Skip to main content

adk_session/
inmemory.rs

1use crate::{
2    AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_TEMP,
3    ListRequest, Session, SessionService, State, state_utils,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
14type StateMap = HashMap<String, Value>;
15
16#[derive(Clone)]
17struct SessionData {
18    identity: AdkIdentity,
19    events: Vec<Event>,
20    state: StateMap,
21    updated_at: DateTime<Utc>,
22}
23
24/// In-memory session service for testing and lightweight deployments.
25///
26/// All data is stored in process memory and lost on restart.
27pub struct InMemorySessionService {
28    sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
29    app_state: Arc<RwLock<HashMap<String, StateMap>>>,
30    user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
31}
32
33impl InMemorySessionService {
34    /// Creates a new empty in-memory session service.
35    pub fn new() -> Self {
36        Self {
37            sessions: Arc::new(RwLock::new(HashMap::new())),
38            app_state: Arc::new(RwLock::new(HashMap::new())),
39            user_state: Arc::new(RwLock::new(HashMap::new())),
40        }
41    }
42
43    fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
44        state_utils::extract_state_deltas(delta)
45    }
46
47    fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
48        state_utils::merge_states(app, user, session)
49    }
50
51    /// Build an [`AdkIdentity`] from raw string fields, returning a session
52    /// error if any field fails validation.
53    fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
54        Ok(AdkIdentity::new(
55            AppName::try_from(app_name)
56                .map_err(|e| adk_core::AdkError::session(format!("invalid app_name: {e}")))?,
57            UserId::try_from(user_id)
58                .map_err(|e| adk_core::AdkError::session(format!("invalid user_id: {e}")))?,
59            SessionId::try_from(session_id)
60                .map_err(|e| adk_core::AdkError::session(format!("invalid session_id: {e}")))?,
61        ))
62    }
63
64    /// Rewind a session to before all events (remove all events and reset state).
65    async fn rewind_to_empty(&self, session_id: &str) -> Result<Box<dyn Session>> {
66        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
67        let data = sessions
68            .values_mut()
69            .find(|d| d.identity.session_id.as_ref() == session_id)
70            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
71
72        data.events.clear();
73        data.state = HashMap::new();
74        data.updated_at = Utc::now();
75
76        let app_name = data.identity.app_name.as_ref().to_string();
77        let user_id = data.identity.user_id.as_ref().to_string();
78        let identity = data.identity.clone();
79        let updated_at = data.updated_at;
80        drop(sessions);
81
82        let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
83        let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
84        drop(app_state_lock);
85
86        let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
87        let user_state = user_state_lock
88            .get(&app_name)
89            .and_then(|m| m.get(&user_id))
90            .cloned()
91            .unwrap_or_default();
92        drop(user_state_lock);
93
94        let merged_state = state_utils::merge_states(&app_state, &user_state, &HashMap::new());
95
96        Ok(Box::new(InMemorySession {
97            identity,
98            state: merged_state,
99            events: Vec::new(),
100            updated_at,
101        }))
102    }
103}
104
105impl Default for InMemorySessionService {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111#[async_trait]
112impl SessionService for InMemorySessionService {
113    async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
114        let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
115
116        let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
117
118        let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
119
120        let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
121        let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
122        app_state.extend(app_delta);
123        let app_state_clone = app_state.clone();
124        drop(app_state_lock);
125
126        let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
127        let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
128        let user_state = user_map.entry(req.user_id.clone()).or_default();
129        user_state.extend(user_delta);
130        let user_state_clone = user_state.clone();
131        drop(user_state_lock);
132
133        let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
134
135        let data = SessionData {
136            identity: identity.clone(),
137            events: Vec::new(),
138            state: merged_state.clone(),
139            updated_at: Utc::now(),
140        };
141
142        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
143        sessions.insert(identity.clone(), data);
144        drop(sessions);
145
146        Ok(Box::new(InMemorySession {
147            identity,
148            state: merged_state,
149            events: Vec::new(),
150            updated_at: Utc::now(),
151        }))
152    }
153
154    async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
155        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
156
157        let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
158        let data = sessions
159            .get(&identity)
160            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
161
162        let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
163        let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
164        drop(app_state_lock);
165
166        let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
167        let user_state = user_state_lock
168            .get(&req.app_name)
169            .and_then(|m| m.get(&req.user_id))
170            .cloned()
171            .unwrap_or_default();
172        drop(user_state_lock);
173
174        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
175
176        let mut events = data.events.clone();
177        if let Some(num) = req.num_recent_events {
178            let start = events.len().saturating_sub(num);
179            events = events[start..].to_vec();
180        }
181        if let Some(after) = req.after {
182            events.retain(|e| e.timestamp >= after);
183        }
184
185        Ok(Box::new(InMemorySession {
186            identity: data.identity.clone(),
187            state: merged_state,
188            events,
189            updated_at: data.updated_at,
190        }))
191    }
192
193    async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
194        let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
195        let offset = req.offset.unwrap_or(0);
196        let limit = req.limit.unwrap_or(usize::MAX);
197        let mut result = Vec::new();
198
199        for data in sessions.values() {
200            if data.identity.app_name.as_ref() == req.app_name
201                && data.identity.user_id.as_ref() == req.user_id
202            {
203                result.push(data.clone());
204            }
205        }
206
207        // Sort by updated_at descending for consistency with other backends
208        result.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
209
210        let result: Vec<Box<dyn Session>> = result
211            .into_iter()
212            .skip(offset)
213            .take(limit)
214            .map(|data| {
215                Box::new(InMemorySession {
216                    identity: data.identity,
217                    state: data.state,
218                    events: data.events,
219                    updated_at: data.updated_at,
220                }) as Box<dyn Session>
221            })
222            .collect();
223
224        Ok(result)
225    }
226
227    async fn delete(&self, req: DeleteRequest) -> Result<()> {
228        let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
229
230        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
231        sessions.remove(&identity);
232        Ok(())
233    }
234
235    async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
236        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
237        sessions.retain(|_, data| {
238            !(data.identity.app_name.as_ref() == app_name
239                && data.identity.user_id.as_ref() == user_id)
240        });
241        Ok(())
242    }
243
244    async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
245        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
246
247        let (app_name, user_id, app_delta, user_delta, _session_delta) = {
248            let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
249            let data = sessions
250                .values_mut()
251                .find(|d| d.identity.session_id.as_ref() == session_id)
252                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
253
254            data.events.push(event.clone());
255            data.updated_at = event.timestamp;
256
257            let (app_delta, user_delta, session_delta) =
258                Self::extract_state_deltas(&event.actions.state_delta);
259            data.state.extend(session_delta.clone());
260
261            (
262                data.identity.app_name.as_ref().to_string(),
263                data.identity.user_id.as_ref().to_string(),
264                app_delta,
265                user_delta,
266                session_delta,
267            )
268        };
269
270        if !app_delta.is_empty() {
271            let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
272            let app_state = app_state_lock.entry(app_name.clone()).or_default();
273            app_state.extend(app_delta);
274        }
275
276        if !user_delta.is_empty() {
277            let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
278            let user_map = user_state_lock.entry(app_name).or_default();
279            let user_state = user_map.entry(user_id).or_default();
280            user_state.extend(user_delta);
281        }
282
283        Ok(())
284    }
285
286    async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
287        let mut event = req.event;
288        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
289
290        let identity = req.identity;
291
292        let (app_name_str, user_id_str, app_delta, user_delta) = {
293            let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
294            let data = sessions
295                .get_mut(&identity)
296                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
297
298            data.events.push(event.clone());
299            data.updated_at = event.timestamp;
300
301            let (app_delta, user_delta, session_delta) =
302                Self::extract_state_deltas(&event.actions.state_delta);
303            data.state.extend(session_delta);
304
305            (
306                identity.app_name.as_ref().to_string(),
307                identity.user_id.as_ref().to_string(),
308                app_delta,
309                user_delta,
310            )
311        };
312
313        if !app_delta.is_empty() {
314            let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
315            let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
316            app_state.extend(app_delta);
317        }
318
319        if !user_delta.is_empty() {
320            let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
321            let user_map = user_state_lock.entry(app_name_str).or_default();
322            let user_state = user_map.entry(user_id_str).or_default();
323            user_state.extend(user_delta);
324        }
325
326        Ok(())
327    }
328
329    async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
330        let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
331        let data = sessions
332            .get(identity)
333            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
334
335        let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
336        let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
337        drop(app_state_lock);
338
339        let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
340        let user_state = user_state_lock
341            .get(identity.app_name.as_ref())
342            .and_then(|m| m.get(identity.user_id.as_ref()))
343            .cloned()
344            .unwrap_or_default();
345        drop(user_state_lock);
346
347        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
348
349        Ok(Box::new(InMemorySession {
350            identity: data.identity.clone(),
351            state: merged_state,
352            events: data.events.clone(),
353            updated_at: data.updated_at,
354        }))
355    }
356
357    async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
358        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
359        sessions.remove(identity);
360        Ok(())
361    }
362
363    async fn rewind(&self, session_id: &str, target_event_id: &str) -> Result<Box<dyn Session>> {
364        let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
365
366        // Find the session by session_id
367        let data = sessions
368            .values_mut()
369            .find(|d| d.identity.session_id.as_ref() == session_id)
370            .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
371
372        // Find the target event index
373        let target_index =
374            data.events.iter().position(|e| e.id == target_event_id).ok_or_else(|| {
375                adk_core::AdkError::session(format!("target event not found: {target_event_id}"))
376            })?;
377
378        // Truncate events after the target (keep 0..=target_index)
379        data.events.truncate(target_index + 1);
380
381        // Rebuild session state from remaining events' state deltas
382        let mut rebuilt_session_state: HashMap<String, Value> = HashMap::new();
383        for event in &data.events {
384            let (_app_delta, _user_delta, session_delta) =
385                state_utils::extract_state_deltas(&event.actions.state_delta);
386            rebuilt_session_state.extend(session_delta);
387        }
388
389        // Get app and user state (these are not rewound — they are separate)
390        let app_name = data.identity.app_name.as_ref().to_string();
391        let user_id = data.identity.user_id.as_ref().to_string();
392
393        // Update the stored session state with rebuilt session-level state
394        data.state = rebuilt_session_state.clone();
395        data.updated_at = data.events.last().map(|e| e.timestamp).unwrap_or(Utc::now());
396
397        let identity = data.identity.clone();
398        let events = data.events.clone();
399        let updated_at = data.updated_at;
400        drop(sessions);
401
402        // Merge with app and user state for the returned session
403        let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
404        let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
405        drop(app_state_lock);
406
407        let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
408        let user_state = user_state_lock
409            .get(&app_name)
410            .and_then(|m| m.get(&user_id))
411            .cloned()
412            .unwrap_or_default();
413        drop(user_state_lock);
414
415        let merged_state =
416            state_utils::merge_states(&app_state, &user_state, &rebuilt_session_state);
417
418        Ok(Box::new(InMemorySession { identity, state: merged_state, events, updated_at }))
419    }
420
421    async fn rewind_steps(&self, session_id: &str, steps: usize) -> Result<Box<dyn Session>> {
422        if steps == 0 {
423            // Return the session unchanged
424            let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
425            let data = sessions
426                .values()
427                .find(|d| d.identity.session_id.as_ref() == session_id)
428                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
429
430            let app_name = data.identity.app_name.as_ref().to_string();
431            let user_id = data.identity.user_id.as_ref().to_string();
432            let identity = data.identity.clone();
433            let events = data.events.clone();
434            let session_state = data.state.clone();
435            let updated_at = data.updated_at;
436            drop(sessions);
437
438            let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
439            let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
440            drop(app_state_lock);
441
442            let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
443            let user_state = user_state_lock
444                .get(&app_name)
445                .and_then(|m| m.get(&user_id))
446                .cloned()
447                .unwrap_or_default();
448            drop(user_state_lock);
449
450            let merged_state = state_utils::merge_states(&app_state, &user_state, &session_state);
451
452            return Ok(Box::new(InMemorySession {
453                identity,
454                state: merged_state,
455                events,
456                updated_at,
457            }));
458        }
459
460        // Read the event count and determine target
461        let rewind_target = {
462            let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
463            let data = sessions
464                .values()
465                .find(|d| d.identity.session_id.as_ref() == session_id)
466                .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
467
468            if steps > data.events.len() {
469                return Err(adk_core::AdkError::session("rewind steps exceeds event count"));
470            }
471
472            let target_index = data.events.len() - steps;
473            if target_index == 0 {
474                // Rewinding all events
475                None
476            } else {
477                Some(data.events[target_index - 1].id.clone())
478            }
479        };
480
481        match rewind_target {
482            Some(target_event_id) => self.rewind(session_id, &target_event_id).await,
483            None => self.rewind_to_empty(session_id).await,
484        }
485    }
486}
487
488struct InMemorySession {
489    identity: AdkIdentity,
490    state: StateMap,
491    events: Vec<Event>,
492    updated_at: DateTime<Utc>,
493}
494
495impl Session for InMemorySession {
496    fn id(&self) -> &str {
497        self.identity.session_id.as_ref()
498    }
499
500    fn app_name(&self) -> &str {
501        self.identity.app_name.as_ref()
502    }
503
504    fn user_id(&self) -> &str {
505        self.identity.user_id.as_ref()
506    }
507
508    fn state(&self) -> &dyn State {
509        self
510    }
511
512    fn events(&self) -> &dyn Events {
513        self
514    }
515
516    fn last_update_time(&self) -> DateTime<Utc> {
517        self.updated_at
518    }
519}
520
521impl State for InMemorySession {
522    fn get(&self, key: &str) -> Option<Value> {
523        self.state.get(key).cloned()
524    }
525
526    fn set(&mut self, key: String, value: Value) {
527        if let Err(msg) = adk_core::validate_state_key(&key) {
528            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
529            return;
530        }
531        self.state.insert(key, value);
532    }
533
534    fn all(&self) -> HashMap<String, Value> {
535        self.state.clone()
536    }
537}
538
539impl Events for InMemorySession {
540    fn all(&self) -> Vec<Event> {
541        self.events.clone()
542    }
543
544    fn len(&self) -> usize {
545        self.events.len()
546    }
547
548    fn at(&self, index: usize) -> Option<&Event> {
549        self.events.get(index)
550    }
551}