use crate::{
ClientError, Error, Result,
client::{Client, PreparedCommand},
commands::{
BitFieldSubCommand, BitRange, BitmapCommands, ClientTrackingOptions, ClientTrackingStatus,
ConnectionCommands, HashCommands, ListCommands, SetCommands, SortedSetCommands,
StringCommands, ZRangeOptions,
},
resp::{
BulkString, Command, CommandArgsMut, FastPathCommandBuilder, RespDeserializer,
RespResponse, Response,
},
};
use bytes::Bytes;
use dashmap::DashMap;
use futures_util::StreamExt;
use serde::{Serialize, de::DeserializeOwned};
use std::{sync::Arc, time::Duration};
pub use moka::future::CacheBuilder;
/// Builder type accepted by [`Cache::from_builder`].
type MokaCacheBuilder = moka::future::CacheBuilder<BulkString, Arc<SubCache>, MokaCache>;
/// Top-level cache: one entry per key, holding every cached reply involving that key.
type MokaCache = moka::future::Cache<BulkString, Arc<SubCache>>;
/// Raw replies cached for a single key, indexed by the serialized command bytes
/// that produced them.
type SubCache = DashMap<Bytes, RespResponse>;
/// Client-side cache in front of a [`Client`], kept coherent by `CLIENT TRACKING`
/// invalidation messages.
///
/// Replies are grouped per key. The moka cache maps a key to a map from serialized
/// read command to the raw reply it produced. Invalidating a key therefore drops
/// every cached command that touched it at once.
pub struct Cache {
    /// Per-key cache: key -> (serialized command bytes -> raw reply).
    cache: Arc<MokaCache>,
    /// Connection used to fetch replies on cache misses.
    client: Client,
    /// Background task that applies invalidation messages to `cache`.
    /// It is aborted when the `Cache` is dropped.
    invalidation_task: tokio::task::JoinHandle<()>,
}

impl Drop for Cache {
    fn drop(&mut self) {
        // Dropping a `JoinHandle` only detaches the task. Abort it explicitly so the
        // invalidation loop (and the moka cache it keeps alive) does not outlive
        // this `Cache`.
        self.invalidation_task.abort();
    }
}
impl Cache {
    /// Maximum number of keys held by a cache created through [`Cache::new`].
    const DEFAULT_MAX_CAPACITY: u64 = 10_000;

    /// Builds a client-side cache on top of `client` from a custom moka `builder`.
    ///
    /// Enables `CLIENT TRACKING` on `client` with `tracking_opts`. It then spawns a
    /// background task that reads the client's tracking invalidation stream and
    /// evicts every invalidated key from the cache.
    ///
    /// # Errors
    /// Returns an error if enabling client tracking, or creating the invalidation
    /// stream, fails.
    #[allow(clippy::type_complexity)]
    pub async fn from_builder(
        client: Client,
        builder: MokaCacheBuilder,
        tracking_opts: ClientTrackingOptions,
    ) -> Result<Arc<Self>> {
        client
            .client_tracking(ClientTrackingStatus::On, tracking_opts)
            .await?;
        let mut invalidations = client.create_client_tracking_invalidation_stream()?;
        let cache = Arc::new(builder.build());
        let task_cache = Arc::clone(&cache);
        let connection_tag = client.connection_tag().to_owned();
        let invalidation_task = tokio::spawn(async move {
            while let Some(keys) = invalidations.next().await {
                for key in keys {
                    log::debug!(
                        "[{}] Invalidating key `{key}` from client cache",
                        connection_tag
                    );
                    let key: BulkString = key.into_bytes().into();
                    task_cache.invalidate(&key).await;
                }
            }
        });
        Ok(Arc::new(Self {
            cache,
            client,
            invalidation_task,
        }))
    }

    /// Builds a client-side cache holding at most [`Self::DEFAULT_MAX_CAPACITY`] keys.
    ///
    /// Each key's cache entry is given a moka `time_to_live` of `ttl_secs` seconds.
    ///
    /// # Errors
    /// Same as [`Cache::from_builder`].
    pub async fn new(
        client: Client,
        ttl_secs: u64,
        tracking_opts: ClientTrackingOptions,
    ) -> Result<Arc<Self>> {
        let builder = MokaCache::builder()
            .time_to_live(Duration::from_secs(ttl_secs))
            .max_capacity(Self::DEFAULT_MAX_CAPACITY);
        Self::from_builder(client, builder, tracking_opts).await
    }

