1use crate::llm::{Message, TokenUsage, ToolDefinition};
33use crate::planning::Task;
34use crate::session::{ContextUsage, SessionConfig, SessionState};
35use anyhow::{Context, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::path::{Path, PathBuf};
39use tokio::fs;
40use tokio::io::AsyncWriteExt;
41
/// Serializable snapshot of a session, as saved and loaded by a [`SessionStore`].
///
/// `serde_json` emits keys in field declaration order, so reordering fields
/// changes the on-disk representation written by `FileSessionStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Session identifier; `FileSessionStore` derives the file name from it.
    pub id: String,

    /// Configuration the session was created with.
    pub config: SessionConfig,

    /// Current lifecycle state of the session.
    pub state: SessionState,

    /// Conversation history.
    pub messages: Vec<Message>,

    /// Context-window usage snapshot.
    pub context_usage: ContextUsage,

    /// Accumulated token usage.
    pub total_usage: TokenUsage,

    /// Accumulated cost; deserializes as `0.0` when the key is absent.
    #[serde(default)]
    pub total_cost: f64,

    /// Model name; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_name: Option<String>,

    /// Per-call cost records; deserializes as empty when the key is absent.
    #[serde(default)]
    pub cost_records: Vec<crate::telemetry::LlmCostRecord>,

    /// Names of the tools available to the session
    /// (see [`SessionData::tool_names_from_definitions`]).
    pub tool_names: Vec<String>,

    /// Thinking flag recorded with the session.
    pub thinking_enabled: bool,

    /// Optional thinking budget recorded with the session.
    pub thinking_budget: Option<usize>,

    /// Creation timestamp. NOTE(review): tests use values like 1700000000,
    /// which look like Unix seconds — confirm against the writer.
    pub created_at: i64,

    /// Last-update timestamp, same unit as `created_at`.
    pub updated_at: i64,

    /// LLM provider settings; omitted when `None`. `LlmConfigData::api_key`
    /// is itself marked `skip_serializing`, so keys are never persisted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_config: Option<LlmConfigData>,

    /// Planning tasks. Also accepted under the key `todos` when deserializing;
    /// deserializes as empty when absent.
    #[serde(default, alias = "todos")]
    pub tasks: Vec<Task>,

    /// Id of the parent session, if any; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_id: Option<String>,
}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct LlmConfigData {
113 pub provider: String,
114 pub model: String,
115 #[serde(skip_serializing, default)]
117 pub api_key: Option<String>,
118 pub base_url: Option<String>,
119}
120
121impl SessionData {
122 pub fn tool_names_from_definitions(tools: &[ToolDefinition]) -> Vec<String> {
124 tools.iter().map(|t| t.name.clone()).collect()
125 }
126}
127
/// Storage backend for persisting and retrieving [`SessionData`].
///
/// Implementors must be `Send + Sync`; the async methods are provided through
/// `async_trait`.
#[async_trait::async_trait]
pub trait SessionStore: Send + Sync {
    /// Persists `session` under its `id`.
    ///
    /// Both implementations in this file overwrite any existing entry with the same id.
    async fn save(&self, session: &SessionData) -> Result<()>;

    /// Loads the session stored under `id`.
    ///
    /// The implementations in this file return `Ok(None)` when no such session exists.
    async fn load(&self, id: &str) -> Result<Option<SessionData>>;

    /// Removes the session stored under `id`.
    ///
    /// The implementations in this file treat deleting a missing session as success.
    async fn delete(&self, id: &str) -> Result<()>;

    /// Returns the ids of all stored sessions, in no guaranteed order.
    async fn list(&self) -> Result<Vec<String>>;

    /// Reports whether a session is stored under `id`.
    async fn exists(&self, id: &str) -> Result<bool>;

    /// Checks that the backend is usable. The default implementation always succeeds.
    async fn health_check(&self) -> Result<()> {
        Ok(())
    }

