1use async_trait::async_trait;
7use cortexai_core::{errors::MemoryError, Message};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Session {
18 pub id: String,
20 pub agent_id: String,
22 pub messages: Vec<Message>,
24 pub metadata: HashMap<String, serde_json::Value>,
26 pub created_at: i64,
28 pub updated_at: i64,
30 pub resume_token: Option<String>,
32}
33
34impl Session {
35 pub fn new(agent_id: impl Into<String>) -> Self {
37 let now = chrono::Utc::now().timestamp_millis();
38 Self {
39 id: Uuid::new_v4().to_string(),
40 agent_id: agent_id.into(),
41 messages: Vec::new(),
42 metadata: HashMap::new(),
43 created_at: now,
44 updated_at: now,
45 resume_token: None,
46 }
47 }
48
49 pub fn with_id(id: impl Into<String>, agent_id: impl Into<String>) -> Self {
51 let now = chrono::Utc::now().timestamp_millis();
52 Self {
53 id: id.into(),
54 agent_id: agent_id.into(),
55 messages: Vec::new(),
56 metadata: HashMap::new(),
57 created_at: now,
58 updated_at: now,
59 resume_token: None,
60 }
61 }
62
63 pub fn add_message(&mut self, message: Message) {
65 self.messages.push(message);
66 self.updated_at = chrono::Utc::now().timestamp_millis();
67 }
68
69 pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
71 self.metadata.insert(key.into(), value);
72 self.updated_at = chrono::Utc::now().timestamp_millis();
73 }
74
75 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
77 self.metadata.get(key)
78 }
79
80 pub fn generate_resume_token(&mut self) -> String {
82 let token = format!("{}:{}", self.id, Uuid::new_v4());
83 self.resume_token = Some(token.clone());
84 self.updated_at = chrono::Utc::now().timestamp_millis();
85 token
86 }
87}
88
89#[async_trait]
91pub trait MemoryBackend: Send + Sync {
92 async fn save_session(&self, session: &Session) -> Result<(), MemoryError>;
94
95 async fn load_session(&self, session_id: &str) -> Result<Option<Session>, MemoryError>;
97
98 async fn load_by_resume_token(&self, token: &str) -> Result<Option<Session>, MemoryError>;
100
101 async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError>;
103
104 async fn list_sessions(&self, agent_id: &str) -> Result<Vec<Session>, MemoryError>;
106
107 async fn list_session_ids(&self, agent_id: &str) -> Result<Vec<String>, MemoryError>;
109
110 async fn session_exists(&self, session_id: &str) -> Result<bool, MemoryError>;
112
113 async fn clear_agent_sessions(&self, agent_id: &str) -> Result<usize, MemoryError>;
115
116 fn backend_name(&self) -> &'static str;
118}
119
120pub struct InMemoryBackend {
122 sessions: Arc<RwLock<HashMap<String, Session>>>,
123 resume_tokens: Arc<RwLock<HashMap<String, String>>>, }
125
126impl InMemoryBackend {
127 pub fn new() -> Self {
128 Self {
129 sessions: Arc::new(RwLock::new(HashMap::new())),
130 resume_tokens: Arc::new(RwLock::new(HashMap::new())),
131 }
132 }
133}
134
135impl Default for InMemoryBackend {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141#[async_trait]
142impl MemoryBackend for InMemoryBackend {
143 async fn save_session(&self, session: &Session) -> Result<(), MemoryError> {
144 let mut sessions = self.sessions.write().await;
145
146 if let Some(token) = &session.resume_token {
148 let mut tokens = self.resume_tokens.write().await;
149 tokens.insert(token.clone(), session.id.clone());
150 }
151
152 sessions.insert(session.id.clone(), session.clone());
153 Ok(())
154 }
155
156 async fn load_session(&self, session_id: &str) -> Result<Option<Session>, MemoryError> {
157 let sessions = self.sessions.read().await;
158 Ok(sessions.get(session_id).cloned())
159 }
160
161 async fn load_by_resume_token(&self, token: &str) -> Result<Option<Session>, MemoryError> {
162 let tokens = self.resume_tokens.read().await;
163 if let Some(session_id) = tokens.get(token) {
164 let sessions = self.sessions.read().await;
165 Ok(sessions.get(session_id).cloned())
166 } else {
167 Ok(None)
168 }
169 }
170
171 async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
172 let mut sessions = self.sessions.write().await;
173 if let Some(session) = sessions.remove(session_id) {
174 if let Some(token) = session.resume_token {
176 let mut tokens = self.resume_tokens.write().await;
177 tokens.remove(&token);
178 }
179 }
180 Ok(())
181 }
182
183 async fn list_sessions(&self, agent_id: &str) -> Result<Vec<Session>, MemoryError> {
184 let sessions = self.sessions.read().await;
185 Ok(sessions
186 .values()
187 .filter(|s| s.agent_id == agent_id)
188 .cloned()
189 .collect())
190 }
191
192 async fn list_session_ids(&self, agent_id: &str) -> Result<Vec<String>, MemoryError> {
193 let sessions = self.sessions.read().await;
194 Ok(sessions
195 .values()
196 .filter(|s| s.agent_id == agent_id)
197 .map(|s| s.id.clone())
198 .collect())
199 }
200
201 async fn session_exists(&self, session_id: &str) -> Result<bool, MemoryError> {
202 let sessions = self.sessions.read().await;
203 Ok(sessions.contains_key(session_id))
204 }
205
206 async fn clear_agent_sessions(&self, agent_id: &str) -> Result<usize, MemoryError> {
207 let mut sessions = self.sessions.write().await;
208 let mut tokens = self.resume_tokens.write().await;
209
210 let to_remove: Vec<_> = sessions
211 .iter()
212 .filter(|(_, s)| s.agent_id == agent_id)
213 .map(|(id, s)| (id.clone(), s.resume_token.clone()))
214 .collect();
215
216 let count = to_remove.len();
217
218 for (id, token) in to_remove {
219 sessions.remove(&id);
220 if let Some(t) = token {
221 tokens.remove(&t);
222 }
223 }
224
225 Ok(count)
226 }
227
228 fn backend_name(&self) -> &'static str {
229 "in-memory"
230 }
231}
232
233pub struct SledBackend {
235 db: sled::Db,
236 sessions_tree: sled::Tree,
237 tokens_tree: sled::Tree,
238 agent_index_tree: sled::Tree,
239}
240
241impl SledBackend {
242 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, MemoryError> {
244 let db = sled::open(path).map_err(|e| MemoryError::StorageError(e.to_string()))?;
245
246 let sessions_tree = db
247 .open_tree("sessions")
248 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
249
250 let tokens_tree = db
251 .open_tree("resume_tokens")
252 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
253
254 let agent_index_tree = db
255 .open_tree("agent_sessions")
256 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
257
258 Ok(Self {
259 db,
260 sessions_tree,
261 tokens_tree,
262 agent_index_tree,
263 })
264 }
265
266 pub fn temporary() -> Result<Self, MemoryError> {
268 let db = sled::Config::new()
269 .temporary(true)
270 .open()
271 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
272
273 let sessions_tree = db
274 .open_tree("sessions")
275 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
276
277 let tokens_tree = db
278 .open_tree("resume_tokens")
279 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
280
281 let agent_index_tree = db
282 .open_tree("agent_sessions")
283 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
284
285 Ok(Self {
286 db,
287 sessions_tree,
288 tokens_tree,
289 agent_index_tree,
290 })
291 }
292
293 pub fn flush(&self) -> Result<(), MemoryError> {
295 self.db
296 .flush()
297 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
298 Ok(())
299 }
300
301 fn agent_session_key(agent_id: &str, session_id: &str) -> Vec<u8> {
303 format!("{}:{}", agent_id, session_id).into_bytes()
304 }
305}
306
307#[async_trait]
308impl MemoryBackend for SledBackend {
309 async fn save_session(&self, session: &Session) -> Result<(), MemoryError> {
310 let session_bytes = serde_json::to_vec(session)
311 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
312
313 self.sessions_tree
315 .insert(session.id.as_bytes(), session_bytes)
316 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
317
318 if let Some(token) = &session.resume_token {
320 self.tokens_tree
321 .insert(token.as_bytes(), session.id.as_bytes())
322 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
323 }
324
325 let agent_key = Self::agent_session_key(&session.agent_id, &session.id);
327 self.agent_index_tree
328 .insert(agent_key, &[1u8])
329 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
330
331 Ok(())
332 }
333
334 async fn load_session(&self, session_id: &str) -> Result<Option<Session>, MemoryError> {
335 match self.sessions_tree.get(session_id.as_bytes()) {
336 Ok(Some(bytes)) => {
337 let session: Session = serde_json::from_slice(&bytes)
338 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
339 Ok(Some(session))
340 }
341 Ok(None) => Ok(None),
342 Err(e) => Err(MemoryError::StorageError(e.to_string())),
343 }
344 }
345
346 async fn load_by_resume_token(&self, token: &str) -> Result<Option<Session>, MemoryError> {
347 match self.tokens_tree.get(token.as_bytes()) {
348 Ok(Some(session_id_bytes)) => {
349 let session_id = String::from_utf8(session_id_bytes.to_vec())
350 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
351 self.load_session(&session_id).await
352 }
353 Ok(None) => Ok(None),
354 Err(e) => Err(MemoryError::StorageError(e.to_string())),
355 }
356 }
357
358 async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
359 if let Some(session) = self.load_session(session_id).await? {
361 self.sessions_tree
363 .remove(session_id.as_bytes())
364 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
365
366 if let Some(token) = session.resume_token {
368 self.tokens_tree
369 .remove(token.as_bytes())
370 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
371 }
372
373 let agent_key = Self::agent_session_key(&session.agent_id, session_id);
375 self.agent_index_tree
376 .remove(agent_key)
377 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
378 }
379
380 Ok(())
381 }
382
383 async fn list_sessions(&self, agent_id: &str) -> Result<Vec<Session>, MemoryError> {
384 let prefix = format!("{}:", agent_id);
385 let mut sessions = Vec::new();
386
387 for result in self.agent_index_tree.scan_prefix(prefix.as_bytes()) {
388 let (key, _) = result.map_err(|e| MemoryError::StorageError(e.to_string()))?;
389 let key_str = String::from_utf8(key.to_vec())
390 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
391
392 if let Some(session_id) = key_str.strip_prefix(&prefix) {
394 if let Some(session) = self.load_session(session_id).await? {
395 sessions.push(session);
396 }
397 }
398 }
399
400 Ok(sessions)
401 }
402
403 async fn list_session_ids(&self, agent_id: &str) -> Result<Vec<String>, MemoryError> {
404 let prefix = format!("{}:", agent_id);
405 let mut ids = Vec::new();
406
407 for result in self.agent_index_tree.scan_prefix(prefix.as_bytes()) {
408 let (key, _) = result.map_err(|e| MemoryError::StorageError(e.to_string()))?;
409 let key_str = String::from_utf8(key.to_vec())
410 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
411
412 if let Some(session_id) = key_str.strip_prefix(&prefix) {
413 ids.push(session_id.to_string());
414 }
415 }
416
417 Ok(ids)
418 }
419
420 async fn session_exists(&self, session_id: &str) -> Result<bool, MemoryError> {
421 self.sessions_tree
422 .contains_key(session_id.as_bytes())
423 .map_err(|e| MemoryError::StorageError(e.to_string()))
424 }
425
426 async fn clear_agent_sessions(&self, agent_id: &str) -> Result<usize, MemoryError> {
427 let session_ids = self.list_session_ids(agent_id).await?;
428 let count = session_ids.len();
429
430 for session_id in session_ids {
431 self.delete_session(&session_id).await?;
432 }
433
434 Ok(count)
435 }
436
437 fn backend_name(&self) -> &'static str {
438 "sled"
439 }
440}
441
442pub struct SessionManager {
444 backend: Arc<dyn MemoryBackend>,
445}
446
447impl SessionManager {
448 pub fn new(backend: Arc<dyn MemoryBackend>) -> Self {
450 Self { backend }
451 }
452
453 pub fn in_memory() -> Self {
455 Self {
456 backend: Arc::new(InMemoryBackend::new()),
457 }
458 }
459
460 pub fn sled<P: AsRef<Path>>(path: P) -> Result<Self, MemoryError> {
462 Ok(Self {
463 backend: Arc::new(SledBackend::new(path)?),
464 })
465 }
466
467 pub fn sled_temporary() -> Result<Self, MemoryError> {
469 Ok(Self {
470 backend: Arc::new(SledBackend::temporary()?),
471 })
472 }
473
474 pub fn backend(&self) -> &dyn MemoryBackend {
476 self.backend.as_ref()
477 }
478
479 pub async fn create_session(&self, agent_id: &str) -> Result<Session, MemoryError> {
481 let session = Session::new(agent_id);
482 self.backend.save_session(&session).await?;
483 tracing::debug!(
484 backend = self.backend.backend_name(),
485 session_id = %session.id,
486 agent_id = %agent_id,
487 "Created new session"
488 );
489 Ok(session)
490 }
491
492 pub async fn get_or_create_session(
494 &self,
495 agent_id: &str,
496 session_id: Option<&str>,
497 ) -> Result<Session, MemoryError> {
498 if let Some(id) = session_id {
499 if let Some(session) = self.backend.load_session(id).await? {
500 return Ok(session);
501 }
502 }
503 self.create_session(agent_id).await
504 }
505
506 pub async fn resume_session(&self, token: &str) -> Result<Option<Session>, MemoryError> {
508 let session = self.backend.load_by_resume_token(token).await?;
509 if session.is_some() {
510 tracing::debug!(
511 backend = self.backend.backend_name(),
512 token = %token,
513 "Resumed session from token"
514 );
515 }
516 Ok(session)
517 }
518
519 pub async fn save_session(&self, session: &Session) -> Result<(), MemoryError> {
521 self.backend.save_session(session).await
522 }
523
524 pub async fn add_message(
526 &self,
527 session_id: &str,
528 message: Message,
529 ) -> Result<Session, MemoryError> {
530 let mut session = self
531 .backend
532 .load_session(session_id)
533 .await?
534 .ok_or_else(|| MemoryError::SessionNotFound(session_id.to_string()))?;
535
536 session.add_message(message);
537 self.backend.save_session(&session).await?;
538 Ok(session)
539 }
540
541 pub async fn create_resume_token(&self, session_id: &str) -> Result<String, MemoryError> {
543 let mut session = self
544 .backend
545 .load_session(session_id)
546 .await?
547 .ok_or_else(|| MemoryError::SessionNotFound(session_id.to_string()))?;
548
549 let token = session.generate_resume_token();
550 self.backend.save_session(&session).await?;
551
552 tracing::debug!(
553 backend = self.backend.backend_name(),
554 session_id = %session_id,
555 "Generated resume token"
556 );
557
558 Ok(token)
559 }
560
561 pub async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
563 self.backend.delete_session(session_id).await
564 }
565
566 pub async fn list_sessions(&self, agent_id: &str) -> Result<Vec<Session>, MemoryError> {
568 self.backend.list_sessions(agent_id).await
569 }
570
571 pub async fn clear_agent_sessions(&self, agent_id: &str) -> Result<usize, MemoryError> {
573 let count = self.backend.clear_agent_sessions(agent_id).await?;
574 tracing::info!(
575 backend = self.backend.backend_name(),
576 agent_id = %agent_id,
577 count = count,
578 "Cleared agent sessions"
579 );
580 Ok(count)
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use cortexai_core::{types::AgentId, Content};

    /// Builds a user-authored message addressed to the test agent.
    fn make_user_message(text: &str) -> Message {
        Message::user(AgentId::new("test-agent"), text)
    }

    /// Builds a plain-text reply from the test agent to the user.
    fn make_assistant_message(text: &str) -> Message {
        let from = AgentId::new("test-agent");
        let to = AgentId::new("user");
        Message::new(from, to, Content::Text(text.to_string()))
    }

    #[tokio::test]
    async fn test_in_memory_backend_basic() {
        let store = InMemoryBackend::new();

        let mut original = Session::new("test-agent");
        for msg in [make_user_message("Hello"), make_assistant_message("Hi there!")] {
            original.add_message(msg);
        }
        store.save_session(&original).await.unwrap();

        let fetched = store
            .load_session(&original.id)
            .await
            .unwrap()
            .expect("saved session should be loadable");
        assert_eq!(fetched.id, original.id);
        assert_eq!(fetched.messages.len(), 2);
        assert_eq!(fetched.agent_id, "test-agent");
    }

    #[tokio::test]
    async fn test_in_memory_resume_token() {
        let store = InMemoryBackend::new();

        let mut original = Session::new("test-agent");
        let token = original.generate_resume_token();
        store.save_session(&original).await.unwrap();

        let resumed = store
            .load_by_resume_token(&token)
            .await
            .unwrap()
            .expect("token should resolve to the saved session");
        assert_eq!(resumed.id, original.id);
    }

    #[tokio::test]
    async fn test_in_memory_list_sessions() {
        let store = InMemoryBackend::new();

        let fixtures = [
            Session::new("agent-1"),
            Session::new("agent-1"),
            Session::new("agent-2"),
        ];
        for session in &fixtures {
            store.save_session(session).await.unwrap();
        }

        for (agent, expected) in [("agent-1", 2usize), ("agent-2", 1usize)] {
            let listed = store.list_sessions(agent).await.unwrap();
            assert_eq!(listed.len(), expected);
        }
    }

    #[tokio::test]
    async fn test_sled_backend_basic() {
        let store = SledBackend::temporary().unwrap();

        let mut original = Session::new("test-agent");
        original.add_message(make_user_message("Test message"));
        original.set_metadata("key", serde_json::json!("value"));
        store.save_session(&original).await.unwrap();

        let fetched = store
            .load_session(&original.id)
            .await
            .unwrap()
            .expect("saved session should be loadable");
        assert_eq!(fetched.id, original.id);
        assert_eq!(fetched.messages.len(), 1);
        let expected_value = serde_json::json!("value");
        assert_eq!(fetched.get_metadata("key"), Some(&expected_value));
    }

    #[tokio::test]
    async fn test_sled_backend_resume_token() {
        let store = SledBackend::temporary().unwrap();

        let mut original = Session::new("test-agent");
        let token = original.generate_resume_token();
        store.save_session(&original).await.unwrap();

        let resumed = store
            .load_by_resume_token(&token)
            .await
            .unwrap()
            .expect("token should resolve to the saved session");
        assert_eq!(resumed.id, original.id);
    }

    #[tokio::test]
    async fn test_sled_backend_delete() {
        let store = SledBackend::temporary().unwrap();

        let session = Session::new("test-agent");
        store.save_session(&session).await.unwrap();
        let exists_before = store.session_exists(&session.id).await.unwrap();
        assert!(exists_before);

        store.delete_session(&session.id).await.unwrap();
        let exists_after = store.session_exists(&session.id).await.unwrap();
        assert!(!exists_after);
    }

    #[tokio::test]
    async fn test_session_manager_create_and_resume() {
        let manager = SessionManager::in_memory();

        let created = manager.create_session("my-agent").await.unwrap();
        let token = manager.create_resume_token(&created.id).await.unwrap();

        let resumed = manager
            .resume_session(&token)
            .await
            .unwrap()
            .expect("freshly issued token should resume the session");
        assert_eq!(resumed.id, created.id);
    }

    #[tokio::test]
    async fn test_session_manager_add_message() {
        let manager = SessionManager::in_memory();
        let created = manager.create_session("my-agent").await.unwrap();

        let updated = manager
            .add_message(&created.id, make_user_message("Hello"))
            .await
            .unwrap();

        assert_eq!(updated.messages.len(), 1);
        assert!(updated.messages[0].is_text());
    }

    #[tokio::test]
    async fn test_session_manager_sled() {
        let manager = SessionManager::sled_temporary().unwrap();
        let created = manager.create_session("persistent-agent").await.unwrap();

        manager
            .add_message(&created.id, make_user_message("Persistent message"))
            .await
            .unwrap();

        let listed = manager.list_sessions("persistent-agent").await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_clear() {
        let manager = SessionManager::in_memory();

        for agent in ["agent-1", "agent-1", "agent-2"] {
            manager.create_session(agent).await.unwrap();
        }

        let removed = manager.clear_agent_sessions("agent-1").await.unwrap();
        assert_eq!(removed, 2);

        for (agent, remaining) in [("agent-1", 0usize), ("agent-2", 1usize)] {
            let listed = manager.list_sessions(agent).await.unwrap();
            assert_eq!(listed.len(), remaining);
        }
    }
}