    /// Cached `GET key`.
    pub async fn get<R: Response + DeserializeOwned>(&self, key: impl Serialize) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.get(key))
            .await
    }

    /// Cached `MGET key [key ...]`.
    ///
    /// Each key is looked up under the bytes produced by
    /// `FastPathCommandBuilder::get(key)`. Presumably these match the bytes of a
    /// `GET key` command, so entries would be shared with [`Cache::get`]; TODO confirm.
    /// All keys that miss are fetched in a single `MGET` and cached individually.
    /// The final reply is reassembled in the original key order.
    ///
    /// # Errors
    /// Returns `ClientError::ExpectedArrayForMGet` if the server's reply to the
    /// `MGET` of missing keys is not an array. Transport and deserialization errors
    /// are propagated.
    pub async fn mget<R: Response + DeserializeOwned>(&self, keys: impl Serialize) -> Result<R> {
        let prepared_command = self.client.mget::<R>(keys);
        let mut responses = Vec::with_capacity(prepared_command.command.num_args());
        // (position in the requested key list, key) for every key not served from the cache.
        let mut missing: Vec<(usize, BulkString)> = Vec::new();
        for (i, arg) in prepared_command.command.args().enumerate() {
            let key = BulkString::from(arg.clone());
            if let Some(values) = self.cache.get(&key).await
                && let Some(response) = values.get(FastPathCommandBuilder::get(key.clone()).bytes())
            {
                log::debug!(
                    "[{}] Cache hit on key `{}`",
                    self.client.connection_tag(),
                    key
                );
                responses.push(response.clone());
            } else {
                log::debug!(
                    "[{}] Cache miss on key `{}`",
                    self.client.connection_tag(),
                    key
                );
                // Placeholder, overwritten below once the server has answered.
                responses.push(RespResponse::null());
                missing.push((i, key));
            }
        }
        if missing.is_empty() {
            log::debug!("[{}] Cache hit on mget", self.client.connection_tag());
        } else {
            let missing_keys: Vec<BulkString> =
                missing.iter().map(|(_, key)| key.clone()).collect();
            let missing_prepared_command = self.client.mget::<R>(missing_keys);
            let response = self
                .client
                .internal_send(missing_prepared_command.command, None)
                .await?;
            let Ok(array_iter) = response.into_array_iter() else {
                return Err(Error::Client(ClientError::ExpectedArrayForMGet));
            };
            // `zip` stops at the shorter side, so a reply of unexpected length can no
            // longer index out of bounds.
            for ((original_idx, key), value) in missing.into_iter().zip(array_iter) {
                self.cache
                    .entry(key.clone())
                    .or_insert_with(async { Arc::new(DashMap::new()) })
                    .await
                    .value()
                    .insert(FastPathCommandBuilder::get(key).bytes().clone(), value.clone());
                responses[original_idx] = value;
            }
        }
        let response = RespResponse::owned_array(responses);
        let deserializer = RespDeserializer::new(response.view());
        R::deserialize(deserializer)
    }

    /// Cached `GETRANGE key start end`.
    pub async fn getrange<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        start: isize,
        end: isize,
    ) -> Result<R> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.getrange(key, start, end),
        )
        .await
    }

    /// Cached `STRLEN key`.
    pub async fn strlen(&self, key: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.strlen(key))
            .await
    }

    /// Cached `HEXISTS key field`.
    pub async fn hexists(&self, key: impl Serialize, field: impl Serialize) -> Result<bool> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hexists(key, field))
            .await
    }

    /// Cached `HGET key field`.
    pub async fn hget<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        field: impl Serialize,
    ) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hget(key, field))
            .await
    }

    /// Cached `HGETALL key`.
    pub async fn hgetall<R: Response + DeserializeOwned>(&self, key: impl Serialize) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hgetall(key))
            .await
    }

    /// Cached `HLEN key`.
    pub async fn hlen(&self, key: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hlen(key))
            .await
    }

    /// Cached `HKEYS key`.
    pub async fn hkeys<R: Response + DeserializeOwned>(&self, key: impl Serialize) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hkeys(key))
            .await
    }

    /// Cached `HVALS key`.
    pub async fn hvals<R: Response + DeserializeOwned>(&self, key: impl Serialize) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hvals(key))
            .await
    }

    /// Cached `HSTRLEN key field`.
    pub async fn hstrlen(&self, key: impl Serialize, field: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hstrlen(key, field))
            .await
    }

    /// Cached `HMGET key field [field ...]`.
    pub async fn hmget<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        fields: impl Serialize,
    ) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.hmget(key, fields))
            .await
    }

    /// Cached `LRANGE key start stop`.
    pub async fn lrange<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        start: isize,
        stop: isize,
    ) -> Result<R> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.lrange(key, start, stop),
        )
        .await
    }

    /// Cached `LLEN key`.
    pub async fn llen(&self, key: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.llen(key))
            .await
    }

    /// Cached `LINDEX key index`.
    pub async fn lindex<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        index: isize,
    ) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.lindex(key, index))
            .await
    }

    /// Cached `SMEMBERS key`.
    pub async fn smembers<R: Response + DeserializeOwned>(&self, key: impl Serialize) -> Result<R> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.smembers(key))
            .await
    }

    /// Cached `SCARD key`.
    pub async fn scard(&self, key: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.scard(key))
            .await
    }

    /// Cached `SISMEMBER key member`.
    pub async fn sismember(&self, key: impl Serialize, member: impl Serialize) -> Result<bool> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.sismember(key, member))
            .await
    }

    /// Cached `ZCARD key`.
    pub async fn zcard(&self, key: impl Serialize) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.zcard(key))
            .await
    }

    /// Cached `ZCOUNT key min max`.
    pub async fn zcount(
        &self,
        key: impl Serialize,
        min: impl Serialize,
        max: impl Serialize,
    ) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.zcount(key, min, max))
            .await
    }

    /// Cached `ZLEXCOUNT key min max`.
    pub async fn zlexcount(
        &self,
        key: impl Serialize,
        min: impl Serialize,
        max: impl Serialize,
    ) -> Result<usize> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.zlexcount(key, min, max),
        )
        .await
    }

    /// Cached `ZRANGE key start stop [options]`.
    pub async fn zrange<R: Response + DeserializeOwned>(
        &self,
        key: impl Serialize,
        start: impl Serialize,
        stop: impl Serialize,
        options: ZRangeOptions,
    ) -> Result<R> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.zrange(key, start, stop, options),
        )
        .await
    }

    /// Cached `ZRANK key member`.
    pub async fn zrank(
        &self,
        key: impl Serialize,
        member: impl Serialize,
    ) -> Result<Option<usize>> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.zrank(key, member))
            .await
    }

    /// `ZREMRANGEBYSCORE key start stop`. Always sent to the server, never cached.
    ///
    /// This command modifies `key`. Serving it from the cache would skip the
    /// removal on every call after the first, so it bypasses the cache. Afterwards
    /// it evicts `key` locally, without relying on the server's invalidation push
    /// (which is not delivered to this connection if tracking was enabled with
    /// `NOLOOP`).
    pub async fn zremrangebyscore(
        &self,
        key: impl Serialize,
        start: impl Serialize,
        stop: impl Serialize,
    ) -> Result<usize> {
        let cache_key = key_to_bulk_string(&key);
        let removed = self
            .send_prepared_command_uncached(self.client.zremrangebyscore(key, start, stop))
            .await?;
        self.cache.invalidate(&cache_key).await;
        Ok(removed)
    }

    /// Cached `ZREVRANK key member`.
    pub async fn zrevrank(
        &self,
        key: impl Serialize,
        member: impl Serialize,
    ) -> Result<Option<usize>> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.zrevrank(key, member))
            .await
    }

    /// Cached `ZSCORE key member`.
    pub async fn zscore(&self, key: impl Serialize, member: impl Serialize) -> Result<Option<f64>> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.zscore(key, member))
            .await
    }

    /// Cached `BITCOUNT key [range]`.
    pub async fn bitcount(&self, key: impl Serialize, range: BitRange) -> Result<usize> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.bitcount(key, range))
            .await
    }

    /// Cached `BITPOS key bit [range]`.
    ///
    /// NOTE(review): the server can reply `-1` for `BITPOS`. Confirm that
    /// deserializing that reply into `usize` behaves as intended.
    pub async fn bitpos(&self, key: impl Serialize, bit: u64, range: BitRange) -> Result<usize> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.bitpos(key, bit, range),
        )
        .await
    }

    /// Cached `GETBIT key offset`.
    pub async fn getbit(&self, key: impl Serialize, offset: u64) -> Result<u64> {
        self.process_prepared_command(key_to_bulk_string(&key), self.client.getbit(key, offset))
            .await
    }

    /// Cached `BITFIELD_RO key sub_commands...`.
    pub async fn bitfield_readonly<'a>(
        &self,
        key: impl Serialize,
        sub_commands: impl IntoIterator<Item = BitFieldSubCommand<'a>> + Serialize,
    ) -> Result<Vec<u64>> {
        self.process_prepared_command(
            key_to_bulk_string(&key),
            self.client.bitfield_readonly(key, sub_commands),
        )
        .await
    }

    /// Runs the command inside `prepared_command` through the cache under `key`.
    async fn process_prepared_command<'a, R>(
        &self,
        key: BulkString,
        prepared_command: PreparedCommand<'a, &'a Client, R>,
    ) -> Result<R>
    where
        R: Response + DeserializeOwned,
    {
        self.process_command(key, prepared_command.command).await
    }

    /// Sends the command inside `prepared_command` straight to the server.
    ///
    /// The cache is neither read nor written. Used for commands whose reply must
    /// never be cached.
    async fn send_prepared_command_uncached<'a, R>(
        &self,
        prepared_command: PreparedCommand<'a, &'a Client, R>,
    ) -> Result<R>
    where
        R: Response + DeserializeOwned,
    {
        let response = self
            .client
            .internal_send(prepared_command.command, None)
            .await?;
        R::deserialize(RespDeserializer::new(response.view()))
    }

    /// Serves `command` from the cache, or sends it and caches the raw reply.
    ///
    /// Replies are stored under `key`, then under the serialized bytes of
    /// `command`. On a miss the reply is cached only if it deserializes into `R`
    /// successfully.
    ///
    /// NOTE(review): if an invalidation for `key` is applied while the request is
    /// in flight, the reply inserted afterwards may already be stale. It would
    /// remain until the next invalidation of `key` or until the entry expires.
    async fn process_command<R>(&self, key: BulkString, command: Command) -> Result<R>
    where
        R: Response + DeserializeOwned,
    {
        if let Some(values) = self.cache.get(&key).await
            && let Some(response) = values.get(command.bytes())
        {
            log::debug!(
                "[{}] Cache hit on key `{}`",
                self.client.connection_tag(),
                key
            );
            let deserializer = RespDeserializer::new(response.view());
            return R::deserialize(deserializer);
        }
        log::debug!(
            "[{}] Cache miss on key `{}`",
            self.client.connection_tag(),
            key
        );
        // `internal_send` consumes the command, so keep its bytes for the cache entry.
        let command_bytes = command.bytes().clone();
        let response = self.client.internal_send(command, None).await?;
        let deserializer = RespDeserializer::new(response.view());
        let deserialized = R::deserialize(deserializer)?;
        self.cache
            .entry(key)
            .or_insert_with(async { Arc::new(DashMap::new()) })
            .await
            .value()
            .insert(command_bytes, response);
        Ok(deserialized)
    }
}
/// Serializes `key` the way it would be encoded as a command argument.
///
/// Returns the first resulting argument as a [`BulkString`], which is used as the
/// top-level cache key. If `key` serializes to several arguments, only the first
/// one is kept.
///
/// # Panics
/// Panics if serializing `key` yields no argument at all.
fn key_to_bulk_string(key: &impl Serialize) -> BulkString {
    let mut serialized = CommandArgsMut::default().arg(key).freeze().into_iter();
    let first = serialized.next().expect("expected a single argument");
    first.into()
}