1use std::path::Path;
26use std::sync::{Arc, Mutex};
27
28use aonyx_core::{AonyxError, Message, Result, Role};
29use async_trait::async_trait;
30use chrono::{DateTime, Utc};
31use rusqlite::{params, Connection, OptionalExtension};
32use serde::{Deserialize, Serialize};
33use uuid::Uuid;
34
/// Identifier of a stored session.
///
/// Sessions created through [`SqliteSessionStore::create`] / `fork` get a fresh
/// random (v4) UUID.
pub type SessionId = Uuid;
37
/// A persisted chat session: its metadata plus the full message history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRecord {
    /// Unique id of the session (primary key of the `sessions` table).
    pub id: SessionId,
    /// Project the session belongs to; listing and `latest` are scoped by it.
    pub project: String,
    /// When the session row was first created (UTC).
    pub created_at: DateTime<Utc>,
    /// When the session was last written (UTC); refreshed by `update` and
    /// `rename`, and used for "most recent first" ordering.
    pub updated_at: DateTime<Utc>,
    /// Session this one was forked from, or `None` for a root session.
    pub parent_id: Option<SessionId>,
    /// Short display title, derived from the first user message or set
    /// explicitly via `rename`.
    pub title: String,
    /// Turn counter as supplied by the caller on `fork` / `update`.
    pub turns: u32,
    /// Complete message history, stored as JSON in the `messages_json` column.
    pub messages: Vec<Message>,
}
58
/// One result of [`SessionStore::search`]: session metadata plus an excerpt
/// showing where the query matched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchHit {
    /// Id of the matching session.
    pub id: SessionId,
    /// Project the matching session belongs to.
    pub project: String,
    /// Title of the matching session.
    pub title: String,
    /// Last-update time of the matching session (UTC).
    pub updated_at: DateTime<Utc>,
    /// Turn counter of the matching session.
    pub turns: u32,
    /// Single-line excerpt around the first matching message, or a preview
    /// of the first user message when no message body contains the query.
    pub snippet: String,
}
77
/// Async persistence interface for chat sessions.
///
/// The `Send + Sync` bound lets one store be shared across async tasks. Every
/// method reports failures through the crate-wide [`Result`].
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Creates and persists a new session for `project` holding `messages`.
    async fn create(&self, project: &str, messages: Vec<Message>) -> Result<SessionRecord>;

    /// Creates and persists a new session whose `parent_id` is `parent`,
    /// seeded with the caller-supplied `messages` and `turns`.
    ///
    /// NOTE(review): the SQLite implementation does not read the parent's
    /// history itself; callers pass the messages to copy.
    async fn fork(
        &self,
        project: &str,
        parent: SessionId,
        messages: Vec<Message>,
        turns: u32,
    ) -> Result<SessionRecord>;

    /// Replaces the stored messages and turn count of session `id`.
    async fn update(&self, id: SessionId, messages: Vec<Message>, turns: u32) -> Result<()>;

    /// Sets the title of session `id`.
    async fn rename(&self, id: SessionId, title: &str) -> Result<()>;

    /// Returns up to `limit` sessions of `project`, most recently updated first.
    async fn list_by_project(&self, project: &str, limit: usize) -> Result<Vec<SessionRecord>>;

    /// Fetches session `id`, or `None` if no such session exists.
    async fn get(&self, id: SessionId) -> Result<Option<SessionRecord>>;

    /// Removes session `id`.
    async fn delete(&self, id: SessionId) -> Result<()>;

    /// Returns the most recently updated session of `project`, if any.
    async fn latest(&self, project: &str) -> Result<Option<SessionRecord>>;

    /// Searches all sessions for `query` (case-insensitively in the SQLite
    /// implementation), returning at most `limit` hits, most recent first.
    async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchHit>>;

    /// Returns up to `limit` sessions whose id starts with `prefix`, e.g. to
    /// resolve a shortened id typed by a user.
    async fn find_by_id_prefix(&self, prefix: &str, limit: usize) -> Result<Vec<SessionRecord>>;
}
125
/// [`SessionStore`] backed by a single SQLite connection.
///
/// The connection lives behind an `Arc<Mutex<_>>`, so clones of the store
/// share it and all database work is serialized through that one lock.
#[derive(Clone)]
pub struct SqliteSessionStore {
    /// Shared connection; locked inside a blocking task for every query.
    conn: Arc<Mutex<Connection>>,
}
131
132impl SqliteSessionStore {
133 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
135 let conn = Connection::open(path.as_ref())
136 .map_err(|e| AonyxError::Memory(format!("open sessions db: {e}")))?;
137 Self::migrate(&conn)?;
138 Ok(Self {
139 conn: Arc::new(Mutex::new(conn)),
140 })
141 }
142
143 pub fn open_in_memory() -> Result<Self> {
145 let conn = Connection::open_in_memory()
146 .map_err(|e| AonyxError::Memory(format!("open in-memory sessions: {e}")))?;
147 Self::migrate(&conn)?;
148 Ok(Self {
149 conn: Arc::new(Mutex::new(conn)),
150 })
151 }
152
153 fn migrate(conn: &Connection) -> Result<()> {
154 conn.execute_batch(MIGRATION_V1)
155 .map_err(|e| AonyxError::Memory(format!("migrate sessions schema: {e}")))?;
156 Ok(())
157 }
158
159 async fn insert_record(&self, record: SessionRecord) -> Result<SessionRecord> {
162 let conn = self.conn.clone();
163 let to_insert = record.clone();
164 tokio::task::spawn_blocking(move || -> Result<()> {
165 let lock = conn.lock().expect("sessions mutex poisoned");
166 let json = serde_json::to_string(&to_insert.messages)
167 .map_err(|e| AonyxError::Memory(format!("encode messages: {e}")))?;
168 lock.execute(
169 &format!(
170 "INSERT INTO sessions ({COLUMNS}) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
171 ),
172 params![
173 to_insert.id.to_string(),
174 to_insert.project,
175 to_insert.created_at.to_rfc3339(),
176 to_insert.updated_at.to_rfc3339(),
177 to_insert.parent_id.map(|u| u.to_string()),
178 to_insert.title,
179 to_insert.turns as i64,
180 json,
181 ],
182 )
183 .map_err(|e| AonyxError::Memory(format!("insert session: {e}")))?;
184 Ok(())
185 })
186 .await
187 .map_err(|e| AonyxError::Memory(format!("insert join: {e}")))??;
188 Ok(record)
189 }
190}
191
/// Schema for the `sessions` table, applied every time a store is opened.
///
/// Idempotent (`IF NOT EXISTS`), so re-running it on an existing database is a
/// no-op. Timestamps are stored as RFC 3339 text, and `messages_json` holds the
/// serialized message list. The `(project, updated_at DESC)` index backs the
/// per-project "most recent first" listing.
const MIGRATION_V1: &str = r#"
CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    project TEXT NOT NULL,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL,
    parent_id TEXT,
    title TEXT NOT NULL,
    turns INTEGER NOT NULL DEFAULT 0,
    messages_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_sessions_project_updated
    ON sessions(project, updated_at DESC);
