1use crate::{
2 AppendEventRequest, CreateRequest, DeleteRequest, Event, Events, GetRequest, KEY_PREFIX_APP,
3 KEY_PREFIX_TEMP, KEY_PREFIX_USER, ListRequest, Session, SessionService, State,
4};
5use adk_core::Result;
6use adk_core::identity::{AdkIdentity, AppName, SessionId, UserId};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use uuid::Uuid;
13
14type StateMap = HashMap<String, Value>;
15
16#[derive(Clone)]
17struct SessionData {
18 identity: AdkIdentity,
19 events: Vec<Event>,
20 state: StateMap,
21 updated_at: DateTime<Utc>,
22}
23
24pub struct InMemorySessionService {
28 sessions: Arc<RwLock<HashMap<AdkIdentity, SessionData>>>,
29 app_state: Arc<RwLock<HashMap<String, StateMap>>>,
30 user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
31}
32
33impl InMemorySessionService {
34 pub fn new() -> Self {
36 Self {
37 sessions: Arc::new(RwLock::new(HashMap::new())),
38 app_state: Arc::new(RwLock::new(HashMap::new())),
39 user_state: Arc::new(RwLock::new(HashMap::new())),
40 }
41 }
42
43 fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
44 let mut app_delta = StateMap::new();
45 let mut user_delta = StateMap::new();
46 let mut session_delta = StateMap::new();
47
48 for (key, value) in delta {
49 if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
50 app_delta.insert(clean_key.to_string(), value.clone());
51 } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
52 user_delta.insert(clean_key.to_string(), value.clone());
53 } else if !key.starts_with(KEY_PREFIX_TEMP) {
54 session_delta.insert(key.clone(), value.clone());
55 }
56 }
57
58 (app_delta, user_delta, session_delta)
59 }
60
61 fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
62 let mut merged = session.clone();
63 for (k, v) in app {
64 merged.insert(format!("{KEY_PREFIX_APP}{k}"), v.clone());
65 }
66 for (k, v) in user {
67 merged.insert(format!("{KEY_PREFIX_USER}{k}"), v.clone());
68 }
69 merged
70 }
71
72 fn make_identity(app_name: &str, user_id: &str, session_id: &str) -> Result<AdkIdentity> {
75 Ok(AdkIdentity::new(
76 AppName::try_from(app_name)
77 .map_err(|e| adk_core::AdkError::session(format!("invalid app_name: {e}")))?,
78 UserId::try_from(user_id)
79 .map_err(|e| adk_core::AdkError::session(format!("invalid user_id: {e}")))?,
80 SessionId::try_from(session_id)
81 .map_err(|e| adk_core::AdkError::session(format!("invalid session_id: {e}")))?,
82 ))
83 }
84}
85
86impl Default for InMemorySessionService {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92#[async_trait]
93impl SessionService for InMemorySessionService {
94 async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
95 let session_id_str = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
96
97 let identity = Self::make_identity(&req.app_name, &req.user_id, &session_id_str)?;
98
99 let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
100
101 let mut app_state_lock = self.app_state.write().unwrap();
102 let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
103 app_state.extend(app_delta);
104 let app_state_clone = app_state.clone();
105 drop(app_state_lock);
106
107 let mut user_state_lock = self.user_state.write().unwrap();
108 let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
109 let user_state = user_map.entry(req.user_id.clone()).or_default();
110 user_state.extend(user_delta);
111 let user_state_clone = user_state.clone();
112 drop(user_state_lock);
113
114 let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
115
116 let data = SessionData {
117 identity: identity.clone(),
118 events: Vec::new(),
119 state: merged_state.clone(),
120 updated_at: Utc::now(),
121 };
122
123 let mut sessions = self.sessions.write().unwrap();
124 sessions.insert(identity.clone(), data);
125 drop(sessions);
126
127 Ok(Box::new(InMemorySession {
128 identity,
129 state: merged_state,
130 events: Vec::new(),
131 updated_at: Utc::now(),
132 }))
133 }
134
135 async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
136 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
137
138 let sessions = self.sessions.read().unwrap();
139 let data = sessions
140 .get(&identity)
141 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
142
143 let app_state_lock = self.app_state.read().unwrap();
144 let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
145 drop(app_state_lock);
146
147 let user_state_lock = self.user_state.read().unwrap();
148 let user_state = user_state_lock
149 .get(&req.app_name)
150 .and_then(|m| m.get(&req.user_id))
151 .cloned()
152 .unwrap_or_default();
153 drop(user_state_lock);
154
155 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
156
157 let mut events = data.events.clone();
158 if let Some(num) = req.num_recent_events {
159 let start = events.len().saturating_sub(num);
160 events = events[start..].to_vec();
161 }
162 if let Some(after) = req.after {
163 events.retain(|e| e.timestamp >= after);
164 }
165
166 Ok(Box::new(InMemorySession {
167 identity: data.identity.clone(),
168 state: merged_state,
169 events,
170 updated_at: data.updated_at,
171 }))
172 }
173
174 async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
175 let sessions = self.sessions.read().unwrap();
176 let offset = req.offset.unwrap_or(0);
177 let limit = req.limit.unwrap_or(usize::MAX);
178 let mut result = Vec::new();
179
180 for data in sessions.values() {
181 if data.identity.app_name.as_ref() == req.app_name
182 && data.identity.user_id.as_ref() == req.user_id
183 {
184 result.push(data.clone());
185 }
186 }
187
188 result.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
190
191 let result: Vec<Box<dyn Session>> = result
192 .into_iter()
193 .skip(offset)
194 .take(limit)
195 .map(|data| {
196 Box::new(InMemorySession {
197 identity: data.identity,
198 state: data.state,
199 events: data.events,
200 updated_at: data.updated_at,
201 }) as Box<dyn Session>
202 })
203 .collect();
204
205 Ok(result)
206 }
207
208 async fn delete(&self, req: DeleteRequest) -> Result<()> {
209 let identity = Self::make_identity(&req.app_name, &req.user_id, &req.session_id)?;
210
211 let mut sessions = self.sessions.write().unwrap();
212 sessions.remove(&identity);
213 Ok(())
214 }
215
216 async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
217 let mut sessions = self.sessions.write().unwrap();
218 sessions.retain(|_, data| {
219 !(data.identity.app_name.as_ref() == app_name
220 && data.identity.user_id.as_ref() == user_id)
221 });
222 Ok(())
223 }
224
225 async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
226 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
227
228 let (app_name, user_id, app_delta, user_delta, _session_delta) = {
229 let mut sessions = self.sessions.write().unwrap();
230 let data = sessions
231 .values_mut()
232 .find(|d| d.identity.session_id.as_ref() == session_id)
233 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
234
235 data.events.push(event.clone());
236 data.updated_at = event.timestamp;
237
238 let (app_delta, user_delta, session_delta) =
239 Self::extract_state_deltas(&event.actions.state_delta);
240 data.state.extend(session_delta.clone());
241
242 (
243 data.identity.app_name.as_ref().to_string(),
244 data.identity.user_id.as_ref().to_string(),
245 app_delta,
246 user_delta,
247 session_delta,
248 )
249 };
250
251 if !app_delta.is_empty() {
252 let mut app_state_lock = self.app_state.write().unwrap();
253 let app_state = app_state_lock.entry(app_name.clone()).or_default();
254 app_state.extend(app_delta);
255 }
256
257 if !user_delta.is_empty() {
258 let mut user_state_lock = self.user_state.write().unwrap();
259 let user_map = user_state_lock.entry(app_name).or_default();
260 let user_state = user_map.entry(user_id).or_default();
261 user_state.extend(user_delta);
262 }
263
264 Ok(())
265 }
266
267 async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
268 let mut event = req.event;
269 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
270
271 let identity = req.identity;
272
273 let (app_name_str, user_id_str, app_delta, user_delta) = {
274 let mut sessions = self.sessions.write().unwrap();
275 let data = sessions
276 .get_mut(&identity)
277 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
278
279 data.events.push(event.clone());
280 data.updated_at = event.timestamp;
281
282 let (app_delta, user_delta, session_delta) =
283 Self::extract_state_deltas(&event.actions.state_delta);
284 data.state.extend(session_delta);
285
286 (
287 identity.app_name.as_ref().to_string(),
288 identity.user_id.as_ref().to_string(),
289 app_delta,
290 user_delta,
291 )
292 };
293
294 if !app_delta.is_empty() {
295 let mut app_state_lock = self.app_state.write().unwrap();
296 let app_state = app_state_lock.entry(app_name_str.clone()).or_default();
297 app_state.extend(app_delta);
298 }
299
300 if !user_delta.is_empty() {
301 let mut user_state_lock = self.user_state.write().unwrap();
302 let user_map = user_state_lock.entry(app_name_str).or_default();
303 let user_state = user_map.entry(user_id_str).or_default();
304 user_state.extend(user_delta);
305 }
306
307 Ok(())
308 }
309
310 async fn get_for_identity(&self, identity: &AdkIdentity) -> Result<Box<dyn Session>> {
311 let sessions = self.sessions.read().unwrap();
312 let data = sessions
313 .get(identity)
314 .ok_or_else(|| adk_core::AdkError::session("session not found"))?;
315
316 let app_state_lock = self.app_state.read().unwrap();
317 let app_state = app_state_lock.get(identity.app_name.as_ref()).cloned().unwrap_or_default();
318 drop(app_state_lock);
319
320 let user_state_lock = self.user_state.read().unwrap();
321 let user_state = user_state_lock
322 .get(identity.app_name.as_ref())
323 .and_then(|m| m.get(identity.user_id.as_ref()))
324 .cloned()
325 .unwrap_or_default();
326 drop(user_state_lock);
327
328 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
329
330 Ok(Box::new(InMemorySession {
331 identity: data.identity.clone(),
332 state: merged_state,
333 events: data.events.clone(),
334 updated_at: data.updated_at,
335 }))
336 }
337
338 async fn delete_for_identity(&self, identity: &AdkIdentity) -> Result<()> {
339 let mut sessions = self.sessions.write().unwrap();
340 sessions.remove(identity);
341 Ok(())
342 }
343}
344
345struct InMemorySession {
346 identity: AdkIdentity,
347 state: StateMap,
348 events: Vec<Event>,
349 updated_at: DateTime<Utc>,
350}
351
352impl Session for InMemorySession {
353 fn id(&self) -> &str {
354 self.identity.session_id.as_ref()
355 }
356
357 fn app_name(&self) -> &str {
358 self.identity.app_name.as_ref()
359 }
360
361 fn user_id(&self) -> &str {
362 self.identity.user_id.as_ref()
363 }
364
365 fn state(&self) -> &dyn State {
366 self
367 }
368
369 fn events(&self) -> &dyn Events {
370 self
371 }
372
373 fn last_update_time(&self) -> DateTime<Utc> {
374 self.updated_at
375 }
376}
377
378impl State for InMemorySession {
379 fn get(&self, key: &str) -> Option<Value> {
380 self.state.get(key).cloned()
381 }
382
383 fn set(&mut self, key: String, value: Value) {
384 if let Err(msg) = adk_core::validate_state_key(&key) {
385 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
386 return;
387 }
388 self.state.insert(key, value);
389 }
390
391 fn all(&self) -> HashMap<String, Value> {
392 self.state.clone()
393 }
394}
395
396impl Events for InMemorySession {
397 fn all(&self) -> Vec<Event> {
398 self.events.clone()
399 }
400
401 fn len(&self) -> usize {
402 self.events.len()
403 }
404
405 fn at(&self, index: usize) -> Option<&Event> {
406 self.events.get(index)
407 }
408}