use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::Mutex;
use crate::core::{
Event, GetSessionConfig, ListSessionsResponse, Session, SessionMeta, SessionService, State,
StateDelta, StateScope,
};
use crate::error::{Error, Result};
/// In-process [`SessionService`] that keeps every session and all scoped state in memory.
///
/// Three independent maps hold the data:
/// - `sessions`: one record per `(app_name, user_id, session_id)`;
/// - `app_state`: app-scope state shared by every session of an app;
/// - `user_state`: user-scope state shared by every session of one user within an app.
///
/// Every value sits behind its own `Arc<Mutex<_>>`, so callers clone the `Arc` out of the
/// map and lock the individual entry rather than keeping a map guard alive.
#[derive(Debug, Default)]
pub struct InMemorySessionService {
    // Stored sessions, keyed by (app_name, user_id, session_id).
    sessions: DashMap<(String, String, String), Arc<Mutex<Session>>>,
    // App-scope state, keyed by app_name.
    app_state: DashMap<String, Arc<Mutex<State>>>,
    // User-scope state, keyed by (app_name, user_id).
    user_state: DashMap<(String, String), Arc<Mutex<State>>>,
}
impl InMemorySessionService {
#[must_use]
pub fn new() -> Self {
Self::default()
}
fn key(app: &str, user: &str, sid: &str) -> (String, String, String) {
(app.to_string(), user.to_string(), sid.to_string())
}
fn app_slot(&self, app: &str) -> Arc<Mutex<State>> {
self.app_state
.entry(app.to_string())
.or_insert_with(|| Arc::new(Mutex::new(State::new())))
.value()
.clone()
}
fn user_slot(&self, app: &str, user: &str) -> Arc<Mutex<State>> {
self.user_state
.entry((app.to_string(), user.to_string()))
.or_insert_with(|| Arc::new(Mutex::new(State::new())))
.value()
.clone()
}
fn mirror_event(
&self,
key: (String, String, String),
seed: &Session,
event: &Event,
session_delta: &StateDelta,
) {
let arc = self
.sessions
.entry(key)
.or_insert_with(|| Arc::new(Mutex::new(seed.clone())))
.value()
.clone();
let mut stored = arc.lock();
if !stored.events.iter().any(|e| e.id == event.id) {
stored.events.push(event.clone());
}
stored.state.apply(session_delta);
if seed.last_update_time > stored.last_update_time {
stored.last_update_time = seed.last_update_time;
}
}
fn overlay_state(&self, sess: &mut Session) {
let app_view = self.app_slot(&sess.app_name);
let user_view = self.user_slot(&sess.app_name, &sess.user_id);
let mut merged = State::new();
for (k, v) in app_view.lock().iter() {
merged.set(k.clone(), v.clone());
}
for (k, v) in user_view.lock().iter() {
merged.set(k.clone(), v.clone());
}
for (k, v) in sess.state.iter() {
merged.set(k.clone(), v.clone());
}
sess.state = merged;
}
}
#[async_trait]
impl SessionService for InMemorySessionService {
async fn create_session(
&self,
app_name: &str,
user_id: &str,
state: Option<State>,
id: Option<&str>,
) -> Result<Session> {
let sid = id
.map(str::to_string)
.unwrap_or_else(crate::core::services::new_session_id);
let key = Self::key(app_name, user_id, &sid);
if self.sessions.contains_key(&key) {
return Err(Error::already_exists(format!("session {sid}")));
}
let mut s = Session::new(app_name, user_id, sid);
if let Some(state) = state {
let (app_delta, user_delta, session_delta, _temp_delta) =
State::partition_by_scope(&state.map);
if !app_delta.is_empty() {
self.app_slot(app_name).lock().apply(&app_delta);
}
if !user_delta.is_empty() {
self.user_slot(app_name, user_id).lock().apply(&user_delta);
}
s.state = State::from_iter(session_delta);
}
let arc = Arc::new(Mutex::new(s.clone()));
self.sessions.insert(key, arc);
self.overlay_state(&mut s);
Ok(s)
}
async fn get_session(
&self,
app_name: &str,
user_id: &str,
session_id: &str,
cfg: GetSessionConfig,
) -> Result<Option<Session>> {
let key = Self::key(app_name, user_id, session_id);
let Some(arc) = self.sessions.get(&key) else {
return Ok(None);
};
let mut snap = arc.lock().clone();
self.overlay_state(&mut snap);
Ok(Some(apply_filter(snap, &cfg)))
}
async fn list_sessions(&self, app_name: &str, user_id: &str) -> Result<ListSessionsResponse> {
let sessions: Vec<SessionMeta> = self
.sessions
.iter()
.filter(|kv| kv.key().0 == app_name && kv.key().1 == user_id)
.map(|kv| {
let s = kv.value().lock();
SessionMeta {
id: s.id.clone(),
app_name: s.app_name.clone(),
user_id: s.user_id.clone(),
last_update_time: s.last_update_time,
}
})
.collect();
Ok(ListSessionsResponse { sessions })
}
async fn delete_session(&self, app_name: &str, user_id: &str, session_id: &str) -> Result<()> {
self.sessions
.remove(&Self::key(app_name, user_id, session_id));
Ok(())
}
async fn append_event(&self, session: &mut Session, mut event: Event) -> Result<Event> {
if event.partial == Some(true) {
return Ok(event);
}
let (app_delta, user_delta, session_delta, temp_delta) =
State::partition_by_scope(&event.actions.state_delta);
for (k, v) in &temp_delta {
session.state.set(k.clone(), v.clone());
}
if !app_delta.is_empty() {
self.app_slot(&session.app_name).lock().apply(&app_delta);
}
if !user_delta.is_empty() {
self.user_slot(&session.app_name, &session.user_id)
.lock()
.apply(&user_delta);
}
session.state.apply(&app_delta);
session.state.apply(&user_delta);
session.state.apply(&session_delta);
event.actions.state_delta = session_delta.clone();
session.last_update_time = crate::core::session::now_secs();
session.events.push(event.clone());
let key = Self::key(&session.app_name, &session.user_id, &session.id);
let mut seed = session.clone();
seed.state = State::from_iter(
session
.state
.iter()
.filter(|(k, _)| StateScope::of(k) == StateScope::Session)
.map(|(k, v)| (k.clone(), v.clone())),
);
self.mirror_event(key, &seed, &event, &session_delta);
Ok(event)
}
async fn append_event_locked(
&self,
session_lock: &Arc<Mutex<Session>>,
event: Event,
) -> Result<Event> {
if event.partial == Some(true) {
return Ok(event);
}
let (app_delta, user_delta, session_delta, _temp_delta) =
State::partition_by_scope(&event.actions.state_delta);
let (event, key, session_only_snap) = {
let mut sess = session_lock.lock();
let event = crate::core::services::apply_event_to_session(&mut sess, event);
let key = Self::key(&sess.app_name, &sess.user_id, &sess.id);
let session_state = State::from_iter(
sess.state
.iter()
.filter(|(k, _)| StateScope::of(k) == StateScope::Session)
.map(|(k, v)| (k.clone(), v.clone())),
);
let mut snap = sess.clone();
snap.state = session_state;
(event, key, snap)
};
if !app_delta.is_empty() {
self.app_slot(&session_only_snap.app_name)
.lock()
.apply(&app_delta);
}
if !user_delta.is_empty() {
self.user_slot(&session_only_snap.app_name, &session_only_snap.user_id)
.lock()
.apply(&user_delta);
}
self.mirror_event(key, &session_only_snap, &event, &session_delta);
Ok(event)
}
}
#[cfg(test)]
mod race_tests {
    use super::*;
    use crate::core::LlmResponse;