"#;
206
/// Column list shared by every `INSERT` into and `SELECT` from `sessions`.
///
/// The order is load-bearing: the insert placeholders `?1..?8` and the
/// positional `row.get(0..=7)` calls in `row_to_record` both follow it.
const COLUMNS: &str = "id, project, created_at, updated_at, parent_id, title, turns, messages_json";
208
209fn extract_title(messages: &[Message]) -> String {
210 let raw = messages
211 .iter()
212 .find(|m| m.role == Role::User)
213 .map(|m| m.content.trim().to_string())
214 .unwrap_or_else(|| "new session".to_string());
215 let single_line = raw.replace('\n', " ");
216 if single_line.chars().count() > 60 {
217 let cut: String = single_line.chars().take(60).collect();
218 format!("{cut}…")
219 } else if single_line.is_empty() {
220 "new session".to_string()
221 } else {
222 single_line
223 }
224}
225
226fn row_to_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
227 let id_str: String = row.get(0)?;
228 let project: String = row.get(1)?;
229 let created_raw: String = row.get(2)?;
230 let updated_raw: String = row.get(3)?;
231 let parent_raw: Option<String> = row.get(4)?;
232 let title: String = row.get(5)?;
233 let turns: i64 = row.get(6)?;
234 let messages_raw: String = row.get(7)?;
235
236 let id = Uuid::parse_str(&id_str).map_err(|e| {
237 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
238 })?;
239 let parent_id = parent_raw
240 .as_deref()
241 .map(Uuid::parse_str)
242 .transpose()
243 .map_err(|e| {
244 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
245 })?;
246 let created_at = DateTime::parse_from_rfc3339(&created_raw)
247 .map(|d| d.with_timezone(&Utc))
248 .unwrap_or_else(|_| Utc::now());
249 let updated_at = DateTime::parse_from_rfc3339(&updated_raw)
250 .map(|d| d.with_timezone(&Utc))
251 .unwrap_or_else(|_| Utc::now());
252 let messages: Vec<Message> = serde_json::from_str(&messages_raw).map_err(|e| {
253 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
254 })?;
255
256 Ok(SessionRecord {
257 id,
258 project,
259 created_at,
260 updated_at,
261 parent_id,
262 title,
263 turns: turns.max(0) as u32,
264 messages,
265 })
266}
267
268#[async_trait]
269impl SessionStore for SqliteSessionStore {
270 async fn create(&self, project: &str, messages: Vec<Message>) -> Result<SessionRecord> {
271 let record = SessionRecord {
272 id: Uuid::new_v4(),
273 project: project.to_string(),
274 created_at: Utc::now(),
275 updated_at: Utc::now(),
276 parent_id: None,
277 title: extract_title(&messages),
278 turns: 0,
279 messages,
280 };
281 self.insert_record(record).await
282 }
283
284 async fn fork(
285 &self,
286 project: &str,
287 parent: SessionId,
288 messages: Vec<Message>,
289 turns: u32,
290 ) -> Result<SessionRecord> {
291 let record = SessionRecord {
292 id: Uuid::new_v4(),
293 project: project.to_string(),
294 created_at: Utc::now(),
295 updated_at: Utc::now(),
296 parent_id: Some(parent),
297 title: extract_title(&messages),
298 turns,
299 messages,
300 };
301 self.insert_record(record).await
302 }
303
304 async fn update(&self, id: SessionId, messages: Vec<Message>, turns: u32) -> Result<()> {
305 let conn = self.conn.clone();
306 let title = extract_title(&messages);
307 tokio::task::spawn_blocking(move || -> Result<()> {
308 let lock = conn.lock().expect("sessions mutex poisoned");
309 let json = serde_json::to_string(&messages)
310 .map_err(|e| AonyxError::Memory(format!("encode messages: {e}")))?;
311 lock.execute(
312 "UPDATE sessions
313 SET updated_at = ?2, messages_json = ?3, turns = ?4, title = ?5
314 WHERE id = ?1",
315 params![
316 id.to_string(),
317 Utc::now().to_rfc3339(),
318 json,
319 turns as i64,
320 title,
321 ],
322 )
323 .map_err(|e| AonyxError::Memory(format!("update session: {e}")))?;
324 Ok(())
325 })
326 .await
327 .map_err(|e| AonyxError::Memory(format!("update join: {e}")))?
328 }
329
330 async fn rename(&self, id: SessionId, title: &str) -> Result<()> {
331 let conn = self.conn.clone();
332 let title = title.to_string();
333 tokio::task::spawn_blocking(move || -> Result<()> {
334 let lock = conn.lock().expect("sessions mutex poisoned");
335 let n = lock
336 .execute(
337 "UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1",
338 params![id.to_string(), title, Utc::now().to_rfc3339()],
339 )
340 .map_err(|e| AonyxError::Memory(format!("rename session: {e}")))?;
341 if n == 0 {
342 return Err(AonyxError::Memory(format!("rename: no session {id}")));
343 }
344 Ok(())
345 })
346 .await
347 .map_err(|e| AonyxError::Memory(format!("rename join: {e}")))?
348 }
349
350 async fn list_by_project(&self, project: &str, limit: usize) -> Result<Vec<SessionRecord>> {
351 let conn = self.conn.clone();
352 let project = project.to_string();
353 let limit = limit as i64;
354 tokio::task::spawn_blocking(move || -> Result<Vec<SessionRecord>> {
355 let lock = conn.lock().expect("sessions mutex poisoned");
356 let mut stmt = lock
357 .prepare(&format!(
358 "SELECT {COLUMNS} FROM sessions
359 WHERE project = ?1
360 ORDER BY updated_at DESC
361 LIMIT ?2"
362 ))
363 .map_err(|e| AonyxError::Memory(format!("prepare list: {e}")))?;
364 let rows = stmt
365 .query_map(params![project, limit], row_to_record)
366 .map_err(|e| AonyxError::Memory(format!("query list: {e}")))?;
367 let mut out = Vec::new();
368 for r in rows {
369 out.push(r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?);
370 }
371 Ok(out)
372 })
373 .await
374 .map_err(|e| AonyxError::Memory(format!("list join: {e}")))?
375 }
376
377 async fn get(&self, id: SessionId) -> Result<Option<SessionRecord>> {
378 let conn = self.conn.clone();
379 tokio::task::spawn_blocking(move || -> Result<Option<SessionRecord>> {
380 let lock = conn.lock().expect("sessions mutex poisoned");
381 let mut stmt = lock
382 .prepare(&format!("SELECT {COLUMNS} FROM sessions WHERE id = ?1"))
383 .map_err(|e| AonyxError::Memory(format!("prepare get: {e}")))?;
384 stmt.query_row(params![id.to_string()], row_to_record)
385 .optional()
386 .map_err(|e| AonyxError::Memory(format!("get session: {e}")))
387 })
388 .await
389 .map_err(|e| AonyxError::Memory(format!("get join: {e}")))?
390 }
391
392 async fn delete(&self, id: SessionId) -> Result<()> {
393 let conn = self.conn.clone();
394 tokio::task::spawn_blocking(move || -> Result<()> {
395 let lock = conn.lock().expect("sessions mutex poisoned");
396 lock.execute(
397 "DELETE FROM sessions WHERE id = ?1",
398 params![id.to_string()],
399 )
400 .map_err(|e| AonyxError::Memory(format!("delete session: {e}")))?;
401 Ok(())
402 })
403 .await
404 .map_err(|e| AonyxError::Memory(format!("delete join: {e}")))?
405 }
406
407 async fn latest(&self, project: &str) -> Result<Option<SessionRecord>> {
408 let list = self.list_by_project(project, 1).await?;
409 Ok(list.into_iter().next())
410 }
411
412 async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchHit>> {
413 let conn = self.conn.clone();
414 let needle = query.to_string();
415 let like = format!("%{}%", needle);
416 tokio::task::spawn_blocking(move || -> Result<Vec<SearchHit>> {
417 let lock = conn.lock().expect("sessions mutex poisoned");
418 let mut stmt = lock
419 .prepare(&format!(
420 "SELECT {COLUMNS} FROM sessions
421 WHERE messages_json LIKE ?1 COLLATE NOCASE
422 OR title LIKE ?1 COLLATE NOCASE
423 ORDER BY updated_at DESC
424 LIMIT ?2"
425 ))
426 .map_err(|e| AonyxError::Memory(format!("prepare search: {e}")))?;
427 let rows = stmt
428 .query_map(params![like, limit as i64], row_to_record)
429 .map_err(|e| AonyxError::Memory(format!("query search: {e}")))?;
430 let mut out = Vec::new();
431 for r in rows {
432 let rec = r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?;
433 let snippet = extract_snippet(&rec.messages, &needle);
434 out.push(SearchHit {
435 id: rec.id,
436 project: rec.project,
437 title: rec.title,
438 updated_at: rec.updated_at,
439 turns: rec.turns,
440 snippet,
441 });
442 }
443 Ok(out)
444 })
445 .await
446 .map_err(|e| AonyxError::Memory(format!("search join: {e}")))?
447 }
448
449 async fn find_by_id_prefix(&self, prefix: &str, limit: usize) -> Result<Vec<SessionRecord>> {
450 let conn = self.conn.clone();
451 let like = format!("{}%", prefix.to_lowercase());
452 tokio::task::spawn_blocking(move || -> Result<Vec<SessionRecord>> {
453 let lock = conn.lock().expect("sessions mutex poisoned");
454 let mut stmt = lock
455 .prepare(&format!(
456 "SELECT {COLUMNS} FROM sessions
457 WHERE id LIKE ?1 COLLATE NOCASE
458 ORDER BY updated_at DESC
459 LIMIT ?2"
460 ))
461 .map_err(|e| AonyxError::Memory(format!("prepare prefix: {e}")))?;
462 let rows = stmt
463 .query_map(params![like, limit as i64], row_to_record)
464 .map_err(|e| AonyxError::Memory(format!("query prefix: {e}")))?;
465 let mut out = Vec::new();
466 for r in rows {
467 out.push(r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?);
468 }
469 Ok(out)
470 })
471 .await
472 .map_err(|e| AonyxError::Memory(format!("prefix join: {e}")))?
473 }
474}
475
476fn extract_snippet(messages: &[Message], needle: &str) -> String {
481 const WINDOW: usize = 120;
482 let lower_needle = needle.to_lowercase();
483 for m in messages {
484 let lower = m.content.to_lowercase();
485 if let Some(idx) = lower.find(&lower_needle) {
486 let chars: Vec<char> = m.content.chars().collect();
489 let mut byte_count = 0usize;
492 let mut char_idx = 0usize;
493 for (i, c) in chars.iter().enumerate() {
494 if byte_count >= idx {
495 char_idx = i;
496 break;
497 }
498 byte_count += c.len_utf8();
499 }
500 let start = char_idx.saturating_sub(WINDOW / 4);
501 let end = (start + WINDOW).min(chars.len());
502 let mut snip: String = chars[start..end].iter().collect();
503 snip = snip.replace('\n', " ");
504 if start > 0 {
505 snip.insert(0, '…');
506 }
507 if end < chars.len() {
508 snip.push('…');
509 }
510 return snip;
511 }
512 }
513 let first = messages
515 .iter()
516 .find(|m| m.role == Role::User)
517 .or_else(|| messages.first())
518 .map(|m| m.content.clone())
519 .unwrap_or_default();
520 let single: String = first.replace('\n', " ");
521 if single.chars().count() > 120 {
522 let cut: String = single.chars().take(120).collect();
523 format!("{cut}…")
524 } else {
525 single
526 }
527}
528
#[cfg(test)]
mod tests {
    // Every async test runs against a fresh in-memory database, so no
    // filesystem setup or cleanup is needed.
    use super::*;
    use aonyx_core::Role;

