1#![allow(missing_docs)]
4use parking_lot::RwLock;
5use std::collections::HashMap;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use crate::error::Error;
12
/// A single conversation: its identity, optional title, creation time, the messages
/// exchanged so far, and optional user/tenant ownership.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Identifier of the session; stores key and look up sessions by this value.
    pub id: Uuid,
    /// Optional human-readable title.
    pub title: Option<String>,
    /// When the session was created, in UTC.
    pub created_at: DateTime<Utc>,
    /// Messages in conversation order (the in-memory store appends new ones at the end).
    pub messages: Vec<SessionMessage>,
    /// Owning user, if any.
    ///
    /// `#[serde(default)]` yields `None` when the field is absent from the input, so JSON
    /// written without ownership fields still deserializes. `skip_serializing_if` omits the
    /// field from the output when it is `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Owning tenant, if any; this is the field `SessionStore::list_for_tenant` matches on.
    /// Serde handling is the same as for `user_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
27
/// One turn of a [`Session`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
    /// Who authored the message.
    pub role: SessionRole,
    /// Message text, stored verbatim.
    pub content: String,
    /// Timestamp supplied by the caller that builds the message; nothing in this module sets it.
    pub timestamp: DateTime<Utc>,
}
35
/// Author of a [`SessionMessage`].
///
/// Serialized in snake_case, i.e. as `"user"` / `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionRole {
    /// Message written by the end user.
    User,
    /// Message produced by the assistant.
    Assistant,
}
43
44pub fn format_session_context(history: &[SessionMessage], message: &str) -> String {
50 if history.is_empty() {
51 return message.to_string();
52 }
53
54 let mut ctx = String::from("## Conversation history\n");
55 for msg in history {
56 let role = match msg.role {
57 SessionRole::User => "User",
58 SessionRole::Assistant => "Assistant",
59 };
60 ctx.push_str(&format!("{role}: {}\n", msg.content));
61 }
62 ctx.push_str(&format!("\n## Current message\n{message}"));
63 ctx
64}
65
66pub trait SessionStore: Send + Sync {
68 fn create(&self, title: Option<String>) -> Result<Session, Error>;
70 fn get(&self, id: Uuid) -> Result<Option<Session>, Error>;
72 fn list(&self) -> Result<Vec<Session>, Error>;
74 fn delete(&self, id: Uuid) -> Result<bool, Error>;
76 fn add_message(&self, id: Uuid, message: SessionMessage) -> Result<(), Error>;
78
79 fn create_with_user(
82 &self,
83 title: Option<String>,
84 user_id: &str,
85 tenant_id: &str,
86 ) -> Result<Session, Error> {
87 let mut session = self.create(title)?;
88 session.user_id = Some(user_id.to_string());
89 session.tenant_id = Some(tenant_id.to_string());
90 Ok(session)
91 }
92
93 fn list_for_tenant(&self, tenant_id: &str) -> Result<Vec<Session>, Error> {
96 let all = self.list()?;
97 Ok(all
98 .into_iter()
99 .filter(|s| s.tenant_id.as_deref() == Some(tenant_id))
100 .collect())
101 }
102}
103
/// [`SessionStore`] that keeps every session in process memory.
///
/// Sessions live in a `HashMap` keyed by session id, guarded by a `parking_lot::RwLock`.
/// Nothing is persisted, so all sessions are lost when the store is dropped.
pub struct InMemorySessionStore {
    // Reads (`get`, `list`) take the shared lock; mutations take the exclusive lock.
    sessions: RwLock<HashMap<Uuid, Session>>,
}
111
112impl InMemorySessionStore {
113 pub fn new() -> Self {
114 Self {
115 sessions: RwLock::new(HashMap::new()),
116 }
117 }
118}
119
120impl Default for InMemorySessionStore {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl SessionStore for InMemorySessionStore {
127 fn create(&self, title: Option<String>) -> Result<Session, Error> {
128 let session = Session {
129 id: Uuid::new_v4(),
130 title,
131 created_at: Utc::now(),
132 messages: Vec::new(),
133 user_id: None,
134 tenant_id: None,
135 };
136 self.sessions.write().insert(session.id, session.clone());
137 Ok(session)
138 }
139
140 fn create_with_user(
141 &self,
142 title: Option<String>,
143 user_id: &str,
144 tenant_id: &str,
145 ) -> Result<Session, Error> {
146 let session = Session {
147 id: Uuid::new_v4(),
148 title,
149 created_at: Utc::now(),
150 messages: Vec::new(),
151 user_id: Some(user_id.to_string()),
152 tenant_id: Some(tenant_id.to_string()),
153 };
154 self.sessions.write().insert(session.id, session.clone());
155 Ok(session)
156 }
157
158 fn get(&self, id: Uuid) -> Result<Option<Session>, Error> {
159 Ok(self.sessions.read().get(&id).cloned())
160 }
161
162 fn list(&self) -> Result<Vec<Session>, Error> {
163 let mut list: Vec<Session> = self.sessions.read().values().cloned().collect();
164 list.sort_by_key(|s| std::cmp::Reverse(s.created_at));
166 Ok(list)
167 }
168
169 fn delete(&self, id: Uuid) -> Result<bool, Error> {
170 Ok(self.sessions.write().remove(&id).is_some())
171 }
172
173 fn add_message(&self, id: Uuid, message: SessionMessage) -> Result<(), Error> {
174 match self.sessions.write().get_mut(&id) {
175 Some(session) => {
176 session.messages.push(message);
177 Ok(())
178 }
179 None => Err(Error::Channel(format!("session {id} not found"))),
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message timestamped "now"; tests only care about role and content.
    fn make_message(role: SessionRole, content: &str) -> SessionMessage {
        SessionMessage {
            role,
            content: content.to_string(),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn create_session() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(session.title.is_none());
        assert!(session.messages.is_empty());
        assert!(session.created_at <= Utc::now());
    }

    #[test]
    fn create_session_with_title() {
        let store = InMemorySessionStore::new();
        let session = store.create(Some("My Chat".to_string())).unwrap();
        assert_eq!(session.title.as_deref(), Some("My Chat"));
        assert!(session.messages.is_empty());
    }

    #[test]
    fn get_existing_session() {
        let store = InMemorySessionStore::new();
        let created = store.create(Some("Test".to_string())).unwrap();
        let fetched = store
            .get(created.id)
            .unwrap()
            .expect("session should exist");
        assert_eq!(fetched.id, created.id);
        assert_eq!(fetched.title, created.title);
        assert_eq!(fetched.messages.len(), created.messages.len());
    }

    #[test]
    fn get_missing_session() {
        let store = InMemorySessionStore::new();
        let result = store.get(Uuid::new_v4()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn list_empty() {
        let store = InMemorySessionStore::new();
        let list = store.list().unwrap();
        assert!(list.is_empty());
    }

    #[test]
    fn list_multiple() {
        let store = InMemorySessionStore::new();
        store.create(None).unwrap();
        store.create(None).unwrap();
        store.create(None).unwrap();
        let list = store.list().unwrap();
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn list_ordered_by_created_at() {
        let store = InMemorySessionStore::new();
        {
            // Insert directly with timestamps hours apart so the ordering check does not
            // depend on clock resolution between successive `create` calls.
            let mut sessions = store.sessions.write();

            let old = Session {
                id: Uuid::new_v4(),
                title: Some("old".to_string()),
                created_at: Utc::now() - chrono::Duration::hours(2),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };
            let mid = Session {
                id: Uuid::new_v4(),
                title: Some("mid".to_string()),
                created_at: Utc::now() - chrono::Duration::hours(1),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };
            let new = Session {
                id: Uuid::new_v4(),
                title: Some("new".to_string()),
                created_at: Utc::now(),
                messages: Vec::new(),
                user_id: None,
                tenant_id: None,
            };

            // Deliberately out of chronological order.
            sessions.insert(mid.id, mid);
            sessions.insert(old.id, old);
            sessions.insert(new.id, new);
        }

        let list = store.list().unwrap();
        assert_eq!(list.len(), 3);
        assert_eq!(list[0].title.as_deref(), Some("new"));
        assert_eq!(list[1].title.as_deref(), Some("mid"));
        assert_eq!(list[2].title.as_deref(), Some("old"));
    }

    #[test]
    fn delete_existing() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(store.delete(session.id).unwrap());
        assert!(store.get(session.id).unwrap().is_none());
    }

    #[test]
    fn delete_missing() {
        let store = InMemorySessionStore::new();
        assert!(!store.delete(Uuid::new_v4()).unwrap());
    }

    #[test]
    fn add_message_to_existing() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        let msg = make_message(SessionRole::User, "hello");
        store.add_message(session.id, msg).unwrap();

        let fetched = store.get(session.id).unwrap().unwrap();
        assert_eq!(fetched.messages.len(), 1);
        assert_eq!(fetched.messages[0].content, "hello");
        assert_eq!(fetched.messages[0].role, SessionRole::User);
    }

    #[test]
    fn add_message_to_missing() {
        let store = InMemorySessionStore::new();
        let msg = make_message(SessionRole::User, "hello");
        let err = store.add_message(Uuid::new_v4(), msg).unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn add_multiple_messages() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();

        store
            .add_message(session.id, make_message(SessionRole::User, "first"))
            .unwrap();
        store
            .add_message(session.id, make_message(SessionRole::Assistant, "second"))
            .unwrap();
        store
            .add_message(session.id, make_message(SessionRole::User, "third"))
            .unwrap();

        let fetched = store.get(session.id).unwrap().unwrap();
        assert_eq!(fetched.messages.len(), 3);
        assert_eq!(fetched.messages[0].content, "first");
        assert_eq!(fetched.messages[1].content, "second");
        assert_eq!(fetched.messages[2].content, "third");
        assert_eq!(fetched.messages[0].role, SessionRole::User);
        assert_eq!(fetched.messages[1].role, SessionRole::Assistant);
        assert_eq!(fetched.messages[2].role, SessionRole::User);
    }

    #[test]
    fn session_role_serde() {
        let user_json = serde_json::to_string(&SessionRole::User).unwrap();
        assert_eq!(user_json, "\"user\"");

        let assistant_json = serde_json::to_string(&SessionRole::Assistant).unwrap();
        assert_eq!(assistant_json, "\"assistant\"");

        let user: SessionRole = serde_json::from_str("\"user\"").unwrap();
        assert_eq!(user, SessionRole::User);

        let assistant: SessionRole = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(assistant, SessionRole::Assistant);
    }

    #[test]
    fn session_message_roundtrip() {
        let msg = SessionMessage {
            role: SessionRole::Assistant,
            content: "Hello, world!".to_string(),
            timestamp: Utc::now(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SessionMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, msg.role);
        assert_eq!(deserialized.content, msg.content);
        assert_eq!(deserialized.timestamp, msg.timestamp);
    }

    #[test]
    fn concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemorySessionStore::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            // Each thread creates its own session and appends one message to it.
            let store = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                let session = store
                    .create(Some(format!("thread-{i}")))
                    .expect("create should succeed");
                let msg = SessionMessage {
                    role: SessionRole::User,
                    content: format!("msg from thread {i}"),
                    timestamp: Utc::now(),
                };
                store
                    .add_message(session.id, msg)
                    .expect("add_message should succeed");
                session.id
            }));
        }

        let ids: Vec<Uuid> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        for id in &ids {
            // No thread's write may have been lost.
            let session = store.get(*id).unwrap().expect("session should exist");
            assert_eq!(session.messages.len(), 1);
        }

        let list = store.list().unwrap();
        assert_eq!(list.len(), 10);
    }

    #[test]
    fn format_context_no_history() {
        let result = format_session_context(&[], "Hello");
        assert_eq!(result, "Hello");
    }

    #[test]
    fn format_context_with_history() {
        let history = vec![
            make_message(SessionRole::User, "What is Rust?"),
            make_message(SessionRole::Assistant, "A systems programming language."),
        ];
        let result = format_session_context(&history, "Tell me more");
        assert!(result.contains("## Conversation history"));
        assert!(result.contains("User: What is Rust?"));
        assert!(result.contains("Assistant: A systems programming language."));
        assert!(result.contains("## Current message"));
        assert!(result.contains("Tell me more"));
    }

    #[test]
    fn format_context_exact_layout() {
        let history = vec![
            make_message(SessionRole::User, "Hi"),
            make_message(SessionRole::Assistant, "Hello!"),
        ];
        assert_eq!(
            format_session_context(&history, "Bye"),
            "## Conversation history\nUser: Hi\nAssistant: Hello!\n\n## Current message\nBye"
        );
    }

    #[test]
    fn format_context_preserves_message_order() {
        let history = vec![
            make_message(SessionRole::User, "First"),
            make_message(SessionRole::Assistant, "Second"),
            make_message(SessionRole::User, "Third"),
            make_message(SessionRole::Assistant, "Fourth"),
        ];
        let result = format_session_context(&history, "Fifth");
        let first_pos = result.find("First").unwrap();
        let second_pos = result.find("Second").unwrap();
        let third_pos = result.find("Third").unwrap();
        let fourth_pos = result.find("Fourth").unwrap();
        let fifth_pos = result.find("Fifth").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
        assert!(third_pos < fourth_pos);
        assert!(fourth_pos < fifth_pos);
    }

    #[test]
    fn format_context_single_message_history() {
        let history = vec![make_message(SessionRole::User, "Prior question")];
        let result = format_session_context(&history, "Follow-up");
        assert!(result.contains("User: Prior question"));
        assert!(result.contains("Follow-up"));
    }

    #[test]
    fn create_with_user_sets_fields() {
        let store = InMemorySessionStore::new();
        let session = store
            .create_with_user(Some("Test".into()), "alice", "acme")
            .unwrap();
        assert_eq!(session.user_id.as_deref(), Some("alice"));
        assert_eq!(session.tenant_id.as_deref(), Some("acme"));
        assert_eq!(session.title.as_deref(), Some("Test"));
    }

    #[test]
    fn create_with_user_persists_ownership() {
        // The stored copy, not just the returned value, must carry the ownership fields.
        let store = InMemorySessionStore::new();
        let created = store.create_with_user(None, "alice", "acme").unwrap();
        let stored = store
            .get(created.id)
            .unwrap()
            .expect("session should exist");
        assert_eq!(stored.user_id.as_deref(), Some("alice"));
        assert_eq!(stored.tenant_id.as_deref(), Some("acme"));
    }

    #[test]
    fn create_without_user_has_none_fields() {
        let store = InMemorySessionStore::new();
        let session = store.create(None).unwrap();
        assert!(session.user_id.is_none());
        assert!(session.tenant_id.is_none());
    }

    #[test]
    fn list_for_tenant_filters_by_tenant() {
        let store = InMemorySessionStore::new();
        store
            .create_with_user(Some("acme-1".into()), "alice", "acme")
            .unwrap();
        store
            .create_with_user(Some("acme-2".into()), "bob", "acme")
            .unwrap();
        store
            .create_with_user(Some("globex-1".into()), "charlie", "globex")
            .unwrap();
        // An ownerless session must not match any tenant.
        store.create(Some("legacy".into())).unwrap();

        let acme = store.list_for_tenant("acme").unwrap();
        assert_eq!(acme.len(), 2);
        assert!(acme.iter().all(|s| s.tenant_id.as_deref() == Some("acme")));

        let globex = store.list_for_tenant("globex").unwrap();
        assert_eq!(globex.len(), 1);
        assert_eq!(globex[0].tenant_id.as_deref(), Some("globex"));

        assert!(store.list_for_tenant("initech").unwrap().is_empty());

        // The unfiltered listing still includes the ownerless session.
        let all = store.list().unwrap();
        assert_eq!(all.len(), 4);
    }

    #[test]
    fn session_serde_backward_compat() {
        // JSON with no ownership fields must still deserialize.
        let json = r#"{"id":"00000000-0000-0000-0000-000000000000","title":"old","created_at":"2026-01-01T00:00:00Z","messages":[]}"#;
        let session: Session = serde_json::from_str(json).unwrap();
        assert!(session.user_id.is_none());
        assert!(session.tenant_id.is_none());
        assert_eq!(session.title.as_deref(), Some("old"));
    }

    #[test]
    fn session_serde_with_tenant() {
        let session = Session {
            id: Uuid::nil(),
            title: None,
            created_at: Utc::now(),
            messages: Vec::new(),
            user_id: Some("alice".into()),
            tenant_id: Some("acme".into()),
        };
        let json = serde_json::to_string(&session).unwrap();
        assert!(json.contains(r#""user_id":"alice""#));
        assert!(json.contains(r#""tenant_id":"acme""#));

        let deserialized: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.user_id.as_deref(), Some("alice"));
        assert_eq!(deserialized.tenant_id.as_deref(), Some("acme"));
    }
}