1use crate::llm::{Message, TokenUsage, ToolDefinition};
33use crate::planning::Task;
34use crate::prompts::PlanningMode;
35use crate::queue::SessionQueueConfig;
36use crate::tools::ArtifactStore;
37use crate::trace::TraceEvent;
38use crate::verification::VerificationReport;
39use anyhow::{Context, Result};
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43use tokio::fs;
44use tokio::io::AsyncWriteExt;
45
/// Lifecycle state recorded for a session.
///
/// Every variant has an explicit integer discriminant. `Unknown` is the
/// `Default` value. The derived serde impls carry no renaming attributes, so
/// the serialized form is driven by the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SessionState {
    /// No state has been assigned; this is what `SessionState::default()` returns.
    #[default]
    Unknown = 0,
    /// The session is active.
    Active = 1,
    /// The session is paused.
    Paused = 2,
    /// The session has completed.
    Completed = 3,
    /// The session is in an error state.
    Error = 4,
}
60
/// Snapshot of how much of the model context window a session is using.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextUsage {
    /// Number of tokens currently counted against the context window.
    pub used_tokens: usize,
    /// Size of the context window in tokens (the `Default` impl uses 200_000).
    pub max_tokens: usize,
    /// Share of the window in use.
    ///
    /// NOTE(review): appears to be a 0.0-1.0 fraction rather than 0-100; the
    /// test fixture in this file stores 0.0005 for 100 of 200000 tokens.
    /// Confirm against the code that computes it.
    pub percent: f32,
    /// Number of conversation turns recorded.
    pub turns: usize,
}
69
70impl Default for ContextUsage {
71 fn default() -> Self {
72 Self {
73 used_tokens: 0,
74 max_tokens: 200_000,
75 percent: 0.0,
76 turns: 0,
77 }
78 }
79}
80
/// Auto-compaction threshold used when a configuration does not specify one.
pub const DEFAULT_AUTO_COMPACT_THRESHOLD: f32 = 0.80;