    /// Shorthand for building a message with the given role and text.
    fn msg(role: Role, content: &str) -> Message {
        Message::new(role, content.to_string())
    }

    #[tokio::test]
    async fn create_then_get_round_trips() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        // The system message comes first, so the title must come from the user message.
        let messages = vec![msg(Role::System, "be brief"), msg(Role::User, "hello")];
        let created = store.create("demo", messages.clone()).await.unwrap();
        let got = store.get(created.id).await.unwrap().expect("found");
        assert_eq!(got.project, "demo");
        assert_eq!(got.title, "hello");
        assert_eq!(got.messages.len(), 2);
        assert_eq!(got.turns, 0);
    }

    #[tokio::test]
    async fn update_bumps_turns_and_title() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let created = store
            .create("demo", vec![msg(Role::User, "first")])
            .await
            .unwrap();
        let new_msgs = vec![
            msg(Role::User, "second user query that drives the title"),
            msg(Role::Assistant, "ok"),
        ];
        store.update(created.id, new_msgs, 1).await.unwrap();
        let got = store.get(created.id).await.unwrap().unwrap();
        assert_eq!(got.turns, 1);
        // The auto-derived title follows the new first user message.
        assert!(got.title.starts_with("second user"));
    }

    #[tokio::test]
    async fn rename_sets_explicit_title_and_survives() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let created = store
            .create("demo", vec![msg(Role::User, "auto-derived title")])
            .await
            .unwrap();
        store.rename(created.id, "my refactor").await.unwrap();
        // The explicit title must round-trip through storage.
        let got = store.get(created.id).await.unwrap().unwrap();
        assert_eq!(got.title, "my refactor");
        // Renaming an id that was never stored is an error, not a silent no-op.
        assert!(store.rename(SessionId::new_v4(), "x").await.is_err());
    }

    #[tokio::test]
    async fn list_orders_by_updated_desc_and_scopes_project() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let _a = store
            .create("demo", vec![msg(Role::User, "older")])
            .await
            .unwrap();
        // Sleep so the two sessions get distinct `updated_at` timestamps.
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        let b = store
            .create("demo", vec![msg(Role::User, "newer")])
            .await
            .unwrap();
        let _c = store
            .create("other", vec![msg(Role::User, "wrong project")])
            .await
            .unwrap();

        // Only the two "demo" sessions come back, newest first.
        let list = store.list_by_project("demo", 10).await.unwrap();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].id, b.id);
    }

    #[tokio::test]
    async fn latest_returns_most_recent_for_project() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let _ = store
            .create("demo", vec![msg(Role::User, "old")])
            .await
            .unwrap();
        // Sleep so the two sessions get distinct `updated_at` timestamps.
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        let recent = store
            .create("demo", vec![msg(Role::User, "fresh")])
            .await
            .unwrap();

        let latest = store.latest("demo").await.unwrap().unwrap();
        assert_eq!(latest.id, recent.id);
        // An unknown project yields `None` rather than an error.
        assert!(store.latest("nothing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_removes_a_session() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let s = store
            .create("demo", vec![msg(Role::User, "x")])
            .await
            .unwrap();
        store.delete(s.id).await.unwrap();
        assert!(store.get(s.id).await.unwrap().is_none());
    }

    #[test]
    fn extract_title_truncates_long_first_user_message() {
        let m = vec![msg(Role::User, &"a".repeat(200))];
        let title = extract_title(&m);
        // At most the character limit plus the trailing ellipsis.
        assert!(title.chars().count() <= 61);
        assert!(title.ends_with('…'));
    }

    #[test]
    fn extract_title_collapses_newlines() {
        let m = vec![msg(Role::User, "line one\nline two\nline three")];
        let title = extract_title(&m);
        assert!(!title.contains('\n'));
        assert!(title.contains("line one"));
    }

    #[tokio::test]
    async fn fork_copies_history_and_sets_parent_id() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let parent = store
            .create("demo", vec![msg(Role::User, "original line")])
            .await
            .unwrap();
        // The caller supplies the history to copy; `fork` does not read it from the parent.
        let forked = store
            .fork(
                "demo",
                parent.id,
                vec![
                    msg(Role::User, "original line"),
                    msg(Role::Assistant, "reply"),
                ],
                3,
            )
            .await
            .unwrap();
        assert_ne!(forked.id, parent.id);
        assert_eq!(forked.parent_id, Some(parent.id));
        assert_eq!(forked.turns, 3);
        assert_eq!(forked.messages.len(), 2);
        // The parent link must also survive a round trip through storage.
        let reloaded = store.get(forked.id).await.unwrap().unwrap();
        assert_eq!(reloaded.parent_id, Some(parent.id));
    }

    #[tokio::test]
    async fn search_finds_hits_across_message_bodies() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let _ = store
            .create(
                "demo",
                vec![msg(Role::User, "implement OAuth flow for the API")],
            )
            .await
            .unwrap();
        let _ = store
            .create("demo", vec![msg(Role::User, "unrelated work")])
            .await
            .unwrap();
        let hits = store.search("oauth", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].snippet.to_lowercase().contains("oauth"));
    }

    #[tokio::test]
    async fn search_is_case_insensitive() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let _ = store
            .create("demo", vec![msg(Role::User, "FIX THE LOGIN BUG")])
            .await
            .unwrap();
        let hits = store.search("login", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    // NOTE(review): "deploy" also occurs in the message body, so this test does
    // not isolate the title-only match branch its name describes.
    #[tokio::test]
    async fn search_matches_title_when_body_does_not() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let _ = store
            .create("demo", vec![msg(Role::User, "deploy pipeline rework")])
            .await
            .unwrap();
        let hits = store.search("deploy", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[tokio::test]
    async fn find_by_id_prefix_resolves_short_id() {
        let store = SqliteSessionStore::open_in_memory().unwrap();
        let created = store
            .create("demo", vec![msg(Role::User, "x")])
            .await
            .unwrap();
        // The first 8 characters of the hyphenated id act as a short id.
        let prefix: String = created.id.to_string().chars().take(8).collect();
        let matches = store.find_by_id_prefix(&prefix, 5).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id, created.id);
    }

    #[test]
    fn extract_snippet_returns_window_around_match() {
        let msgs = vec![msg(
            Role::User,
            "this is a long preamble describing the OAuth flow setup and then more text",
        )];
        let snip = extract_snippet(&msgs, "oauth");
        assert!(snip.to_lowercase().contains("oauth"));
        // Either the window starts mid-text (leading ellipsis) or at the very beginning.
        assert!(snip.starts_with("…") || snip.starts_with("this"));
    }

    #[test]
    fn extract_snippet_falls_back_to_first_user_message() {
        let msgs = vec![msg(Role::User, "no match here")];
        let snip = extract_snippet(&msgs, "missing");
        assert!(snip.contains("no match here"));
    }
}
747}