use async_memcached::{AsciiProtocol, Client};
use async_trait::async_trait;
use bb8::{Pool, PooledConnection};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{CacheBackend, CacheEntry, CacheRead};
use crate::error::CacheError;
/// `bb8` connection manager that opens `async_memcached` clients against a
/// single server address.
#[derive(Debug, Clone)]
pub struct MemcachedConnectionManager {
    /// Address handed to `Client::new` for every new connection.
    address: String,
}
impl MemcachedConnectionManager {
    /// Creates a manager whose connections all target `address`.
    pub fn new(address: impl Into<String>) -> Self {
        let address: String = address.into();
        Self { address }
    }
}
#[async_trait]
impl bb8::ManageConnection for MemcachedConnectionManager {
    type Connection = Client;
    type Error = async_memcached::Error;

    /// Opens a fresh memcached client to the configured address.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let address = &self.address;
        Client::new(address).await
    }

    /// Health-checks a pooled connection by issuing a `version` request and
    /// discarding the reply.
    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        conn.version().await.map(|_version| ())
    }

    /// Never flags a connection as broken; failures are only detected when
    /// `is_valid` is run or when a command on the connection fails.
    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
        false
    }
}
/// `bb8` pool of memcached clients created by `MemcachedConnectionManager`.
type MemcachedPool = Pool<MemcachedConnectionManager>;
/// Cache backend that stores entries in memcached through a `bb8` connection pool.
///
/// NOTE(review): clones presumably share the same underlying pool (bb8 pools are
/// reference-counted handles) — confirm against the bb8 version in use.
#[derive(Clone)]
pub struct MemcachedBackend {
    /// Pool of memcached client connections.
    pool: MemcachedPool,
    /// Prefix joined to every cache key as `"{namespace}:{key}"`.
    namespace: String,
}
impl MemcachedBackend {
    /// Connects to `address` using the default builder settings.
    ///
    /// Equivalent to `MemcachedBackend::builder().address(address).build()`.
    pub async fn new(address: impl Into<String>) -> Result<Self, CacheError> {
        let builder = Self::builder().address(address);
        builder.build().await
    }

    /// Returns a builder pre-populated with the default settings.
    pub fn builder() -> MemcachedBackendBuilder {
        <MemcachedBackendBuilder as Default>::default()
    }

    /// Checks a connection out of the pool, mapping pool failures to
    /// `CacheError::Backend`.
    async fn get_connection(
        &self,
    ) -> Result<PooledConnection<'_, MemcachedConnectionManager>, CacheError> {
        match self.pool.get().await {
            Ok(conn) => Ok(conn),
            Err(err) => Err(CacheError::Backend(format!(
                "Failed to get connection: {}",
                err
            ))),
        }
    }

    /// Builds the memcached key for `key` as `"{namespace}:{key}"`.
    fn make_key(&self, key: &str) -> String {
        [self.namespace.as_str(), key].join(":")
    }

    /// Reports the current total and idle connection counts of the pool.
    pub fn pool_state(&self) -> PoolState {
        let snapshot = self.pool.state();
        let idle_connections = snapshot.idle_connections;
        let connections = snapshot.connections;
        PoolState {
            connections,
            idle_connections,
        }
    }
}
/// Snapshot of the connection pool's occupancy, as reported by `bb8`.
///
/// A plain value type: cheap to copy and comparable, so it can be asserted on
/// or diffed between two observations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolState {
    /// Number of connections currently managed by the pool.
    pub connections: u32,
    /// Number of those connections that are idle.
    pub idle_connections: u32,
}
/// Configuration used to construct a `MemcachedBackend`.
#[derive(Debug, Clone)]
pub struct MemcachedBackendBuilder {
    /// Server address; must be set before `build` is called.
    address: Option<String>,
    /// Prefix joined to every cache key as `"{namespace}:{key}"`.
    namespace: String,
    /// Maximum number of pooled connections (bb8 `max_size`).
    max_connections: u32,
    /// Number of idle connections the pool tries to keep (bb8 `min_idle`).
    min_connections: u32,
    /// Passed to bb8's `connection_timeout`.
    connection_timeout: Duration,
}
impl Default for MemcachedBackendBuilder {
    /// Returns a builder with no address, the `tower_http_cache` namespace,
    /// `max_connections = 10`, `min_connections = 2` and a 30-second
    /// connection timeout.
    fn default() -> Self {
        let namespace = String::from("tower_http_cache");
        let connection_timeout = Duration::from_secs(30);
        Self {
            connection_timeout,
            namespace,
            min_connections: 2,
            max_connections: 10,
            address: None,
        }
    }
}
impl MemcachedBackendBuilder {
    /// Sets the memcached server address (required).
    pub fn address(mut self, address: impl Into<String>) -> Self {
        self.address = Some(address.into());
        self
    }

