1use crate::{
2 CreateRequest, DeleteRequest, Event, Events, GetRequest, ListRequest, Session, SessionService,
3 State, KEY_PREFIX_APP, KEY_PREFIX_TEMP, KEY_PREFIX_USER,
4};
5use adk_core::Result;
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use uuid::Uuid;
12
13type StateMap = HashMap<String, Value>;
14
15#[derive(Clone)]
16struct SessionData {
17 id: SessionId,
18 events: Vec<Event>,
19 state: StateMap,
20 updated_at: DateTime<Utc>,
21}
22
/// Fully-qualified identity of a session: the app it belongs to, the owning
/// user, and the session's own id.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SessionId {
    app_name: String,
    user_id: String,
    session_id: String,
}

impl SessionId {
    /// Builds the string key under which this session is stored.
    ///
    /// The app and user components are prefixed with their byte length, so the
    /// encoding is unambiguous even when a component itself contains `:`.
    /// A plain `app:user:session` join would map `("a:b", "c", s)` and
    /// `("a", "b:c", s)` to the same key and let one session clobber another.
    fn key(&self) -> String {
        format!(
            "{}:{}:{}:{}:{}",
            self.app_name.len(),
            self.app_name,
            self.user_id.len(),
            self.user_id,
            self.session_id
        )
    }
}
35
36pub struct InMemorySessionService {
37 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
38 app_state: Arc<RwLock<HashMap<String, StateMap>>>,
39 user_state: Arc<RwLock<HashMap<String, HashMap<String, StateMap>>>>,
40}
41
42impl InMemorySessionService {
43 pub fn new() -> Self {
44 Self {
45 sessions: Arc::new(RwLock::new(HashMap::new())),
46 app_state: Arc::new(RwLock::new(HashMap::new())),
47 user_state: Arc::new(RwLock::new(HashMap::new())),
48 }
49 }
50
51 fn extract_state_deltas(delta: &HashMap<String, Value>) -> (StateMap, StateMap, StateMap) {
52 let mut app_delta = StateMap::new();
53 let mut user_delta = StateMap::new();
54 let mut session_delta = StateMap::new();
55
56 for (key, value) in delta {
57 if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_APP) {
58 app_delta.insert(clean_key.to_string(), value.clone());
59 } else if let Some(clean_key) = key.strip_prefix(KEY_PREFIX_USER) {
60 user_delta.insert(clean_key.to_string(), value.clone());
61 } else if !key.starts_with(KEY_PREFIX_TEMP) {
62 session_delta.insert(key.clone(), value.clone());
63 }
64 }
65
66 (app_delta, user_delta, session_delta)
67 }
68
69 fn merge_states(app: &StateMap, user: &StateMap, session: &StateMap) -> StateMap {
70 let mut merged = session.clone();
71 for (k, v) in app {
72 merged.insert(format!("{}{}", KEY_PREFIX_APP, k), v.clone());
73 }
74 for (k, v) in user {
75 merged.insert(format!("{}{}", KEY_PREFIX_USER, k), v.clone());
76 }
77 merged
78 }
79}
80
81impl Default for InMemorySessionService {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87#[async_trait]
88impl SessionService for InMemorySessionService {
89 async fn create(&self, req: CreateRequest) -> Result<Box<dyn Session>> {
90 let session_id = req.session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
91
92 let id = SessionId {
93 app_name: req.app_name.clone(),
94 user_id: req.user_id.clone(),
95 session_id: session_id.clone(),
96 };
97
98 let (app_delta, user_delta, session_state) = Self::extract_state_deltas(&req.state);
99
100 let mut app_state_lock = self.app_state.write().unwrap();
101 let app_state = app_state_lock.entry(req.app_name.clone()).or_default();
102 app_state.extend(app_delta);
103 let app_state_clone = app_state.clone();
104 drop(app_state_lock);
105
106 let mut user_state_lock = self.user_state.write().unwrap();
107 let user_map = user_state_lock.entry(req.app_name.clone()).or_default();
108 let user_state = user_map.entry(req.user_id.clone()).or_default();
109 user_state.extend(user_delta);
110 let user_state_clone = user_state.clone();
111 drop(user_state_lock);
112
113 let merged_state = Self::merge_states(&app_state_clone, &user_state_clone, &session_state);
114
115 let data = SessionData {
116 id: id.clone(),
117 events: Vec::new(),
118 state: merged_state.clone(),
119 updated_at: Utc::now(),
120 };
121
122 let mut sessions = self.sessions.write().unwrap();
123 sessions.insert(id.key(), data);
124 drop(sessions);
125
126 Ok(Box::new(InMemorySession {
127 id,
128 state: merged_state,
129 events: Vec::new(),
130 updated_at: Utc::now(),
131 }))
132 }
133
134 async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
135 let id = SessionId {
136 app_name: req.app_name.clone(),
137 user_id: req.user_id.clone(),
138 session_id: req.session_id.clone(),
139 };
140
141 let sessions = self.sessions.read().unwrap();
142 let data = sessions
143 .get(&id.key())
144 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
145
146 let app_state_lock = self.app_state.read().unwrap();
147 let app_state = app_state_lock.get(&req.app_name).cloned().unwrap_or_default();
148 drop(app_state_lock);
149
150 let user_state_lock = self.user_state.read().unwrap();
151 let user_state = user_state_lock
152 .get(&req.app_name)
153 .and_then(|m| m.get(&req.user_id))
154 .cloned()
155 .unwrap_or_default();
156 drop(user_state_lock);
157
158 let merged_state = Self::merge_states(&app_state, &user_state, &data.state);
159
160 let mut events = data.events.clone();
161 if let Some(num) = req.num_recent_events {
162 let start = events.len().saturating_sub(num);
163 events = events[start..].to_vec();
164 }
165 if let Some(after) = req.after {
166 events.retain(|e| e.timestamp >= after);
167 }
168
169 Ok(Box::new(InMemorySession {
170 id: data.id.clone(),
171 state: merged_state,
172 events,
173 updated_at: data.updated_at,
174 }))
175 }
176
177 async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
178 let sessions = self.sessions.read().unwrap();
179 let mut result = Vec::new();
180
181 for data in sessions.values() {
182 if data.id.app_name == req.app_name && data.id.user_id == req.user_id {
183 result.push(Box::new(InMemorySession {
184 id: data.id.clone(),
185 state: data.state.clone(),
186 events: data.events.clone(),
187 updated_at: data.updated_at,
188 }) as Box<dyn Session>);
189 }
190 }
191
192 Ok(result)
193 }
194
195 async fn delete(&self, req: DeleteRequest) -> Result<()> {
196 let id =
197 SessionId { app_name: req.app_name, user_id: req.user_id, session_id: req.session_id };
198
199 let mut sessions = self.sessions.write().unwrap();
200 sessions.remove(&id.key());
201 Ok(())
202 }
203
204 async fn append_event(&self, session_id: &str, mut event: Event) -> Result<()> {
205 event.actions.state_delta.retain(|k, _| !k.starts_with(KEY_PREFIX_TEMP));
206
207 let (app_name, user_id, app_delta, user_delta, _session_delta) = {
208 let mut sessions = self.sessions.write().unwrap();
209 let data = sessions
210 .values_mut()
211 .find(|d| d.id.session_id == session_id)
212 .ok_or_else(|| adk_core::AdkError::Session("session not found".into()))?;
213
214 data.events.push(event.clone());
215 data.updated_at = event.timestamp;
216
217 let (app_delta, user_delta, session_delta) =
218 Self::extract_state_deltas(&event.actions.state_delta);
219 data.state.extend(session_delta.clone());
220
221 (
222 data.id.app_name.clone(),
223 data.id.user_id.clone(),
224 app_delta,
225 user_delta,
226 session_delta,
227 )
228 };
229
230 if !app_delta.is_empty() {
231 let mut app_state_lock = self.app_state.write().unwrap();
232 let app_state = app_state_lock.entry(app_name.clone()).or_default();
233 app_state.extend(app_delta);
234 }
235
236 if !user_delta.is_empty() {
237 let mut user_state_lock = self.user_state.write().unwrap();
238 let user_map = user_state_lock.entry(app_name).or_default();
239 let user_state = user_map.entry(user_id).or_default();
240 user_state.extend(user_delta);
241 }
242
243 Ok(())
244 }
245}
246
247struct InMemorySession {
248 id: SessionId,
249 state: StateMap,
250 events: Vec<Event>,
251 updated_at: DateTime<Utc>,
252}
253
254impl Session for InMemorySession {
255 fn id(&self) -> &str {
256 &self.id.session_id
257 }
258
259 fn app_name(&self) -> &str {
260 &self.id.app_name
261 }
262
263 fn user_id(&self) -> &str {
264 &self.id.user_id
265 }
266
267 fn state(&self) -> &dyn State {
268 self
269 }
270
271 fn events(&self) -> &dyn Events {
272 self
273 }
274
275 fn last_update_time(&self) -> DateTime<Utc> {
276 self.updated_at
277 }
278}
279
280impl State for InMemorySession {
281 fn get(&self, key: &str) -> Option<Value> {
282 self.state.get(key).cloned()
283 }
284
285 fn set(&mut self, key: String, value: Value) {
286 self.state.insert(key, value);
287 }
288
289 fn all(&self) -> HashMap<String, Value> {
290 self.state.clone()
291 }
292}
293
294impl Events for InMemorySession {
295 fn all(&self) -> Vec<Event> {
296 self.events.clone()
297 }
298
299 fn len(&self) -> usize {
300 self.events.len()
301 }
302
303 fn at(&self, index: usize) -> Option<&Event> {
304 self.events.get(index)
305 }
306}