    /// Builds an `"agent"` event whose state delta sets `key` to `value`.
    fn agent_event(key: impl Into<String>, value: serde_json::Value) -> Event {
        let mut ev = Event::new("agent", LlmResponse::default());
        ev.actions.state_delta.insert(key.into(), value);
        ev
    }

    /// Reads a stored session back with the default config; panics if it is missing.
    async fn reload(svc: &dyn SessionService, app: &str, user: &str, id: &str) -> Session {
        svc.get_session(app, user, id, GetSessionConfig::default())
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn concurrent_invocations_on_same_session_keep_all_events() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        let base = svc
            .create_session("app", "u", None, Some("sid"))
            .await
            .unwrap();
        // Two invocations, each holding its own copy of the same stored session.
        let inv1 = Arc::new(Mutex::new(base.clone()));
        let inv2 = Arc::new(Mutex::new(base));

        let first = agent_event("from_inv1", serde_json::json!(1));
        let second = agent_event("from_inv2", serde_json::json!(2));
        let ev1 = svc.append_event_locked(&inv1, first).await.unwrap();
        let ev2 = svc.append_event_locked(&inv2, second).await.unwrap();
        let third = agent_event("from_inv1_again", serde_json::json!(3));
        let ev3 = svc.append_event_locked(&inv1, third).await.unwrap();

        let stored = reload(&svc, "app", "u", "sid").await;
        let ids: Vec<&str> = stored.events.iter().map(|e| e.id.as_str()).collect();
        assert!(ids.contains(&ev1.id.as_str()), "lost inv1's first event");
        assert!(ids.contains(&ev2.id.as_str()), "lost inv2's event");
        assert!(ids.contains(&ev3.id.as_str()), "lost inv1's second event");
        assert_eq!(stored.events.len(), 3);
        for (key, want) in [("from_inv1", 1), ("from_inv2", 2), ("from_inv1_again", 3)] {
            assert_eq!(stored.state.get(key), Some(&serde_json::json!(want)));
        }
    }

