use crate::{
AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_TEMP,
ListRequest, Session, SessionService, State, state_utils,
};
use adk_core::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use fred::clients::Transaction;
use fred::prelude::*;
use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;
use tracing::instrument;
use uuid::Uuid;
/// Connection and expiry settings for [`RedisSessionService`].
#[derive(Debug, Clone)]
pub struct RedisSessionConfig {
    /// Redis connection URL, parsed with fred's `Config::from_url`.
    /// Not used when `cluster_nodes` is set.
    pub url: String,
    /// Optional time-to-live, applied with `EXPIRE` (whole seconds) to a session's hash and
    /// event keys after the session is created and after each appended event. App state,
    /// user state, the per-user session index and the session-id lookup key get no TTL.
    pub ttl: Option<Duration>,
    /// Cluster seed nodes as `host:port` strings; an entry without a port uses 6379.
    /// When set, a clustered connection is built from these entries instead of `url`.
    ///
    /// NOTE(review): the service puts keys that do not share a hash tag (session, events,
    /// lookup, app/user state) into single `MULTI`/`EXEC` transactions and multi-key `DEL`s;
    /// Redis Cluster requires such keys to live in one hash slot — confirm cluster mode works.
    pub cluster_nodes: Option<Vec<String>>,
}
/// Key of the session hash: `{app}:{user}:{session}`.
pub fn session_key(app: &str, user: &str, session: &str) -> String {
    [app, user, session].join(":")
}
/// Key of the sorted set holding a session's events: `{app}:{user}:{session}:events`.
pub fn events_key(app: &str, user: &str, session: &str) -> String {
    [app, user, session, "events"].join(":")
}
/// Key of the app-scoped state hash: `app_state:{app}`.
pub fn app_state_key(app: &str) -> String {
    let mut key = String::from("app_state:");
    key.push_str(app);
    key
}
/// Key of the user-scoped state hash: `user_state:{app}:{user}`.
pub fn user_state_key(app: &str, user: &str) -> String {
    ["user_state", app, user].join(":")
}
/// Key of the set listing a user's session ids: `sessions_idx:{app}:{user}`.
pub fn index_key(app: &str, user: &str) -> String {
    ["sessions_idx", app, user].join(":")
}
/// Key mapping a bare session id back to its owner: `session_lookup:{session}`.
fn lookup_key(session: &str) -> String {
    ["session_lookup:", session].concat()
}
/// [`SessionService`] implementation that stores sessions in Redis through a fred [`Client`].
///
/// Keys used (built by the key helper functions in this module):
/// - `{app}:{user}:{session}` — session hash (`app_name`, `user_id`, `session_id`, `state`,
///   `created_at`, `updated_at`)
/// - `{app}:{user}:{session}:events` — sorted set of JSON-encoded events, scored by the event
///   timestamp in milliseconds
/// - `app_state:{app}` / `user_state:{app}:{user}` — shared state hashes with JSON values
/// - `sessions_idx:{app}:{user}` — set of session ids used for listing
/// - `session_lookup:{session}` — `"{app}:{user}"` owner of a bare session id
pub struct RedisSessionService {
    /// Client initialised in `new`; every operation goes through it.
    client: Client,
    /// TTL copied from `RedisSessionConfig::ttl`.
    ttl: Option<Duration>,
}
impl RedisSessionService {
    /// Connects to Redis and returns a ready-to-use service.
    ///
    /// When `config.cluster_nodes` is set, a clustered connection is built from those
    /// `host:port` entries (a missing port defaults to 6379) and `config.url` is ignored;
    /// otherwise `config.url` is parsed with [`Config::from_url`].
    ///
    /// # Errors
    ///
    /// Returns a session error if `cluster_nodes` is empty or has an entry with an invalid
    /// port, if the URL cannot be parsed, or if the client fails to build or connect.
    pub async fn new(config: RedisSessionConfig) -> Result<Self> {
        let redis_config = match config.cluster_nodes.as_deref() {
            Some([]) => {
                return Err(adk_core::AdkError::session(
                    "redis connection failed: cluster_nodes is empty",
                ));
            }
            Some(nodes) => {
                let hosts = nodes
                    .iter()
                    .map(|node| Self::parse_cluster_node(node))
                    .collect::<Result<Vec<_>>>()?;
                Config { server: ServerConfig::new_clustered(hosts), ..Default::default() }
            }
            None => Config::from_url(&config.url).map_err(|e| {
                adk_core::AdkError::session(format!("redis connection failed: {e}"))
            })?,
        };
        let client = Builder::from_config(redis_config)
            .build()
            .map_err(|e| adk_core::AdkError::session(format!("redis connection failed: {e}")))?;
        client
            .init()
            .await
            .map_err(|e| adk_core::AdkError::session(format!("redis connection failed: {e}")))?;
        Ok(Self { client, ttl: config.ttl })
    }

