1use crate::llm::{Message, TokenUsage, ToolDefinition};
33use crate::planning::Task;
34use crate::session::{ContextUsage, SessionConfig, SessionState};
35use anyhow::{Context, Result};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::path::{Path, PathBuf};
39use tokio::fs;
40use tokio::io::AsyncWriteExt;
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct SessionData {
52 pub id: String,
54
55 pub config: SessionConfig,
57
58 pub state: SessionState,
60
61 pub messages: Vec<Message>,
63
64 pub context_usage: ContextUsage,
66
67 pub total_usage: TokenUsage,
69
70 #[serde(default)]
72 pub total_cost: f64,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
76 pub model_name: Option<String>,
77
78 #[serde(default)]
80 pub cost_records: Vec<crate::telemetry::LlmCostRecord>,
81
82 pub tool_names: Vec<String>,
84
85 pub thinking_enabled: bool,
87
88 pub thinking_budget: Option<usize>,
90
91 pub created_at: i64,
93
94 pub updated_at: i64,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub llm_config: Option<LlmConfigData>,
100
101 #[serde(default, alias = "todos")]
103 pub tasks: Vec<Task>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub parent_id: Option<String>,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct LlmConfigData {
113 pub provider: String,
114 pub model: String,
115 #[serde(skip_serializing, default)]
117 pub api_key: Option<String>,
118 pub base_url: Option<String>,
119}
120
121impl SessionData {
122 pub fn tool_names_from_definitions(tools: &[ToolDefinition]) -> Vec<String> {
124 tools.iter().map(|t| t.name.clone()).collect()
125 }
126}
127
128#[async_trait::async_trait]
134pub trait SessionStore: Send + Sync {
135 async fn save(&self, session: &SessionData) -> Result<()>;
137
138 async fn load(&self, id: &str) -> Result<Option<SessionData>>;
140
141 async fn delete(&self, id: &str) -> Result<()>;
143
144 async fn list(&self) -> Result<Vec<String>>;
146
147 async fn exists(&self, id: &str) -> Result<bool>;
149
150 async fn health_check(&self) -> Result<()> {
152 Ok(())
153 }
154
155 fn backend_name(&self) -> &str {
157 "unknown"
158 }
159}
160
/// [`SessionStore`] that persists each session as a pretty-printed JSON file.
///
/// Every session lives at `<dir>/<sanitized id>.json`; `session_path` defines the
/// id-to-filename mapping.
pub struct FileSessionStore {
    /// Root directory containing the session files.
    dir: PathBuf,
}
177
178impl FileSessionStore {
179 pub async fn new<P: AsRef<Path>>(dir: P) -> Result<Self> {
183 let dir = dir.as_ref().to_path_buf();
184
185 fs::create_dir_all(&dir)
187 .await
188 .with_context(|| format!("Failed to create session directory: {}", dir.display()))?;
189
190 Ok(Self { dir })
191 }
192
193 fn session_path(&self, id: &str) -> PathBuf {
195 let safe_id = id.replace(['/', '\\'], "_").replace("..", "_");
197 self.dir.join(format!("{}.json", safe_id))
198 }
199}
200
201#[async_trait::async_trait]
202impl SessionStore for FileSessionStore {
203 async fn save(&self, session: &SessionData) -> Result<()> {
204 let path = self.session_path(&session.id);
205
206 let json = serde_json::to_string_pretty(session)
208 .with_context(|| format!("Failed to serialize session: {}", session.id))?;
209
210 let temp_path = path.with_extension("json.tmp");
212
213 let mut file = fs::File::create(&temp_path)
214 .await
215 .with_context(|| format!("Failed to create temp file: {}", temp_path.display()))?;
216
217 file.write_all(json.as_bytes())
218 .await
219 .with_context(|| format!("Failed to write session data: {}", session.id))?;
220
221 file.sync_all()
222 .await
223 .with_context(|| format!("Failed to sync session file: {}", session.id))?;
224
225 fs::rename(&temp_path, &path)
227 .await
228 .with_context(|| format!("Failed to rename session file: {}", session.id))?;
229
230 tracing::debug!("Saved session {} to {}", session.id, path.display());
231 Ok(())
232 }
233
234 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
235 let path = self.session_path(id);
236
237 if !path.exists() {
238 return Ok(None);
239 }
240
241 let json = fs::read_to_string(&path)
242 .await
243 .with_context(|| format!("Failed to read session file: {}", path.display()))?;
244
245 let session: SessionData = serde_json::from_str(&json)
246 .with_context(|| format!("Failed to parse session file: {}", path.display()))?;
247
248 tracing::debug!("Loaded session {} from {}", id, path.display());
249 Ok(Some(session))
250 }
251
252 async fn delete(&self, id: &str) -> Result<()> {
253 let path = self.session_path(id);
254
255 if path.exists() {
256 fs::remove_file(&path)
257 .await
258 .with_context(|| format!("Failed to delete session file: {}", path.display()))?;
259
260 tracing::debug!("Deleted session {} from {}", id, path.display());
261 }
262
263 Ok(())
264 }
265
266 async fn list(&self) -> Result<Vec<String>> {
267 let mut session_ids = Vec::new();
268
269 let mut entries = fs::read_dir(&self.dir)
270 .await
271 .with_context(|| format!("Failed to read session directory: {}", self.dir.display()))?;
272
273 while let Some(entry) = entries.next_entry().await? {
274 let path = entry.path();
275
276 if path.extension().is_some_and(|ext| ext == "json") {
277 if let Some(stem) = path.file_stem() {
278 if let Some(id) = stem.to_str() {
279 session_ids.push(id.to_string());
280 }
281 }
282 }
283 }
284
285 Ok(session_ids)
286 }
287
288 async fn exists(&self, id: &str) -> Result<bool> {
289 let path = self.session_path(id);
290 Ok(path.exists())
291 }
292
293 async fn health_check(&self) -> Result<()> {
294 let probe = self.dir.join(".health_check");
296 fs::write(&probe, b"ok")
297 .await
298 .with_context(|| format!("Store directory not writable: {}", self.dir.display()))?;
299 let _ = fs::remove_file(&probe).await;
300 Ok(())
301 }
302
303 fn backend_name(&self) -> &str {
304 "file"
305 }
306}
307
308pub struct MemorySessionStore {
314 sessions: tokio::sync::RwLock<HashMap<String, SessionData>>,
315}
316
317impl MemorySessionStore {
318 pub fn new() -> Self {
319 Self {
320 sessions: tokio::sync::RwLock::new(HashMap::new()),
321 }
322 }
323}
324
325impl Default for MemorySessionStore {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331#[async_trait::async_trait]
332impl SessionStore for MemorySessionStore {
333 async fn save(&self, session: &SessionData) -> Result<()> {
334 let mut sessions = self.sessions.write().await;
335 sessions.insert(session.id.clone(), session.clone());
336 Ok(())
337 }
338
339 async fn load(&self, id: &str) -> Result<Option<SessionData>> {
340 let sessions = self.sessions.read().await;
341 Ok(sessions.get(id).cloned())
342 }
343
344 async fn delete(&self, id: &str) -> Result<()> {
345 let mut sessions = self.sessions.write().await;
346 sessions.remove(id);
347 Ok(())
348 }
349
350 async fn list(&self) -> Result<Vec<String>> {
351 let sessions = self.sessions.read().await;
352 Ok(sessions.keys().cloned().collect())
353 }
354
355 async fn exists(&self, id: &str) -> Result<bool> {
356 let sessions = self.sessions.read().await;
357 Ok(sessions.contains_key(id))
358 }
359
360 fn backend_name(&self) -> &str {
361 "memory"
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hitl::ConfirmationPolicy;
    use crate::permissions::PermissionPolicy;
    use crate::queue::SessionQueueConfig;
    use tempfile::tempdir;

    /// Builds a small, fully populated session fixture with two messages.
    fn create_test_session_data() -> SessionData {
        SessionData {
            id: "test-session-1".to_string(),
            config: SessionConfig {
                name: "Test Session".to_string(),
                workspace: "/tmp/workspace".to_string(),
                system_prompt: Some("You are helpful.".to_string()),
                max_context_length: 200000,
                auto_compact: false,
                auto_compact_threshold: crate::session::DEFAULT_AUTO_COMPACT_THRESHOLD,
                storage_type: crate::config::StorageBackend::File,
                queue_config: None,
                confirmation_policy: None,
                permission_policy: None,
                parent_id: None,
                security_config: None,
                hook_engine: None,
                planning_enabled: false,
                goal_tracking: false,
            },
            state: SessionState::Active,
            messages: vec![
                Message::user("Hello"),
                Message {
                    role: "assistant".to_string(),
                    content: vec![crate::llm::ContentBlock::Text {
                        text: "Hi there!".to_string(),
                    }],
                    reasoning_content: None,
                },
            ],
            context_usage: ContextUsage {
                used_tokens: 100,
                max_tokens: 200000,
                percent: 0.0005,
                turns: 2,
            },
            total_usage: TokenUsage {
                prompt_tokens: 50,
                completion_tokens: 50,
                total_tokens: 100,
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            tool_names: vec!["bash".to_string(), "read".to_string()],
            thinking_enabled: false,
            thinking_budget: None,
            created_at: 1700000000,
            updated_at: 1700000100,
            llm_config: None,
            tasks: vec![],
            parent_id: None,
            total_cost: 0.0,
            model_name: None,
            cost_records: Vec::new(),
        }
    }

    /// A saved session round-trips through the file store.
    #[tokio::test]
    async fn test_file_store_save_and_load() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);
        assert_eq!(loaded.config.name, session.config.name);
        assert_eq!(loaded.messages.len(), 2);
        assert_eq!(loaded.state, SessionState::Active);
    }

    /// Loading an unknown id yields `None` rather than an error.
    #[tokio::test]
    async fn test_file_store_load_nonexistent() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let loaded = store.load("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    /// Deleting a saved session removes it from `exists` and `load`.
    #[tokio::test]
    async fn test_file_store_delete() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();
        store.save(&session).await.unwrap();

        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();

        assert!(!store.exists(&session.id).await.unwrap());
        assert!(store.load(&session.id).await.unwrap().is_none());
    }

    /// Deleting an id that was never saved is a no-op, not an error.
    #[tokio::test]
    async fn test_file_store_delete_nonexistent() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        store.delete("never-saved").await.unwrap();
        assert!(!store.exists("never-saved").await.unwrap());
    }

    /// `list` starts empty and then reports every saved id.
    #[tokio::test]
    async fn test_file_store_list() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let list = store.list().await.unwrap();
        assert!(list.is_empty());

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
        assert!(list.contains(&"session-1".to_string()));
        assert!(list.contains(&"session-2".to_string()));
        assert!(list.contains(&"session-3".to_string()));
    }

    /// Files that are not `<id>.json` (e.g. leftover temp files) are not listed.
    #[tokio::test]
    async fn test_file_store_list_ignores_non_session_files() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        store.save(&create_test_session_data()).await.unwrap();
        std::fs::write(dir.path().join("notes.txt"), b"hello").unwrap();
        std::fs::write(dir.path().join("stale.json.tmp"), b"{}").unwrap();

        assert_eq!(
            store.list().await.unwrap(),
            vec!["test-session-1".to_string()]
        );
    }

    /// Saving the same id again replaces the stored contents.
    #[tokio::test]
    async fn test_file_store_overwrite() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        store.save(&session).await.unwrap();

        session.messages.push(Message::user("Another message"));
        session.updated_at = 1700000200;
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 3);
        assert_eq!(loaded.updated_at, 1700000200);
    }

    /// A successful save leaves only the final `.json` file, never the temp file.
    #[tokio::test]
    async fn test_file_store_save_leaves_no_temp_file() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        store.save(&session).await.unwrap();

        let names: Vec<String> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names, vec!["test-session-1.json".to_string()]);
    }

    /// A file removed behind the store's back is reported as missing, not as an error.
    #[tokio::test]
    async fn test_file_store_load_after_external_removal() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        std::fs::remove_file(dir.path().join("test-session-1.json")).unwrap();

        assert!(store.load(&session.id).await.unwrap().is_none());
        assert!(!store.exists(&session.id).await.unwrap());
    }

    /// An id with `..` and `/` is sanitized into a single file inside the store dir.
    #[tokio::test]
    async fn test_file_store_path_traversal_prevention() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = "../../../etc/passwd".to_string();
        store.save(&session).await.unwrap();

        let files: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .collect();
        assert_eq!(files.len(), 1);

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
    }

    /// Optional policy/queue configuration survives a save/load round-trip.
    #[tokio::test]
    async fn test_file_store_with_policies() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.config.confirmation_policy = Some(ConfirmationPolicy::enabled());
        session.config.permission_policy = Some(PermissionPolicy::new().allow("Bash(cargo:*)"));
        session.config.queue_config = Some(SessionQueueConfig::default());

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert!(loaded.config.confirmation_policy.is_some());
        assert!(loaded.config.permission_policy.is_some());
        assert!(loaded.config.queue_config.is_some());
    }

    /// LLM config round-trips, but the API key is never persisted.
    #[tokio::test]
    async fn test_file_store_with_llm_config() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.llm_config = Some(LlmConfigData {
            provider: "anthropic".to_string(),
            model: "claude-3-5-sonnet-20241022".to_string(),
            api_key: Some("secret".to_string()),
            base_url: None,
        });

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        let llm_config = loaded.llm_config.unwrap();
        assert_eq!(llm_config.provider, "anthropic");
        assert_eq!(llm_config.model, "claude-3-5-sonnet-20241022");
        assert!(llm_config.api_key.is_none());
    }

    /// A saved session round-trips through the memory store.
    #[tokio::test]
    async fn test_memory_store_save_and_load() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, session.id);
    }

    /// Deleting from the memory store removes the entry.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = MemorySessionStore::new();
        let session = create_test_session_data();

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    /// The memory store lists every saved id.
    #[tokio::test]
    async fn test_memory_store_list() {
        let store = MemorySessionStore::new();

        for i in 1..=3 {
            let mut session = create_test_session_data();
            session.id = format!("session-{}", i);
            store.save(&session).await.unwrap();
        }

        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 3);
    }

    /// Saving the same id twice in memory keeps a single, latest copy.
    #[tokio::test]
    async fn test_memory_store_overwrite() {
        let store = MemorySessionStore::new();
        let mut session = create_test_session_data();
        store.save(&session).await.unwrap();

        session.messages.push(Message::user("Another message"));
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 3);
        assert_eq!(store.list().await.unwrap().len(), 1);
    }

    /// `SessionData` round-trips through compact JSON.
    #[test]
    fn test_session_data_serialization() {
        let session = create_test_session_data();
        let json = serde_json::to_string(&session).unwrap();
        let parsed: SessionData = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.id, session.id);
        assert_eq!(parsed.messages.len(), session.messages.len());
    }

    /// Tool names are extracted in definition order.
    #[test]
    fn test_tool_names_from_definitions() {
        let tools = vec![
            crate::llm::ToolDefinition {
                name: "bash".to_string(),
                description: "Execute bash".to_string(),
                parameters: serde_json::json!({}),
            },
            crate::llm::ToolDefinition {
                name: "read".to_string(),
                description: "Read file".to_string(),
                parameters: serde_json::json!({}),
            },
        ];

        let names = SessionData::tool_names_from_definitions(&tools);
        assert_eq!(names, vec!["bash", "read"]);
    }

    /// Backslashes in ids are mapped to `_` in the file name.
    #[tokio::test]
    async fn test_file_store_backslash_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo\bar\baz".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz.json");
        assert!(expected_path.exists());
    }

    /// Mixed separators and `..` are all mapped to `_` in the file name.
    #[tokio::test]
    async fn test_file_store_mixed_separator_sanitization() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let mut session = create_test_session_data();
        session.id = r"foo/bar\baz..qux".to_string();
        store.save(&session).await.unwrap();

        let loaded = store.load(&session.id).await.unwrap();
        assert!(loaded.is_some());

        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, session.id);

        let expected_path = dir.path().join("foo_bar_baz_qux.json");
        assert!(expected_path.exists());
    }

    /// A corrupted session file surfaces as an error rather than `None`.
    #[tokio::test]
    async fn test_file_store_corrupted_json_recovery() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let corrupted_path = dir.path().join("test-id.json");
        tokio::fs::write(&corrupted_path, b"not valid json {{{")
            .await
            .unwrap();

        let result = store.load("test-id").await;
        assert!(result.is_err());
    }

    /// `exists` tracks save and delete for the file store.
    #[tokio::test]
    async fn test_file_store_exists() {
        let dir = tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();

        let session = create_test_session_data();

        assert!(!store.exists(&session.id).await.unwrap());

        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());

        store.delete(&session.id).await.unwrap();
        assert!(!store.exists(&session.id).await.unwrap());
    }

    /// `exists` tracks save for the memory store.
    #[tokio::test]
    async fn test_memory_store_exists() {
        let store = MemorySessionStore::new();

        assert!(!store.exists("unknown-id").await.unwrap());

        let session = create_test_session_data();
        store.save(&session).await.unwrap();
        assert!(store.exists(&session.id).await.unwrap());
    }

    /// A writable directory passes the file-store health check.
    #[tokio::test]
    async fn test_file_store_health_check() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSessionStore::new(dir.path()).await.unwrap();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "file");
    }

    /// A missing directory fails the file-store health check.
    #[tokio::test]
    async fn test_file_store_health_check_bad_dir() {
        let store = FileSessionStore {
            dir: std::path::PathBuf::from("/nonexistent/path/that/does/not/exist"),
        };
        assert!(store.health_check().await.is_err());
    }

    /// The memory store is always healthy and reports its backend name.
    #[tokio::test]
    async fn test_memory_store_health_check() {
        let store = MemorySessionStore::new();
        assert!(store.health_check().await.is_ok());
        assert_eq!(store.backend_name(), "memory");
    }
}