1use crate::{
2 AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_TEMP,
3 ListRequest, Session, SessionService, State, state_utils,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
14type StateMap = HashMap<String, Value>;
15
16#[derive(Clone)]
17struct SessionData {
18 identity: AdkIdentity,
19 events: Vec<Event>,
20 state: StateMap,
21 updated_at: DateTime<Utc>,
22}
23
24pub struct InMemorySessionService {
28 sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
29 app_state: Arc<RwLock<HashMap<String, StateMap>>>,
30 user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
31}
32
33impl InMemorySessionService {
34 pub fn new() -> Self {
36 Self {
37 sessions: Arc::new(RwLock::new(HashMap::new())),
38 app_state: Arc::new(RwLock::new(HashMap::new())),
39 user_state: Arc::new(RwLock::new(HashMap::new())),
40 }
41 }
42
43 fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
44 state_utils::extract_state_deltas(delta)
45 }
46
47 fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
48 state_utils::merge_states(app, user, session)
49 }
50
51 fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
54 Ok(AdkIdentity::new(
55 AppName::try_from(app_name)
56 .map_err(|e| adk_core::AdkError::session(format!("invalid app_name: {e}")))?,
57 UserId::try_from(user_id)
58 .map_err(|e| adk_core::AdkError::session(format!("invalid user_id: {e}")))?,
59 SessionId::try_from(session_id)
60 .map_err(|e| adk_core::AdkError::session(format!("invalid session_id: {e}")))?,
61 ))
62 }
63
64 async fn rewind_to_empty(&self, session_id: &str) -> Result<Box<dyn Session>> {
66 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
67 let data = sessions
68 .values_mut()
69 .find(|d| d.identity.session_id.as_ref() == session_id)
70 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
71
72 data.events.clear();
73 data.state = HashMap::new();
74 data.updated_at = Utc::now();
75
76 let app_name = data.identity.app_name.as_ref().to_string();
77 let user_id = data.identity.user_id.as_ref().to_string();
78 let identity = data.identity.clone();
79 let updated_at = data.updated_at;
80 drop(sessions);
81
82 let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
83 let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
84 drop(app_state_lock);
85
86 let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
87 let user_state = user_state_lock
88 .get(&app_name)
89 .and_then(|m| m.get(&user_id))
90 .cloned()
91 .unwrap_or_default();
92 drop(user_state_lock);
93
94 let merged_state = state_utils::merge_states(&app_state, &user_state, &HashMap::new());
95
96 Ok(Box::new(InMemorySession {
97 identity,
98 state: merged_state,
99 events: Vec::new(),
100 updated_at,
101 }))
102 }
103}
104
105impl Default for InMemorySessionService {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[async_trait]
112impl SessionService for InMemorySessionService {
113 async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
114 let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
115
116 let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
117
118 let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
119
120 let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
121 let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
122 app_state.extend(app_delta);
123 let app_state_clone = app_state.clone();
124 drop(app_state_lock);
125
126 let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
127 let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
128 let user_state = user_map.entry(req.user_id.clone()).or_default();
129 user_state.extend(user_delta);
130 let user_state_clone = user_state.clone();
131 drop(user_state_lock);
132
133 let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
134
135 let data = SessionData {
136 identity: identity.clone(),
137 events: Vec::new(),
138 state: merged_state.clone(),
139 updated_at: Utc::now(),
140 };
141
142 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
143 sessions.insert(identity.clone(), data);
144 drop(sessions);
145
146 Ok(Box::new(InMemorySession {
147 identity,
148 state: merged_state,
149 events: Vec::new(),
150 updated_at: Utc::now(),
151 }))
152 }
153
154 async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
155 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
156
157 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
158 let data = sessions
159 .get(&identity)
160 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
161
162 let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
163 let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
164 drop(app_state_lock);
165
166 let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
167 let user_state = user_state_lock
168 .get(&req.app_name)
169 .and_then(|m| m.get(&req.user_id))
170 .cloned()
171 .unwrap_or_default();
172 drop(user_state_lock);
173
174 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
175
176 let mut events = data.events.clone();
177 if let Some(num) = req.num_recent_events {
178 let start = events.len().saturating_sub(num);
179 events = events[start..].to_vec();
180 }
181 if let Some(after) = req.after {
182 events.retain(|e| e.timestamp >= after);
183 }
184
185 Ok(Box::new(InMemorySession {
186 identity: data.identity.clone(),
187 state: merged_state,
188 events,
189 updated_at: data.updated_at,
190 }))
191 }
192
193 async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
194 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
195 let offset = req.offset.unwrap_or(0);
196 let limit = req.limit.unwrap_or(usize::MAX);
197 let mut result = Vec::new();
198
199 for data in sessions.values() {
200 if data.identity.app_name.as_ref() == req.app_name
201 && data.identity.user_id.as_ref() == req.user_id
202 {
203 result.push(data.clone());
204 }
205 }
206
207 result.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
209
210 let result: Vec<Box<dyn Session>> = result
211 .into_iter()
212 .skip(offset)
213 .take(limit)
214 .map(|data| {
215 Box::new(InMemorySession {
216 identity: data.identity,
217 state: data.state,
218 events: data.events,
219 updated_at: data.updated_at,
220 }) as Box<dyn Session>
221 })
222 .collect();
223
224 Ok(result)
225 }
226
227 async fn delete(&self, req: DeleteRequest) -> Result<()> {
228 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
229
230 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
231 sessions.remove(&identity);
232 Ok(())
233 }
234
235 async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
236 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
237 sessions.retain(|_, data| {
238 !(data.identity.app_name.as_ref() == app_name
239 && data.identity.user_id.as_ref() == user_id)
240 });
241 Ok(())
242 }
243
244 async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
245 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
246
247 let (app_name, user_id, app_delta, user_delta, _session_delta) = {
248 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
249 let data = sessions
250 .values_mut()
251 .find(|d| d.identity.session_id.as_ref() == session_id)
252 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
253
254 data.events.push(event.clone());
255 data.updated_at = event.timestamp;
256
257 let (app_delta, user_delta, session_delta) =
258 Self::extract_state_deltas(&event.actions.state_delta);
259 data.state.extend(session_delta.clone());
260
261 (
262 data.identity.app_name.as_ref().to_string(),
263 data.identity.user_id.as_ref().to_string(),
264 app_delta,
265 user_delta,
266 session_delta,
267 )
268 };
269
270 if !app_delta.is_empty() {
271 let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
272 let app_state = app_state_lock.entry(app_name.clone()).or_default();
273 app_state.extend(app_delta);
274 }
275
276 if !user_delta.is_empty() {
277 let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
278 let user_map = user_state_lock.entry(app_name).or_default();
279 let user_state = user_map.entry(user_id).or_default();
280 user_state.extend(user_delta);
281 }
282
283 Ok(())
284 }
285
286 async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
287 let mut event = req.event;
288 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
289
290 let identity = req.identity;
291
292 let (app_name_str, user_id_str, app_delta, user_delta) = {
293 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
294 let data = sessions
295 .get_mut(&identity)
296 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
297
298 data.events.push(event.clone());
299 data.updated_at = event.timestamp;
300
301 let (app_delta, user_delta, session_delta) =
302 Self::extract_state_deltas(&event.actions.state_delta);
303 data.state.extend(session_delta);
304
305 (
306 identity.app_name.as_ref().to_string(),
307 identity.user_id.as_ref().to_string(),
308 app_delta,
309 user_delta,
310 )
311 };
312
313 if !app_delta.is_empty() {
314 let mut app_state_lock = self.app_state.write().unwrap_or_else(|e| e.into_inner());
315 let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
316 app_state.extend(app_delta);
317 }
318
319 if !user_delta.is_empty() {
320 let mut user_state_lock = self.user_state.write().unwrap_or_else(|e| e.into_inner());
321 let user_map = user_state_lock.entry(app_name_str).or_default();
322 let user_state = user_map.entry(user_id_str).or_default();
323 user_state.extend(user_delta);
324 }
325
326 Ok(())
327 }
328
329 async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
330 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
331 let data = sessions
332 .get(identity)
333 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
334
335 let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
336 let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
337 drop(app_state_lock);
338
339 let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
340 let user_state = user_state_lock
341 .get(identity.app_name.as_ref())
342 .and_then(|m| m.get(identity.user_id.as_ref()))
343 .cloned()
344 .unwrap_or_default();
345 drop(user_state_lock);
346
347 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
348
349 Ok(Box::new(InMemorySession {
350 identity: data.identity.clone(),
351 state: merged_state,
352 events: data.events.clone(),
353 updated_at: data.updated_at,
354 }))
355 }
356
357 async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
358 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
359 sessions.remove(identity);
360 Ok(())
361 }
362
363 async fn rewind(&self, session_id: &str, target_event_id: &str) -> Result<Box<dyn Session>> {
364 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
365
366 let data = sessions
368 .values_mut()
369 .find(|d| d.identity.session_id.as_ref() == session_id)
370 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
371
372 let target_index =
374 data.events.iter().position(|e| e.id == target_event_id).ok_or_else(|| {
375 adk_core::AdkError::session(format!("target event not found: {target_event_id}"))
376 })?;
377
378 data.events.truncate(target_index + 1);
380
381 let mut rebuilt_session_state: HashMap<String, Value> = HashMap::new();
383 for event in &data.events {
384 let (_app_delta, _user_delta, session_delta) =
385 state_utils::extract_state_deltas(&event.actions.state_delta);
386 rebuilt_session_state.extend(session_delta);
387 }
388
389 let app_name = data.identity.app_name.as_ref().to_string();
391 let user_id = data.identity.user_id.as_ref().to_string();
392
393 data.state = rebuilt_session_state.clone();
395 data.updated_at = data.events.last().map(|e| e.timestamp).unwrap_or(Utc::now());
396
397 let identity = data.identity.clone();
398 let events = data.events.clone();
399 let updated_at = data.updated_at;
400 drop(sessions);
401
402 let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
404 let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
405 drop(app_state_lock);
406
407 let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
408 let user_state = user_state_lock
409 .get(&app_name)
410 .and_then(|m| m.get(&user_id))
411 .cloned()
412 .unwrap_or_default();
413 drop(user_state_lock);
414
415 let merged_state =
416 state_utils::merge_states(&app_state, &user_state, &rebuilt_session_state);
417
418 Ok(Box::new(InMemorySession { identity, state: merged_state, events, updated_at }))
419 }
420
421 async fn rewind_steps(&self, session_id: &str, steps: usize) -> Result<Box<dyn Session>> {
422 if steps == 0 {
423 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
425 let data = sessions
426 .values()
427 .find(|d| d.identity.session_id.as_ref() == session_id)
428 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
429
430 let app_name = data.identity.app_name.as_ref().to_string();
431 let user_id = data.identity.user_id.as_ref().to_string();
432 let identity = data.identity.clone();
433 let events = data.events.clone();
434 let session_state = data.state.clone();
435 let updated_at = data.updated_at;
436 drop(sessions);
437
438 let app_state_lock = self.app_state.read().unwrap_or_else(|e| e.into_inner());
439 let app_state = app_state_lock.get(&app_name).cloned().unwrap_or_default();
440 drop(app_state_lock);
441
442 let user_state_lock = self.user_state.read().unwrap_or_else(|e| e.into_inner());
443 let user_state = user_state_lock
444 .get(&app_name)
445 .and_then(|m| m.get(&user_id))
446 .cloned()
447 .unwrap_or_default();
448 drop(user_state_lock);
449
450 let merged_state = state_utils::merge_states(&app_state, &user_state, &session_state);
451
452 return Ok(Box::new(InMemorySession {
453 identity,
454 state: merged_state,
455 events,
456 updated_at,
457 }));
458 }
459
460 let rewind_target = {
462 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
463 let data = sessions
464 .values()
465 .find(|d| d.identity.session_id.as_ref() == session_id)
466 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
467
468 if steps > data.events.len() {
469 return Err(adk_core::AdkError::session("rewind steps exceeds event count"));
470 }
471
472 let target_index = data.events.len() - steps;
473 if target_index == 0 {
474 None
476 } else {
477 Some(data.events[target_index - 1].id.clone())
478 }
479 };
480
481 match rewind_target {
482 Some(target_event_id) => self.rewind(session_id, &target_event_id).await,
483 None => self.rewind_to_empty(session_id).await,
484 }
485 }
486}
487
488struct InMemorySession {
489 identity: AdkIdentity,
490 state: StateMap,
491 events: Vec<Event>,
492 updated_at: DateTime<Utc>,
493}
494
495impl Session for InMemorySession {
496 fn id(&self) -> &str {
497 self.identity.session_id.as_ref()
498 }
499
500 fn app_name(&self) -> &str {
501 self.identity.app_name.as_ref()
502 }
503
504 fn user_id(&self) -> &str {
505 self.identity.user_id.as_ref()
506 }
507
508 fn state(&self) -> &dyn State {
509 self
510 }
511
512 fn events(&self) -> &dyn Events {
513 self
514 }
515
516 fn last_update_time(&self) -> DateTime<Utc> {
517 self.updated_at
518 }
519}
520
521impl State for InMemorySession {
522 fn get(&self, key: &str) -> Option<Value> {
523 self.state.get(key).cloned()
524 }
525
526 fn set(&mut self, key: String, value: Value) {
527 if let Err(msg) = adk_core::validate_state_key(&key) {
528 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
529 return;
530 }
531 self.state.insert(key, value);
532 }
533
534 fn all(&self) -> HashMap<String, Value> {
535 self.state.clone()
536 }
537}
538
539impl Events for InMemorySession {
540 fn all(&self) -> Vec<Event> {
541 self.events.clone()
542 }
543
544 fn len(&self) -> usize {
545 self.events.len()
546 }
547
548 fn at(&self, index: usize) -> Option<&Event> {
549 self.events.get(index)
550 }
551}