1use crate::{
2 AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_APP,
3 KEY_PREFIX_TEMP, KEY_PREFIX_USER, ListRequest, Session, SessionService, State,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
/// Flat key/value state bag: string keys mapped to arbitrary JSON values.
type StateMap = HashMap<String, Value>;
15
/// Record kept for one session inside the service's session map.
#[derive(Clone)]
struct SessionData {
    /// Identity of the session (app name, user id, session id); also used as the map key.
    identity: AdkIdentity,
    /// Events appended to the session, in the order they were appended.
    events: Vec<Event>,
    /// State stored with the session: the merged state captured at creation plus
    /// every session-scoped delta applied by later events.
    state: StateMap,
    /// Creation time, or the timestamp of the most recently appended event.
    updated_at: DateTime<Utc>,
}
23
/// Session service that keeps every session and all shared state in process memory.
///
/// Each map sits behind its own `RwLock`; the locks are taken with `unwrap()`, so a
/// poisoned lock makes the calling method panic.
pub struct InMemorySessionService {
    /// All sessions, keyed by their full identity.
    sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
    /// App-scoped state, keyed by app name. Keys are stored without the app prefix.
    app_state: Arc<RwLock<HashMap<String, StateMap>>>,
    /// User-scoped state, keyed by app name and then user id. Keys are stored without
    /// the user prefix.
    user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
}
29
30impl InMemorySessionService {
31 pub fn new() -> Self {
32 Self {
33 sessions: Arc::new(RwLock::new(HashMap::new())),
34 app_state: Arc::new(RwLock::new(HashMap::new())),
35 user_state: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38
39 fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
40 let mut app_delta = StateMap::new();
41 let mut user_delta = StateMap::new();
42 let mut session_delta = StateMap::new();
43
44 for (key, value) in delta {
45 if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
46 app_delta.insert(clean_key.to_string(), value.clone());
47 } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
48 user_delta.insert(clean_key.to_string(), value.clone());
49 } else if !key.starts_with(KEY_PREFIX_TEMP) {
50 session_delta.insert(key.clone(), value.clone());
51 }
52 }
53
54 (app_delta, user_delta, session_delta)
55 }
56
57 fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
58 let mut merged = session.clone();
59 for (k, v) in app {
60 merged.insert(format!("{KEY_PREFIX_APP}{k}"), v.clone());
61 }
62 for (k, v) in user {
63 merged.insert(format!("{KEY_PREFIX_USER}{k}"), v.clone());
64 }
65 merged
66 }
67
68 fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
71 Ok(AdkIdentity::new(
72 AppName::try_from(app_name)
73 .map_err(|e| adk_core::AdkError::Session(format!("invalid app_name: {e}")))?,
74 UserId::try_from(user_id)
75 .map_err(|e| adk_core::AdkError::Session(format!("invalid user_id: {e}")))?,
76 SessionId::try_from(session_id)
77 .map_err(|e| adk_core::AdkError::Session(format!("invalid session_id: {e}")))?,
78 ))
79 }
80}
81
impl Default for InMemorySessionService {
    /// Equivalent to [`InMemorySessionService::new`]: an empty service.
    fn default() -> Self {
        Self::new()
    }
}
87
88#[async_trait]
89impl SessionService for InMemorySessionService {
90 async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
91 let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
92
93 let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
94
95 let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
96
97 let mut app_state_lock = self.app_state.write().unwrap();
98 let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
99 app_state.extend(app_delta);
100 let app_state_clone = app_state.clone();
101 drop(app_state_lock);
102
103 let mut user_state_lock = self.user_state.write().unwrap();
104 let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
105 let user_state = user_map.entry(req.user_id.clone()).or_default();
106 user_state.extend(user_delta);
107 let user_state_clone = user_state.clone();
108 drop(user_state_lock);
109
110 let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
111
112 let data = SessionData {
113 identity: identity.clone(),
114 events: Vec::new(),
115 state: merged_state.clone(),
116 updated_at: Utc::now(),
117 };
118
119 let mut sessions = self.sessions.write().unwrap();
120 sessions.insert(identity.clone(), data);
121 drop(sessions);
122
123 Ok(Box::new(InMemorySession {
124 identity,
125 state: merged_state,
126 events: Vec::new(),
127 updated_at: Utc::now(),
128 }))
129 }
130
131 async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
132 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
133
134 let sessions = self.sessions.read().unwrap();
135 let data = sessions
136 .get(&identity)
137 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
138
139 let app_state_lock = self.app_state.read().unwrap();
140 let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
141 drop(app_state_lock);
142
143 let user_state_lock = self.user_state.read().unwrap();
144 let user_state = user_state_lock
145 .get(&req.app_name)
146 .and_then(|m| m.get(&req.user_id))
147 .cloned()
148 .unwrap_or_default();
149 drop(user_state_lock);
150
151 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
152
153 let mut events = data.events.clone();
154 if let Some(num) = req.num_recent_events {
155 let start = events.len().saturating_sub(num);
156 events = events[start..].to_vec();
157 }
158 if let Some(after) = req.after {
159 events.retain(|e| e.timestamp >= after);
160 }
161
162 Ok(Box::new(InMemorySession {
163 identity: data.identity.clone(),
164 state: merged_state,
165 events,
166 updated_at: data.updated_at,
167 }))
168 }
169
170 async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
171 let sessions = self.sessions.read().unwrap();
172 let offset = req.offset.unwrap_or(0);
173 let limit = req.limit.unwrap_or(usize::MAX);
174 let mut result = Vec::new();
175
176 for data in sessions.values() {
177 if data.identity.app_name.as_ref() == req.app_name
178 && data.identity.user_id.as_ref() == req.user_id
179 {
180 result.push(data.clone());
181 }
182 }
183
184 result.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
186
187 let result: Vec<Box<dyn Session>> = result
188 .into_iter()
189 .skip(offset)
190 .take(limit)
191 .map(|data| {
192 Box::new(InMemorySession {
193 identity: data.identity,
194 state: data.state,
195 events: data.events,
196 updated_at: data.updated_at,
197 }) as Box<dyn Session>
198 })
199 .collect();
200
201 Ok(result)
202 }
203
204 async fn delete(&self, req: DeleteRequest) -> Result<()> {
205 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
206
207 let mut sessions = self.sessions.write().unwrap();
208 sessions.remove(&identity);
209 Ok(())
210 }
211
212 async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
213 let mut sessions = self.sessions.write().unwrap();
214 sessions.retain(|_, data| {
215 !(data.identity.app_name.as_ref() == app_name
216 && data.identity.user_id.as_ref() == user_id)
217 });
218 Ok(())
219 }
220
221 async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
222 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
223
224 let (app_name, user_id, app_delta, user_delta, _session_delta) = {
225 let mut sessions = self.sessions.write().unwrap();
226 let data = sessions
227 .values_mut()
228 .find(|d| d.identity.session_id.as_ref() == session_id)
229 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
230
231 data.events.push(event.clone());
232 data.updated_at = event.timestamp;
233
234 let (app_delta, user_delta, session_delta) =
235 Self::extract_state_deltas(&event.actions.state_delta);
236 data.state.extend(session_delta.clone());
237
238 (
239 data.identity.app_name.as_ref().to_string(),
240 data.identity.user_id.as_ref().to_string(),
241 app_delta,
242 user_delta,
243 session_delta,
244 )
245 };
246
247 if !app_delta.is_empty() {
248 let mut app_state_lock = self.app_state.write().unwrap();
249 let app_state = app_state_lock.entry(app_name.clone()).or_default();
250 app_state.extend(app_delta);
251 }
252
253 if !user_delta.is_empty() {
254 let mut user_state_lock = self.user_state.write().unwrap();
255 let user_map = user_state_lock.entry(app_name).or_default();
256 let user_state = user_map.entry(user_id).or_default();
257 user_state.extend(user_delta);
258 }
259
260 Ok(())
261 }
262
263 async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
264 let mut event = req.event;
265 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
266
267 let identity = req.identity;
268
269 let (app_name_str, user_id_str, app_delta, user_delta) = {
270 let mut sessions = self.sessions.write().unwrap();
271 let data = sessions
272 .get_mut(&identity)
273 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
274
275 data.events.push(event.clone());
276 data.updated_at = event.timestamp;
277
278 let (app_delta, user_delta, session_delta) =
279 Self::extract_state_deltas(&event.actions.state_delta);
280 data.state.extend(session_delta);
281
282 (
283 identity.app_name.as_ref().to_string(),
284 identity.user_id.as_ref().to_string(),
285 app_delta,
286 user_delta,
287 )
288 };
289
290 if !app_delta.is_empty() {
291 let mut app_state_lock = self.app_state.write().unwrap();
292 let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
293 app_state.extend(app_delta);
294 }
295
296 if !user_delta.is_empty() {
297 let mut user_state_lock = self.user_state.write().unwrap();
298 let user_map = user_state_lock.entry(app_name_str).or_default();
299 let user_state = user_map.entry(user_id_str).or_default();
300 user_state.extend(user_delta);
301 }
302
303 Ok(())
304 }
305
306 async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
307 let sessions = self.sessions.read().unwrap();
308 let data = sessions
309 .get(identity)
310 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
311
312 let app_state_lock = self.app_state.read().unwrap();
313 let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
314 drop(app_state_lock);
315
316 let user_state_lock = self.user_state.read().unwrap();
317 let user_state = user_state_lock
318 .get(identity.app_name.as_ref())
319 .and_then(|m| m.get(identity.user_id.as_ref()))
320 .cloned()
321 .unwrap_or_default();
322 drop(user_state_lock);
323
324 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
325
326 Ok(Box::new(InMemorySession {
327 identity: data.identity.clone(),
328 state: merged_state,
329 events: data.events.clone(),
330 updated_at: data.updated_at,
331 }))
332 }
333
334 async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
335 let mut sessions = self.sessions.write().unwrap();
336 sessions.remove(identity);
337 Ok(())
338 }
339}
340
/// Owned snapshot of a session, as handed out by `InMemorySessionService`.
///
/// The struct holds its own copies of the data and no reference back to the service.
struct InMemorySession {
    /// Identity of the session (app name, user id, session id).
    identity: AdkIdentity,
    /// State visible through the `State` implementation.
    state: StateMap,
    /// Events visible through the `Events` implementation.
    events: Vec<Event>,
    /// Value reported by `last_update_time`.
    updated_at: DateTime<Utc>,
}
347
impl Session for InMemorySession {
    /// Session id component of the identity.
    fn id(&self) -> &str {
        self.identity.session_id.as_ref()
    }

