use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Async key/value store for JSON [`Value`]s, handed out as `Arc<dyn MemoryStorage>`
/// by `create_memory_storage`.
///
/// Every method reports failure as a human-readable `String`. This file provides
/// [`InMemoryStorage`] plus feature-gated Redis and Postgres backends.
#[async_trait]
pub trait MemoryStorage: Send + Sync {
    /// Returns the value stored under `key`, or `Ok(None)` when the key is absent.
    async fn get(&self, key: &str) -> Result<Option<Value>, String>;
    /// Stores `value` under `key`, replacing any previous value.
    async fn set(&self, key: &str, value: Value) -> Result<(), String>;
    /// Removes `key`. In the implementations in this file, removing a key that
    /// does not exist still returns `Ok(())`.
    async fn delete(&self, key: &str) -> Result<(), String>;
    /// Lists stored keys, optionally filtered by `pattern`; `None` lists every key.
    ///
    /// NOTE(review): the interpretation of `pattern` is backend-specific. The
    /// in-memory store does a plain substring match, while the Redis store passes
    /// it to Redis as a glob pattern. Check the implementation you rely on.
    async fn list(&self, pattern: Option<&str>) -> Result<Vec<String>, String>;
    /// Removes every entry held by this storage.
    ///
    /// NOTE(review): the Redis backend implements this with `FLUSHDB`, which clears
    /// the entire Redis database, not only keys written through this trait.
    async fn clear(&self) -> Result<(), String>;
}
/// Process-local [`MemoryStorage`]: a `HashMap` of JSON values guarded by a Tokio
/// `RwLock` and held in an `Arc`. Nothing is persisted outside the process.
pub struct InMemoryStorage {
    data: Arc<RwLock<HashMap<String, Value>>>,
}

impl InMemoryStorage {
    /// Creates an empty store; equivalent to [`InMemoryStorage::default`].
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for InMemoryStorage {
    /// Builds a store containing no entries.
    fn default() -> Self {
        let empty: HashMap<String, Value> = HashMap::new();
        Self {
            data: Arc::new(RwLock::new(empty)),
        }
    }
}
#[async_trait]
impl MemoryStorage for InMemoryStorage {
    /// Clones the value under `key` out of the map while holding a read lock.
    async fn get(&self, key: &str) -> Result<Option<Value>, String> {
        Ok(self.data.read().await.get(key).cloned())
    }

    /// Inserts or overwrites `key` under the write lock; never fails.
    async fn set(&self, key: &str, value: Value) -> Result<(), String> {
        self.data.write().await.insert(String::from(key), value);
        Ok(())
    }

    /// Drops `key` if present; an absent key is not an error.
    async fn delete(&self, key: &str) -> Result<(), String> {
        self.data.write().await.remove(key);
        Ok(())
    }

    /// Returns all keys, or only those containing `pattern` as a substring.
    async fn list(&self, pattern: Option<&str>) -> Result<Vec<String>, String> {
        let guard = self.data.read().await;
        let names = guard.keys();
        Ok(match pattern {
            Some(needle) => names.filter(|name| name.contains(needle)).cloned().collect(),
            None => names.cloned().collect(),
        })
    }

    /// Empties the map under the write lock.
    async fn clear(&self) -> Result<(), String> {
        self.data.write().await.clear();
        Ok(())
    }
}
#[cfg(feature = "redis")]
pub mod redis {
    //! Redis-backed [`MemoryStorage`] implementation.
    //!
    //! The external crate is always written as `::redis`. Inside this module,
    //! `use super::*` also brings in the name `redis` (this very module), so a
    //! bare `redis::...` path is ambiguous and can fail to reach the crate.
    use super::*;
    use ::redis::AsyncCommands;

    /// Number of keys Redis is asked to examine per `SCAN` round trip in `list`.
    const SCAN_BATCH: usize = 100;

    /// Stores each value as its JSON text under the raw key in Redis.
    ///
    /// Every operation opens its own async connection from `client`.
    pub struct RedisStorage {
        client: ::redis::Client,
    }

    impl RedisStorage {
        /// Creates a storage handle for the Redis server at `url`.
        ///
        /// # Errors
        /// Returns an error if `::redis::Client::open` rejects `url`. Per the redis
        /// crate docs, this only validates the connection info; no connection is
        /// opened until the first operation.
        pub fn new(url: &str) -> Result<Self, String> {
            let client = ::redis::Client::open(url)
                .map_err(|e| format!("Failed to connect to Redis: {}", e))?;
            Ok(Self { client })
        }
    }

