use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use aws_sdk_dynamodb::{primitives::Blob, types::AttributeValue};
use super::{Driver, Error as CacheError};
/// Settings for constructing a [`DynamoDBDriver`].
///
/// See the `Default` impl for the default table and attribute names.
#[derive(Debug, Clone)]
pub struct Config {
/// Name of the DynamoDB table every cache operation reads from and writes to.
pub table: String,
/// String prepended to every cache key before it is used as the stored item key.
pub prefix: String,
/// Name of the item attribute that holds the (prefixed) key, stored as a string (`S`).
pub key_attribute: String,
/// Name of the item attribute that holds the cached bytes, stored as binary (`B`).
pub value_attribute: String,
/// Name of the item attribute that holds the expiration time as a number (`N`) of
/// whole seconds since the Unix epoch.
pub expiration_attribute: String,
/// AWS SDK configuration used to build the DynamoDB client.
pub aws_config: aws_types::SdkConfig,
}
impl Default for Config {
    /// Targets the `cache` table with attributes named `key`, `value` and
    /// `expires_at`, no key prefix, and an `SdkConfig` built with no settings.
    fn default() -> Self {
        let aws_config = aws_types::SdkConfig::builder().build();
        Self {
            table: "cache".into(),
            prefix: String::default(),
            key_attribute: "key".to_owned(),
            value_attribute: "value".to_owned(),
            expiration_attribute: "expires_at".to_owned(),
            aws_config,
        }
    }
}
/// Cache [`Driver`] that stores entries as items in a DynamoDB table.
///
/// Each entry is one item keyed by the prefixed cache key, with the value held in a
/// binary attribute and an optional expiration timestamp in a numeric attribute.
// NOTE(review): presumably allowed because the enclosing module is named after DynamoDB,
// which makes clippy flag the repeated name in `DynamoDBDriver` — confirm.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct DynamoDBDriver {
/// Table that every operation targets.
table: String,
/// Prepended to cache keys to form the stored item key.
prefix: String,
/// Name of the string attribute holding the stored item key.
key_attribute: String,
/// Name of the binary attribute holding the cached bytes.
value_attribute: String,
/// Name of the numeric attribute holding the expiration in Unix seconds.
expiration_attribute: String,
/// SDK client built from `Config::aws_config` in `new`.
client: aws_sdk_dynamodb::Client,
}
impl DynamoDBDriver {
#[must_use]
pub fn new(config: Config) -> Self {
Self {
table: config.table,
prefix: config.prefix,
key_attribute: config.key_attribute,
value_attribute: config.value_attribute,
expiration_attribute: config.expiration_attribute,
client: aws_sdk_dynamodb::Client::new(&config.aws_config),
}
}
async fn get_item(&self, key: &str) -> Result<Option<Vec<u8>>, CacheError> {
let response = self
.client
.get_item()
.table_name(&self.table)
.key(
self.key_attribute.clone(),
AttributeValue::S(format!("{}{key}", self.prefix)),
)
.send()
.await
.map_err(|e| CacheError::Other(Box::new(e)))?;
let Some(item) = response.item else {
return Ok(None);
};
if let Some(expires_at) = item
.get(&self.expiration_attribute)
.map(|value| value.as_n())
{
let expires_at: u64 = expires_at
.map_err(|_| CacheError::Other(Box::new(Error::InvalidDataFormat)))?
.parse()
.map_err(|_| CacheError::Other(Box::new(Error::InvalidDataFormat)))?;
if UNIX_EPOCH + Duration::from_secs(expires_at) < SystemTime::now() {
return Ok(None);
}
}
let data = if let Some(data) = item.get(&self.value_attribute).map(|value| value.as_b()) {
data.map_err(|_| CacheError::Other(Box::new(Error::InvalidDataFormat)))?
} else {
return Err(CacheError::Other(Box::new(Error::InvalidDataFormat)));
};
Ok(Some(data.as_ref().to_vec()))
}
}
#[async_trait]
impl Driver for DynamoDBDriver {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, CacheError> {
self.get_item(key).await
}
async fn has(&self, key: &str) -> Result<bool, CacheError> {
Ok(self.get_item(key).await?.is_some())
}
async fn put(
&self,
key: &str,
value: Vec<u8>,
expiry: Option<Duration>,
) -> Result<(), CacheError> {
let expires_at = expiry.map(|expiry| SystemTime::now() + expiry);
self.client
.put_item()
.table_name(&self.table)
.item(
self.key_attribute.clone(),
AttributeValue::S(format!("{}{key}", self.prefix)),
)
.item(
self.value_attribute.clone(),
AttributeValue::B(Blob::new(value)),
)
.item(
self.expiration_attribute.clone(),
expires_at.map_or(AttributeValue::Null(true), |expires_at| {
AttributeValue::N(
expires_at
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
.to_string(),
)
}),
)
.send()
.await
.map_err(|e| CacheError::Other(Box::new(e)))?;
Ok(())
}
async fn forget(&self, key: &str) -> Result<(), CacheError> {
self.client
.delete_item()
.table_name(&self.table)
.key(&self.key_attribute, AttributeValue::S(key.to_string()))
.send()
.await
.map_err(|e| CacheError::Other(Box::new(e)))?;
Ok(())
}
async fn flush(&self) -> Result<(), CacheError> {
Err(CacheError::Other(Box::new(Error::FlushNotSupported)))
}
}
/// Errors specific to the DynamoDB cache driver.
///
/// These are returned to callers boxed inside `CacheError::Other`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Returned by `flush`, which this driver does not implement.
#[error("DynamoDB does not support flushing the cache.")]
FlushNotSupported,
/// A stored item's value attribute was missing or not binary, or its expiration
/// attribute had an unexpected type or could not be parsed as an integer.
#[error("the stored data was on an unexpected format.")]
InvalidDataFormat,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Cache;

    /// Round-trips one entry through the driver: absent, then stored, then forgotten.
    ///
    /// NOTE(review): uses `Config::default()`, so it presumably needs reachable AWS
    /// credentials/region and an existing `cache` table — confirm how CI provides them.
    #[tokio::test]
    async fn test_dynamodb_driver() {
        const KEY: &str = "foo";
        let driver = DynamoDBDriver::new(Config::default());
        let cache = Cache::new(driver);

        // Nothing has been stored yet.
        assert!(cache.get::<String>(KEY).await.unwrap().is_none());
        assert!(!cache.has(KEY).await.unwrap());

        // Store with a one-second expiry and read it straight back.
        cache.put(KEY, &"bar", Duration::from_secs(1)).await.unwrap();
        let stored = cache.get::<String>(KEY).await.unwrap();
        assert_eq!(stored.as_deref(), Some("bar"));
        assert!(cache.has(KEY).await.unwrap());

        // Forgetting removes the entry again.
        cache.forget(KEY).await.unwrap();
        assert!(cache.get::<String>(KEY).await.unwrap().is_none());
        assert!(!cache.has(KEY).await.unwrap());
    }
}