use std::any::TypeId;
use std::borrow::Cow;
use std::collections::hash_map::Entry;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
use strategy::{
LookupStatementResult, StatementCacheStrategy, WithCacheStrategy, WithoutCacheStrategy,
};
use crate::backend::Backend;
use crate::connection::InstrumentationEvent;
use crate::query_builder::*;
use crate::result::QueryResult;
use super::{CacheSize, Instrumentation};
#[allow(unreachable_pub)]
pub mod strategy;
/// Cache of prepared statements, keyed by [`StatementCacheKey`].
///
/// Where (and whether) statements are stored is delegated to a boxed
/// [`StatementCacheStrategy`], which can be replaced at runtime via
/// `set_cache_size` or `set_strategy`.
#[allow(missing_debug_implementations, unreachable_pub)]
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub struct StatementCache<DB: Backend, Statement> {
    /// Strategy object that owns the cached statements and decides whether a
    /// new statement may be stored.
    cache: Box<dyn StatementCacheStrategy<DB, Statement>>,
    /// Incremented every time a statement is prepared for caching; the new
    /// value is passed to the prepare callback as `PrepareForCache::Yes { counter }`.
    cache_counter: u64,
}
/// Tells a statement-preparation callback whether the statement it is about
/// to prepare will be stored in the statement cache.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
#[allow(unreachable_pub)]
pub enum PrepareForCache {
    /// The prepared statement will be inserted into the cache.
    Yes {
        /// Value of the cache's counter after it was incremented for this
        /// statement.
        // NOTE(review): presumably backends use this to derive a distinct
        // statement name; `dead_code` is allowed because not every backend
        // reads it — confirm against backend implementations.
        #[allow(dead_code)]
        counter: u64,
    },
    /// The prepared statement will not be stored in the cache.
    No,
}
#[allow(clippy::new_without_default, unreachable_pub)]
impl<DB, Statement> StatementCache<DB, Statement>
where
    DB: Backend + 'static,
    Statement: Send + 'static,
    DB::TypeMetadata: Send + Clone,
    DB::QueryBuilder: Default,
    StatementCacheKey<DB>: Hash + Eq,
{
    /// Creates an empty cache using `WithCacheStrategy` (the strategy that
    /// `set_cache_size(CacheSize::Unbounded)` selects), with the statement
    /// counter starting at `0`.
    #[allow(unreachable_pub)]
    pub fn new() -> Self {
        StatementCache {
            cache: Box::new(WithCacheStrategy::default()),
            cache_counter: 0,
        }
    }

    /// Switches the caching strategy to the one matching `size`.
    ///
    /// If the current strategy already reports `size`, nothing happens.
    /// Otherwise the current strategy is replaced, which drops it together with
    /// whatever it had stored. `cache_counter` is not reset.
    pub fn set_cache_size(&mut self, size: CacheSize) {
        if self.cache.cache_size() != size {
            self.cache = match size {
                CacheSize::Unbounded => Box::new(WithCacheStrategy::default()),
                CacheSize::Disabled => Box::new(WithoutCacheStrategy::default()),
            }
        }
    }

    /// Replaces the caching strategy with an arbitrary implementation `s`,
    /// dropping the previous strategy.
    // NOTE(review): `dead_code` is allowed presumably because not every
    // feature configuration calls this — confirm.
    #[allow(dead_code)]
    pub(crate) fn set_strategy<Strategy>(&mut self, s: Strategy)
    where
        Strategy: StatementCacheStrategy<DB, Statement> + 'static,
    {
        self.cache = Box::new(s);
    }

    /// Returns the prepared statement for `source`. If it is not cached yet,
    /// it is prepared through `prepare_fn` and cached where possible.
    ///
    /// This generic entry point supplies `T::query_id()` as the type-based
    /// cache key and forwards everything else to `cached_statement_non_generic`.
    ///
    /// * `bind_types` – metadata of the bind parameters; it is part of the
    ///   SQL-based cache key and is forwarded to `prepare_fn`.
    /// * `conn` – connection value handed to `prepare_fn`.
    /// * `prepare_fn` – prepares the SQL; its `PrepareForCache` argument says
    ///   whether the result will be stored in the cache.
    /// * `instrumentation` – receives an `InstrumentationEvent::CacheQuery`
    ///   for every statement that is prepared for caching.
    #[allow(unreachable_pub)]
    #[cfg(any(
        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes",
        feature = "sqlite",
        feature = "mysql"
    ))]
    pub fn cached_statement<'a, T, R, C>(
        &'a mut self,
        source: &T,
        backend: &DB,
        bind_types: &[DB::TypeMetadata],
        conn: C,
        prepare_fn: fn(C, &str, PrepareForCache, &[DB::TypeMetadata]) -> R,
        instrumentation: &mut dyn Instrumentation,
    ) -> R::Return<'a>
    where
        T: QueryFragment<DB> + QueryId,
        R: StatementCallbackReturnType<Statement, C> + 'a,
    {
        self.cached_statement_non_generic(
            T::query_id(),
            source,
            backend,
            bind_types,
            conn,
            prepare_fn,
            instrumentation,
        )
    }

    /// Type-erased counterpart of `cached_statement`.
    ///
    /// The query is passed as a `&dyn QueryFragmentForCachedStatement`. Its
    /// `TypeId`, if any, is passed as `maybe_type_id`; `None` falls back to a
    /// cache key built from the rendered SQL text and `bind_types`.
    ///
    /// When a statement is about to be cached, this does three things:
    /// * emits an `InstrumentationEvent::CacheQuery` event;
    /// * increments `cache_counter`;
    /// * calls `prepare_fn` with `PrepareForCache::Yes { counter }`.
    ///
    /// Otherwise it calls `prepare_fn` with `PrepareForCache::No`.
    #[allow(unreachable_pub)]
    #[allow(clippy::too_many_arguments)]
    pub fn cached_statement_non_generic<'a, R, C>(
        &'a mut self,
        maybe_type_id: Option<TypeId>,
        source: &dyn QueryFragmentForCachedStatement<DB>,
        backend: &DB,
        bind_types: &[DB::TypeMetadata],
        conn: C,
        prepare_fn: fn(C, &str, PrepareForCache, &[DB::TypeMetadata]) -> R,
        instrumentation: &mut dyn Instrumentation,
    ) -> R::Return<'a>
    where
        R: StatementCallbackReturnType<Statement, C> + 'a,
    {
        Self::cached_statement_non_generic_impl(
            // Only the `cache` field is borrowed here. Closures capture
            // disjoint fields (Rust 2021), so the closure below can still
            // mutate `self.cache_counter`.
            self.cache.as_mut(),
            maybe_type_id,
            source,
            backend,
            bind_types,
            conn,
            |conn, sql, is_cached| {
                if is_cached {
                    instrumentation.on_connection_event(InstrumentationEvent::CacheQuery { sql });
                    self.cache_counter += 1;
                    prepare_fn(
                        conn,
                        sql,
                        PrepareForCache::Yes {
                            counter: self.cache_counter,
                        },
                        bind_types,
                    )
                } else {
                    prepare_fn(conn, sql, PrepareForCache::No, bind_types)
                }
            },
        )
    }

    /// Core lookup/prepare logic behind the public entry points.
    ///
    /// 1. Build the cache key. This renders the SQL when `maybe_type_id` is `None`.
    /// 2. If `source` reports it is not safe to cache, prepare it uncached.
    /// 3. Otherwise ask the strategy for an entry. Then do one of the following:
    ///    * return the cached statement;
    ///    * prepare a new statement and insert it;
    ///    * if the strategy declines to cache, prepare it uncached.
    ///
    /// `prepare_fn` receives `(conn, sql, is_cached)`. It is not called when a
    /// cached statement is returned.
    ///
    /// Errors are converted with `R::from_error`. They can come from building
    /// the key, from the safety check, or from rendering the SQL.
    fn cached_statement_non_generic_impl<'a, R, C>(
        cache: &'a mut dyn StatementCacheStrategy<DB, Statement>,
        maybe_type_id: Option<TypeId>,
        source: &dyn QueryFragmentForCachedStatement<DB>,
        backend: &DB,
        bind_types: &[DB::TypeMetadata],
        conn: C,
        prepare_fn: impl FnOnce(C, &str, bool) -> R,
    ) -> R::Return<'a>
    where
        R: StatementCallbackReturnType<Statement, C> + 'a,
    {
        let cache_key =
            match StatementCacheKey::for_source(maybe_type_id, source, bind_types, backend) {
                Ok(o) => o,
                Err(e) => return R::from_error(e),
            };
        let is_safe_to_cache_prepared = match source.is_safe_to_cache_prepared(backend) {
            Ok(o) => o,
            Err(e) => return R::from_error(e),
        };
        // Queries that are not safe to cache are prepared fresh and never
        // reach the cache strategy.
        if !is_safe_to_cache_prepared {
            // `Sql` keys reuse the SQL rendered above; `Type` keys render it now.
            let sql = match cache_key.sql(source, backend) {
                Ok(sql) => sql,
                Err(e) => return R::from_error(e),
            };
            return prepare_fn(conn, &sql, false).map_to_no_cache();
        }
        let entry = cache.lookup_statement(cache_key);
        match entry {
            // Already cached: hand out the stored statement; `prepare_fn` is
            // not called.
            LookupStatementResult::CacheEntry(Entry::Occupied(e)) => {
                R::map_to_cache(e.into_mut(), conn)
            }
            // Not cached yet, and the strategy offers a slot. Prepare the
            // statement and insert it under the key held by the vacant entry.
            LookupStatementResult::CacheEntry(Entry::Vacant(e)) => {
                let sql = match e.key().sql(source, backend) {
                    Ok(sql) => sql,
                    Err(e) => return R::from_error(e),
                };
                let st = prepare_fn(conn, &sql, true);
                // `sql` may borrow from `e.key()`. Its last use is the line
                // above, so `e` can now be moved into the insertion callback.
                st.register_cache(|stmt| e.insert(stmt))
            }
            // The strategy declined to cache and handed the key back. Prepare
            // the statement without storing it.
            LookupStatementResult::NoCache(cache_key) => {
                let sql = match cache_key.sql(source, backend) {
                    Ok(sql) => sql,
                    Err(e) => return R::from_error(e),
                };
                prepare_fn(conn, &sql, false).map_to_no_cache()
            }
        }
    }
}
/// Object-safe interface the statement cache uses to render and inspect a
/// query, so that `StatementCache` can operate on `&dyn` query fragments.
///
/// A blanket implementation is provided for every `QueryFragment<DB>` whose
/// backend's query builder implements `Default`.
#[allow(unreachable_pub)]
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub trait QueryFragmentForCachedStatement<DB> {
    /// Renders the complete SQL text of the query for `backend`.
    fn construct_sql(&self, backend: &DB) -> QueryResult<String>;
    /// Reports whether a prepared statement for this query may be stored in
    /// the statement cache. When it returns `false`, the statement is prepared
    /// without caching.
    fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>;
}
impl<T, DB> QueryFragmentForCachedStatement<DB> for T
where
    DB: Backend,
    DB::QueryBuilder: Default,
    T: QueryFragment<DB>,
{
    /// Renders `self` into a freshly created query builder and returns the
    /// finished SQL. The builder is consumed only if rendering succeeded.
    fn construct_sql(&self, backend: &DB) -> QueryResult<String> {
        let mut builder = <DB::QueryBuilder as Default>::default();
        self.to_sql(&mut builder, backend)
            .map(|()| builder.finish())
    }

    /// Forwards to `QueryFragment::is_safe_to_cache_prepared`. A qualified path
    /// is required because both traits declare a method with this name.
    fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool> {
        <Self as QueryFragment<DB>>::is_safe_to_cache_prepared(self, backend)
    }
}
/// A prepared statement that is either owned by the caller because it was
/// not cached, or mutably borrowed from the statement cache.
///
/// Dereferences to the statement in both cases.
#[allow(missing_debug_implementations, unreachable_pub)]
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
#[non_exhaustive]
pub enum MaybeCached<'a, T: 'a> {
    /// Statement that was prepared without being stored in the cache.
    CannotCache(T),
    /// Statement stored in (and borrowed from) the cache.
    Cached(&'a mut T),
}
/// Abstraction over what a statement-preparation callback returns.
///
/// `StatementCache` is written against this trait rather than a concrete
/// result type. `S` is the statement type stored in the cache, and `C` is the
/// connection value passed to the callback.
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
#[allow(unreachable_pub)]
pub trait StatementCallbackReturnType<S: 'static, C> {
    /// Value produced by a cache lookup; may borrow from the cache for `'a`.
    type Return<'a>;
    /// Converts an error raised before (or instead of) running the callback.
    fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a>;
    /// Wraps a freshly prepared statement that is not stored in the cache.
    fn map_to_no_cache<'a>(self) -> Self::Return<'a>
    where
        Self: 'a;
    /// Builds the return value for a statement that is already cached. `conn`
    /// is handed over in case the implementation needs it.
    fn map_to_cache(stmt: &mut S, conn: C) -> Self::Return<'_>;
    /// Stores a freshly prepared statement through `callback` and builds the
    /// return value from the reference it yields. `callback` inserts the
    /// statement into the cache and returns a reference to the stored copy.
    fn register_cache<'a>(
        self,
        callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
    ) -> Self::Return<'a>
    where
        Self: 'a;
}
impl<S, C> StatementCallbackReturnType<S, C> for QueryResult<S>
where
S: 'static,
{
type Return<'a> = QueryResult<MaybeCached<'a, S>>;
fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> {
Err(e)
}
fn map_to_no_cache<'a>(self) -> Self::Return<'a> {
self.map(MaybeCached::CannotCache)
}
fn map_to_cache(stmt: &mut S, _conn: C) -> Self::Return<'_> {
Ok(MaybeCached::Cached(stmt))
}
fn register_cache<'a>(
self,
callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
) -> Self::Return<'a>
where
Self: 'a,
{
Ok(MaybeCached::Cached(callback(self?)))
}
}
impl<T> Deref for MaybeCached<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match *self {
MaybeCached::CannotCache(ref x) => x,
MaybeCached::Cached(ref x) => x,
}
}
}
impl<T> DerefMut for MaybeCached<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match *self {
MaybeCached::CannotCache(ref mut x) => x,
MaybeCached::Cached(ref mut x) => x,
}
}
}
/// Key under which a prepared statement is stored in the statement cache.
// NOTE(review): the derived `Hash`/`Eq` impls carry bounds involving `DB`.
// This is presumably why `StatementCache`'s impl block explicitly requires
// `StatementCacheKey<DB>: Hash + Eq`.
#[allow(missing_debug_implementations, unreachable_pub)]
#[derive(Hash, PartialEq, Eq)]
#[cfg_attr(
    diesel_docsrs,
    doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub enum StatementCacheKey<DB: Backend> {
    /// Key derived from the query's static type; used when a `TypeId` is
    /// available (e.g. from `QueryId::query_id`).
    Type(TypeId),
    /// Fallback key used when no `TypeId` is available.
    Sql {
        /// The fully rendered SQL text.
        sql: String,
        /// Metadata of the bind parameters the statement was prepared with.
        bind_types: Vec<DB::TypeMetadata>,
    },
}
impl<DB> StatementCacheKey<DB>
where
    DB: Backend,
    DB::QueryBuilder: Default,
    DB::TypeMetadata: Clone,
{
    /// Builds the cache key for `source`.
    ///
    /// A known `TypeId` yields a cheap `Type` key. Without one, the SQL is
    /// rendered and stored together with a copy of `bind_types` in a `Sql` key.
    ///
    /// # Errors
    /// Returns the error from rendering the SQL when `maybe_type_id` is `None`.
    #[allow(unreachable_pub)]
    pub fn for_source(
        maybe_type_id: Option<TypeId>,
        source: &dyn QueryFragmentForCachedStatement<DB>,
        bind_types: &[DB::TypeMetadata],
        backend: &DB,
    ) -> QueryResult<Self> {
        if let Some(type_id) = maybe_type_id {
            return Ok(Self::Type(type_id));
        }
        let rendered = source.construct_sql(backend)?;
        Ok(Self::Sql {
            sql: rendered,
            bind_types: bind_types.to_vec(),
        })
    }

    /// Returns the SQL text for this key.
    ///
    /// `Sql` keys lend out the text they already hold. `Type` keys render the
    /// SQL from `source` on demand.
    ///
    /// # Errors
    /// Returns the rendering error for `Type` keys.
    #[allow(unreachable_pub)]
    pub fn sql(
        &self,
        source: &dyn QueryFragmentForCachedStatement<DB>,
        backend: &DB,
    ) -> QueryResult<Cow<'_, str>> {
        match self {
            Self::Sql { sql, .. } => Ok(Cow::Borrowed(sql.as_str())),
            Self::Type(_) => Ok(Cow::Owned(source.construct_sql(backend)?)),
        }
    }
}