brainwires_stores/session/
in_memory.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use chrono::Utc;
12use tokio::sync::Mutex;
13
14use brainwires_core::Message;
15
16use super::{SessionId, SessionRecord, SessionStore};
17
18#[derive(Debug)]
19struct Entry {
20 messages: Vec<Message>,
21 created_at: chrono::DateTime<chrono::Utc>,
22 updated_at: chrono::DateTime<chrono::Utc>,
23}
24
25#[derive(Clone, Default)]
27pub struct InMemorySessionStore {
28 inner: Arc<Mutex<HashMap<SessionId, Entry>>>,
29}
30
31impl InMemorySessionStore {
32 pub fn new() -> Self {
34 Self::default()
35 }
36}
37
38#[async_trait]
39impl SessionStore for InMemorySessionStore {
40 async fn load(&self, id: &SessionId) -> Result<Option<Vec<Message>>> {
41 let map = self.inner.lock().await;
42 Ok(map.get(id).map(|e| e.messages.clone()))
43 }
44
45 async fn save(&self, id: &SessionId, messages: &[Message]) -> Result<()> {
46 let mut map = self.inner.lock().await;
47 let now = Utc::now();
48 match map.get_mut(id) {
49 Some(entry) => {
50 entry.messages = messages.to_vec();
51 entry.updated_at = now;
52 }
53 None => {
54 map.insert(
55 id.clone(),
56 Entry {
57 messages: messages.to_vec(),
58 created_at: now,
59 updated_at: now,
60 },
61 );
62 }
63 }
64 Ok(())
65 }
66
67 async fn list(&self) -> Result<Vec<SessionRecord>> {
68 let map = self.inner.lock().await;
69 let mut out: Vec<SessionRecord> = map
70 .iter()
71 .map(|(id, e)| SessionRecord {
72 id: id.clone(),
73 message_count: e.messages.len(),
74 created_at: e.created_at,
75 updated_at: e.updated_at,
76 })
77 .collect();
78 out.sort_by_key(|r| r.updated_at);
79 Ok(out)
80 }
81
82 async fn delete(&self, id: &SessionId) -> Result<()> {
83 self.inner.lock().await.remove(id);
84 Ok(())
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::Message;

    /// A session is absent before saving, readable afterwards, and absent
    /// again once deleted.
    #[tokio::test]
    async fn roundtrip_save_load_delete() {
        let store = InMemorySessionStore::new();
        let session = SessionId::new("alice");

        let before = store.load(&session).await.unwrap();
        assert!(before.is_none());

        let transcript = [Message::user("hi"), Message::assistant("hello")];
        store.save(&session, &transcript).await.unwrap();

        let loaded = store.load(&session).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].text(), Some("hi"));

        store.delete(&session).await.unwrap();
        let after = store.load(&session).await.unwrap();
        assert!(after.is_none());
    }

    /// A second save under the same id replaces the earlier messages.
    #[tokio::test]
    async fn save_overwrites_atomically() {
        let store = InMemorySessionStore::new();
        let session = SessionId::new("bob");

        let first = [Message::user("one")];
        store.save(&session, &first).await.unwrap();

        let second = [Message::user("two"), Message::user("three")];
        store.save(&session, &second).await.unwrap();

        let loaded = store.load(&session).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].text(), Some("two"));
    }

    /// Every saved session appears in the listing.
    #[tokio::test]
    async fn list_returns_known_sessions() {
        let store = InMemorySessionStore::new();

        let first = SessionId::new("a");
        store.save(&first, &[Message::user("x")]).await.unwrap();
        let second = SessionId::new("b");
        store.save(&second, &[Message::user("y")]).await.unwrap();

        let records = store.list().await.unwrap();
        assert_eq!(records.len(), 2);

        let ids: Vec<&str> = records.iter().map(|r| r.id.as_str()).collect();
        for expected in ["a", "b"] {
            assert!(ids.contains(&expected));
        }
    }

    /// Deleting an id that was never saved succeeds.
    #[tokio::test]
    async fn delete_unknown_is_noop() {
        let store = InMemorySessionStore::new();
        let missing = SessionId::new("never-existed");
        store.delete(&missing).await.unwrap();
    }
}