Skip to main content

adk_session/
inmemory.rs

1use crate::{
2    AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_APP,
3    KEY_PREFIX_TEMP, KEY_PREFIX_USER, ListRequest, Session, SessionService, State,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
14type StateMap = HashMap<String, Value>;
15
16#[derive(Clone)]
17struct SessionData {
18    identity: AdkIdentity,
19    events: Vec<Event>,
20    state: StateMap,
21    updated_at: DateTime<Utc>,
22}
23
24/// In-memory session service for testing and lightweight deployments.
25///
26/// All data is stored in process memory and lost on restart.
27pub struct InMemorySessionService {
28    sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
29    app_state: Arc<RwLock<HashMap<String, StateMap>>>,
30    user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
31}
32
33impl InMemorySessionService {
34    /// Creates a new empty in-memory session service.
35    pub fn new() -> Self {
36        Self {
37            sessions: Arc::new(RwLock::new(HashMap::new())),
38            app_state: Arc::new(RwLock::new(HashMap::new())),
39            user_state: Arc::new(RwLock::new(HashMap::new())),
40        }
41    }
42
43    fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
44        let mut app_delta = StateMap::new();
45        let mut user_delta = StateMap::new();
46        let mut session_delta = StateMap::new();
47
48        for (key, value) in delta {
49            if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
50                app_delta.insert(clean_key.to_string(), value.clone());
51            } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
52                user_delta.insert(clean_key.to_string(), value.clone());
53            } else if !key.starts_with(KEY_PREFIX_TEMP) {
54                session_delta.insert(key.clone(), value.clone());
55            }
56        }
57
58        (app_delta, user_delta, session_delta)
59    }
60
61    fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
62        let mut merged = session.clone();
63        for (k, v) in app {
64            merged.insert(format!("{KEY_PREFIX_APP}{k}"), v.clone());
65        }
66        for (k, v) in user {
67            merged.insert(format!("{KEY_PREFIX_USER}{k}"), v.clone());
68        }
69        merged
70    }
71
72    /// Build an [`AdkIdentity`] from raw string fields, returning a session
73    /// error if any field fails validation.
74    fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
75        Ok(AdkIdentity::new(
76            AppName::try_from(app_name)
77                .map_err(|e| adk_core::AdkError::session(format!("invalid app_name: {e}")))?,
78            UserId::try_from(user_id)
79                .map_err(|e| adk_core::AdkError::session(format!("invalid user_id: {e}")))?,
80            SessionId::try_from(session_id)
81                .map_err(|e| adk_core::AdkError::session(format!("invalid session_id: {e}")))?,
82        ))
83    }
84}
85
86impl Default for InMemorySessionService {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92#[async_trait]
93impl SessionService for InMemorySessionService {
94    async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
95        let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
96
97        let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
98
99        let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
100
101        let mut app_state_lock = self.app_state.write().unwrap();
102        let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
103        app_state.extend(app_delta);
104        let app_state_clone = app_state.clone();
105        drop(app_state_lock);
106
107        let mut user_state_lock = self.user_state.write().unwrap();
108        let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
109        let user_state = user_map.entry(req.user_id.clone()).or_default();
110        user_state.extend(user_delta);
111        let user_state_clone = user_state.clone();
112        drop(user_state_lock);
113
114        let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
115
116        let data = SessionData {
117            identity: identity.clone(),
118            events: Vec::new(),
119            state: merged_state.clone(),
120            updated_at: Utc::now(),
121        };
122
123        let mut sessions = self.sessions.write().unwrap();
124        sessions.insert(identity.clone(), data);
125        drop(sessions);
126
127        Ok(Box::new(InMemorySession {
128            identity,
129            state: merged_state,
130            events: Vec::new(),
131            updated_at: Utc::now(),
132        }))
133    }
134
135    async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
136        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
137
138        let sessions = self.sessions.read().unwrap();
139        let data = sessions
140            .get(&identity)
141            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
142
143        let app_state_lock = self.app_state.read().unwrap();
144        let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
145        drop(app_state_lock);
146
147        let user_state_lock = self.user_state.read().unwrap();
148        let user_state = user_state_lock
149            .get(&req.app_name)
150            .and_then(|m| m.get(&req.user_id))
151            .cloned()
152            .unwrap_or_default();
153        drop(user_state_lock);
154
155        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
156
157        let mut events = data.events.clone();
158        if let Some(num) = req.num_recent_events {
159            let start = events.len().saturating_sub(num);
160            events = events[start..].to_vec();
161        }
162        if let Some(after) = req.after {
163            events.retain(|e| e.timestamp >= after);
164        }
165
166        Ok(Box::new(InMemorySession {
167            identity: data.identity.clone(),
168            state: merged_state,
169            events,
170            updated_at: data.updated_at,
171        }))
172    }
173
174    async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
175        let sessions = self.sessions.read().unwrap();
176        let offset = req.offset.unwrap_or(0);
177        let limit = req.limit.unwrap_or(usize::MAX);
178        let mut result = Vec::new();
179
180        for data in sessions.values() {
181            if data.identity.app_name.as_ref() == req.app_name
182                && data.identity.user_id.as_ref() == req.user_id
183            {
184                result.push(data.clone());
185            }
186        }
187
188        // Sort by updated_at descending for consistency with other backends
189        result.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
190
191        let result: Vec<Box<dyn Session>> = result
192            .into_iter()
193            .skip(offset)
194            .take(limit)
195            .map(|data| {
196                Box::new(InMemorySession {
197                    identity: data.identity,
198                    state: data.state,
199                    events: data.events,
200                    updated_at: data.updated_at,
201                }) as Box<dyn Session>
202            })
203            .collect();
204
205        Ok(result)
206    }
207
208    async fn delete(&self, req: DeleteRequest) -> Result<()> {
209        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
210
211        let mut sessions = self.sessions.write().unwrap();
212        sessions.remove(&identity);
213        Ok(())
214    }
215
216    async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
217        let mut sessions = self.sessions.write().unwrap();
218        sessions.retain(|_, data| {
219            !(data.identity.app_name.as_ref() == app_name
220                && data.identity.user_id.as_ref() == user_id)
221        });
222        Ok(())
223    }
224
225    async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
226        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
227
228        let (app_name, user_id, app_delta, user_delta, _session_delta) = {
229            let mut sessions = self.sessions.write().unwrap();
230            let data = sessions
231                .values_mut()
232                .find(|d| d.identity.session_id.as_ref() == session_id)
233                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
234
235            data.events.push(event.clone());
236            data.updated_at = event.timestamp;
237
238            let (app_delta, user_delta, session_delta) =
239                Self::extract_state_deltas(&event.actions.state_delta);
240            data.state.extend(session_delta.clone());
241
242            (
243                data.identity.app_name.as_ref().to_string(),
244                data.identity.user_id.as_ref().to_string(),
245                app_delta,
246                user_delta,
247                session_delta,
248            )
249        };
250
251        if !app_delta.is_empty() {
252            let mut app_state_lock = self.app_state.write().unwrap();
253            let app_state = app_state_lock.entry(app_name.clone()).or_default();
254            app_state.extend(app_delta);
255        }
256
257        if !user_delta.is_empty() {
258            let mut user_state_lock = self.user_state.write().unwrap();
259            let user_map = user_state_lock.entry(app_name).or_default();
260            let user_state = user_map.entry(user_id).or_default();
261            user_state.extend(user_delta);
262        }
263
264        Ok(())
265    }
266
267    async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
268        let mut event = req.event;
269        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
270
271        let identity = req.identity;
272
273        let (app_name_str, user_id_str, app_delta, user_delta) = {
274            let mut sessions = self.sessions.write().unwrap();
275            let data = sessions
276                .get_mut(&identity)
277                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
278
279            data.events.push(event.clone());
280            data.updated_at = event.timestamp;
281
282            let (app_delta, user_delta, session_delta) =
283                Self::extract_state_deltas(&event.actions.state_delta);
284            data.state.extend(session_delta);
285
286            (
287                identity.app_name.as_ref().to_string(),
288                identity.user_id.as_ref().to_string(),
289                app_delta,
290                user_delta,
291            )
292        };
293
294        if !app_delta.is_empty() {
295            let mut app_state_lock = self.app_state.write().unwrap();
296            let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
297            app_state.extend(app_delta);
298        }
299
300        if !user_delta.is_empty() {
301            let mut user_state_lock = self.user_state.write().unwrap();
302            let user_map = user_state_lock.entry(app_name_str).or_default();
303            let user_state = user_map.entry(user_id_str).or_default();
304            user_state.extend(user_delta);
305        }
306
307        Ok(())
308    }
309
310    async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
311        let sessions = self.sessions.read().unwrap();
312        let data = sessions
313            .get(identity)
314            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
315
316        let app_state_lock = self.app_state.read().unwrap();
317        let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
318        drop(app_state_lock);
319
320        let user_state_lock = self.user_state.read().unwrap();
321        let user_state = user_state_lock
322            .get(identity.app_name.as_ref())
323            .and_then(|m| m.get(identity.user_id.as_ref()))
324            .cloned()
325            .unwrap_or_default();
326        drop(user_state_lock);
327
328        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
329
330        Ok(Box::new(InMemorySession {
331            identity: data.identity.clone(),
332            state: merged_state,
333            events: data.events.clone(),
334            updated_at: data.updated_at,
335        }))
336    }
337
338    async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
339        let mut sessions = self.sessions.write().unwrap();
340        sessions.remove(identity);
341        Ok(())
342    }
343}
344
345struct InMemorySession {
346    identity: AdkIdentity,
347    state: StateMap,
348    events: Vec<Event>,
349    updated_at: DateTime<Utc>,
350}
351
352impl Session for InMemorySession {
353    fn id(&self) -> &str {
354        self.identity.session_id.as_ref()
355    }
356
357    fn app_name(&self) -> &str {
358        self.identity.app_name.as_ref()
359    }
360
361    fn user_id(&self) -> &str {
362        self.identity.user_id.as_ref()
363    }
364
365    fn state(&self) -> &dyn State {
366        self
367    }
368
369    fn events(&self) -> &dyn Events {
370        self
371    }
372
373    fn last_update_time(&self) -> DateTime<Utc> {
374        self.updated_at
375    }
376}
377
378impl State for InMemorySession {
379    fn get(&self, key: &str) -> Option<Value> {
380        self.state.get(key).cloned()
381    }
382
383    fn set(&mut self, key: String, value: Value) {
384        if let Err(msg) = adk_core::validate_state_key(&key) {
385            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
386            return;
387        }
388        self.state.insert(key, value);
389    }
390
391    fn all(&self) -> HashMap<String, Value> {
392        self.state.clone()
393    }
394}
395
396impl Events for InMemorySession {
397    fn all(&self) -> Vec<Event> {
398        self.events.clone()
399    }
400
401    fn len(&self) -> usize {
402        self.events.len()
403    }
404
405    fn at(&self, index: usize) -> Option<&Event> {
406        self.events.get(index)
407    }
408}