/// Function form of [`DEFAULT_AUTO_COMPACT_THRESHOLD`].
///
/// `#[serde(default = "...")]` takes a function path rather than a constant,
/// so this exists to be named from such an attribute.
fn default_auto_compact_threshold() -> f32 { DEFAULT_AUTO_COMPACT_THRESHOLD }
87
/// Per-session configuration that is persisted as part of [`SessionData`].
///
/// Fields marked `#[serde(default)]` may be absent from stored data and fall
/// back to their defaults when loaded. Optional fields marked
/// `skip_serializing_if` are omitted from the output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Session name.
    pub name: String,
    /// Workspace associated with the session, stored as a string.
    pub workspace: String,
    /// Optional system prompt for the session.
    pub system_prompt: Option<String>,
    /// Maximum context length for the session.
    pub max_context_length: u32,
    /// Whether automatic compaction is enabled.
    pub auto_compact: bool,
    /// Threshold associated with automatic compaction. When missing from
    /// stored data it is filled in by `default_auto_compact_threshold`.
    #[serde(default = "default_auto_compact_threshold")]
    pub auto_compact_threshold: f32,
    /// Storage backend selected for this session; uses the backend's
    /// `Default` when missing from stored data.
    #[serde(default)]
    pub storage_type: crate::config::StorageBackend,
    /// Optional queue configuration; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub queue_config: Option<SessionQueueConfig>,
    /// Optional confirmation policy; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confirmation_policy: Option<crate::hitl::ConfirmationPolicy>,
    /// Optional permission policy; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_policy: Option<crate::permissions::PermissionPolicy>,
    /// Optional id of a parent session; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_id: Option<String>,
    /// Optional security configuration; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub security_config: Option<crate::security::SecurityConfig>,
    /// Runtime-only hook executor. It is skipped by serde in both directions,
    /// so a configuration loaded from storage always has `None` here.
    #[serde(skip)]
    pub hook_engine: Option<std::sync::Arc<dyn crate::hooks::HookExecutor>>,
    /// Planning mode; uses `PlanningMode::default()` when missing from stored data.
    #[serde(default)]
    pub planning_mode: PlanningMode,
    /// Whether goal tracking is enabled; `false` when missing from stored data.
    #[serde(default)]
    pub goal_tracking: bool,
}
131
132impl Default for SessionConfig {
133 fn default() -> Self {
134 Self {
135 name: String::new(),
136 workspace: String::new(),
137 system_prompt: None,
138 max_context_length: 0,
139 auto_compact: false,
140 auto_compact_threshold: DEFAULT_AUTO_COMPACT_THRESHOLD,
141 storage_type: crate::config::StorageBackend::default(),
142 queue_config: None,
143 confirmation_policy: None,
144 permission_policy: None,
145 parent_id: None,
146 security_config: None,
147 hook_engine: None,
148 planning_mode: PlanningMode::default(),
149 goal_tracking: false,
150 }
151 }
152}
153
/// Complete serializable record of a session, as written and read by a
/// [`SessionStore`].
///
/// Fields marked `#[serde(default)]` may be missing from stored data and fall
/// back to their defaults when loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Session identifier; the stores in this file use it as the lookup key.
    pub id: String,

    /// Configuration the session was created with.
    pub config: SessionConfig,

    /// Lifecycle state of the session.
    pub state: SessionState,

    /// Conversation messages.
    pub messages: Vec<Message>,

    /// Context window usage snapshot.
    pub context_usage: ContextUsage,

    /// Token usage totals for the session.
    pub total_usage: TokenUsage,

    /// Total cost for the session; `0.0` when missing from stored data.
    /// NOTE(review): the currency/unit is not visible in this file.
    #[serde(default)]
    pub total_cost: f64,

    /// Optional model name; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_name: Option<String>,

    /// Cost records; empty when missing from stored data.
    #[serde(default)]
    pub cost_records: Vec<crate::telemetry::LlmCostRecord>,

    /// Names of the tools associated with the session (see
    /// [`SessionData::tool_names_from_definitions`]).
    pub tool_names: Vec<String>,

    /// Whether thinking is enabled.
    pub thinking_enabled: bool,

    /// Optional thinking budget.
    pub thinking_budget: Option<usize>,

    /// Creation timestamp.
    /// NOTE(review): presumably Unix seconds (the test fixture uses
    /// 1700000000) - confirm against the code that sets it.
    pub created_at: i64,

    /// Last-update timestamp, in the same unit as `created_at`.
    pub updated_at: i64,

    /// Optional LLM configuration; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_config: Option<LlmConfigData>,

    /// Tasks attached to the session. Also accepted under the key `todos`
    /// when deserializing; empty when missing from stored data.
    #[serde(default, alias = "todos")]
    pub tasks: Vec<Task>,

    /// Optional id of a parent session; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_id: Option<String>,
}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct LlmConfigData {
221 pub provider: String,
222 pub model: String,
223 #[serde(skip_serializing, default)]
225 pub api_key: Option<String>,
226 pub base_url: Option<String>,
227}
228
229impl SessionData {
230 pub fn tool_names_from_definitions(tools: &[ToolDefinition]) -> Vec<String> {
232 tools.iter().map(|t| t.name.clone()).collect()
233 }
234}
235
/// Storage backend for session records and their side data.
///
/// `save`, `load`, `delete`, `list` and `exists` are required. The artifact,
/// trace-event and verification-report methods have default implementations
/// that store nothing and load `None`, so a backend only needs to override
/// them if it persists that data. Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait SessionStore: Send + Sync {
    /// Persists `session`, keyed by `session.id`.
    async fn save(&self, session: &SessionData) -> Result<()>;

    /// Loads the session stored under `id`, or `Ok(None)` if there is none.
    async fn load(&self, id: &str) -> Result<Option<SessionData>>;

    /// Removes the session stored under `id`.
    async fn delete(&self, id: &str) -> Result<()>;

    /// Returns the ids of the stored sessions.
    async fn list(&self) -> Result<Vec<String>>;

    /// Reports whether a session is stored under `id`.
    async fn exists(&self, id: &str) -> Result<bool>;

    /// Persists the artifacts of session `_id`. The default does nothing.
    async fn save_artifacts(&self, _id: &str, _artifacts: &ArtifactStore) -> Result<()> {
        Ok(())
    }

    /// Loads the artifacts of session `_id`. The default always returns `Ok(None)`.
    async fn load_artifacts(&self, _id: &str) -> Result<Option<ArtifactStore>> {
        Ok(None)
    }

    /// Persists the trace events of session `_id`. The default does nothing.
    async fn save_trace_events(&self, _id: &str, _events: &[TraceEvent]) -> Result<()> {
        Ok(())
    }

    /// Loads the trace events of session `_id`. The default always returns `Ok(None)`.
    async fn load_trace_events(&self, _id: &str) -> Result<Option<Vec<TraceEvent>>> {
        Ok(None)
    }

    /// Persists the verification reports of session `_id`. The default does nothing.
    async fn save_verification_reports(
        &self,
        _id: &str,
        _reports: &[VerificationReport],
    ) -> Result<()> {
        Ok(())
    }

    /// Loads the verification reports of session `_id`. The default always
    /// returns `Ok(None)`.
    async fn load_verification_reports(
        &self,
        _id: &str,
    ) -> Result<Option<Vec<VerificationReport>>> {
        Ok(None)
    }

    /// Checks that the backend is usable. The default reports success
    /// without checking anything.
    async fn health_check(&self) -> Result<()> {
        Ok(())
    }

    /// Short name identifying the backend; `"unknown"` unless overridden.
    fn backend_name(&self) -> &str {
        "unknown"
    }
}
305
/// [`SessionStore`] backed by JSON files under a single directory.
///
/// Layout, relative to `dir` (ids are passed through `safe_session_id` first):
///
/// - `<id>.json` - the session record
/// - `artifacts/<id>/` - artifact directory
/// - `traces/<id>.json` - trace events
/// - `verification/<id>.json` - verification reports
pub struct FileSessionStore {
    /// Root directory that holds every file written by this store.
    dir: PathBuf,
}
322
323impl FileSessionStore {
324 pub async fn new<P: AsRef<Path>>(dir: P) -> Result<Self> {
328 let dir = dir.as_ref().to_path_buf();
329
330 fs::create_dir_all(&dir)
332 .await
333 .with_context(|| format!("Failed to create session directory: {}", dir.display()))?;
334
335 Ok(Self { dir })
336 }
337
338 fn session_path(&self, id: &str) -> PathBuf {
340 self.dir.join(format!("{}.json", safe_session_id(id)))
342 }
343
344 fn artifact_dir(&self, id: &str) -> PathBuf {
345 self.dir.join("artifacts").join(safe_session_id(id))
346 }
347
348 fn trace_path(&self, id: &str) -> PathBuf {
349 self.dir
350 .join("traces")
351 .join(format!("{}.json", safe_session_id(id)))
352 }
353
354 fn verification_path(&self, id: &str) -> PathBuf {
355 self.dir
356 .join("verification")
357 .join(format!("{}.json", safe_session_id(id)))
358 }
359}
360
/// Turns a session id into a string that is safe to use as a single path
/// component.
///
/// Every `/` and `\` becomes `_`; afterwards each non-overlapping `..`
/// (scanned left to right) becomes `_`. Distinct ids can map to the same
/// result (for example `a/b` and `a_b`).
fn safe_session_id(id: &str) -> String {
    // Pass 1: neutralize path separators character by character.
    let without_separators: String = id
        .chars()
        .map(|ch| if matches!(ch, '/' | '\\') { '_' } else { ch })
        .collect();

    // Pass 2: collapse parent-directory markers.
    without_separators.replace("..", "_")
}
364
365#[async_trait::async_trait]
366impl SessionStore for FileSessionStore {
367 async fn save(&self, session: &SessionData) -> Result<()> {
368 let path = self.session_path(&session.id);
369
370 let json = serde_json::to_string_pretty(session)
372 .with_context(|| format!("Failed to serialize session: {}", session.id))?;
373
374 let unique_suffix = format!(
377 "{}.{}",
378 std::time::SystemTime::now()
379 .duration_since(std::time::UNIX_EPOCH)
380 .unwrap()
381 .as_nanos(),
382 std::process::id()
383 );
384 let temp_path = path.with_extension(format!("json.{}.tmp", unique_suffix));
385
386 let mut file = fs::File::create(&temp_path)
387 .await
388 .with_context(|| format!("Failed to create temp file: {}", temp_path.display()))?;
389
390 file.write_all(json.as_bytes())
391 .await
392 .with_context(|| format!("Failed to write session data: {}", session.id))?;
393
394 file.sync_all()
395 .await
396 .with_context(|| format!("Failed to sync session file: {}", session.id))?;
397
398 fs::rename(&temp_path, &path)
400 .await
401 .with_context(|| format!("Failed to rename session file: {}", session.id))?;
402
403 tracing::debug!("Saved session {} to {}", session.id, path.display());
404 Ok(())
405 }
406
407 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
408 let path = self.session_path(id);
409
410 if !path.exists() {
411 return Ok(None);
412 }
413
414 let json = fs::read_to_string(&path)
415 .await
416 .with_context(|| format!("Failed to read session file: {}", path.display()))?;
417
418 let session: SessionData = serde_json::from_str(&json)
419 .with_context(|| format!("Failed to parse session file: {}", path.display()))?;
420
421 tracing::debug!("Loaded session {} from {}", id, path.display());
422 Ok(Some(session))
423 }
424
425 async fn delete(&self, id: &str) -> Result<()> {
426 let path = self.session_path(id);
427
428 if path.exists() {
429 fs::remove_file(&path)
430 .await
431 .with_context(|| format!("Failed to delete session file: {}", path.display()))?;
432
433 tracing::debug!("Deleted session {} from {}", id, path.display());
434 }
435
436 let artifact_dir = self.artifact_dir(id);
437 if artifact_dir.exists() {
438 fs::remove_dir_all(&artifact_dir).await.with_context(|| {
439 format!(
440 "Failed to delete artifact directory for session {}: {}",
441 id,
442 artifact_dir.display()
443 )
444 })?;
445 }
446
447 let trace_path = self.trace_path(id);
448 if trace_path.exists() {
449 fs::remove_file(&trace_path).await.with_context(|| {
450 format!(
451 "Failed to delete trace file for session {}: {}",
452 id,
453 trace_path.display()
454 )
455 })?;
456 }
457
458 let verification_path = self.verification_path(id);
459 if verification_path.exists() {
460 fs::remove_file(&verification_path).await.with_context(|| {
461 format!(
462 "Failed to delete verification report file for session {}: {}",
463 id,
464 verification_path.display()
465 )
466 })?;
467 }
468
469 Ok(())
470 }
471
472 async fn list(&self) -> Result<Vec<String>> {
473 let mut session_ids = Vec::new();
474
475 let mut entries = fs::read_dir(&self.dir)
476 .await
477 .with_context(|| format!("Failed to read session directory: {}", self.dir.display()))?;
478
479 while let Some(entry) = entries.next_entry().await? {
480 let path = entry.path();
481
482 if path.extension().is_some_and(|ext| ext == "json") {
483 if let Some(stem) = path.file_stem() {
484 if let Some(id) = stem.to_str() {
485 session_ids.push(id.to_string());
486 }
487 }
488 }
489 }
490
491 Ok(session_ids)
492 }
493
494 async fn exists(&self, id: &str) -> Result<bool> {
495 let path = self.session_path(id);
496 Ok(path.exists())
497 }
498
499 async fn save_artifacts(&self, id: &str, artifacts: &ArtifactStore) -> Result<()> {
500 let artifact_dir = self.artifact_dir(id);
501 artifacts.save_to_dir(&artifact_dir).with_context(|| {
502 format!(
503 "Failed to save artifacts for session {} to {}",
504 id,
505 artifact_dir.display()
506 )
507 })
508 }
509
510 async fn load_artifacts(&self, id: &str) -> Result<Option<ArtifactStore>> {
511 let artifact_dir = self.artifact_dir(id);
512 if !artifact_dir.exists() {
513 return Ok(None);
514 }
515
516 let artifacts = ArtifactStore::load_from_dir(&artifact_dir).with_context(|| {
517 format!(
518 "Failed to load artifacts for session {} from {}",
519 id,
520 artifact_dir.display()
521 )
522 })?;
523 Ok(Some(artifacts))
524 }
525
526 async fn save_trace_events(&self, id: &str, events: &[TraceEvent]) -> Result<()> {
527 let path = self.trace_path(id);
528 if let Some(parent) = path.parent() {
529 fs::create_dir_all(parent).await.with_context(|| {
530 format!("Failed to create trace directory: {}", parent.display())
531 })?;
532 }
533
534 let json = serde_json::to_string_pretty(events)
535 .with_context(|| format!("Failed to serialize trace events for session {id}"))?;
536 fs::write(&path, json)
537 .await
538 .with_context(|| format!("Failed to write trace events to {}", path.display()))?;
539 Ok(())
540 }
541
542 async fn load_trace_events(&self, id: &str) -> Result<Option<Vec<TraceEvent>>> {
543 let path = self.trace_path(id);
544 if !path.exists() {
545 return Ok(None);
546 }
547
548 let json = fs::read_to_string(&path)
549 .await
550 .with_context(|| format!("Failed to read trace events from {}", path.display()))?;
551 let events = serde_json::from_str(&json)
552 .with_context(|| format!("Failed to parse trace events from {}", path.display()))?;
553 Ok(Some(events))
554 }
555
556 async fn save_verification_reports(
557 &self,
558 id: &str,
559 reports: &[VerificationReport],
560 ) -> Result<()> {
561 let path = self.verification_path(id);
562 if let Some(parent) = path.parent() {
563 fs::create_dir_all(parent).await.with_context(|| {
564 format!(
565 "Failed to create verification report directory: {}",
566 parent.display()
567 )
568 })?;
569 }
570
571 let json = serde_json::to_string_pretty(reports).with_context(|| {
572 format!("Failed to serialize verification reports for session {id}")
573 })?;
574 fs::write(&path, json).await.with_context(|| {
575 format!("Failed to write verification reports to {}", path.display())
576 })?;
577 Ok(())
578 }
579
580 async fn load_verification_reports(&self, id: &str) -> Result<Option<Vec<VerificationReport>>> {
581 let path = self.verification_path(id);
582 if !path.exists() {
583 return Ok(None);
584 }
585
586 let json = fs::read_to_string(&path).await.with_context(|| {
587 format!(
588 "Failed to read verification reports from {}",
589 path.display()
590 )
591 })?;
592 let reports = serde_json::from_str(&json).with_context(|| {
593 format!(
594 "Failed to parse verification reports from {}",
595 path.display()
596 )
597 })?;
598 Ok(Some(reports))
599 }
600
601 async fn health_check(&self) -> Result<()> {
602 let probe = self.dir.join(".health_check");
604 fs::write(&probe, b"ok")
605 .await
606 .with_context(|| format!("Store directory not writable: {}", self.dir.display()))?;
607 let _ = fs::remove_file(&probe).await;
608 Ok(())
609 }
610
611 fn backend_name(&self) -> &str {
612 "file"
613 }
614}
615
/// [`SessionStore`] that keeps everything in process memory.
///
/// Each kind of data lives in its own map keyed by session id and is guarded
/// by its own `tokio::sync::RwLock`. Nothing is written to disk.
pub struct MemorySessionStore {
    /// Session records by session id.
    sessions: tokio::sync::RwLock<HashMap<String, SessionData>>,
    /// Artifact stores by session id.
    artifacts: tokio::sync::RwLock<HashMap<String, ArtifactStore>>,
    /// Trace events by session id.
    trace_events: tokio::sync::RwLock<HashMap<String, Vec<TraceEvent>>>,
    /// Verification reports by session id.
    verification_reports: tokio::sync::RwLock<HashMap<String, Vec<VerificationReport>>>,
}
627
628impl MemorySessionStore {
629 pub fn new() -> Self {
630 Self {
631 sessions: tokio::sync::RwLock::new(HashMap::new()),
632 artifacts: tokio::sync::RwLock::new(HashMap::new()),
633 trace_events: tokio::sync::RwLock::new(HashMap::new()),
634 verification_reports: tokio::sync::RwLock::new(HashMap::new()),
635 }
636 }
637}
638
639impl Default for MemorySessionStore {
640 fn default() -> Self {
641 Self::new()
642 }
643}
644
645#[async_trait::async_trait]
646impl SessionStore for MemorySessionStore {
647 async fn save(&self, session: &SessionData) -> Result<()> {
648 let mut sessions = self.sessions.write().await;
649 sessions.insert(session.id.clone(), session.clone());
650 Ok(())
651 }
652
653 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
654 let sessions = self.sessions.read().await;
655 Ok(sessions.get(id).cloned())
656 }
657
658 async fn delete(&self, id: &str) -> Result<()> {
659 let mut sessions = self.sessions.write().await;
660 sessions.remove(id);
661 self.artifacts.write().await.remove(id);
662 self.trace_events.write().await.remove(id);
663 self.verification_reports.write().await.remove(id);
664 Ok(())
665 }
666
667 async fn list(&self) -> Result<Vec<String>> {
668 let sessions = self.sessions.read().await;
669 Ok(sessions.keys().cloned().collect())
670 }
671
672 async fn exists(&self, id: &str) -> Result<bool> {
673 let sessions = self.sessions.read().await;
674 Ok(sessions.contains_key(id))
675 }
676
677 async fn save_artifacts(&self, id: &str, artifacts: &ArtifactStore) -> Result<()> {
678 self.artifacts
679 .write()
680 .await
681 .insert(id.to_string(), artifacts.clone());
682 Ok(())
683 }
684
685 async fn load_artifacts(&self, id: &str) -> Result<Option<ArtifactStore>> {
686 Ok(self.artifacts.read().await.get(id).cloned())
687 }
688
689 async fn save_trace_events(&self, id: &str, events: &[TraceEvent]) -> Result<()> {
690 self.trace_events
691 .write()
692 .await
693 .insert(id.to_string(), events.to_vec());
694 Ok(())
695 }
696
697 async fn load_trace_events(&self, id: &str) -> Result<Option<Vec<TraceEvent>>> {
698 Ok(self.trace_events.read().await.get(id).cloned())
699 }
700
701 async fn save_verification_reports(
702 &self,
703 id: &str,
704 reports: &[VerificationReport],
705 ) -> Result<()> {
706 self.verification_reports
707 .write()
708 .await
709 .insert(id.to_string(), reports.to_vec());
710 Ok(())
711 }
712
713 async fn load_verification_reports(&self, id: &str) -> Result<Option<Vec<VerificationReport>>> {
714 Ok(self.verification_reports.read().await.get(id).cloned())
715 }
716
717 fn backend_name(&self) -> &str {
718 "memory"
719 }
720}
721
// Unit tests for the file-backed and in-memory session stores and for the
// serialization of the session data types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hitl::ConfirmationPolicy;
    use crate::permissions::PermissionPolicy;
    use crate::prompts::PlanningMode;
    use crate::queue::SessionQueueConfig;
    use tempfile::tempdir;

    // Builds a fully populated session with id "test-session-1", two messages
    // and no optional policies; shared by most tests below.
    fn create_test_session_data() -> SessionData {
        SessionData {
            id: "test-session-1".to_string(),
            config: SessionConfig {
                name: "Test Session".to_string(),
                workspace: "/tmp/workspace".to_string(),
                system_prompt: Some("You are helpful.".to_string()),
                max_context_length: 200000,
                auto_compact: false,
                auto_compact_threshold: DEFAULT_AUTO_COMPACT_THRESHOLD,
                storage_type: crate::config::StorageBackend::File,
                queue_config: None,
                confirmation_policy: None,
                permission_policy: None,
                parent_id: None,
                security_config: None,
                hook_engine: None,
                planning_mode: PlanningMode::default(),
                goal_tracking: false,
            },
            state: SessionState::Active,
            messages: vec![
                Message::user("Hello"),
                Message {
                    role: "assistant".to_string(),
                    content: vec![crate::llm::ContentBlock::Text {
                        text: "Hi there!".to_string(),
                    }],
                    reasoning_content: None,
                },
            ],
            context_usage: ContextUsage {
                used_tokens: 100,
                max_tokens: 200000,
                percent: 0.0005,
                turns: 2,
            },
            total_usage: TokenUsage {
                prompt_tokens: 50,
                completion_tokens: 50,
                total_tokens: 100,
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            tool_names: vec!["bash".to_string(), "read".to_string()],
            thinking_enabled: false,
            thinking_budget: None,
            created_at: 1700000000,
            updated_at: 1700000100,
            llm_config: None,
            tasks: vec![],
            parent_id: None,
            total_cost: 0.0,
            model_name: None,
            cost_records: Vec::new(),
        }
    }

    // Builds a verification report containing a single required check that
    // has been marked as passed.
    fn create_test_verification_report() -> VerificationReport {
        VerificationReport::new(
            "program:test",
            vec![crate::verification::VerificationCheck::required(
                "check:test",
                "test",
                "Run tests",
            )
            .with_status(crate::verification::VerificationStatus::Passed)],
        )
    }

    // ------------------------------------------------------------------
    // FileSessionStore: basic save / load / delete
    // ------------------------------------------------------------------

    // A saved session can be loaded back with its key fields intact.
    #[tokio::test]
    async fn test_file_store_save_and_load() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);
        assert_eq!(loaded.config.name, session.config.name);
        assert_eq!(loaded.messages.len(), 2);
        assert_eq!(loaded.state, SessionState::Active);
    }

    // Loading an id that was never saved yields Ok(None).
    #[tokio::test]
    async fn test_file_store_load_nonexistent() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let loaded = store.load("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    // After delete, the session neither exists nor loads.
    #[tokio::test]
    async fn test_file_store_delete() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();
        store.save(&session).await.unwrap();

        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();

        assert!(!store.exists(&session.id).await.unwrap());
        assert!(store.load(&session.id).await.unwrap().is_none());
    }

    // ------------------------------------------------------------------
    // FileSessionStore: artifacts, trace events, verification reports
    // ------------------------------------------------------------------

    // Artifacts round-trip through the file store; the id contains a `/` to
    // exercise id sanitization in the artifact path.
    #[tokio::test]
    async fn test_file_store_save_and_load_artifacts() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        let artifacts = ArtifactStore::new();
        artifacts.put(crate::tools::ToolArtifact {
            artifact_id: "tool-output:test:a".to_string(),
            artifact_uri: "a3s://tool-output/test/a".to_string(),
            tool_name: "test".to_string(),
            content: "artifact content".to_string(),
            original_bytes: 16,
            shown_bytes: 4,
        });

        store.save_artifacts("session/a", &artifacts).await.unwrap();
        let loaded = store
            .load_artifacts("session/a")
            .await
            .unwrap()
            .expect("artifacts");

        assert_eq!(loaded.len(), 1);
        assert_eq!(
            loaded
                .get("a3s://tool-output/test/a")
                .expect("artifact")
                .content,
            "artifact content"
        );
    }

    // Trace events round-trip through the file store unchanged.
    #[tokio::test]
    async fn test_file_store_save_and_load_trace_events() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        let event = TraceEvent::tool_execution(
            "read",
            true,
            0,
            std::time::Duration::from_millis(9),
            12,
            Some(&serde_json::json!({
                "artifact": {
                    "artifact_uri": "a3s://tool-output/read/abc"
                }
            })),
        );

        store
            .save_trace_events("session/a", std::slice::from_ref(&event))
            .await
            .unwrap();
        let loaded = store
            .load_trace_events("session/a")
            .await
            .unwrap()
            .expect("trace events");

        assert_eq!(loaded, vec![event]);
    }

    // Verification reports round-trip through the file store unchanged.
    #[tokio::test]
    async fn test_file_store_save_and_load_verification_reports() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        let report = create_test_verification_report();

        store
            .save_verification_reports("session/a", std::slice::from_ref(&report))
            .await
            .unwrap();
        let loaded = store
            .load_verification_reports("session/a")
            .await
            .unwrap()
            .expect("verification reports");

        assert_eq!(loaded, vec![report]);
    }

    // ------------------------------------------------------------------
    // MemorySessionStore: side data is stored and removed with the session
    // ------------------------------------------------------------------

    // Artifacts can be saved and loaded, and are gone after delete.
    #[tokio::test]
    async fn test_memory_store_save_load_and_delete_artifacts() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        let artifacts = ArtifactStore::new();
        artifacts.put(crate::tools::ToolArtifact {
            artifact_id: "tool-output:test:a".to_string(),
            artifact_uri: "a3s://tool-output/test/a".to_string(),
            tool_name: "test".to_string(),
            content: "artifact content".to_string(),
            original_bytes: 16,
            shown_bytes: 4,
        });

        store.save_artifacts(&session.id, &artifacts).await.unwrap();
        assert!(store
            .load_artifacts(&session.id)
            .await
            .unwrap()
            .expect("artifacts")
            .get("a3s://tool-output/test/a")
            .is_some());

        store.delete(&session.id).await.unwrap();
        assert!(store.load_artifacts(&session.id).await.unwrap().is_none());
    }

    // Trace events can be saved and loaded, and are gone after delete.
    #[tokio::test]
    async fn test_memory_store_save_load_and_delete_trace_events() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();
        let event = TraceEvent::tool_execution(
            "grep",
            false,
            1,
            std::time::Duration::from_millis(2),
            24,
            None,
        );

        store.save(&session).await.unwrap();
        store
            .save_trace_events(&session.id, std::slice::from_ref(&event))
            .await
            .unwrap();
        let loaded = store
            .load_trace_events(&session.id)
            .await
            .unwrap()
            .expect("trace events");
        assert_eq!(loaded, vec![event]);

        store.delete(&session.id).await.unwrap();
        assert!(store
            .load_trace_events(&session.id)
            .await
            .unwrap()
            .is_none());
    }

    // Verification reports can be saved and loaded, and are gone after delete.
    #[tokio::test]
    async fn test_memory_store_save_load_and_delete_verification_reports() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();
        let report = create_test_verification_report();

        store.save(&session).await.unwrap();
        store
            .save_verification_reports(&session.id, std::slice::from_ref(&report))
            .await
            .unwrap();
        let loaded = store
            .load_verification_reports(&session.id)
            .await
            .unwrap()
            .expect("verification reports");
        assert_eq!(loaded, vec![report]);

        store.delete(&session.id).await.unwrap();
        assert!(store
            .load_verification_reports(&session.id)
            .await
            .unwrap()
            .is_none());
    }

    // ------------------------------------------------------------------
    // FileSessionStore: listing, overwriting, sanitization, policies
    // ------------------------------------------------------------------

    // list() is empty for a fresh store and returns every saved id afterwards.
    #[tokio::test]
    async fn test_file_store_list() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let list = store.list().await.unwrap();
        assert!(list.is_empty());

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
        assert!(list.contains(&"session-1".to_string()));
        assert!(list.contains(&"session-2".to_string()));
        assert!(list.contains(&"session-3".to_string()));
    }

    // Saving the same id twice replaces the stored record.
    #[tokio::test]
    async fn test_file_store_overwrite() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        store.save(&session).await.unwrap();

        session.messages.push(Message::user("Another message"));
        session.updated_at = 1700000200;
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 3);
        assert_eq!(loaded.updated_at, 1700000200);
    }

    // An id that looks like a path traversal produces exactly one entry
    // inside the store directory and can still be loaded by the same id.
    #[tokio::test]
    async fn test_file_store_path_traversal_prevention() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = "../../../etc/passwd".to_string();
        store.save(&session).await.unwrap();

        let files: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .collect();
        assert_eq!(files.len(), 1);

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
    }

    // Optional policy and queue settings survive a save/load round trip.
    #[tokio::test]
    async fn test_file_store_with_policies() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.config.confirmation_policy = Some(ConfirmationPolicy::enabled());
        session.config.permission_policy = Some(PermissionPolicy::new().allow("Bash(cargo:*)"));
        session.config.queue_config = Some(SessionQueueConfig::default());

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert!(loaded.config.confirmation_policy.is_some());
        assert!(loaded.config.permission_policy.is_some());
        assert!(loaded.config.queue_config.is_some());
    }

    // Provider and model are persisted, but the API key is not written to
    // disk and therefore comes back as None.
    #[tokio::test]
    async fn test_file_store_with_llm_config() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.llm_config = Some(LlmConfigData {
            provider: "anthropic".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            // The key is set here only to prove it is dropped on save.
            api_key: Some("secret".to_string()), base_url: None,
        });

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        let llm_config = loaded.llm_config.unwrap();
        assert_eq!(llm_config.provider, "anthropic");
        assert_eq!(llm_config.model, "claude-3-5-sonnet-20241022");
        assert!(llm_config.api_key.is_none());
    }

    // ------------------------------------------------------------------
    // MemorySessionStore: basic operations
    // ------------------------------------------------------------------

    // A saved session can be loaded back by id.
    #[tokio::test]
    async fn test_memory_store_save_and_load() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, session.id);
    }

    // A deleted session no longer exists.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    // list() returns one id per saved session.
    #[tokio::test]
    async fn test_memory_store_list() {
        let store = MemorySessionStore::new();

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
    }

    // ------------------------------------------------------------------
    // Data types
    // ------------------------------------------------------------------

    // SessionData survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_session_data_serialization() {
        let session = create_test_session_data();
        let json = serde_json::to_string(&session).unwrap();
        let parsed: SessionData = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, session.id);
        assert_eq!(parsed.messages.len(), session.messages.len());
    }

    // Tool names are extracted in definition order.
    #[test]
    fn test_tool_names_from_definitions() {
        let tools = vec![
            crate::llm::ToolDefinition {
                name: "bash".to_string(),
                description: "Execute bash".to_string(),
                parameters: serde_json::json!({}),
            },
            crate::llm::ToolDefinition {
                name: "read".to_string(),
                description: "Read file".to_string(),
                parameters: serde_json::json!({}),
            },
        ];

        let names = SessionData::tool_names_from_definitions(&tools);
        assert_eq!(names, vec!["bash", "read"]);
    }

    // ------------------------------------------------------------------
    // FileSessionStore: id sanitization details
    // ------------------------------------------------------------------

    // Backslashes in the id become underscores in the file name, while the
    // id stored inside the record is unchanged.
    #[tokio::test]
    async fn test_file_store_backslash_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo\bar\baz".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz.json");
        assert!(expected_path.exists());
    }

    // Forward slashes, backslashes and `..` are all rewritten to underscores.
    #[tokio::test]
    async fn test_file_store_mixed_separator_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo/bar\baz..qux".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz_qux.json");
        assert!(expected_path.exists());
    }

    // ------------------------------------------------------------------
    // FileSessionStore: error handling and health checks
    // ------------------------------------------------------------------

    // A session file containing invalid JSON makes load() return an error.
    #[tokio::test]
    async fn test_file_store_corrupted_json_recovery() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let corrupted_path = dir.path().join("test-id.json");
        tokio::fs::write(&corrupted_path, b"not valid json {{{")
            .await
            .unwrap();

        let result = store.load("test-id").await;
        assert!(result.is_err());
    }

    // exists() tracks the save/delete lifecycle of a session.
    #[tokio::test]
    async fn test_file_store_exists() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        assert!(!store.exists(&session.id).await.unwrap());

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    // exists() is false for an unknown id and true after saving.
    #[tokio::test]
    async fn test_memory_store_exists() {
        let store = MemorySessionStore::new();

        assert!(!store.exists("unknown-id").await.unwrap());

        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());
    }

    // The health check passes for a writable directory.
    #[tokio::test]
    async fn test_file_store_health_check() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "file");
    }

    // The health check fails when the store directory does not exist. The
    // store is built directly to bypass the directory creation in `new`.
    #[tokio::test]
    async fn test_file_store_health_check_bad_dir() {
        let store = FileSessionStore {
            dir: std::path::PathBuf::from("/nonexistent/path/that/does/not/exist"),
        };
        assert!(store.health_check().await.is_err());
    }

    // The in-memory store uses the trait's default health check.
    #[tokio::test]
    async fn test_memory_store_health_check() {
        let store = MemorySessionStore::new();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "memory");
    }

    // ------------------------------------------------------------------
    // FileSessionStore: edge cases
    // ------------------------------------------------------------------

    // An empty session file is a parse error, not a missing session.
    #[tokio::test]
    async fn test_file_store_load_empty_file() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let empty_path = dir.path().join("empty-session.json");
        tokio::fs::write(&empty_path, b"").await.unwrap();

        let result = store.load("empty-session").await;
        assert!(
            result.is_err(),
            "Empty file must return error, not Ok(None)"
        );
    }

    // A truncated JSON document is a parse error.
    #[tokio::test]
    async fn test_file_store_load_partial_json() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let partial_path = dir.path().join("partial-session.json");
        tokio::fs::write(&partial_path, b"{\"id\":\"partial-session\",\"message")
            .await
            .unwrap();

        let result = store.load("partial-session").await;
        assert!(result.is_err(), "Partial JSON must return error");
    }

    // Several concurrent saves of the same session all succeed and leave a
    // loadable record behind.
    #[tokio::test]
    async fn test_file_store_concurrent_save() {
        let dir = tempdir().unwrap();
        let store = std::sync::Arc::new(FileSessionStore::new(dir.path()).await.unwrap());

        let session = create_test_session_data();
        let id = session.id.clone();

        store.save(&session).await.unwrap();

        let mut handles = Vec::new();
        for _ in 0..5 {
            let s = store.clone();
            let sess = session.clone();
            handles.push(tokio::spawn(async move { s.save(&sess).await }));
        }
        for h in handles {
            h.await.unwrap().unwrap();
        }

        let loaded = store.load(&id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    // A missing session is reported as Ok(None) rather than as an error.
    #[tokio::test]
    async fn test_file_store_load_nonexistent_returns_none() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let result = store.load("does-not-exist-at-all").await.unwrap();
        assert!(result.is_none(), "Missing session must return Ok(None)");
    }
}