    /// Sets the key namespace; keys are stored as `"{namespace}:{key}"`.
    pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
        self.namespace = namespace.into();
        self
    }

    /// Sets the maximum number of pooled connections. Must be greater than zero.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets the number of idle connections the pool tries to keep.
    /// Must not exceed `max_connections`.
    pub fn min_connections(mut self, min: u32) -> Self {
        self.min_connections = min;
        self
    }

    /// Sets the pool's connection timeout. Must be non-zero.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Creates the connection pool and returns the configured backend.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::Backend` when:
    /// - no address was configured;
    /// - `max_connections` is zero, `min_connections` exceeds
    ///   `max_connections`, or `connection_timeout` is zero;
    /// - bb8 fails to create the pool.
    ///
    /// bb8's builder panics on the invalid pool limits, so they are checked here
    /// to keep `build` panic-free.
    pub async fn build(self) -> Result<MemcachedBackend, CacheError> {
        let Self {
            address,
            namespace,
            max_connections,
            min_connections,
            connection_timeout,
        } = self;

        let address =
            address.ok_or_else(|| CacheError::Backend("address is required".to_string()))?;

        if max_connections == 0 {
            return Err(CacheError::Backend(
                "max_connections must be greater than zero".to_string(),
            ));
        }
        if min_connections > max_connections {
            return Err(CacheError::Backend(format!(
                "min_connections ({}) must not exceed max_connections ({})",
                min_connections, max_connections
            )));
        }
        if connection_timeout.is_zero() {
            return Err(CacheError::Backend(
                "connection_timeout must be non-zero".to_string(),
            ));
        }

        let manager = MemcachedConnectionManager::new(address);
        let pool = Pool::builder()
            .max_size(max_connections)
            .min_idle(Some(min_connections))
            .connection_timeout(connection_timeout)
            .build(manager)
            .await
            .map_err(|e| CacheError::Backend(format!("Failed to create connection pool: {}", e)))?;

        Ok(MemcachedBackend { pool, namespace })
    }
}
/// Value stored in memcached for each cache key, encoded with `bincode`.
///
/// bincode's encoding follows field order, so reordering or retyping these
/// fields makes previously stored entries undecodable.
#[derive(serde::Serialize, serde::Deserialize)]
struct MemcachedRecord {
    /// The cached response entry.
    entry: CacheEntry,
    /// Unix time in milliseconds at which the entry stops being fresh (`now + ttl`).
    expires_at_ms: u64,
    /// Unix time in milliseconds marking the end of the stale window
    /// (`expires_at_ms + stale_for`).
    stale_until_ms: u64,
}
/// Converts a `SystemTime` into whole milliseconds since the Unix epoch.
///
/// `Duration::as_millis` yields a `u128`; values beyond `u64::MAX` saturate
/// rather than wrapping around, which is what a plain `as u64` cast would do.
///
/// # Errors
///
/// Returns `CacheError::Backend` if `time` is earlier than `UNIX_EPOCH`.
fn system_time_to_unix_ms(time: SystemTime) -> Result<u64, CacheError> {
    time.duration_since(UNIX_EPOCH)
        // Lossless cast: the value is clamped to u64::MAX first.
        .map(|d| d.as_millis().min(u128::from(u64::MAX)) as u64)
        .map_err(|e| CacheError::Backend(format!("Time conversion error: {}", e)))
}
fn unix_ms_to_system_time(ms: u64) -> Result<SystemTime, CacheError> {
Ok(UNIX_EPOCH + Duration::from_millis(ms))
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
fn current_millis() -> Result<u64, CacheError> {
    let now = SystemTime::now();
    system_time_to_unix_ms(now)
}
/// Converts a duration to whole milliseconds, saturating at `u64::MAX`.
///
/// `Duration::as_millis` returns a `u128`. A plain `as u64` cast would wrap
/// huge durations (e.g. `Duration::from_secs(u64::MAX)`) into a smaller,
/// meaningless value, which would corrupt the expiry arithmetic that uses
/// this helper.
fn duration_millis(duration: Duration) -> u64 {
    // Lossless cast: the value is clamped to u64::MAX first.
    duration.as_millis().min(u128::from(u64::MAX)) as u64
}
#[async_trait]
impl CacheBackend for MemcachedBackend {
async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
let namespaced_key = self.make_key(key);
let mut conn = self.get_connection().await?;
let value = (*conn)
.get(namespaced_key.as_bytes())
.await
.map_err(|e| CacheError::Backend(format!("Memcached get failed: {}", e)))?;
if let Some(data) = value {
let data_bytes = data
.data
.as_ref()
.ok_or_else(|| CacheError::Backend("Memcached value has no data".to_string()))?;
let record: MemcachedRecord = bincode::deserialize(data_bytes.as_slice())
.map_err(|e| CacheError::Backend(format!("Deserialization failed: {}", e)))?;
Ok(Some(CacheRead {
entry: record.entry,
expires_at: Some(unix_ms_to_system_time(record.expires_at_ms)?),
stale_until: Some(unix_ms_to_system_time(record.stale_until_ms)?),
}))
} else {
Ok(None)
}
}
async fn set(
&self,
key: String,
entry: CacheEntry,
ttl: Duration,
stale_for: Duration,
) -> Result<(), CacheError> {
if ttl.is_zero() {
return Ok(());
}
let namespaced_key = self.make_key(&key);
let now_ms = current_millis()?;
let expires_at_ms = now_ms.saturating_add(duration_millis(ttl));
let stale_until_ms = expires_at_ms.saturating_add(duration_millis(stale_for));
let record = MemcachedRecord {
entry,
expires_at_ms,
stale_until_ms,
};
let bytes = bincode::serialize(&record)
.map_err(|e| CacheError::Backend(format!("Serialization failed: {}", e)))?;
let total_ttl = ttl.saturating_add(stale_for);
let ttl_secs = total_ttl.as_secs();
let ttl_u32 = ttl_secs.min(u32::MAX as u64) as u32;
let mut conn = self.get_connection().await?;
(*conn)
.set(
namespaced_key.as_bytes(),
bytes.as_slice(),
Some(ttl_u32 as i64),
Default::default(),
)
.await
.map_err(|e| CacheError::Backend(format!("Memcached set failed: {}", e)))?;
Ok(())
}
async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
let namespaced_key = self.make_key(key);
let mut conn = self.get_connection().await?;
(*conn)
.delete(namespaced_key.as_bytes())
.await
.map_err(|e| CacheError::Backend(format!("Memcached delete failed: {}", e)))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;

    #[test]
    fn test_make_key() {
        // NOTE(review): this checks a local copy of the `namespace:key` format,
        // not `MemcachedBackend::make_key` itself, because building a backend
        // requires a live connection pool.
        let namespace = "test_app";
        let make_key = |key: &str| format!("{}:{}", namespace, key);
        assert_eq!(make_key("my_key"), "test_app:my_key");
        assert_eq!(make_key("another/key"), "test_app:another/key");
    }

    #[test]
    fn test_system_time_conversion() {
        let now = SystemTime::now();
        let ms = system_time_to_unix_ms(now).unwrap();
        let converted = unix_ms_to_system_time(ms).unwrap();
        let diff = now
            .duration_since(converted)
            .or_else(|_| converted.duration_since(now))
            .unwrap();
        assert!(diff.as_millis() < 2);
    }

    #[test]
    fn test_epoch_conversion_is_exact() {
        assert_eq!(system_time_to_unix_ms(UNIX_EPOCH).unwrap(), 0);
        assert_eq!(unix_ms_to_system_time(0).unwrap(), UNIX_EPOCH);
        assert_eq!(
            unix_ms_to_system_time(1_500).unwrap(),
            UNIX_EPOCH + Duration::from_millis(1_500)
        );
        assert_eq!(
            system_time_to_unix_ms(UNIX_EPOCH + Duration::from_millis(1_500)).unwrap(),
            1_500
        );
    }

    #[test]
    fn test_memcached_record_serialization() {
        let entry = CacheEntry::new(
            StatusCode::OK,
            http::Version::HTTP_11,
            vec![("content-type".to_string(), b"application/json".to_vec())],
            Bytes::from_static(b"{\"test\":true}"),
        );
        let record = MemcachedRecord {
            entry,
            expires_at_ms: 1_000_000,
            stale_until_ms: 2_000_000,
        };

        // Round-trip through the same bincode encoding the backend uses.
        let bytes = bincode::serialize(&record).expect("record should serialize");
        let decoded: MemcachedRecord =
            bincode::deserialize(&bytes).expect("record should deserialize");

        assert_eq!(decoded.entry.status, StatusCode::OK);
        assert_eq!(decoded.entry.version, http::Version::HTTP_11);
        assert_eq!(decoded.entry.body, Bytes::from_static(b"{\"test\":true}"));
        assert_eq!(decoded.expires_at_ms, 1_000_000);
        assert_eq!(decoded.stale_until_ms, 2_000_000);
    }

    #[test]
    fn test_builder_defaults() {
        let builder = MemcachedBackendBuilder::default();
        assert_eq!(builder.namespace, "tower_http_cache");
        assert_eq!(builder.max_connections, 10);
        assert_eq!(builder.min_connections, 2);
        assert_eq!(builder.connection_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_builder_customization() {
        let builder = MemcachedBackendBuilder::default()
            .address("127.0.0.1:11211")
            .namespace("custom")
            .max_connections(20)
            .min_connections(5)
            .connection_timeout(Duration::from_secs(10));
        assert_eq!(builder.address, Some("127.0.0.1:11211".to_string()));
        assert_eq!(builder.namespace, "custom");
        assert_eq!(builder.max_connections, 20);
        assert_eq!(builder.min_connections, 5);
        assert_eq!(builder.connection_timeout, Duration::from_secs(10));
    }
}