use std::error::Error;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;
use hydracache::{CacheKeyBuilder, CacheOptions, HydraCache, PostcardCodec, TagSet};
use hydracache_core::CacheCodec;
use serde::{de::DeserializeOwned, Serialize};
use crate::{CacheEntity, DbCacheError, QueryCachePolicy, Result};
/// Namespaced entry point for caching database query results in a [`HydraCache`].
///
/// Each query built from this handle receives a clone of the cache handle and
/// of the namespace, which is joined onto query keys with `:` (an empty
/// namespace leaves keys unprefixed).
pub struct DbCache<C: CacheCodec = PostcardCodec> {
    /// Backend shared (by clone) with every [`DbQuery`] this handle creates.
    cache: HydraCache<C>,
    /// Prefix for physical cache keys; may be empty.
    namespace: String,
}
// Written by hand instead of `#[derive(Clone)]` so that cloning a `DbCache`
// does not additionally require the codec type `C` to be `Clone`.
impl<C: CacheCodec> Clone for DbCache<C> {
    /// Clones the cache handle and the namespace string.
    fn clone(&self) -> Self {
        let Self { cache, namespace } = self;
        Self {
            cache: cache.clone(),
            namespace: namespace.clone(),
        }
    }
}
impl<C: CacheCodec> fmt::Debug for DbCache<C> {
    /// Prints the namespace only; the cache handle is omitted and the output
    /// is marked non-exhaustive (`..`).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = formatter.debug_struct("DbCache");
        out.field("namespace", &self.namespace);
        out.finish_non_exhaustive()
    }
}
impl<C> DbCache<C>
where
    C: CacheCodec,
{
    /// Wraps `cache`, prefixing every physical key produced through this
    /// handle with `namespace` (joined by `:`; an empty namespace adds no prefix).
    pub fn new(cache: HydraCache<C>, namespace: impl Into<String>) -> Self {
        Self {
            cache,
            namespace: namespace.into(),
        }
    }

    /// Returns the namespace used to prefix physical cache keys.
    pub fn namespace(&self) -> &str {
        &self.namespace
    }

    /// Returns the underlying cache backend.
    pub fn cache(&self) -> &HydraCache<C> {
        &self.cache
    }

    /// Starts a query for values of type `T` with a default [`QueryCachePolicy`].
    ///
    /// Loading fails with [`DbCacheError::MissingKey`] if no key has been set
    /// on the query by the time it runs.
    pub fn cached<T>(&self) -> DbQuery<T, C> {
        self.query_with_policy(QueryCachePolicy::new())
    }

    /// Starts a query for values of type `T` using the supplied `policy`.
    pub fn cached_with<T>(&self, policy: QueryCachePolicy) -> DbQuery<T, C> {
        // Build directly from `policy` rather than creating a default policy
        // only to overwrite it.
        self.query_with_policy(policy)
    }

    /// Starts a query keyed by an entity `kind` and `id`
    /// (delegates to [`DbQuery::for_entity`]).
    pub fn entity<T>(&self, kind: impl ToString, id: impl ToString) -> DbQuery<T, C> {
        self.cached::<T>().for_entity(kind, id)
    }

    /// Starts a query keyed by the [`CacheEntity`] identity of `T`
    /// (delegates to [`DbQuery::for_cache_entity`]).
    pub fn for_entity<T>(&self, id: T::Id) -> DbQuery<T, C>
    where
        T: CacheEntity,
    {
        self.cached::<T>().for_cache_entity(id)
    }

    /// Starts a query keyed by a collection `name`
    /// (delegates to [`DbQuery::collection`]).
    pub fn collection<T>(&self, name: impl ToString) -> DbQuery<T, C> {
        self.cached::<T>().collection(name)
    }

    /// Starts a query whose policy is created by [`QueryCachePolicy::named`].
    pub fn named<T>(&self, name: impl Into<String>) -> DbQuery<T, C> {
        self.query_with_policy(QueryCachePolicy::named(name))
    }

    /// Alias of [`DbCache::named`]: the SQL text is used as the query name.
    pub fn query_as<T>(&self, sql: impl Into<String>) -> DbQuery<T, C> {
        self.named(sql)
    }

    /// Single construction point for [`DbQuery`]: shares this handle's cache
    /// and namespace and installs `policy`.
    fn query_with_policy<T>(&self, policy: QueryCachePolicy) -> DbQuery<T, C> {
        DbQuery {
            cache: self.cache.clone(),
            namespace: self.namespace.clone(),
            policy,
            value: PhantomData,
        }
    }
}
/// A pending, namespaced cache lookup for a value of type `T`.
///
/// Created from a [`DbCache`] and configured with builder methods (key, tags,
/// TTL, name). The builder methods only update the policy; the cache is
/// consulted only when one of the async fetch methods (`load`, `fetch_with`,
/// `fetch_value_with`) is awaited.
#[must_use = "a DbQuery does nothing until `load`/`fetch_with` is awaited"]
pub struct DbQuery<T, C = PostcardCodec>
where
    C: CacheCodec,
{
    /// Cache backend that the fetch methods call `get_or_load` on.
    cache: HydraCache<C>,
    /// Prefix joined (with `:`) onto the policy key to form the physical key.
    namespace: String,
    /// Key, tags, TTL and name configuration for this query.
    policy: QueryCachePolicy,
    /// Records the result type without storing a `T`; `fn() -> T` keeps the
    /// query `Send`/`Sync` regardless of `T`.
    value: PhantomData<fn() -> T>,
}
// Hand-written so that cloning a query requires neither `T: Clone` nor
// `C: Clone`, which `#[derive(Clone)]` would demand.
impl<T, C: CacheCodec> Clone for DbQuery<T, C> {
    /// Clones the cache handle, namespace and policy; the type marker is rebuilt.
    fn clone(&self) -> Self {
        let Self {
            cache,
            namespace,
            policy,
            ..
        } = self;
        Self {
            cache: cache.clone(),
            namespace: namespace.clone(),
            policy: policy.clone(),
            value: PhantomData,
        }
    }
}
impl<T, C: CacheCodec> fmt::Debug for DbQuery<T, C> {
    /// Prints the namespace and policy; the cache handle is omitted and the
    /// output is marked non-exhaustive (`..`).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            namespace, policy, ..
        } = self;
        formatter
            .debug_struct("DbQuery")
            .field("namespace", namespace)
            .field("policy", policy)
            .finish_non_exhaustive()
    }
}
impl<T, C> DbQuery<T, C>
where
    C: CacheCodec,
{
    /// Returns the query name, if one is set on the policy.
    pub fn name(&self) -> Option<&str> {
        self.policy.name()
    }

    /// Sets the query name (delegates to [`QueryCachePolicy::with_name`]).
    pub fn with_name(self, name: impl Into<String>) -> Self {
        self.map_policy(|policy| policy.with_name(name))
    }

    /// Returns the current cache policy.
    pub fn cache_policy(&self) -> &QueryCachePolicy {
        &self.policy
    }

    /// Replaces the whole cache policy.
    pub fn with_policy(self, policy: QueryCachePolicy) -> Self {
        self.map_policy(|_| policy)
    }

    /// Returns the namespace that prefixes the physical key.
    pub fn namespace(&self) -> &str {
        &self.namespace
    }

    /// Returns the logical (un-namespaced) key, if one is set.
    pub fn key_value(&self) -> Option<&str> {
        self.policy.key_value()
    }

    /// Returns the namespaced key used against the cache, or `None` when no
    /// key is set.
    pub fn physical_key(&self) -> Option<String> {
        // Bare `physical_key` here is the module-level helper, not this method.
        self.key_value()
            .map(|logical| physical_key(&self.namespace, logical))
    }

    /// Returns the tags recorded on the policy.
    pub fn tags_value(&self) -> &[String] {
        self.policy.tags_value()
    }

    /// Returns the TTL recorded on the policy, if any.
    pub fn ttl_value(&self) -> Option<Duration> {
        self.policy.ttl_value()
    }

    /// Sets the logical key (delegates to [`QueryCachePolicy::key`]).
    pub fn key(self, key: impl Into<String>) -> Self {
        self.map_policy(|policy| policy.key(key))
    }

    /// Sets the logical key from a [`CacheKeyBuilder`]'s built string.
    pub fn key_builder(self, key: CacheKeyBuilder) -> Self {
        let built = key.build_string();
        self.key(built)
    }

    /// Keys the query by entity `kind` and `id`
    /// (delegates to [`QueryCachePolicy::for_entity`]).
    pub fn for_entity(self, kind: impl ToString, id: impl ToString) -> Self {
        self.map_policy(|policy| policy.for_entity(kind, id))
    }

    /// Keys the query by the [`CacheEntity`] identity of `T`
    /// (delegates to [`QueryCachePolicy::for_cache_entity`]).
    pub fn for_cache_entity(self, id: T::Id) -> Self
    where
        T: CacheEntity,
    {
        self.map_policy(|policy| policy.for_cache_entity::<T>(id))
    }

    /// Keys the query by collection `name`
    /// (delegates to [`QueryCachePolicy::collection`]).
    pub fn collection(self, name: impl ToString) -> Self {
        self.map_policy(|policy| policy.collection(name))
    }

    /// Adds a single tag (delegates to [`QueryCachePolicy::tag`]).
    pub fn tag(self, tag: impl Into<String>) -> Self {
        self.map_policy(|policy| policy.tag(tag))
    }

    /// Adds a collection tag (delegates to [`QueryCachePolicy::collection_tag`]).
    pub fn collection_tag(self, name: impl ToString) -> Self {
        self.map_policy(|policy| policy.collection_tag(name))
    }

    /// Adds several tags (delegates to [`QueryCachePolicy::tags`]).
    pub fn tags<I, S>(self, tags: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.map_policy(|policy| policy.tags(tags))
    }

    /// Applies a [`TagSet`] (delegates to [`QueryCachePolicy::tag_set`]).
    pub fn tag_set(self, tags: TagSet) -> Self {
        self.map_policy(|policy| policy.tag_set(tags))
    }

    /// Sets the time-to-live (delegates to [`QueryCachePolicy::ttl`]).
    pub fn ttl(self, ttl: Duration) -> Self {
        self.map_policy(|policy| policy.ttl(ttl))
    }

    /// Alias of [`DbQuery::fetch_with`].
    ///
    /// # Errors
    /// Same as [`DbQuery::fetch_value_with`].
    pub async fn load<E, F, Fut>(self, loader: F) -> Result<T>
    where
        T: Serialize + DeserializeOwned + Send + 'static,
        E: Error + Send + Sync + 'static,
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = std::result::Result<T, E>> + Send + 'static,
    {
        self.fetch_value_with(loader).await
    }

    /// Fetches a value of the query's own type `T`; see
    /// [`DbQuery::fetch_value_with`].
    ///
    /// # Errors
    /// Same as [`DbQuery::fetch_value_with`].
    pub async fn fetch_with<E, F, Fut>(self, loader: F) -> Result<T>
    where
        T: Serialize + DeserializeOwned + Send + 'static,
        E: Error + Send + Sync + 'static,
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = std::result::Result<T, E>> + Send + 'static,
    {
        self.fetch_value_with(loader).await
    }

    /// Calls `HydraCache::get_or_load` with this query's physical key and
    /// policy-derived options, passing `loader` through; `U` may differ from `T`.
    ///
    /// # Errors
    /// - [`DbCacheError::MissingKey`] (labelled with the query name, or the
    ///   namespace-based default) when no key is set.
    /// - Any error from `get_or_load`, converted via `DbCacheError::from`.
    pub async fn fetch_value_with<U, E, F, Fut>(self, loader: F) -> Result<U>
    where
        U: Serialize + DeserializeOwned + Send + 'static,
        E: Error + Send + Sync + 'static,
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = std::result::Result<U, E>> + Send + 'static,
    {
        let key = self.required_physical_key()?;
        let options = self.options();
        self.cache
            .get_or_load(&key, options, loader)
            .await
            .map_err(DbCacheError::from)
    }

    /// Replaces the policy with `update(policy)`; shared plumbing for the
    /// builder methods above.
    fn map_policy<F>(mut self, update: F) -> Self
    where
        F: FnOnce(QueryCachePolicy) -> QueryCachePolicy,
    {
        self.policy = update(self.policy);
        self
    }

    /// Cache options derived from the current policy.
    fn options(&self) -> CacheOptions {
        self.policy.cache_options()
    }

    /// Physical key, or a `MissingKey` error naming the operation.
    fn required_physical_key(&self) -> Result<String> {
        match self.physical_key() {
            Some(key) => Ok(key),
            None => Err(DbCacheError::MissingKey {
                operation: self.operation_label(),
            }),
        }
    }

    /// Label used in errors: the query name if set, else a namespace-based default.
    fn operation_label(&self) -> String {
        match self.name() {
            Some(name) => name.to_owned(),
            None => default_operation_label(&self.namespace),
        }
    }
}
/// Joins `namespace` and `key` as `namespace:key`; an empty namespace yields
/// `key` unchanged.
fn physical_key(namespace: &str, key: &str) -> String {
    match namespace {
        "" => key.to_string(),
        _ => format!("{namespace}:{key}"),
    }
}
/// Operation label used when a query has no name: `"unnamed"`, prefixed with
/// `namespace:` when the namespace is non-empty.
fn default_operation_label(namespace: &str) -> String {
    if !namespace.is_empty() {
        return format!("{namespace}:unnamed");
    }
    "unnamed".to_owned()
}