    /// Short backend name (e.g. `"file"`, `"memory"`); defaults to `"unknown"`.
    fn backend_name(&self) -> &str {
        "unknown"
    }
}
160
161pub struct FileSessionStore {
174 dir: PathBuf,
176}
177
178impl FileSessionStore {
179 pub async fn new<P: AsRef<Path>>(dir: P) -> Result<Self> {
183 let dir = dir.as_ref().to_path_buf();
184
185 fs::create_dir_all(&dir)
187 .await
188 .with_context(|| format!("Failed to create session directory: {}", dir.display()))?;
189
190 Ok(Self { dir })
191 }
192
193 fn session_path(&self, id: &str) -> PathBuf {
195 let safe_id = id.replace(['/', '\\'], "_").replace("..", "_");
197 self.dir.join(format!("{}.json", safe_id))
198 }
199}
200
201#[async_trait::async_trait]
202impl SessionStore for FileSessionStore {
203 async fn save(&self, session: &SessionData) -> Result<()> {
204 let path = self.session_path(&session.id);
205
206 let json = serde_json::to_string_pretty(session)
208 .with_context(|| format!("Failed to serialize session: {}", session.id))?;
209
210 let unique_suffix = format!(
213 "{}.{}",
214 std::time::SystemTime::now()
215 .duration_since(std::time::UNIX_EPOCH)
216 .unwrap()
217 .as_nanos(),
218 std::process::id()
219 );
220 let temp_path = path.with_extension(format!("json.{}.tmp", unique_suffix));
221
222 let mut file = fs::File::create(&temp_path)
223 .await
224 .with_context(|| format!("Failed to create temp file: {}", temp_path.display()))?;
225
226 file.write_all(json.as_bytes())
227 .await
228 .with_context(|| format!("Failed to write session data: {}", session.id))?;
229
230 file.sync_all()
231 .await
232 .with_context(|| format!("Failed to sync session file: {}", session.id))?;
233
234 fs::rename(&temp_path, &path)
236 .await
237 .with_context(|| format!("Failed to rename session file: {}", session.id))?;
238
239 tracing::debug!("Saved session {} to {}", session.id, path.display());
240 Ok(())
241 }
242
243 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
244 let path = self.session_path(id);
245
246 if !path.exists() {
247 return Ok(None);
248 }
249
250 let json = fs::read_to_string(&path)
251 .await
252 .with_context(|| format!("Failed to read session file: {}", path.display()))?;
253
254 let session: SessionData = serde_json::from_str(&json)
255 .with_context(|| format!("Failed to parse session file: {}", path.display()))?;
256
257 tracing::debug!("Loaded session {} from {}", id, path.display());
258 Ok(Some(session))
259 }
260
261 async fn delete(&self, id: &str) -> Result<()> {
262 let path = self.session_path(id);
263
264 if path.exists() {
265 fs::remove_file(&path)
266 .await
267 .with_context(|| format!("Failed to delete session file: {}", path.display()))?;
268
269 tracing::debug!("Deleted session {} from {}", id, path.display());
270 }
271
272 Ok(())
273 }
274
275 async fn list(&self) -> Result<Vec<String>> {
276 let mut session_ids = Vec::new();
277
278 let mut entries = fs::read_dir(&self.dir)
279 .await
280 .with_context(|| format!("Failed to read session directory: {}", self.dir.display()))?;
281
282 while let Some(entry) = entries.next_entry().await? {
283 let path = entry.path();
284
285 if path.extension().is_some_and(|ext| ext == "json") {
286 if let Some(stem) = path.file_stem() {
287 if let Some(id) = stem.to_str() {
288 session_ids.push(id.to_string());
289 }
290 }
291 }
292 }
293
294 Ok(session_ids)
295 }
296
297 async fn exists(&self, id: &str) -> Result<bool> {
298 let path = self.session_path(id);
299 Ok(path.exists())
300 }
301
302 async fn health_check(&self) -> Result<()> {
303 let probe = self.dir.join(".health_check");
305 fs::write(&probe, b"ok")
306 .await
307 .with_context(|| format!("Store directory not writable: {}", self.dir.display()))?;
308 let _ = fs::remove_file(&probe).await;
309 Ok(())
310 }
311
312 fn backend_name(&self) -> &str {
313 "file"
314 }
315}
316
317pub struct MemorySessionStore {
323 sessions: tokio::sync::RwLock<HashMap<String, SessionData>>,
324}
325
326impl MemorySessionStore {
327 pub fn new() -> Self {
328 Self {
329 sessions: tokio::sync::RwLock::new(HashMap::new()),
330 }
331 }
332}
333
334impl Default for MemorySessionStore {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340#[async_trait::async_trait]
341impl SessionStore for MemorySessionStore {
342 async fn save(&self, session: &SessionData) -> Result<()> {
343 let mut sessions = self.sessions.write().await;
344 sessions.insert(session.id.clone(), session.clone());
345 Ok(())
346 }
347
348 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
349 let sessions = self.sessions.read().await;
350 Ok(sessions.get(id).cloned())
351 }
352
353 async fn delete(&self, id: &str) -> Result<()> {
354 let mut sessions = self.sessions.write().await;
355 sessions.remove(id);
356 Ok(())
357 }
358
359 async fn list(&self) -> Result<Vec<String>> {
360 let sessions = self.sessions.read().await;
361 Ok(sessions.keys().cloned().collect())
362 }
363
364 async fn exists(&self, id: &str) -> Result<bool> {
365 let sessions = self.sessions.read().await;
366 Ok(sessions.contains_key(id))
367 }
368
369 fn backend_name(&self) -> &str {
370 "memory"
371 }
372}
373
// Unit tests for `SessionData` serialization and the file- and memory-backed stores.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hitl::ConfirmationPolicy;
    use crate::permissions::PermissionPolicy;
    use crate::queue::SessionQueueConfig;
    use tempfile::tempdir;

    /// Builds a fully populated fixture: id `test-session-1`, one user and one
    /// assistant message, two tool names, and no LLM config.
    fn create_test_session_data() -> SessionData {
        SessionData {
            id: "test-session-1".to_string(),
            config: SessionConfig {
                name: "Test Session".to_string(),
                workspace: "/tmp/workspace".to_string(),
                system_prompt: Some("You are helpful.".to_string()),
                max_context_length: 200000,
                auto_compact: false,
                auto_compact_threshold: crate::session::DEFAULT_AUTO_COMPACT_THRESHOLD,
                storage_type: crate::config::StorageBackend::File,
                queue_config: None,
                confirmation_policy: None,
                permission_policy: None,
                parent_id: None,
                security_config: None,
                hook_engine: None,
                planning_enabled: false,
                goal_tracking: false,
            },
            state: SessionState::Active,
            messages: vec![
                Message::user("Hello"),
                Message {
                    role: "assistant".to_string(),
                    content: vec![crate::llm::ContentBlock::Text {
                        text: "Hi there!".to_string(),
                    }],
                    reasoning_content: None,
                },
            ],
            context_usage: ContextUsage {
                used_tokens: 100,
                max_tokens: 200000,
                percent: 0.0005,
                turns: 2,
            },
            total_usage: TokenUsage {
                prompt_tokens: 50,
                completion_tokens: 50,
                total_tokens: 100,
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            tool_names: vec!["bash".to_string(), "read".to_string()],
            thinking_enabled: false,
            thinking_budget: None,
            created_at: 1700000000,
            updated_at: 1700000100,
            llm_config: None,
            tasks: vec![],
            parent_id: None,
            total_cost: 0.0,
            model_name: None,
            cost_records: Vec::new(),
        }
    }

    // Round trip through the file store preserves id, config, messages and state.
    #[tokio::test]
    async fn test_file_store_save_and_load() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);
        assert_eq!(loaded.config.name, session.config.name);
        assert_eq!(loaded.messages.len(), 2);
        assert_eq!(loaded.state, SessionState::Active);
    }

    // Loading an id that was never saved yields Ok(None).
    #[tokio::test]
    async fn test_file_store_load_nonexistent() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let loaded = store.load("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    // After delete, the session neither exists nor loads.
    #[tokio::test]
    async fn test_file_store_delete() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();
        store.save(&session).await.unwrap();

        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();

        assert!(!store.exists(&session.id).await.unwrap());
        assert!(store.load(&session.id).await.unwrap().is_none());
    }

    // list() is empty for a fresh directory and returns every saved id afterwards.
    #[tokio::test]
    async fn test_file_store_list() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let list = store.list().await.unwrap();
        assert!(list.is_empty());

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
        assert!(list.contains(&"session-1".to_string()));
        assert!(list.contains(&"session-2".to_string()));
        assert!(list.contains(&"session-3".to_string()));
    }

    // A second save with the same id replaces the stored contents.
    #[tokio::test]
    async fn test_file_store_overwrite() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        store.save(&session).await.unwrap();

        session.messages.push(Message::user("Another message"));
        session.updated_at = 1700000200;
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 3);
        assert_eq!(loaded.updated_at, 1700000200);
    }

    // A traversal-style id still writes exactly one file inside the store
    // directory, and the same id loads it back.
    #[tokio::test]
    async fn test_file_store_path_traversal_prevention() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = "../../../etc/passwd".to_string();
        store.save(&session).await.unwrap();

        let files: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .collect();
        assert_eq!(files.len(), 1);

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
    }

    // Optional policy/queue config sections survive a save/load round trip.
    #[tokio::test]
    async fn test_file_store_with_policies() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.config.confirmation_policy = Some(ConfirmationPolicy::enabled());
        session.config.permission_policy = Some(PermissionPolicy::new().allow("Bash(cargo:*)"));
        session.config.queue_config = Some(SessionQueueConfig::default());

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert!(loaded.config.confirmation_policy.is_some());
        assert!(loaded.config.permission_policy.is_some());
        assert!(loaded.config.queue_config.is_some());
    }

    // LLM config is persisted, but the API key is not written to disk.
    #[tokio::test]
    async fn test_file_store_with_llm_config() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.llm_config = Some(LlmConfigData {
            provider: "anthropic".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            api_key: Some("secret".to_string()),
            base_url: None,
        });

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        let llm_config = loaded.llm_config.unwrap();
        assert_eq!(llm_config.provider, "anthropic");
        assert_eq!(llm_config.model, "claude-3-5-sonnet-20241022");
        assert!(llm_config.api_key.is_none());
    }

    // Round trip through the in-memory store.
    #[tokio::test]
    async fn test_memory_store_save_and_load() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, session.id);
    }

    // In-memory delete removes the entry.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    // In-memory list returns one id per saved session.
    #[tokio::test]
    async fn test_memory_store_list() {
        let store = MemorySessionStore::new();

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
    }

    // SessionData survives a serde_json round trip.
    #[test]
    fn test_session_data_serialization() {
        let session = create_test_session_data();
        let json = serde_json::to_string(&session).unwrap();
        let parsed: SessionData = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, session.id);
        assert_eq!(parsed.messages.len(), session.messages.len());
    }

    // Tool names are extracted in definition order.
    #[test]
    fn test_tool_names_from_definitions() {
        let tools = vec![
            crate::llm::ToolDefinition {
                name: "bash".to_string(),
                description: "Execute bash".to_string(),
                parameters: serde_json::json!({}),
            },
            crate::llm::ToolDefinition {
                name: "read".to_string(),
                description: "Read file".to_string(),
                parameters: serde_json::json!({}),
            },
        ];

        let names = SessionData::tool_names_from_definitions(&tools);
        assert_eq!(names, vec!["bash", "read"]);
    }

    // Backslashes in ids become underscores in the file name; the original id
    // is preserved inside the stored JSON.
    #[tokio::test]
    async fn test_file_store_backslash_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo\bar\baz".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz.json");
        assert!(expected_path.exists());
    }

    // '/', '\' and ".." are each replaced by '_' in the file name.
    #[tokio::test]
    async fn test_file_store_mixed_separator_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo/bar\baz..qux".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz_qux.json");
        assert!(expected_path.exists());
    }

    // Invalid JSON on disk surfaces as an error rather than Ok(None).
    #[tokio::test]
    async fn test_file_store_corrupted_json_recovery() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let corrupted_path = dir.path().join("test-id.json");
        tokio::fs::write(&corrupted_path, b"not valid json {{{")
            .await
            .unwrap();

        let result = store.load("test-id").await;
        assert!(result.is_err());
    }

    // exists() tracks save and delete for the file store.
    #[tokio::test]
    async fn test_file_store_exists() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        assert!(!store.exists(&session.id).await.unwrap());

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    // exists() tracks save for the memory store.
    #[tokio::test]
    async fn test_memory_store_exists() {
        let store = MemorySessionStore::new();

        assert!(!store.exists("unknown-id").await.unwrap());

        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());
    }

    // A writable directory passes the health check; backend name is "file".
    #[tokio::test]
    async fn test_file_store_health_check() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "file");
    }

    // A store pointed at a missing directory fails the health check. The store
    // is built directly so that `new` does not create the directory.
    #[tokio::test]
    async fn test_file_store_health_check_bad_dir() {
        let store = FileSessionStore {
            dir: std::path::PathBuf::from("/nonexistent/path/that/does/not/exist"),
        };
        assert!(store.health_check().await.is_err());
    }

    // The memory store uses the trait's default health check; backend name is "memory".
    #[tokio::test]
    async fn test_memory_store_health_check() {
        let store = MemorySessionStore::new();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "memory");
    }

    // An empty session file is a parse error, not a missing session.
    #[tokio::test]
    async fn test_file_store_load_empty_file() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let empty_path = dir.path().join("empty-session.json");
        tokio::fs::write(&empty_path, b"").await.unwrap();

        let result = store.load("empty-session").await;
        assert!(
            result.is_err(),
            "Empty file must return error, not Ok(None)"
        );
    }

    // Truncated JSON is a parse error.
    #[tokio::test]
    async fn test_file_store_load_partial_json() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let partial_path = dir.path().join("partial-session.json");
        tokio::fs::write(&partial_path, b"{\"id\":\"partial-session\",\"message")
            .await
            .unwrap();

        let result = store.load("partial-session").await;
        assert!(result.is_err(), "Partial JSON must return error");
    }

    // Several concurrent saves of the same session all succeed, and the result
    // still loads.
    #[tokio::test]
    async fn test_file_store_concurrent_save() {
        let dir = tempdir().unwrap();
        let store = std::sync::Arc::new(FileSessionStore::new(dir.path()).await.unwrap());

        let session = create_test_session_data();
        let id = session.id.clone();

        // Initial save so the concurrent saves below overwrite an existing file.
        store.save(&session).await.unwrap();

        let mut handles = Vec::new();
        for _ in 0..5 {
            let s = store.clone();
            let sess = session.clone();
            handles.push(tokio::spawn(async move { s.save(&sess).await }));
        }
        for h in handles {
            h.await.unwrap().unwrap();
        }

        let loaded = store.load(&id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    // A missing session maps to Ok(None), not an error.
    #[tokio::test]
    async fn test_file_store_load_nonexistent_returns_none() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let result = store.load("does-not-exist-at-all").await.unwrap();
        assert!(result.is_none(), "Missing session must return Ok(None)");
    }
}