    /// Parses one cluster seed entry of the form `host` or `host:port`.
    ///
    /// The port is the text after the last `:`; an entry without `:` uses 6379. A port that is
    /// present but unparsable is rejected rather than silently replaced by 6379, which would
    /// point the client at a node the caller never configured.
    fn parse_cluster_node(node: &str) -> Result<(String, u16)> {
        match node.rsplit_once(':') {
            None => Ok((node.to_string(), 6379)),
            Some((host, port)) => {
                let port = port.parse::<u16>().map_err(|e| {
                    adk_core::AdkError::session(format!(
                        "redis connection failed: invalid port in cluster node '{node}': {e}"
                    ))
                })?;
                Ok((host.to_string(), port))
            }
        }
    }

    /// Applies the configured TTL (if any) to a session hash and its events key.
    ///
    /// # Errors
    ///
    /// Returns a session error if either `EXPIRE` call fails.
    async fn apply_ttl(&self, session_k: &str, events_k: &str) -> Result<()> {
        if let Some(ttl) = self.ttl {
            // A non-positive EXPIRE deletes the key outright: round sub-second TTLs up to one
            // second and saturate (instead of wrapping negative) when converting from u64.
            let seconds = i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX).max(1);
            for key in [session_k, events_k] {
                let _: () = self.client.expire(key, seconds, None).await.map_err(|e| {
                    adk_core::AdkError::session(format!("redis expire failed: {e}"))
                })?;
            }
        }
        Ok(())
    }

    /// Reads a state hash and decodes each field value as JSON.
    ///
    /// Values that are not valid JSON are kept as JSON strings rather than dropped.
    async fn read_state_hash(&self, key: &str) -> Result<HashMap<String, Value>> {
        let raw: HashMap<String, String> = self
            .client
            .hgetall(key)
            .await
            .map_err(|e| adk_core::AdkError::session(format!("redis hgetall failed: {e}")))?;
        Ok(raw
            .into_iter()
            .map(|(k, v)| {
                let val: Value = serde_json::from_str(&v).unwrap_or(Value::String(v));
                (k, val)
            })
            .collect())
    }

    /// Queues an `HSET` of `state` (JSON-encoded values) on `trx`.
    ///
    /// Only the given fields are set; other fields of the hash are left untouched. An empty
    /// map queues nothing.
    ///
    /// # Errors
    ///
    /// Returns a session error if a value fails to serialize or the command cannot be queued.
    async fn write_state_hash(
        trx: &Transaction,
        key: &str,
        state: &HashMap<String, Value>,
    ) -> Result<()> {
        if state.is_empty() {
            return Ok(());
        }
        let fields = state
            .iter()
            .map(|(k, v)| serde_json::to_string(v).map(|s| (k.clone(), s)))
            .collect::<std::result::Result<Vec<(String, String)>, _>>()
            .map_err(|e| {
                adk_core::AdkError::session(format!("serialize state value failed: {e}"))
            })?;
        let _: () = trx
            .hset(key, fields)
            .await
            .map_err(|e| adk_core::AdkError::session(format!("redis hset failed: {e}")))?;
        Ok(())
    }
}
#[async_trait]
impl SessionService for RedisSessionService {
#[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id))]
async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
let session_id = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
let now = Utc::now();
let (app_delta, user_delta, session_state) = state_utils::extract_state_deltas(&req.state);
let mut app_state = self.read_state_hash(&app_state_key(&req.app_name)).await?;
app_state.extend(app_delta);
let mut user_state =
self.read_state_hash(&user_state_key(&req.app_name, &req.user_id)).await?;
user_state.extend(user_delta);
let merged_state = state_utils::merge_states(&app_state, &user_state, &session_state);
let trx = self.client.multi();
let session_k = session_key(&req.app_name, &req.user_id, &session_id);
let state_json = serde_json::to_string(&merged_state)
.map_err(|e| adk_core::AdkError::session(format!("serialize state failed: {e}")))?;
let session_fields: Vec<(String, String)> = vec![
("app_name".into(), req.app_name.clone()),
("user_id".into(), req.user_id.clone()),
("session_id".into(), session_id.clone()),
("state".into(), state_json),
("created_at".into(), now.to_rfc3339()),
("updated_at".into(), now.to_rfc3339()),
];
let _: () = trx
.hset(&session_k, session_fields)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hset failed: {e}")))?;
Self::write_state_hash(&trx, &app_state_key(&req.app_name), &app_state).await?;
Self::write_state_hash(&trx, &user_state_key(&req.app_name, &req.user_id), &user_state)
.await?;
let idx_k = index_key(&req.app_name, &req.user_id);
let _: () = trx
.sadd(&idx_k, &session_id)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis sadd failed: {e}")))?;
let lk = lookup_key(&session_id);
let lookup_val = format!("{}:{}", req.app_name, req.user_id);
let _: () = trx
.set(&lk, lookup_val, None, None, false)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis set failed: {e}")))?;
let _: () = trx
.exec(true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis transaction failed: {e}")))?;
let events_k = events_key(&req.app_name, &req.user_id, &session_id);
self.apply_ttl(&session_k, &events_k).await?;
Ok(Box::new(RedisSession {
app_name: req.app_name,
user_id: req.user_id,
session_id,
state: merged_state,
events: Vec::new(),
updated_at: now,
}))
}
#[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id, session_id = %req.session_id))]
async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
let session_k = session_key(&req.app_name, &req.user_id, &req.session_id);
let exists: bool = self
.client
.exists(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis exists failed: {e}")))?;
if !exists {
return Err(adk_core::AdkError::session("session not found"));
}
let raw: HashMap<String, String> = self
.client
.hgetall(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hgetall failed: {e}")))?;
let updated_at: DateTime<Utc> =
raw.get("updated_at").and_then(|s| s.parse().ok()).unwrap_or_else(Utc::now);
let session_state: HashMap<String, Value> =
raw.get("state").and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default();
let events_k = events_key(&req.app_name, &req.user_id, &req.session_id);
let raw_events: Vec<(String, f64)> = self
.client
.zrange(&events_k, 0, -1, None, false, None, true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis zrange failed: {e}")))?;
let mut events: Vec<Event> = raw_events
.into_iter()
.filter_map(|(json, _score)| serde_json::from_str(&json).ok())
.collect();
if let Some(num) = req.num_recent_events {
let start = events.len().saturating_sub(num);
events = events[start..].to_vec();
}
if let Some(after) = req.after {
events.retain(|e| e.timestamp >= after);
}
Ok(Box::new(RedisSession {
app_name: req.app_name,
user_id: req.user_id,
session_id: req.session_id,
state: session_state,
events,
updated_at,
}))
}
#[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id))]
async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
let idx_k = index_key(&req.app_name, &req.user_id);
let session_ids: Vec<String> = self
.client
.smembers(&idx_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis smembers failed: {e}")))?;
let offset = req.offset.unwrap_or(0);
let limit = req.limit.unwrap_or(usize::MAX);
let mut sessions: Vec<Box<dyn Session>> = Vec::new();
for sid in session_ids.into_iter().skip(offset).take(limit) {
let session_k = session_key(&req.app_name, &req.user_id, &sid);
let raw: HashMap<String, String> =
self.client.hgetall(&session_k).await.map_err(|e| {
adk_core::AdkError::session(format!("redis hgetall failed: {e}"))
})?;
if raw.is_empty() {
continue; }
let state: HashMap<String, Value> =
raw.get("state").and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default();
let updated_at: DateTime<Utc> =
raw.get("updated_at").and_then(|s| s.parse().ok()).unwrap_or_else(Utc::now);
sessions.push(Box::new(RedisSession {
app_name: req.app_name.clone(),
user_id: req.user_id.clone(),
session_id: sid,
state,
events: Vec::new(),
updated_at,
}));
}
Ok(sessions)
}
#[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id, session_id = %req.session_id))]
async fn delete(&self, req: DeleteRequest) -> Result<()> {
let session_k = session_key(&req.app_name, &req.user_id, &req.session_id);
let events_k = events_key(&req.app_name, &req.user_id, &req.session_id);
let idx_k = index_key(&req.app_name, &req.user_id);
let lk = lookup_key(&req.session_id);
let trx = self.client.multi();
let _: () = trx
.del(vec![session_k, events_k, lk])
.await
.map_err(|e| adk_core::AdkError::session(format!("redis del failed: {e}")))?;
let _: () = trx
.srem(&idx_k, &req.session_id)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis srem failed: {e}")))?;
let _: () = trx
.exec(true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis transaction failed: {e}")))?;
Ok(())
}
#[instrument(skip_all, fields(session_id = %session_id))]
async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
let lk = lookup_key(session_id);
let lookup_val: Option<String> = self
.client
.get(&lk)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis get failed: {e}")))?;
let lookup_val =
lookup_val.ok_or_else(|| adk_core::AdkError::session("session not found"))?;
let (app_name, user_id) = lookup_val
.split_once(':')
.ok_or_else(|| adk_core::AdkError::session("corrupt session lookup entry"))?;
let session_k = session_key(app_name, user_id, session_id);
let exists: bool = self
.client
.exists(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis exists failed: {e}")))?;
if !exists {
return Err(adk_core::AdkError::session("session not found"));
}
let raw: HashMap<String, String> = self
.client
.hgetall(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hgetall failed: {e}")))?;
let existing_state: HashMap<String, Value> =
raw.get("state").and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default();
let (_, _, mut session_state) = state_utils::extract_state_deltas(&existing_state);
let app_state = self.read_state_hash(&app_state_key(app_name)).await?;
let user_state = self.read_state_hash(&user_state_key(app_name, user_id)).await?;
let (app_delta, user_delta, session_delta) =
state_utils::extract_state_deltas(&event.actions.state_delta);
let mut new_app_state = app_state;
new_app_state.extend(app_delta);
let mut new_user_state = user_state;
new_user_state.extend(user_delta);
session_state.extend(session_delta);
let merged_state =
state_utils::merge_states(&new_app_state, &new_user_state, &session_state);
let event_json = serde_json::to_string(&event)
.map_err(|e| adk_core::AdkError::session(format!("serialize failed: {e}")))?;
let score = event.timestamp.timestamp_millis() as f64;
let trx = self.client.multi();
Self::write_state_hash(&trx, &app_state_key(app_name), &new_app_state).await?;
Self::write_state_hash(&trx, &user_state_key(app_name, user_id), &new_user_state).await?;
let merged_state_json = serde_json::to_string(&merged_state)
.map_err(|e| adk_core::AdkError::session(format!("serialize state failed: {e}")))?;
let _: () = trx
.hset(
&session_k,
vec![
("state".to_string(), merged_state_json),
("updated_at".to_string(), event.timestamp.to_rfc3339()),
],
)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hset failed: {e}")))?;
let events_k = events_key(app_name, user_id, session_id);
let _: () = trx
.zadd(&events_k, None, None, false, false, (score, event_json))
.await
.map_err(|e| adk_core::AdkError::session(format!("redis zadd failed: {e}")))?;
let _: () = trx
.exec(true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis transaction failed: {e}")))?;
self.apply_ttl(&session_k, &events_k).await?;
Ok(())
}
#[instrument(skip_all, fields(
app_name = %req.identity.app_name,
user_id = %req.identity.user_id,
session_id = %req.identity.session_id,
))]
async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
let mut event = req.event;
event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
let app_name = req.identity.app_name.as_ref();
let user_id = req.identity.user_id.as_ref();
let sid = req.identity.session_id.as_ref();
let session_k = session_key(app_name, user_id, sid);
let exists: bool = self
.client
.exists(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis exists failed: {e}")))?;
if !exists {
return Err(adk_core::AdkError::session("session not found"));
}
let raw: HashMap<String, String> = self
.client
.hgetall(&session_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hgetall failed: {e}")))?;
let existing_state: HashMap<String, Value> =
raw.get("state").and_then(|s| serde_json::from_str(s).ok()).unwrap_or_default();
let (_, _, mut session_state) = state_utils::extract_state_deltas(&existing_state);
let app_state = self.read_state_hash(&app_state_key(app_name)).await?;
let user_state = self.read_state_hash(&user_state_key(app_name, user_id)).await?;
let (app_delta, user_delta, session_delta) =
state_utils::extract_state_deltas(&event.actions.state_delta);
let mut new_app_state = app_state;
new_app_state.extend(app_delta);
let mut new_user_state = user_state;
new_user_state.extend(user_delta);
session_state.extend(session_delta);
let merged_state =
state_utils::merge_states(&new_app_state, &new_user_state, &session_state);
let event_json = serde_json::to_string(&event)
.map_err(|e| adk_core::AdkError::session(format!("serialize failed: {e}")))?;
let score = event.timestamp.timestamp_millis() as f64;
let trx = self.client.multi();
Self::write_state_hash(&trx, &app_state_key(app_name), &new_app_state).await?;
Self::write_state_hash(&trx, &user_state_key(app_name, user_id), &new_user_state).await?;
let merged_state_json = serde_json::to_string(&merged_state)
.map_err(|e| adk_core::AdkError::session(format!("serialize state failed: {e}")))?;
let _: () = trx
.hset(
&session_k,
vec![
("state".to_string(), merged_state_json),
("updated_at".to_string(), event.timestamp.to_rfc3339()),
],
)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis hset failed: {e}")))?;
let events_k = events_key(app_name, user_id, sid);
let _: () = trx
.zadd(&events_k, None, None, false, false, (score, event_json))
.await
.map_err(|e| adk_core::AdkError::session(format!("redis zadd failed: {e}")))?;
let _: () = trx
.exec(true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis transaction failed: {e}")))?;
self.apply_ttl(&session_k, &events_k).await?;
Ok(())
}
#[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id))]
async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
let idx_k = index_key(app_name, user_id);
let session_ids: Vec<String> = self
.client
.smembers(&idx_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis smembers failed: {e}")))?;
if session_ids.is_empty() {
return Ok(());
}
let trx = self.client.multi();
for sid in &session_ids {
let sk = session_key(app_name, user_id, sid);
let ek = events_key(app_name, user_id, sid);
let lk = lookup_key(sid);
let _: () = trx
.del(vec![sk, ek, lk])
.await
.map_err(|e| adk_core::AdkError::session(format!("redis del failed: {e}")))?;
}
let _: () = trx
.del(&idx_k)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis del failed: {e}")))?;
let _: () = trx
.exec(true)
.await
.map_err(|e| adk_core::AdkError::session(format!("redis transaction failed: {e}")))?;
Ok(())
}
#[instrument(skip_all)]
async fn health_check(&self) -> Result<()> {
let _: String = self
.client
.ping(None)
.await
.map_err(|e| adk_core::AdkError::session(format!("health check failed: {e}")))?;
Ok(())
}
}
/// Snapshot of one session as read from (or just written to) Redis.
///
/// It holds no Redis client: changes made through `State::set` only update this in-memory
/// copy and are not written back.
struct RedisSession {
    app_name: String,
    user_id: String,
    session_id: String,
    /// Merged view of app-, user- and session-scoped state.
    state: HashMap<String, Value>,
    /// Events in stored order; left empty for sessions returned by `list`.
    events: Vec<Event>,
    /// The session's `updated_at` (the read time if that field is missing or unparsable).
    updated_at: DateTime<Utc>,
}
impl Session for RedisSession {
    fn id(&self) -> &str {
        self.session_id.as_str()
    }

    fn app_name(&self) -> &str {
        self.app_name.as_str()
    }

    fn user_id(&self) -> &str {
        self.user_id.as_str()
    }

    /// `RedisSession` implements `State` itself, so it is its own state view.
    fn state(&self) -> &dyn State {
        self as &dyn State
    }

    /// `RedisSession` implements `Events` itself, so it is its own event view.
    fn events(&self) -> &dyn Events {
        self as &dyn Events
    }

    fn last_update_time(&self) -> DateTime<Utc> {
        self.updated_at
    }
}
impl State for RedisSession {
fn get(&self, key: &str) -> Option<Value> {
self.state.get(key).cloned()
}
fn set(&mut self, key: String, value: Value) {
if let Err(msg) = adk_core::validate_state_key(&key) {
tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
return;
}
self.state.insert(key, value);
}
fn all(&self) -> HashMap<String, Value> {
self.state.clone()
}
}
impl Events for RedisSession {
    /// Returns copies of all loaded events, in order.
    fn all(&self) -> Vec<Event> {
        self.events.iter().cloned().collect()
    }

    /// Number of loaded events.
    fn len(&self) -> usize {
        let loaded: &[Event] = &self.events;
        loaded.len()
    }

    /// Event at position `index`, or `None` when out of range.
    fn at(&self, index: usize) -> Option<&Event> {
        let loaded: &[Event] = &self.events;
        loaded.get(index)
    }
}