Skip to main content

brainwires_stores/session/
in_memory.rs

1//! In-memory [`SessionStore`] implementation backed by a mutex-guarded map.
2//!
3//! Intended for tests, ephemeral sessions, and embedding use-cases. Nothing
4//! persists across process restarts.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use chrono::Utc;
12use tokio::sync::Mutex;
13
14use brainwires_core::Message;
15
16use super::{SessionId, SessionRecord, SessionStore};
17
18#[derive(Debug)]
19struct Entry {
20    messages: Vec<Message>,
21    created_at: chrono::DateTime<chrono::Utc>,
22    updated_at: chrono::DateTime<chrono::Utc>,
23}
24
25/// In-memory session store. Cheap to `Arc`-clone — all clones share state.
26#[derive(Clone, Default)]
27pub struct InMemorySessionStore {
28    inner: Arc<Mutex<HashMap<SessionId, Entry>>>,
29}
30
31impl InMemorySessionStore {
32    /// Build an empty store.
33    pub fn new() -> Self {
34        Self::default()
35    }
36}
37
38#[async_trait]
39impl SessionStore for InMemorySessionStore {
40    async fn load(&self, id: &SessionId) -> Result<Option<Vec<Message>>> {
41        let map = self.inner.lock().await;
42        Ok(map.get(id).map(|e| e.messages.clone()))
43    }
44
45    async fn save(&self, id: &SessionId, messages: &[Message]) -> Result<()> {
46        let mut map = self.inner.lock().await;
47        let now = Utc::now();
48        match map.get_mut(id) {
49            Some(entry) => {
50                entry.messages = messages.to_vec();
51                entry.updated_at = now;
52            }
53            None => {
54                map.insert(
55                    id.clone(),
56                    Entry {
57                        messages: messages.to_vec(),
58                        created_at: now,
59                        updated_at: now,
60                    },
61                );
62            }
63        }
64        Ok(())
65    }
66
67    async fn list(&self) -> Result<Vec<SessionRecord>> {
68        let map = self.inner.lock().await;
69        let mut out: Vec<SessionRecord> = map
70            .iter()
71            .map(|(id, e)| SessionRecord {
72                id: id.clone(),
73                message_count: e.messages.len(),
74                created_at: e.created_at,
75                updated_at: e.updated_at,
76            })
77            .collect();
78        out.sort_by_key(|r| r.updated_at);
79        Ok(out)
80    }
81
82    async fn delete(&self, id: &SessionId) -> Result<()> {
83        self.inner.lock().await.remove(id);
84        Ok(())
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::Message;

    #[tokio::test]
    async fn roundtrip_save_load_delete() {
        let store = InMemorySessionStore::new();
        let id = SessionId::new("alice");

        assert!(store.load(&id).await.unwrap().is_none());

        let msgs = vec![Message::user("hi"), Message::assistant("hello")];
        store.save(&id, &msgs).await.unwrap();

        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].text(), Some("hi"));

        store.delete(&id).await.unwrap();
        assert!(store.load(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_overwrites_atomically() {
        let store = InMemorySessionStore::new();
        let id = SessionId::new("bob");
        store.save(&id, &[Message::user("one")]).await.unwrap();
        store
            .save(&id, &[Message::user("two"), Message::user("three")])
            .await
            .unwrap();
        let loaded = store.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].text(), Some("two"));
    }

    #[tokio::test]
    async fn list_returns_known_sessions() {
        let store = InMemorySessionStore::new();
        store
            .save(&SessionId::new("a"), &[Message::user("x")])
            .await
            .unwrap();
        store
            .save(&SessionId::new("b"), &[Message::user("y")])
            .await
            .unwrap();
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 2);
        let ids: Vec<&str> = list.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"a") && ids.contains(&"b"));
    }

    #[tokio::test]
    async fn list_on_empty_store_is_empty() {
        let store = InMemorySessionStore::new();
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn resave_keeps_created_at_and_refreshes_count() {
        let store = InMemorySessionStore::new();
        let id = SessionId::new("carol");

        store.save(&id, &[Message::user("one")]).await.unwrap();
        let before = store.list().await.unwrap();
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].message_count, 1);

        store
            .save(&id, &[Message::user("two"), Message::user("three")])
            .await
            .unwrap();
        let after = store.list().await.unwrap();
        assert_eq!(after.len(), 1);
        assert_eq!(after[0].message_count, 2);
        // Overwriting a session must not reset its creation time.
        assert_eq!(after[0].created_at, before[0].created_at);
    }

    #[tokio::test]
    async fn delete_removes_session_from_list() {
        let store = InMemorySessionStore::new();
        let keep = SessionId::new("keep");
        let drop_me = SessionId::new("drop");
        store.save(&keep, &[Message::user("x")]).await.unwrap();
        store.save(&drop_me, &[Message::user("y")]).await.unwrap();

        store.delete(&drop_me).await.unwrap();

        let list = store.list().await.unwrap();
        let ids: Vec<&str> = list.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["keep"]);
    }

    #[tokio::test]
    async fn delete_unknown_is_noop() {
        let store = InMemorySessionStore::new();
        store
            .delete(&SessionId::new("never-existed"))
            .await
            .unwrap();
    }
}