use crate::frame::value::ValueList;
use crate::prepared_statement::PreparedStatement;
use crate::query::Query;
use crate::transport::errors::{DbError, QueryError};
use crate::transport::iterator::RowIterator;
use crate::{QueryResult, Session};
use bytes::Bytes;
use dashmap::DashMap;
use itertools::Either;
/// A wrapper around [`Session`] that caches prepared statements keyed by
/// their query text, so repeated executions of the same query string reuse
/// the cached [`PreparedStatement`] instead of preparing it again.
pub struct CachingSession {
    /// The underlying session used to prepare and execute statements.
    pub session: Session,
    /// Maximum number of prepared statements kept in `cache`.
    ///
    /// When the cache is full, entries are evicted before a new one is
    /// inserted. The victim is whichever entry `DashMap` iteration yields
    /// first, so eviction is effectively arbitrary (not LRU). A capacity of
    /// `0` disables caching: statements are prepared but never stored.
    pub max_capacity: usize,
    /// Prepared statements keyed by the query's `contents` string.
    pub cache: DashMap<String, PreparedStatement>,
}

impl CachingSession {
    /// Wraps `session`, caching at most `cache_size` prepared statements.
    pub fn from(session: Session, cache_size: usize) -> Self {
        Self {
            session,
            max_capacity: cache_size,
            cache: Default::default(),
        }
    }

    /// Executes `query` with `values` as a prepared statement, preparing and
    /// caching it first if it is not already cached.
    ///
    /// If execution fails with [`DbError::Unprepared`], the statement is
    /// re-prepared, its cache entry replaced, and the execution retried once.
    ///
    /// # Errors
    /// Returns any error from preparing the statement, serializing `values`,
    /// or executing the query (including the single retry).
    pub async fn execute(
        &self,
        query: impl Into<Query>,
        values: &impl ValueList,
    ) -> Result<QueryResult, QueryError> {
        let query = query.into();
        let prepared = self.add_prepared_statement(&query).await?;
        let values = values.serialized()?;
        // Keep the serialized values around in case a retry is needed.
        let result = self.session.execute(&prepared, values.clone()).await;
        match self.post_execute_prepared_statement(&query, result).await? {
            Either::Left(result) => Ok(result),
            Either::Right(new_prepared_statement) => {
                self.session.execute(&new_prepared_statement, values).await
            }
        }
    }

    /// Like [`CachingSession::execute`], but returns a [`RowIterator`] over
    /// the result rows instead of a single [`QueryResult`].
    ///
    /// # Errors
    /// Returns any error from preparing the statement, serializing `values`,
    /// or starting the iteration (including the single retry on
    /// [`DbError::Unprepared`]).
    pub async fn execute_iter(
        &self,
        query: impl Into<Query>,
        values: impl ValueList,
    ) -> Result<RowIterator, QueryError> {
        let query = query.into();
        let prepared = self.add_prepared_statement(&query).await?;
        let values = values.serialized()?;
        let result = self.session.execute_iter(prepared, values.clone()).await;
        match self.post_execute_prepared_statement(&query, result).await? {
            Either::Left(result) => Ok(result),
            Either::Right(new_prepared_statement) => {
                self.session
                    .execute_iter(new_prepared_statement, values)
                    .await
            }
        }
    }

    /// Like [`CachingSession::execute`], but fetches a single page starting
    /// at `paging_state` (or the first page when `None`).
    ///
    /// # Errors
    /// Returns any error from preparing the statement, serializing `values`,
    /// or executing the query (including the single retry on
    /// [`DbError::Unprepared`]).
    pub async fn execute_paged(
        &self,
        query: impl Into<Query>,
        values: impl ValueList,
        paging_state: Option<Bytes>,
    ) -> Result<QueryResult, QueryError> {
        let query = query.into();
        let prepared = self.add_prepared_statement(&query).await?;
        let values = values.serialized()?;
        let result = self
            .session
            .execute_paged(&prepared, values.clone(), paging_state.clone())
            .await;
        match self.post_execute_prepared_statement(&query, result).await? {
            Either::Left(result) => Ok(result),
            Either::Right(new_prepared_statement) => {
                self.session
                    .execute_paged(&new_prepared_statement, values, paging_state)
                    .await
            }
        }
    }