    #[tokio::test]
    async fn app_scope_state_is_shared_across_sessions() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        let premium = serde_json::Value::String("premium".into());
        let tier = |s: &Session| {
            s.state
                .get("app:globals")
                .and_then(|v| v.get("tier"))
                .cloned()
        };

        let alice = svc
            .create_session("app", "alice", None, None)
            .await
            .unwrap();
        let alice_lock = Arc::new(Mutex::new(alice.clone()));
        let ev = agent_event("app:globals", serde_json::json!({"tier": "premium"}));
        svc.append_event_locked(&alice_lock, ev).await.unwrap();
        let alice_view = reload(&svc, "app", "alice", &alice.id).await;
        assert_eq!(tier(&alice_view), Some(premium.clone()));

        let bob = svc.create_session("app", "bob", None, None).await.unwrap();
        let bob_view = reload(&svc, "app", "bob", &bob.id).await;
        assert_eq!(tier(&bob_view), Some(premium));
    }

    #[tokio::test]
    async fn user_scope_state_is_per_user() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        let first = svc
            .create_session("app", "alice", None, None)
            .await
            .unwrap();
        let first_lock = Arc::new(Mutex::new(first));
        let ev = agent_event("user:lang", serde_json::json!("en"));
        svc.append_event_locked(&first_lock, ev).await.unwrap();

        let second = svc
            .create_session("app", "alice", None, None)
            .await
            .unwrap();
        let alice_again = reload(&svc, "app", "alice", &second.id).await;
        assert_eq!(
            alice_again.state.get("user:lang"),
            Some(&serde_json::Value::String("en".into()))
        );

        let bob = svc.create_session("app", "bob", None, None).await.unwrap();
        let bob_view = reload(&svc, "app", "bob", &bob.id).await;
        assert!(bob_view.state.get("user:lang").is_none());
    }

    #[tokio::test]
    async fn concurrent_append_event_locked_preserves_every_event() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        let created = svc.create_session("app", "u", None, None).await.unwrap();
        let lock = Arc::new(Mutex::new(created));
        const N: usize = 64;
        let handles: Vec<_> = (0..N)
            .map(|i| {
                let svc = Arc::clone(&svc);
                let lock = Arc::clone(&lock);
                tokio::spawn(async move {
                    let ev = Event::new(format!("writer-{i}"), LlmResponse::default());
                    svc.append_event_locked(&lock, ev).await.unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let final_session = lock.lock();
        assert_eq!(
            final_session.events.len(),
            N,
            "every concurrent writer's event must survive (got {} of {})",
            final_session.events.len(),
            N
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_app_writes_from_different_sessions_preserve_all_keys() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        const N: usize = 32;
        let mut locks = Vec::with_capacity(N);
        for i in 0..N {
            let created = svc
                .create_session("app", &format!("u{i}"), None, None)
                .await
                .unwrap();
            locks.push(Arc::new(Mutex::new(created)));
        }
        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let handles: Vec<_> = locks
            .into_iter()
            .enumerate()
            .map(|(i, lock)| {
                let svc = Arc::clone(&svc);
                let barrier = Arc::clone(&barrier);
                tokio::spawn(async move {
                    barrier.wait().await;
                    let ev = agent_event(format!("app:k{i}"), serde_json::json!(i));
                    svc.append_event_locked(&lock, ev).await.unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        let reader = svc
            .create_session("app", "reader", None, None)
            .await
            .unwrap();
        let reloaded = reload(&svc, "app", "reader", &reader.id).await;
        for i in 0..N {
            let k = format!("app:k{i}");
            assert!(
                reloaded.state.get(&k).is_some(),
                "missing {k}; got keys: {:?}",
                reloaded
                    .state
                    .iter()
                    .map(|(k, _)| k.clone())
                    .collect::<Vec<_>>()
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_user_writes_for_same_user_preserve_all_keys() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        const N: usize = 32;
        let mut locks = Vec::with_capacity(N);
        for _ in 0..N {
            let created = svc
                .create_session("app", "alice", None, None)
                .await
                .unwrap();
            locks.push(Arc::new(Mutex::new(created)));
        }
        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let handles: Vec<_> = locks
            .into_iter()
            .enumerate()
            .map(|(i, lock)| {
                let svc = Arc::clone(&svc);
                let barrier = Arc::clone(&barrier);
                tokio::spawn(async move {
                    barrier.wait().await;
                    let ev = agent_event(format!("user:k{i}"), serde_json::json!(i));
                    svc.append_event_locked(&lock, ev).await.unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        let reader = svc
            .create_session("app", "alice", None, None)
            .await
            .unwrap();
        let reloaded = reload(&svc, "app", "alice", &reader.id).await;
        for i in 0..N {
            let k = format!("user:k{i}");
            assert!(reloaded.state.get(&k).is_some(), "missing {k}");
        }
    }

    #[tokio::test]
    async fn create_session_with_initial_scoped_state_routes_correctly() {
        let svc: Arc<dyn SessionService> = Arc::new(InMemorySessionService::new());
        let mut initial = State::new();
        for (k, v) in [
            ("app:foo", serde_json::json!(1)),
            ("user:bar", serde_json::json!(2)),
            ("baz", serde_json::json!(3)),
            ("temp:x", serde_json::json!(4)),
        ] {
            initial.set(k, v);
        }

        let s1 = svc
            .create_session("app", "alice", Some(initial), Some("s1"))
            .await
            .unwrap();
        assert_eq!(s1.state.get("app:foo"), Some(&serde_json::json!(1)));
        assert_eq!(s1.state.get("user:bar"), Some(&serde_json::json!(2)));
        assert_eq!(s1.state.get("baz"), Some(&serde_json::json!(3)));
        assert!(
            s1.state.get("temp:x").is_none(),
            "temp keys must not survive create_session"
        );

        let s2 = svc
            .create_session("app", "bob", None, Some("s2"))
            .await
            .unwrap();
        let bob = reload(&svc, "app", "bob", &s2.id).await;
        assert_eq!(bob.state.get("app:foo"), Some(&serde_json::json!(1)));
        assert!(
            bob.state.get("user:bar").is_none(),
            "user-scope must not leak across users"
        );
        assert!(
            bob.state.get("baz").is_none(),
            "session-scope must not leak across sessions"
        );

        let s3 = svc
            .create_session("app", "alice", None, Some("s3"))
            .await
            .unwrap();
        let alice2 = reload(&svc, "app", "alice", &s3.id).await;
        assert_eq!(alice2.state.get("app:foo"), Some(&serde_json::json!(1)));
        assert_eq!(alice2.state.get("user:bar"), Some(&serde_json::json!(2)));
        assert!(
            alice2.state.get("baz").is_none(),
            "session-scope must not leak across sessions"
        );
    }
}
/// Narrows `s.events` according to `cfg` and returns the session.
///
/// The two filters run in this order:
/// 1. `after_timestamp`: keeps only events whose `timestamp` is at or after the cutoff.
/// 2. `num_recent_events`: keeps only the last `n` of the surviving events
///    (`Some(0)` removes them all).
///
/// State and other fields are untouched.
fn apply_filter(mut s: Session, cfg: &GetSessionConfig) -> Session {
    let events = &mut s.events;
    if let Some(cutoff) = cfg.after_timestamp {
        events.retain(|ev| ev.timestamp >= cutoff);
    }
    if let Some(keep) = cfg.num_recent_events {
        // Number of oldest events that fall outside the "most recent `keep`" window.
        let surplus = events.len().saturating_sub(keep);
        events.drain(..surplus);
    }
    s
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `(app, user, id)` back with the default config.
    async fn fetch(
        svc: &InMemorySessionService,
        app: &str,
        user: &str,
        id: &str,
    ) -> Option<Session> {
        svc.get_session(app, user, id, GetSessionConfig::default())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn create_get_delete_roundtrip() {
        let svc = InMemorySessionService::new();
        let created = svc
            .create_session("app", "user", None, Some("s1"))
            .await
            .unwrap();
        assert_eq!(created.id, "s1");
        assert!(fetch(&svc, "app", "user", "s1").await.is_some());
        svc.delete_session("app", "user", "s1").await.unwrap();
        assert!(fetch(&svc, "app", "user", "s1").await.is_none());
    }

    #[tokio::test]
    async fn append_event_persists_and_applies_state() {
        let svc = InMemorySessionService::new();
        let mut sess = svc.create_session("app", "user", None, None).await.unwrap();
        let mut ev = Event::user_text("hello");
        let delta = &mut ev.actions.state_delta;
        delta.insert("foo".into(), serde_json::json!("bar"));
        delta.insert("temp:t".into(), serde_json::json!(1));
        svc.append_event(&mut sess, ev).await.unwrap();

        let got = fetch(&svc, "app", "user", &sess.id).await.unwrap();
        assert_eq!(got.state.get("foo"), Some(&serde_json::json!("bar")));
        assert!(
            got.state.get("temp:t").is_none(),
            "temp:t should not survive get_session: {:?}",
            got.state
        );
        assert!(!got.events[0].actions.state_delta.contains_key("temp:t"));
    }

    #[tokio::test]
    async fn list_filters_by_app_and_user() {
        let svc = InMemorySessionService::new();
        for (app, user) in [("app", "u1"), ("app", "u2"), ("other", "u1")] {
            svc.create_session(app, user, None, None).await.unwrap();
        }
        let listed = svc.list_sessions("app", "u1").await.unwrap();
        assert_eq!(listed.sessions.len(), 1);
    }

    #[tokio::test]
    async fn get_session_filters_recent_events() {
        let svc = InMemorySessionService::new();
        let mut sess = svc.create_session("app", "user", None, None).await.unwrap();
        for i in 0..5 {
            let mut e = Event::user_text(format!("m{i}"));
            e.timestamp = f64::from(i);
            svc.append_event(&mut sess, e).await.unwrap();
        }
        let cfg = GetSessionConfig {
            num_recent_events: Some(2),
            ..Default::default()
        };
        let got = svc
            .get_session("app", "user", &sess.id, cfg)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(got.events.len(), 2);
        let oldest_kept = got.events[0]
            .response
            .content
            .as_ref()
            .unwrap()
            .text_concat();
        assert_eq!(oldest_kept, "m3");
    }
}