use super::Result;
use super::cache_trait::Cache;
use async_trait::async_trait;
use memcache_async::ascii::Protocol;
use reinhardt_core::exception::Error;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
/// memcached ASCII-protocol client running over a tokio `TcpStream` that is
/// wrapped in `tokio_util`'s `Compat` adapter (built in `connect_to_server`).
type MemcachedProtocol = Protocol<Compat<TcpStream>>;
/// Connection settings for [`MemcachedCache`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemcachedConfig {
    /// Server addresses, each passed to `TcpStream::connect` (typically
    /// `host:port`, e.g. `"127.0.0.1:11211"`).
    ///
    /// Order matters: keys are sharded by index into the list of servers that
    /// were reachable when the cache was built.
    pub servers: Vec<String>,
    /// Requested number of connections per server.
    ///
    /// NOTE: `MemcachedCache::new` opens a single connection per server and
    /// does not read this value.
    pub pool_size: usize,
    /// Timeout, in milliseconds, for communicating with a server; see
    /// `MemcachedCache::new` for where it is applied.
    pub timeout_ms: u64,
}
impl Default for MemcachedConfig {
fn default() -> Self {
Self {
servers: vec!["127.0.0.1:11211".to_string()],
pool_size: 10,
timeout_ms: 1000,
}
}
}
/// Memcached-backed [`Cache`] implementation.
///
/// Holds one protocol connection per server that was reachable at
/// construction time. Each connection sits behind a tokio `Mutex`, so a given
/// connection serves one request at a time. Single-key operations start at
/// the server chosen by hashing the key and, on error, move on to the
/// following servers in order.
pub struct MemcachedCache {
/// Connections to the reachable servers, in the order they appear in
/// `MemcachedConfig::servers`.
servers: Vec<Mutex<MemcachedProtocol>>,
}
impl MemcachedCache {
pub async fn new(config: MemcachedConfig) -> Result<Self> {
if config.servers.is_empty() {
return Err(Error::Http("No Memcached servers specified".to_string()));
}
let mut protocols = Vec::new();
let mut last_error = None;
for server_addr in &config.servers {
match Self::connect_to_server(server_addr).await {
Ok(protocol) => {
protocols.push(Mutex::new(protocol));
}
Err(e) => {
eprintln!(
"Warning: Failed to connect to Memcached server {}: {}",
server_addr, e
);
last_error = Some(e);
}
}
}
if protocols.is_empty() {
return Err(last_error.unwrap_or_else(|| {
Error::Http("Failed to connect to any Memcached server".to_string())
}));
}
Ok(Self { servers: protocols })
}
async fn connect_to_server(server_addr: &str) -> Result<MemcachedProtocol> {
let stream = TcpStream::connect(server_addr)
.await
.map_err(|e| Error::Http(format!("Failed to connect to Memcached: {}", e)))?;
let compat_stream = stream.compat();
Ok(Protocol::new(compat_stream))
}
fn get_server_index_for_key(&self, key: &str) -> usize {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
(hash as usize) % self.servers.len()
}
fn get_server(&self, index: usize) -> &Mutex<MemcachedProtocol> {
&self.servers[index % self.servers.len()]
}
pub async fn from_url(url: &str) -> Result<Self> {
let config = MemcachedConfig {
servers: vec![url.to_string()],
..Default::default()
};
Self::new(config).await
}
}
#[async_trait]
impl Cache for MemcachedCache {
async fn get<T>(&self, key: &str) -> Result<Option<T>>
where
T: for<'de> Deserialize<'de> + Send,
{
let start_index = self.get_server_index_for_key(key);
let server_count = self.servers.len();
for attempt in 0..server_count {
let index = (start_index + attempt) % server_count;
let server = self.get_server(index);
let mut protocol = server.lock().await;
match protocol.get(&key).await {
Ok(value) => {
if value.is_empty() {
return Ok(None);
}
let deserialized: T = serde_json::from_slice(&value).map_err(|e| {
Error::Serialization(format!("Failed to deserialize value: {}", e))
})?;
return Ok(Some(deserialized));
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => {
if attempt < server_count - 1 {
eprintln!(
"Warning: Get operation failed on server {}, trying next: {}",
index, e
);
} else {
return Err(Error::Http(format!("Memcached get error: {}", e)));
}
}
}
}
Err(Error::Http("All Memcached servers failed".to_string()))
}
async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
where
T: Serialize + Send + Sync,
{
let serialized = serde_json::to_vec(value)
.map_err(|e| Error::Serialization(format!("Failed to serialize value: {}", e)))?;
let expiration = ttl.map(|d| d.as_secs() as u32).unwrap_or(0);
let start_index = self.get_server_index_for_key(key);
let server_count = self.servers.len();
for attempt in 0..server_count {
let index = (start_index + attempt) % server_count;
let server = self.get_server(index);
let mut protocol = server.lock().await;
match protocol.set(&key, &serialized, expiration).await {
Ok(_) => return Ok(()),
Err(e) => {
if attempt < server_count - 1 {
eprintln!(
"Warning: Set operation failed on server {}, trying next: {}",
index, e
);
} else {
return Err(Error::Http(format!("Memcached set error: {}", e)));
}
}
}
}
Err(Error::Http("All Memcached servers failed".to_string()))
}
async fn delete(&self, key: &str) -> Result<()> {
let start_index = self.get_server_index_for_key(key);
let server_count = self.servers.len();
for attempt in 0..server_count {
let index = (start_index + attempt) % server_count;
let server = self.get_server(index);
let mut protocol = server.lock().await;
match protocol.set(&key, &[], 1).await {
Ok(_) => return Ok(()),
Err(e) => {
if attempt < server_count - 1 {
eprintln!(
"Warning: Delete operation failed on server {}, trying next: {}",
index, e
);
} else {
return Err(Error::Http(format!("Memcached delete error: {}", e)));
}
}
}
}
Err(Error::Http("All Memcached servers failed".to_string()))
}
async fn has_key(&self, key: &str) -> Result<bool> {
let start_index = self.get_server_index_for_key(key);
let server_count = self.servers.len();
for attempt in 0..server_count {
let index = (start_index + attempt) % server_count;
let server = self.get_server(index);
let mut protocol = server.lock().await;
match protocol.get(&key).await {
Ok(value) => {
return Ok(!value.is_empty());
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(false),
Err(e) => {
if attempt < server_count - 1 {
eprintln!(
"Warning: Has_key operation failed on server {}, trying next: {}",
index, e
);
} else {
return Err(Error::Http(format!("Memcached has_key error: {}", e)));
}
}
}
}
Err(Error::Http("All Memcached servers failed".to_string()))
}
async fn clear(&self) -> Result<()> {
let mut last_error = None;
let mut success_count = 0;
for server in &self.servers {
let mut protocol = server.lock().await;
match protocol.flush().await {
Ok(_) => success_count += 1,
Err(e) => {
eprintln!("Warning: Failed to clear cache on one server: {}", e);
last_error = Some(Error::Http(format!("Memcached clear error: {}", e)));
}
}
}
if success_count > 0 {
Ok(())
} else {
Err(last_error
.unwrap_or_else(|| Error::Http("Failed to clear cache on all servers".to_string())))
}
}
}