adk_session/
inmemory.rs

1use crate::{
2    CreateRequest, DeleteRequest, Event, Events, GetRequest, ListRequest, Session, SessionService,
3    State, KEY_PREFIX_APP, KEY_PREFIX_TEMP, KEY_PREFIX_USER,
4};
5use adk_core::Result;
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use uuid::Uuid;
12
13type StateMap = HashMap<String, Value>;
14
15#[derive(Clone)]
16struct SessionData {
17    id: SessionId,
18    events: Vec<Event>,
19    state: StateMap,
20    updated_at: DateTime<Utc>,
21}
22
/// Composite identity of a session: owning application, user, and the
/// session's own id.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SessionId {
    app_name: String,
    user_id: String,
    session_id: String,
}

impl SessionId {
    /// Returns the string used to index this session in the session table.
    ///
    /// The app name and user id are each preceded by their byte length
    /// (`<len>:<app>:<len>:<user>:<session>`). A plain `app:user:session`
    /// join is ambiguous when a component itself contains `:` — for example
    /// `("a:b", "c", "d")` and `("a", "b:c", "d")` would share one key and
    /// overwrite each other. With the length prefixes, distinct ids always
    /// produce distinct keys.
    fn key(&self) -> String {
        format!(
            "{}:{}:{}:{}:{}",
            self.app_name.len(),
            self.app_name,
            self.user_id.len(),
            self.user_id,
            self.session_id
        )
    }
}
35
36pub struct InMemorySessionService {
37    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
38    app_state: Arc<RwLock<HashMap<String, StateMap>>>,
39    user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
40}
41
42impl InMemorySessionService {
43    pub fn new() -> Self {
44        Self {
45            sessions: Arc::new(RwLock::new(HashMap::new())),
46            app_state: Arc::new(RwLock::new(HashMap::new())),
47            user_state: Arc::new(RwLock::new(HashMap::new())),
48        }
49    }
50
51    fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
52        let mut app_delta = StateMap::new();
53        let mut user_delta = StateMap::new();
54        let mut session_delta = StateMap::new();
55
56        for (key, value) in delta {
57            if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
58                app_delta.insert(clean_key.to_string(), value.clone());
59            } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
60                user_delta.insert(clean_key.to_string(), value.clone());
61            } else if !key.starts_with(KEY_PREFIX_TEMP) {
62                session_delta.insert(key.clone(), value.clone());
63            }
64        }
65
66        (app_delta, user_delta, session_delta)
67    }
68
69    fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
70        let mut merged = session.clone();
71        for (k, v) in app {
72            merged.insert(format!("{}{}", KEY_PREFIX_APP, k), v.clone());
73        }
74        for (k, v) in user {
75            merged.insert(format!("{}{}", KEY_PREFIX_USER, k), v.clone());
76        }
77        merged
78    }
79}
80
81impl Default for InMemorySessionService {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87#[async_trait]
88impl SessionService for InMemorySessionService {
89    async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
90        let session_id = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
91
92        let id = SessionId {
93            app_name: req.app_name.clone(),
94            user_id: req.user_id.clone(),
95            session_id: session_id.clone(),
96        };
97
98        let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
99
100        let mut app_state_lock = self.app_state.write().unwrap();
101        let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
102        app_state.extend(app_delta);
103        let app_state_clone = app_state.clone();
104        drop(app_state_lock);
105
106        let mut user_state_lock = self.user_state.write().unwrap();
107        let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
108        let user_state = user_map.entry(req.user_id.clone()).or_default();
109        user_state.extend(user_delta);
110        let user_state_clone = user_state.clone();
111        drop(user_state_lock);
112
113        let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
114
115        let data = SessionData {
116            id: id.clone(),
117            events: Vec::new(),
118            state: merged_state.clone(),
119            updated_at: Utc::now(),
120        };
121
122        let mut sessions = self.sessions.write().unwrap();
123        sessions.insert(id.key(), data);
124        drop(sessions);
125
126        Ok(Box::new(InMemorySession {
127            id,
128            state: merged_state,
129            events: Vec::new(),
130            updated_at: Utc::now(),
131        }))
132    }
133
134    async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
135        let id = SessionId {
136            app_name: req.app_name.clone(),
137            user_id: req.user_id.clone(),
138            session_id: req.session_id.clone(),
139        };
140
141        let sessions = self.sessions.read().unwrap();
142        let data = sessions
143            .get(&id.key())
144            .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
145
146        let app_state_lock = self.app_state.read().unwrap();
147        let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
148        drop(app_state_lock);
149
150        let user_state_lock = self.user_state.read().unwrap();
151        let user_state = user_state_lock
152            .get(&req.app_name)
153            .and_then(|m| m.get(&req.user_id))
154            .cloned()
155            .unwrap_or_default();
156        drop(user_state_lock);
157
158        let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
159
160        let mut events = data.events.clone();
161        if let Some(num) = req.num_recent_events {
162            let start = events.len().saturating_sub(num);
163            events = events[start..].to_vec();
164        }
165        if let Some(after) = req.after {
166            events.retain(|e| e.timestamp >= after);
167        }
168
169        Ok(Box::new(InMemorySession {
170            id: data.id.clone(),
171            state: merged_state,
172            events,
173            updated_at: data.updated_at,
174        }))
175    }
176
177    async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
178        let sessions = self.sessions.read().unwrap();
179        let mut result = Vec::new();
180
181        for data in sessions.values() {
182            if data.id.app_name == req.app_name && data.id.user_id == req.user_id {
183                result.push(Box::new(InMemorySession {
184                    id: data.id.clone(),
185                    state: data.state.clone(),
186                    events: data.events.clone(),
187                    updated_at: data.updated_at,
188                }) as Box<dyn Session>);
189            }
190        }
191
192        Ok(result)
193    }
194
195    async fn delete(&self, req: DeleteRequest) -> Result<()> {
196        let id =
197            SessionId { app_name: req.app_name, user_id: req.user_id, session_id: req.session_id };
198
199        let mut sessions = self.sessions.write().unwrap();
200        sessions.remove(&id.key());
201        Ok(())
202    }
203
204    async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
205        event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
206
207        let (app_name, user_id, app_delta, user_delta, _session_delta) = {
208            let mut sessions = self.sessions.write().unwrap();
209            let data = sessions
210                .values_mut()
211                .find(|d| d.id.session_id == session_id)
212                .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
213
214            data.events.push(event.clone());
215            data.updated_at = event.timestamp;
216
217            let (app_delta, user_delta, session_delta) =
218                Self::extract_state_deltas(&event.actions.state_delta);
219            data.state.extend(session_delta.clone());
220
221            (
222                data.id.app_name.clone(),
223                data.id.user_id.clone(),
224                app_delta,
225                user_delta,
226                session_delta,
227            )
228        };
229
230        if !app_delta.is_empty() {
231            let mut app_state_lock = self.app_state.write().unwrap();
232            let app_state = app_state_lock.entry(app_name.clone()).or_default();
233            app_state.extend(app_delta);
234        }
235
236        if !user_delta.is_empty() {
237            let mut user_state_lock = self.user_state.write().unwrap();
238            let user_map = user_state_lock.entry(app_name).or_default();
239            let user_state = user_map.entry(user_id).or_default();
240            user_state.extend(user_delta);
241        }
242
243        Ok(())
244    }
245}
246
247struct InMemorySession {
248    id: SessionId,
249    state: StateMap,
250    events: Vec<Event>,
251    updated_at: DateTime<Utc>,
252}
253
254impl Session for InMemorySession {
255    fn id(&self) -> &str {
256        &self.id.session_id
257    }
258
259    fn app_name(&self) -> &str {
260        &self.id.app_name
261    }
262
263    fn user_id(&self) -> &str {
264        &self.id.user_id
265    }
266
267    fn state(&self) -> &dyn State {
268        self
269    }
270
271    fn events(&self) -> &dyn Events {
272        self
273    }
274
275    fn last_update_time(&self) -> DateTime<Utc> {
276        self.updated_at
277    }
278}
279
280impl State for InMemorySession {
281    fn get(&self, key: &str) -> Option<Value> {
282        self.state.get(key).cloned()
283    }
284
285    fn set(&mut self, key: String, value: Value) {
286        self.state.insert(key, value);
287    }
288
289    fn all(&self) -> HashMap<String, Value> {
290        self.state.clone()
291    }
292}
293
294impl Events for InMemorySession {
295    fn all(&self) -> Vec<Event> {
296        self.events.clone()
297    }
298
299    fn len(&self) -> usize {
300        self.events.len()
301    }
302
303    fn at(&self, index: usize) -> Option<&Event> {
304        self.events.get(index)
305    }
306}