    /// Returns the cached prepared statement for `query`, preparing it and
    /// storing it in the cache on a miss.
    ///
    /// The lookup and the insert are not atomic with respect to each other:
    /// concurrent callers may both prepare the same query, and the later
    /// insert overwrites the earlier one.
    ///
    /// # Errors
    /// Returns any error produced by [`Session::prepare`].
    pub async fn add_prepared_statement(
        &self,
        query: impl Into<&Query>,
    ) -> Result<PreparedStatement, QueryError> {
        let query = query.into();
        if let Some(prepared) = self.cache.get(&query.contents) {
            return Ok(prepared.clone());
        }

        let prepared = self.session.prepare(query.clone()).await?;

        // A zero capacity means "cache nothing". Without this guard the
        // eviction loop below could never make room, and the entry would be
        // inserted anyway.
        if self.max_capacity == 0 {
            return Ok(prepared);
        }

        // Use `>=` in a loop rather than a single `==` check. Concurrent
        // callers can each pass the check and insert, pushing the map past
        // `max_capacity`. An equality test would then never fire again, and
        // the cache would grow without bound.
        while self.cache.len() >= self.max_capacity {
            // Bind the key in its own statement so the iterator, and the
            // shard read lock it holds, is dropped before `remove` takes the
            // write lock.
            let victim = self.cache.iter().next().map(|entry| entry.key().clone());
            match victim {
                Some(key) => {
                    self.cache.remove(&key);
                }
                // Another task emptied the map concurrently; nothing to evict.
                None => break,
            }
        }

        self.cache.insert(query.contents.clone(), prepared.clone());
        Ok(prepared)
    }

    /// Inspects the outcome of executing a cached statement.
    ///
    /// Returns `Left(value)` on success. On [`DbError::Unprepared`], drops the
    /// stale cache entry, re-prepares `query`, and returns `Right(statement)`
    /// so the caller can retry. Any other error is returned unchanged.
    async fn post_execute_prepared_statement<T>(
        &self,
        query: &Query,
        result: Result<T, QueryError>,
    ) -> Result<Either<T, PreparedStatement>, QueryError> {
        match result {
            Ok(qr) => Ok(Either::Left(qr)),
            Err(QueryError::DbError(DbError::Unprepared, _)) => {
                self.cache.remove(&query.contents);
                let prepared = self.add_prepared_statement(query).await?;
                Ok(Either::Right(prepared))
            }
            Err(err) => Err(err),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{CachingSession, SessionBuilder};
    use futures::StreamExt;

    /// Builds a `CachingSession` with capacity 2 on top of
    /// `SessionBuilder::new_for_test`. Inserts the row (1, 2) into
    /// `test_table`, checking that executing populates the cache, and then
    /// clears the cache so every test starts from an empty cache.
    async fn create_caching_session() -> CachingSession {
        let session = CachingSession::from(SessionBuilder::new_for_test().await, 2);

        session
            .execute("insert into test_table(a, b) values (1, 2)", &[])
            .await
            .unwrap();
        assert_eq!(session.cache.len(), 1);

        session.cache.clear();
        session
    }

    /// Preparing three distinct queries with capacity 2 must keep the cache
    /// at capacity. The newest query stays cached and one older one is evicted.
    #[tokio::test]
    async fn test_full() {
        let session = create_caching_session().await;
        let first_query = "select * from test_table";
        let middle_query = "insert into test_table(a, b) values (?, ?)";
        let last_query = "update test_table set b = ? where a = 1";

        session
            .add_prepared_statement(&first_query.into())
            .await
            .unwrap();
        session
            .add_prepared_statement(&middle_query.into())
            .await
            .unwrap();
        session
            .add_prepared_statement(&last_query.into())
            .await
            .unwrap();

        assert_eq!(2, session.cache.len());
        assert!(session.cache.get(last_query).is_some());
        let first_query_removed = session.cache.get(first_query).is_none();
        let middle_query_removed = session.cache.get(middle_query).is_none();
        assert!(first_query_removed || middle_query_removed);
    }

    /// Executing the same query twice reuses one cache entry.
    #[tokio::test]
    async fn test_execute_cached() {
        let session = create_caching_session().await;

        let result = session
            .execute("select * from test_table", &[])
            .await
            .unwrap();
        assert_eq!(1, session.cache.len());
        assert_eq!(1, result.rows.unwrap().len());

        let result = session
            .execute("select * from test_table", &[])
            .await
            .unwrap();
        assert_eq!(1, session.cache.len());
        assert_eq!(1, result.rows.unwrap().len());
    }

    /// `execute_iter` caches the statement and yields every row without error.
    #[tokio::test]
    async fn test_execute_iter_cached() {
        let session = create_caching_session().await;
        assert!(session.cache.is_empty());

        let mut iter = session
            .execute_iter("select * from test_table", &[])
            .await
            .unwrap();

        let mut rows = 0;
        while let Some(row) = iter.next().await {
            // Fail on fetch errors instead of silently counting them as rows.
            row.expect("failed to fetch row");
            rows += 1;
        }

        assert_eq!(1, rows);
        assert_eq!(1, session.cache.len());
    }

    /// `execute_paged` caches the statement and returns the single row.
    #[tokio::test]
    async fn test_execute_paged_cached() {
        let session = create_caching_session().await;
        assert!(session.cache.is_empty());

        let result = session
            .execute_paged("select * from test_table", &[], None)
            .await
            .unwrap();

        assert_eq!(1, session.cache.len());
        assert_eq!(1, result.rows.unwrap().len());
    }
}