    #[async_trait]
    impl MemoryStorage for RedisStorage {
        /// Reads `key` and parses the stored string as JSON; a missing key is `Ok(None)`.
        async fn get(&self, key: &str) -> Result<Option<Value>, String> {
            let mut conn = self
                .client
                .get_async_connection()
                .await
                .map_err(|e| e.to_string())?;
            let result: Option<String> = conn.get(key).await.map_err(|e| e.to_string())?;
            match result {
                Some(s) => serde_json::from_str(&s)
                    .map(Some)
                    .map_err(|e| e.to_string()),
                None => Ok(None),
            }
        }

        /// Serializes `value` to JSON text and writes it with `SET`, overwriting any old value.
        async fn set(&self, key: &str, value: Value) -> Result<(), String> {
            let mut conn = self
                .client
                .get_async_connection()
                .await
                .map_err(|e| e.to_string())?;
            let serialized = serde_json::to_string(&value).map_err(|e| e.to_string())?;
            conn.set(key, serialized).await.map_err(|e| e.to_string())
        }

        /// Deletes `key` with `DEL`; a missing key is not reported as an error.
        async fn delete(&self, key: &str) -> Result<(), String> {
            let mut conn = self
                .client
                .get_async_connection()
                .await
                .map_err(|e| e.to_string())?;
            conn.del(key).await.map_err(|e| e.to_string())
        }

        /// Lists keys matching the Redis glob `pattern` (`*` when `None`).
        ///
        /// Uses cursor-based `SCAN` instead of `KEYS`. `KEYS` walks the whole
        /// keyspace in one blocking command and stalls every other client. `SCAN`
        /// may report a key more than once, so the result is sorted and deduplicated.
        async fn list(&self, pattern: Option<&str>) -> Result<Vec<String>, String> {
            let mut conn = self
                .client
                .get_async_connection()
                .await
                .map_err(|e| e.to_string())?;
            let pat = pattern.unwrap_or("*");
            let mut keys: Vec<String> = Vec::new();
            let mut cursor: u64 = 0;
            loop {
                let (next_cursor, batch): (u64, Vec<String>) = ::redis::cmd("SCAN")
                    .arg(cursor)
                    .arg("MATCH")
                    .arg(pat)
                    .arg("COUNT")
                    .arg(SCAN_BATCH)
                    .query_async(&mut conn)
                    .await
                    .map_err(|e| e.to_string())?;
                keys.extend(batch);
                // Redis signals a completed full iteration by returning cursor 0.
                if next_cursor == 0 {
                    break;
                }
                cursor = next_cursor;
            }
            keys.sort_unstable();
            keys.dedup();
            Ok(keys)
        }

        /// Deletes every key in the current Redis database via `FLUSHDB`.
        ///
        /// WARNING: this is not scoped to keys written through this storage.
        /// Anything else sharing the same Redis database is removed as well.
        async fn clear(&self) -> Result<(), String> {
            let mut conn = self
                .client
                .get_async_connection()
                .await
                .map_err(|e| e.to_string())?;
            ::redis::cmd("FLUSHDB")
                .query_async(&mut conn)
                .await
                .map_err(|e| e.to_string())
        }
    }
}
#[cfg(feature = "postgres")]
pub mod postgres {
    //! PostgreSQL-backed [`MemoryStorage`] implementation using the `agent_memory` table.
    use super::*;
    use sqlx::PgPool;

    /// Persists entries as rows of `agent_memory` (`key TEXT PRIMARY KEY`, `value JSONB`).
    pub struct PostgresStorage {
        pool: PgPool,
    }

    impl PostgresStorage {
        /// Connects a pool to `url` and creates the `agent_memory` table if it does not exist.
        ///
        /// # Errors
        /// Returns an error if connecting fails or the `CREATE TABLE` statement fails.
        pub async fn new(url: &str) -> Result<Self, String> {
            let pool = PgPool::connect(url)
                .await
                .map_err(|e| format!("Failed to connect to Postgres: {}", e))?;
            sqlx::query(
                r#"
                CREATE TABLE IF NOT EXISTS agent_memory (
                    key TEXT PRIMARY KEY,
                    value JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT NOW(),
                    updated_at TIMESTAMP DEFAULT NOW()
                )
                "#,
            )
            .execute(&pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(Self { pool })
        }
    }

    #[async_trait]
    impl MemoryStorage for PostgresStorage {
        /// Selects the JSONB value for `key`; no matching row yields `Ok(None)`.
        async fn get(&self, key: &str) -> Result<Option<Value>, String> {
            let result: Option<(Value,)> =
                sqlx::query_as("SELECT value FROM agent_memory WHERE key = $1")
                    .bind(key)
                    .fetch_optional(&self.pool)
                    .await
                    .map_err(|e| e.to_string())?;
            Ok(result.map(|(v,)| v))
        }