    /// App name component of the identity.
    fn app_name(&self) -> &str {
        self.identity.app_name.as_ref()
    }

    /// User id component of the identity.
    fn user_id(&self) -> &str {
        self.identity.user_id.as_ref()
    }

    /// State view; the snapshot itself implements `State`.
    fn state(&self) -> &dyn State {
        self
    }

    /// Event view; the snapshot itself implements `Events`.
    fn events(&self) -> &dyn Events {
        self
    }

    /// The `updated_at` value captured in this snapshot.
    fn last_update_time(&self) -> DateTime<Utc> {
        self.updated_at
    }
}
373
impl State for InMemorySession {
    /// Returns a clone of the value stored under `key`, or `None` if the key is absent.
    fn get(&self, key: &str) -> Option<Value> {
        self.state.get(key).cloned()
    }

    /// Inserts or overwrites `key` in this snapshot's own state map.
    fn set(&mut self, key: String, value: Value) {
        self.state.insert(key, value);
    }

    /// Returns a clone of the entire state map.
    fn all(&self) -> HashMap<String, Value> {
        self.state.clone()
    }
}
387
impl Events for InMemorySession {
    /// Returns a clone of every event in this snapshot.
    fn all(&self) -> Vec<Event> {
        self.events.clone()
    }

    /// Number of events in this snapshot.
    fn len(&self) -> usize {
        self.events.len()
    }

    /// Borrows the event at `index`, or `None` if `index` is out of range.
    fn at(&self, index: usize) -> Option<&Event> {
        self.events.get(index)
    }
}