Skip to main content

adk_session/
inmemory.rs

1use crate::{
2    AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_APP,
3    KEY_PREFIX_TEMP, KEY_PREFIX_USER, ListRequest, Session, SessionService, State,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
14type StateMap = HashMap<String, Value>;
15
16#[derive(Clone)]
17struct SessionData {
18    identity: AdkIdentity,
19    events: Vec<Event>,
20    state: StateMap,
21    updated_at: DateTime<Utc>,
22}
23
24pub struct InMemorySessionService {
25    sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
26    app_state: Arc<RwLock<HashMap<String, StateMap>>>,
27    user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
28}
29
30impl InMemorySessionService {
31    pub fn new() -> Self {
32        Self {
33            sessions: Arc::new(RwLock::new(HashMap::new())),
34            app_state: Arc::new(RwLock::new(HashMap::new())),
35            user_state: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38
39    fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
40        let mut app_delta = StateMap::new();
41        let mut user_delta = StateMap::new();
42        let mut session_delta = StateMap::new();
43
44        for (key, value) in delta {
45            if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
46                app_delta.insert(clean_key.to_string(), value.clone());
47            } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
48                user_delta.insert(clean_key.to_string(), value.clone());
49            } else if !key.starts_with(KEY_PREFIX_TEMP) {
50                session_delta.insert(key.clone(), value.clone());
51            }
52        }
53
54        (app_delta, user_delta, session_delta)
55    }
56
57    fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
58        let mut merged = session.clone();
59        for (k, v) in app {
60            merged.insert(format!("{KEY_PREFIX_APP}{k}"), v.clone());
61        }
62        for (k, v) in user {
63            merged.insert(format!("{KEY_PREFIX_USER}{k}"), v.clone());
64        }
65        merged
66    }
67
68    /// Build an [`AdkIdentity`] from raw string fields, returning a session
69    /// error if any field fails validation.
70    fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
71        Ok(AdkIdentity::new(
72            AppName::try_from(app_name)
73                .map_err(|e| adk_core::AdkError::Session(format!("invalid app_name: {e}")))?,
74            UserId::try_from(user_id)
75                .map_err(|e| adk_core::AdkError::Session(format!("invalid user_id: {e}")))?,
76            SessionId::try_from(session_id)
77                .map_err(|e| adk_core::AdkError::Session(format!("invalid session_id: {e}")))?,
78        ))
79    }
80}
81
82impl Default for InMemorySessionService {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88#[async_trait]
89impl SessionService for InMemorySessionService {
90    async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
91        let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
92
93        let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
94
95        let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
96
97        let mut app_state_lock = self.app_state.write().unwrap();
98        let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
99        app_state.extend(app_delta);
100        let app_state_clone = app_state.clone();
101        drop(app_state_lock);
102
103        let mut user_state_lock = self.user_state.write().unwrap();
104        let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
105        let user_state = user_map.entry(req.user_id.clone()).or_default();
106        user_state.extend(user_delta);
107        let user_state_clone = user_state.clone();
108        drop(user_state_lock);
109
110        let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
111
112        let data = SessionData {
113            identity: identity.clone(),
114            events: Vec::new(),
115            state: merged_state.clone(),
116            updated_at: Utc::now(),
117        };
118
119        let mut sessions = self.sessions.write().unwrap();
120        sessions.insert(identity.clone(), data);
121        drop(sessions);
122
123        Ok(Box::new(InMemorySession {
124            identity,
125            state: merged_state,
126            events: Vec::new(),
127            updated_at: Utc::now(),
128        }))
129    }
130
131    async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
132        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
133
134        let sessions = self.sessions.read().unwrap();
135        let data = sessions
136            .get(&identity)
137            .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
138
139        let app_state_lock = self.app_state.read().unwrap();
140        let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
141        drop(app_state_lock);
142
143        let user_state_lock = self.user_state.read().unwrap();
144        let user_state = user_state_lock
145            .get(&req.app_name)
146            .and_then(|m| m.get(&req.user_id))
147            .cloned()
148            .unwrap_or_default();
149        drop(user_state_lock);
150
151        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
152
153        let mut events = data.events.clone();
154        if let Some(num) = req.num_recent_events {
155            let start = events.len().saturating_sub(num);
156            events = events[start..].to_vec();
157        }
158        if let Some(after) = req.after {
159            events.retain(|e| e.timestamp >= after);
160        }
161
162        Ok(Box::new(InMemorySession {
163            identity: data.identity.clone(),
164            state: merged_state,
165            events,
166            updated_at: data.updated_at,
167        }))
168    }
169
170    async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
171        let sessions = self.sessions.read().unwrap();
172        let offset = req.offset.unwrap_or(0);
173        let limit = req.limit.unwrap_or(usize::MAX);
174        let mut result = Vec::new();
175
176        for data in sessions.values() {
177            if data.identity.app_name.as_ref() == req.app_name
178                && data.identity.user_id.as_ref() == req.user_id
179            {
180                result.push(data.clone());
181            }
182        }
183
184        // Sort by updated_at descending for consistency with other backends
185        result.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
186
187        let result: Vec<Box<dyn Session>> = result
188            .into_iter()
189            .skip(offset)
190            .take(limit)
191            .map(|data| {
192                Box::new(InMemorySession {
193                    identity: data.identity,
194                    state: data.state,
195                    events: data.events,
196                    updated_at: data.updated_at,
197                }) as Box<dyn Session>
198            })
199            .collect();
200
201        Ok(result)
202    }
203
204    async fn delete(&self, req: DeleteRequest) -> Result<()> {
205        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
206
207        let mut sessions = self.sessions.write().unwrap();
208        sessions.remove(&identity);
209        Ok(())
210    }
211
212    async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
213        let mut sessions = self.sessions.write().unwrap();
214        sessions.retain(|_, data| {
215            !(data.identity.app_name.as_ref() == app_name
216                && data.identity.user_id.as_ref() == user_id)
217        });
218        Ok(())
219    }
220
221    async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
222        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
223
224        let (app_name, user_id, app_delta, user_delta, _session_delta) = {
225            let mut sessions = self.sessions.write().unwrap();
226            let data = sessions
227                .values_mut()
228                .find(|d| d.identity.session_id.as_ref() == session_id)
229                .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
230
231            data.events.push(event.clone());
232            data.updated_at = event.timestamp;
233
234            let (app_delta, user_delta, session_delta) =
235                Self::extract_state_deltas(&event.actions.state_delta);
236            data.state.extend(session_delta.clone());
237
238            (
239                data.identity.app_name.as_ref().to_string(),
240                data.identity.user_id.as_ref().to_string(),
241                app_delta,
242                user_delta,
243                session_delta,
244            )
245        };
246
247        if !app_delta.is_empty() {
248            let mut app_state_lock = self.app_state.write().unwrap();
249            let app_state = app_state_lock.entry(app_name.clone()).or_default();
250            app_state.extend(app_delta);
251        }
252
253        if !user_delta.is_empty() {
254            let mut user_state_lock = self.user_state.write().unwrap();
255            let user_map = user_state_lock.entry(app_name).or_default();
256            let user_state = user_map.entry(user_id).or_default();
257            user_state.extend(user_delta);
258        }
259
260        Ok(())
261    }
262
263    async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
264        let mut event = req.event;
265        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
266
267        let identity = req.identity;
268
269        let (app_name_str, user_id_str, app_delta, user_delta) = {
270            let mut sessions = self.sessions.write().unwrap();
271            let data = sessions
272                .get_mut(&identity)
273                .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
274
275            data.events.push(event.clone());
276            data.updated_at = event.timestamp;
277
278            let (app_delta, user_delta, session_delta) =
279                Self::extract_state_deltas(&event.actions.state_delta);
280            data.state.extend(session_delta);
281
282            (
283                identity.app_name.as_ref().to_string(),
284                identity.user_id.as_ref().to_string(),
285                app_delta,
286                user_delta,
287            )
288        };
289
290        if !app_delta.is_empty() {
291            let mut app_state_lock = self.app_state.write().unwrap();
292            let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
293            app_state.extend(app_delta);
294        }
295
296        if !user_delta.is_empty() {
297            let mut user_state_lock = self.user_state.write().unwrap();
298            let user_map = user_state_lock.entry(app_name_str).or_default();
299            let user_state = user_map.entry(user_id_str).or_default();
300            user_state.extend(user_delta);
301        }
302
303        Ok(())
304    }
305
306    async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
307        let sessions = self.sessions.read().unwrap();
308        let data = sessions
309            .get(identity)
310            .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
311
312        let app_state_lock = self.app_state.read().unwrap();
313        let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
314        drop(app_state_lock);
315
316        let user_state_lock = self.user_state.read().unwrap();
317        let user_state = user_state_lock
318            .get(identity.app_name.as_ref())
319            .and_then(|m| m.get(identity.user_id.as_ref()))
320            .cloned()
321            .unwrap_or_default();
322        drop(user_state_lock);
323
324        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
325
326        Ok(Box::new(InMemorySession {
327            identity: data.identity.clone(),
328            state: merged_state,
329            events: data.events.clone(),
330            updated_at: data.updated_at,
331        }))
332    }
333
334    async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
335        let mut sessions = self.sessions.write().unwrap();
336        sessions.remove(identity);
337        Ok(())
338    }
339}
340
341struct InMemorySession {
342    identity: AdkIdentity,
343    state: StateMap,
344    events: Vec<Event>,
345    updated_at: DateTime<Utc>,
346}
347
348impl Session for InMemorySession {
349    fn id(&self) -> &str {
350        self.identity.session_id.as_ref()
351    }
352
353    fn app_name(&self) -> &str {
354        self.identity.app_name.as_ref()
355    }
356
357    fn user_id(&self) -> &str {
358        self.identity.user_id.as_ref()
359    }
360
361    fn state(&self) -> &dyn State {
362        self
363    }
364
365    fn events(&self) -> &dyn Events {
366        self
367    }
368
369    fn last_update_time(&self) -> DateTime<Utc> {
370        self.updated_at
371    }
372}
373
374impl State for InMemorySession {
375    fn get(&self, key: &str) -> Option<Value> {
376        self.state.get(key).cloned()
377    }
378
379    fn set(&mut self, key: String, value: Value) {
380        self.state.insert(key, value);
381    }
382
383    fn all(&self) -> HashMap<String, Value> {
384        self.state.clone()
385    }
386}
387
388impl Events for InMemorySession {
389    fn all(&self) -> Vec<Event> {
390        self.events.clone()
391    }
392
393    fn len(&self) -> usize {
394        self.events.len()
395    }
396
397    fn at(&self, index: usize) -> Option<&Event> {
398        self.events.get(index)
399    }
400}