        /// Upserts `value` under `key` and refreshes `updated_at`.
        async fn set(&self, key: &str, value: Value) -> Result<(), String> {
            sqlx::query(
                r#"
                INSERT INTO agent_memory (key, value, updated_at)
                VALUES ($1, $2, NOW())
                ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = NOW()
                "#,
            )
            .bind(key)
            .bind(&value)
            .execute(&self.pool)
            .await
            .map_err(|e| e.to_string())?;
            Ok(())
        }

        /// Deletes the row for `key`; deleting a missing key is not an error.
        async fn delete(&self, key: &str) -> Result<(), String> {
            sqlx::query("DELETE FROM agent_memory WHERE key = $1")
                .bind(key)
                .execute(&self.pool)
                .await
                .map_err(|e| e.to_string())?;
            Ok(())
        }

        /// Lists keys, keeping only those that contain `pattern` as a literal substring.
        ///
        /// `strpos` is used instead of `LIKE '%pattern%'` so that `%`, `_` and `\` in
        /// the caller's pattern are matched literally. With `LIKE` they would act as
        /// wildcards or escapes. This keeps the behavior aligned with the in-memory
        /// backend's `str::contains`.
        async fn list(&self, pattern: Option<&str>) -> Result<Vec<String>, String> {
            let keys: Vec<(String,)> = match pattern {
                Some(pat) => {
                    sqlx::query_as("SELECT key FROM agent_memory WHERE strpos(key, $1) > 0")
                        .bind(pat)
                        .fetch_all(&self.pool)
                        .await
                }
                None => {
                    sqlx::query_as("SELECT key FROM agent_memory")
                        .fetch_all(&self.pool)
                        .await
                }
            }
            .map_err(|e| e.to_string())?;
            Ok(keys.into_iter().map(|(k,)| k).collect())
        }

        /// Removes every row from `agent_memory` with `TRUNCATE`.
        async fn clear(&self) -> Result<(), String> {
            sqlx::query("TRUNCATE TABLE agent_memory")
                .execute(&self.pool)
                .await
                .map_err(|e| e.to_string())?;
            Ok(())
        }
    }
}
/// Builds the [`MemoryStorage`] backend named by `memory_type`.
///
/// Accepted names:
/// - `"memory"`, `"buffer"`, `"in-memory"`: a fresh [`InMemoryStorage`]; `config` is ignored.
/// - `"redis"` (feature `redis`): `config` must be the Redis URL.
/// - `"postgres"` (feature `postgres`): `config` must be the Postgres URL. The
///   async connection is driven to completion synchronously, which requires being
///   called from inside a multi-threaded Tokio runtime.
///
/// # Errors
/// Returns an error for an unknown type, or for a backend whose cargo feature is
/// not enabled. It also errors when the required URL is missing or the backend
/// fails to initialise. For `"postgres"`, calling from outside a Tokio runtime or
/// from a current-thread runtime is an error, not a panic.
#[cfg_attr(
    not(any(feature = "redis", feature = "postgres")),
    allow(unused_variables)
)]
pub fn create_memory_storage(
    memory_type: &str,
    config: Option<&str>,
) -> Result<Arc<dyn MemoryStorage>, String> {
    match memory_type {
        "memory" | "buffer" | "in-memory" => Ok(Arc::new(InMemoryStorage::new())),
        #[cfg(feature = "redis")]
        "redis" => {
            let url = config.ok_or("Redis URL required")?;
            Ok(Arc::new(self::redis::RedisStorage::new(url)?))
        }
        #[cfg(not(feature = "redis"))]
        "redis" => Err("Memory type 'redis' requires building with the `redis` feature".to_string()),
        #[cfg(feature = "postgres")]
        "postgres" => {
            let url = config.ok_or("Postgres URL required")?;
            // `block_in_place` panics outside a Tokio runtime and on a current-thread
            // runtime, so check both up front and report them as errors instead.
            let handle = tokio::runtime::Handle::try_current().map_err(|_| {
                "Postgres memory storage must be created inside a Tokio runtime".to_string()
            })?;
            if matches!(
                handle.runtime_flavor(),
                tokio::runtime::RuntimeFlavor::CurrentThread
            ) {
                return Err(
                    "Postgres memory storage requires a multi-threaded Tokio runtime".to_string(),
                );
            }
            let storage = tokio::task::block_in_place(|| {
                handle.block_on(self::postgres::PostgresStorage::new(url))
            })?;
            Ok(Arc::new(storage))
        }
        #[cfg(not(feature = "postgres"))]
        "postgres" => {
            Err("Memory type 'postgres' requires building with the `postgres` feature".to_string())
        }
        _ => Err(format!("Unknown memory type: {}", memory_type)),
    }
}