Skip to main content

heartbit_core/channel/
session.rs

1//! Session management for WebSocket-connected agent interactions.
2
3#![allow(missing_docs)]
4use parking_lot::RwLock;
5use std::collections::HashMap;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use crate::error::Error;
12
13/// A conversation session containing message history.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Session {
16    pub id: Uuid,
17    pub title: Option<String>,
18    pub created_at: DateTime<Utc>,
19    pub messages: Vec<SessionMessage>,
20    /// User who owns this session (multi-tenant isolation).
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    pub user_id: Option<String>,
23    /// Tenant that owns this session (multi-tenant isolation).
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub tenant_id: Option<String>,
26}
27
28/// A single message within a session.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct SessionMessage {
31    pub role: SessionRole,
32    pub content: String,
33    pub timestamp: DateTime<Utc>,
34}
35
36/// Role of a session message participant.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum SessionRole {
40    User,
41    Assistant,
42}
43
44/// Format session history as context to prepend to a new message.
45///
46/// When there is prior conversation history, returns the new message prefixed
47/// with a formatted history section. When history is empty, returns the message
48/// unchanged.
49pub fn format_session_context(history: &[SessionMessage], message: &str) -> String {
50    if history.is_empty() {
51        return message.to_string();
52    }
53
54    let mut ctx = String::from("## Conversation history\n");
55    for msg in history {
56        let role = match msg.role {
57            SessionRole::User => "User",
58            SessionRole::Assistant => "Assistant",
59        };
60        ctx.push_str(&format!("{role}: {}\n", msg.content));
61    }
62    ctx.push_str(&format!("\n## Current message\n{message}"));
63    ctx
64}
65
66/// Trait for session persistence.
67pub trait SessionStore: Send + Sync {
68    /// Create a new session with an optional title.
69    fn create(&self, title: Option<String>) -> Result<Session, Error>;
70    /// Get a session by ID. Returns `None` if not found.
71    fn get(&self, id: Uuid) -> Result<Option<Session>, Error>;
72    /// List all sessions (most recent first).
73    fn list(&self) -> Result<Vec<Session>, Error>;
74    /// Delete a session. Returns true if found and deleted.
75    fn delete(&self, id: Uuid) -> Result<bool, Error>;
76    /// Append a message to an existing session.
77    fn add_message(&self, id: Uuid, message: SessionMessage) -> Result<(), Error>;
78
79    /// Create a session with user/tenant context for multi-tenant isolation.
80    /// Default: delegates to `create()` and patches user/tenant fields.
81    fn create_with_user(
82        &self,
83        title: Option<String>,
84        user_id: &str,
85        tenant_id: &str,
86    ) -> Result<Session, Error> {
87        let mut session = self.create(title)?;
88        session.user_id = Some(user_id.to_string());
89        session.tenant_id = Some(tenant_id.to_string());
90        Ok(session)
91    }
92
93    /// List sessions scoped to a tenant (most recent first).
94    /// Default: calls `list()` and filters in-memory.
95    fn list_for_tenant(&self, tenant_id: &str) -> Result<Vec<Session>, Error> {
96        let all = self.list()?;
97        Ok(all
98            .into_iter()
99            .filter(|s| s.tenant_id.as_deref() == Some(tenant_id))
100            .collect())
101    }
102}
103
104/// In-memory session store using `parking_lot::RwLock` (not tokio — matches
105/// codebase pattern for locks never held across `.await`; `parking_lot` is
106/// adopted on the channel hot path for ~2× faster uncontended reads, see T2
107/// in `tasks/performance-audit-heartbit-core-2026-05-06.md`).
108pub struct InMemorySessionStore {
109    sessions: RwLock<HashMap<Uuid, Session>>,
110}
111
112impl InMemorySessionStore {
113    pub fn new() -> Self {
114        Self {
115            sessions: RwLock::new(HashMap::new()),
116        }
117    }
118}
119
120impl Default for InMemorySessionStore {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl SessionStore for InMemorySessionStore {
127    fn create(&self, title: Option<String>) -> Result<Session, Error> {
128        let session = Session {
129            id: Uuid::new_v4(),
130            title,
131            created_at: Utc::now(),
132            messages: Vec::new(),
133            user_id: None,
134            tenant_id: None,
135        };
136        self.sessions.write().insert(session.id, session.clone());
137        Ok(session)
138    }
139
140    fn create_with_user(
141        &self,
142        title: Option<String>,
143        user_id: &str,
144        tenant_id: &str,
145    ) -> Result<Session, Error> {
146        let session = Session {
147            id: Uuid::new_v4(),
148            title,
149            created_at: Utc::now(),
150            messages: Vec::new(),
151            user_id: Some(user_id.to_string()),
152            tenant_id: Some(tenant_id.to_string()),
153        };
154        self.sessions.write().insert(session.id, session.clone());
155        Ok(session)
156    }
157
158    fn get(&self, id: Uuid) -> Result<Option<Session>, Error> {
159        Ok(self.sessions.read().get(&id).cloned())
160    }
161
162    fn list(&self) -> Result<Vec<Session>, Error> {
163        let mut list: Vec<Session> = self.sessions.read().values().cloned().collect();
164        // Most recent first
165        list.sort_by_key(|s| std::cmp::Reverse(s.created_at));
166        Ok(list)
167    }
168
169    fn delete(&self, id: Uuid) -> Result<bool, Error> {
170        Ok(self.sessions.write().remove(&id).is_some())
171    }
172
173    fn add_message(&self, id: Uuid, message: SessionMessage) -> Result<(), Error> {
174        match self.sessions.write().get_mut(&id) {
175            Some(session) => {
176                session.messages.push(message);
177                Ok(())
178            }
179            None => Err(Error::Channel(format!("session {id} not found"))),
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    fn make_message(role: SessionRole, content: &str) -> SessionMessage {
        SessionMessage {
            role,
            content: content.to_string(),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn create_session() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(session.title.is_none());
        assert!(session.messages.is_empty());
        assert!(session.created_at <= Utc::now());
    }

    #[test]
    fn create_session_with_title() {
        let store = InMemorySessionStore::new();
        let session = store.create(Some("My Chat".to_string())).unwrap();
        assert_eq!(session.title.as_deref(), Some("My Chat"));
        assert!(session.messages.is_empty());
    }

    #[test]
    fn get_existing_session() {
        let store = InMemorySessionStore::new();
        let created = store.create(Some("Test".to_string())).unwrap();
        let fetched = store
            .get(created.id)
            .unwrap()
            .expect("session should exist");
        assert_eq!(fetched.id, created.id);
        assert_eq!(fetched.title, created.title);
        assert_eq!(fetched.messages.len(), created.messages.len());
    }

    #[test]
    fn get_missing_session() {
        let store = InMemorySessionStore::new();
        let result = store.get(Uuid::new_v4()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn list_empty() {
        let store = InMemorySessionStore::new();
        let list = store.list().unwrap();
        assert!(list.is_empty());
    }

    #[test]
    fn list_multiple() {
        let store = InMemorySessionStore::new();
        store.create(None).unwrap();
        store.create(None).unwrap();
        store.create(None).unwrap();
        let list = store.list().unwrap();
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn list_ordered_by_created_at() {
        let store = InMemorySessionStore::new();
        // Create sessions — they get Utc::now() timestamps so ordering depends on
        // insertion order. To test sorting, manually insert with controlled timestamps.
        {
            let mut sessions = store.sessions.write();

            let old = Session {
                id: Uuid::new_v4(),
                title: Some("old".to_string()),
                created_at: Utc::now() - chrono::Duration::hours(2),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };
            let mid = Session {
                id: Uuid::new_v4(),
                title: Some("mid".to_string()),
                created_at: Utc::now() - chrono::Duration::hours(1),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };
            let new = Session {
                id: Uuid::new_v4(),
                title: Some("new".to_string()),
                created_at: Utc::now(),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };

            // Insert in non-sorted order
            sessions.insert(mid.id, mid);
            sessions.insert(old.id, old);
            sessions.insert(new.id, new);
        }

        let list = store.list().unwrap();
        assert_eq!(list.len(), 3);
        assert_eq!(list[0].title.as_deref(), Some("new"));
        assert_eq!(list[1].title.as_deref(), Some("mid"));
        assert_eq!(list[2].title.as_deref(), Some("old"));
    }

    #[test]
    fn delete_existing() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(store.delete(session.id).unwrap());
        assert!(store.get(session.id).unwrap().is_none());
    }

    #[test]
    fn delete_missing() {
        let store = InMemorySessionStore::new();
        assert!(!store.delete(Uuid::new_v4()).unwrap());
    }

    #[test]
    fn add_message_to_existing() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        let msg = make_message(SessionRole::User, "hello");
        store.add_message(session.id, msg).unwrap();

        let fetched = store.get(session.id).unwrap().unwrap();
        assert_eq!(fetched.messages.len(), 1);
        assert_eq!(fetched.messages[0].content, "hello");
        assert_eq!(fetched.messages[0].role, SessionRole::User);
    }

    #[test]
    fn add_message_to_missing() {
        let store = InMemorySessionStore::new();
        let msg = make_message(SessionRole::User, "hello");
        let err = store.add_message(Uuid::new_v4(), msg).unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn add_multiple_messages() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();

        store
            .add_message(session.id, make_message(SessionRole::User, "first"))
            .unwrap();
        store
            .add_message(session.id, make_message(SessionRole::Assistant, "second"))
            .unwrap();
        store
            .add_message(session.id, make_message(SessionRole::User, "third"))
            .unwrap();

        let fetched = store.get(session.id).unwrap().unwrap();
        assert_eq!(fetched.messages.len(), 3);
        assert_eq!(fetched.messages[0].content, "first");
        assert_eq!(fetched.messages[1].content, "second");
        assert_eq!(fetched.messages[2].content, "third");
        assert_eq!(fetched.messages[0].role, SessionRole::User);
        assert_eq!(fetched.messages[1].role, SessionRole::Assistant);
        assert_eq!(fetched.messages[2].role, SessionRole::User);
    }

    #[test]
    fn session_role_serde() {
        let user_json = serde_json::to_string(&SessionRole::User).unwrap();
        assert_eq!(user_json, "\"user\"");

        let assistant_json = serde_json::to_string(&SessionRole::Assistant).unwrap();
        assert_eq!(assistant_json, "\"assistant\"");

        let user: SessionRole = serde_json::from_str("\"user\"").unwrap();
        assert_eq!(user, SessionRole::User);

        let assistant: SessionRole = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(assistant, SessionRole::Assistant);
    }

    #[test]
    fn session_message_roundtrip() {
        let msg = SessionMessage {
            role: SessionRole::Assistant,
            content: "Hello, world!".to_string(),
            timestamp: Utc::now(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SessionMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, msg.role);
        assert_eq!(deserialized.content, msg.content);
        assert_eq!(deserialized.timestamp, msg.timestamp);
    }

    #[test]
    fn concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemorySessionStore::new());
        let mut handles = Vec::new();

        // Spawn threads that create sessions
        for i in 0..10 {
            let store = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                let session = store
                    .create(Some(format!("thread-{i}")))
                    .expect("create should succeed");
                // Add a message to the session we just created
                let msg = SessionMessage {
                    role: SessionRole::User,
                    content: format!("msg from thread {i}"),
                    timestamp: Utc::now(),
                };
                store
                    .add_message(session.id, msg)
                    .expect("add_message should succeed");
                session.id
            }));
        }

        let ids: Vec<Uuid> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // All sessions should exist with one message each
        for id in &ids {
            let session = store.get(*id).unwrap().expect("session should exist");
            assert_eq!(session.messages.len(), 1);
        }

        let list = store.list().unwrap();
        assert_eq!(list.len(), 10);
    }

    // --- format_session_context tests ---

    #[test]
    fn format_context_no_history() {
        let result = format_session_context(&[], "Hello");
        assert_eq!(result, "Hello");
    }

    #[test]
    fn format_context_with_history() {
        let history = vec![
            make_message(SessionRole::User, "What is Rust?"),
            make_message(SessionRole::Assistant, "A systems programming language."),
        ];
        let result = format_session_context(&history, "Tell me more");
        assert!(result.contains("## Conversation history"));
        assert!(result.contains("User: What is Rust?"));
        assert!(result.contains("Assistant: A systems programming language."));
        assert!(result.contains("## Current message"));
        assert!(result.contains("Tell me more"));
    }

    #[test]
    fn format_context_preserves_message_order() {
        let history = vec![
            make_message(SessionRole::User, "First"),
            make_message(SessionRole::Assistant, "Second"),
            make_message(SessionRole::User, "Third"),
            make_message(SessionRole::Assistant, "Fourth"),
        ];
        let result = format_session_context(&history, "Fifth");
        let first_pos = result.find("First").unwrap();
        let second_pos = result.find("Second").unwrap();
        let third_pos = result.find("Third").unwrap();
        let fourth_pos = result.find("Fourth").unwrap();
        let fifth_pos = result.find("Fifth").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
        assert!(third_pos < fourth_pos);
        assert!(fourth_pos < fifth_pos);
    }

    #[test]
    fn format_context_single_message_history() {
        let history = vec![make_message(SessionRole::User, "Prior question")];
        let result = format_session_context(&history, "Follow-up");
        assert!(result.contains("User: Prior question"));
        assert!(result.contains("Follow-up"));
    }

    // --- Multi-tenant session tests ---

    #[test]
    fn create_with_user_sets_fields() {
        let store = InMemorySessionStore::new();
        let session = store
            .create_with_user(Some("Test".into()), "alice", "acme")
            .unwrap();
        assert_eq!(session.user_id.as_deref(), Some("alice"));
        assert_eq!(session.tenant_id.as_deref(), Some("acme"));
        assert_eq!(session.title.as_deref(), Some("Test"));
    }

    #[test]
    fn create_with_user_persists_owner_fields() {
        // The owner fields must be stored, not only set on the returned copy.
        let store = InMemorySessionStore::new();
        let created = store
            .create_with_user(Some("Owned".into()), "alice", "acme")
            .unwrap();
        let fetched = store
            .get(created.id)
            .unwrap()
            .expect("session should exist");
        assert_eq!(fetched.user_id.as_deref(), Some("alice"));
        assert_eq!(fetched.tenant_id.as_deref(), Some("acme"));
    }

    #[test]
    fn create_without_user_has_none_fields() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(session.user_id.is_none());
        assert!(session.tenant_id.is_none());
    }

    #[test]
    fn list_for_tenant_filters_by_tenant() {
        let store = InMemorySessionStore::new();
        store
            .create_with_user(Some("acme-1".into()), "alice", "acme")
            .unwrap();
        store
            .create_with_user(Some("acme-2".into()), "bob", "acme")
            .unwrap();
        store
            .create_with_user(Some("globex-1".into()), "charlie", "globex")
            .unwrap();
        store.create(Some("legacy".into())).unwrap(); // no tenant

        let acme = store.list_for_tenant("acme").unwrap();
        assert_eq!(acme.len(), 2);
        assert!(acme.iter().all(|s| s.tenant_id.as_deref() == Some("acme")));

        let globex = store.list_for_tenant("globex").unwrap();
        assert_eq!(globex.len(), 1);
        assert_eq!(globex[0].tenant_id.as_deref(), Some("globex"));

        // The legacy session (no tenant) is excluded from both tenant listings
        // above (2 + 1 sessions) but still appears in the unscoped list.
        let all = store.list().unwrap();
        assert_eq!(all.len(), 4);
        assert_eq!(all.iter().filter(|s| s.tenant_id.is_none()).count(), 1);
    }

    #[test]
    fn list_for_tenant_unknown_tenant_is_empty() {
        let store = InMemorySessionStore::new();
        store.create_with_user(None, "alice", "acme").unwrap();
        store.create(None).unwrap(); // no tenant

        assert!(store.list_for_tenant("globex").unwrap().is_empty());
        // An empty tenant id must not match sessions that have no tenant.
        assert!(store.list_for_tenant("").unwrap().is_empty());
    }

    #[test]
    fn list_for_tenant_ordered_by_created_at() {
        let store = InMemorySessionStore::new();
        {
            let mut sessions = store.sessions.write();
            let fixtures = [("old", 2, "acme"), ("other", 1, "globex"), ("new", 0, "acme")];
            for &(title, hours_ago, tenant) in fixtures.iter() {
                let session = Session {
                    id: Uuid::new_v4(),
                    title: Some(title.to_string()),
                    created_at: Utc::now() - chrono::Duration::hours(hours_ago),
                    messages: Vec::new(),
                    user_id: None,
                    tenant_id: Some(tenant.to_string()),
                };
                sessions.insert(session.id, session);
            }
        }

        let acme = store.list_for_tenant("acme").unwrap();
        let titles: Vec<Option<&str>> = acme.iter().map(|s| s.title.as_deref()).collect();
        assert_eq!(titles, vec![Some("new"), Some("old")]);
    }

    #[test]
    fn session_serde_backward_compat() {
        // Old JSON without user_id/tenant_id should deserialize with None
        let json = r#"{"id":"00000000-0000-0000-0000-000000000000","title":"old","created_at":"2026-01-01T00:00:00Z","messages":[]}"#;
        let session: Session = serde_json::from_str(json).unwrap();
        assert!(session.user_id.is_none());
        assert!(session.tenant_id.is_none());
        assert_eq!(session.title.as_deref(), Some("old"));
    }

    #[test]
    fn session_serde_with_tenant() {
        let session = Session {
            id: Uuid::nil(),
            title: None,
            created_at: Utc::now(),
            messages: Vec::new(),
            user_id: Some("alice".into()),
            tenant_id: Some("acme".into()),
        };
        let json = serde_json::to_string(&session).unwrap();
        assert!(json.contains(r#""user_id":"alice""#));
        assert!(json.contains(r#""tenant_id":"acme""#));

        let deserialized: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.user_id.as_deref(), Some("alice"));
        assert_eq!(deserialized.tenant_id.as_deref(), Some("acme"));
    }
}