1use anyhow::{bail, Result};
7use chrono::{DateTime, Utc};
8use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
9use std::path::PathBuf;
10use tokio::fs;
11
/// Identifier of a session, stored as its raw string form.
///
/// The inner string is public, so any string can be wrapped; nothing here
/// enforces a particular format.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
15
16impl SessionId {
17 pub fn new() -> Self {
19 Self(uuid::Uuid::new_v4().to_string())
20 }
21}
22
23impl Default for SessionId {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl std::fmt::Display for SessionId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "{}", self.0)
32 }
33}
34
35impl Serialize for SessionId {
36 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37 where
38 S: Serializer,
39 {
40 serializer.serialize_str(&self.0)
41 }
42}
43
44impl<'de> Deserialize<'de> for SessionId {
45 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46 where
47 D: Deserializer<'de>,
48 {
49 let s = String::deserialize(deserializer)?;
50 Ok(Self(s))
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct UserMessage {
57 pub content: String,
59 pub timestamp: DateTime<Utc>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AgentResponse {
66 pub content: String,
68 pub session_id: Option<String>,
70 pub seed_id: Option<String>,
72 pub phase_reached: Option<String>,
74 pub evaluation_passed: Option<bool>,
76 pub timestamp: DateTime<Utc>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct TrajectoryStepRecord {
88 pub tool_name: String,
90 pub tool_args: serde_json::Value,
92 pub output_summary: String,
94 pub duration_ms: u64,
96 pub is_error: bool,
98 pub tool_call_id: String,
100 pub timestamp: DateTime<Utc>,
102}
103
104pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct Session {
114 pub id: SessionId,
116 pub user_id: String,
118 #[serde(default)]
120 pub user_messages: Vec<UserMessage>,
121 #[serde(default)]
123 pub agent_responses: Vec<AgentResponse>,
124 #[serde(default, skip_serializing_if = "Vec::is_empty")]
128 pub trajectory_steps: Vec<TrajectoryStepRecord>,
129 #[serde(skip_serializing_if = "Option::is_none")]
131 pub active_seed_id: Option<String>,
132 #[serde(skip_serializing_if = "Option::is_none")]
134 pub active_persona_id: Option<String>,
135 pub created_at: DateTime<Utc>,
137 pub updated_at: DateTime<Utc>,
139 #[serde(default)]
141 pub metadata: SessionMetadata,
142}
143
144impl Session {
145 pub fn new(user_id: impl Into<String>) -> Self {
147 let now = Utc::now();
148 Self {
149 id: SessionId::new(),
150 user_id: user_id.into(),
151 user_messages: Vec::new(),
152 agent_responses: Vec::new(),
153 trajectory_steps: Vec::new(),
154 active_seed_id: None,
155 active_persona_id: None,
156 created_at: now,
157 updated_at: now,
158 metadata: SessionMetadata::new(),
159 }
160 }
161
162 pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
164 let now = Utc::now();
165 Self {
166 id: session_id,
167 user_id: user_id.into(),
168 user_messages: Vec::new(),
169 agent_responses: Vec::new(),
170 trajectory_steps: Vec::new(),
171 active_seed_id: None,
172 active_persona_id: None,
173 created_at: now,
174 updated_at: now,
175 metadata: SessionMetadata::new(),
176 }
177 }
178
179 pub fn add_user_message(&mut self, content: impl Into<String>) {
181 self.user_messages.push(UserMessage {
182 content: content.into(),
183 timestamp: Utc::now(),
184 });
185 self.updated_at = Utc::now();
186 }
187
188 pub fn add_agent_response(&mut self, response: AgentResponse) {
190 self.agent_responses.push(response);
191 self.updated_at = Utc::now();
192 }
193
194 pub fn extend_trajectory(&mut self, steps: Vec<TrajectoryStepRecord>) {
199 if steps.is_empty() {
200 return;
201 }
202 self.trajectory_steps.extend(steps);
203 self.updated_at = Utc::now();
204 }
205
206 pub fn trajectory(&self) -> &[TrajectoryStepRecord] {
208 &self.trajectory_steps
209 }
210
211 pub fn set_active_seed(&mut self, seed_id: Option<String>) {
213 self.active_seed_id = seed_id;
214 self.updated_at = Utc::now();
215 }
216
217 pub fn set_active_persona(&mut self, persona_id: Option<String>) {
219 self.active_persona_id = persona_id;
220 self.updated_at = Utc::now();
221 }
222
223 pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
225 self.metadata.insert(key.into(), value);
226 self.updated_at = Utc::now();
227 }
228
229 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
231 self.metadata.get(key)
232 }
233
234 pub fn exchange_count(&self) -> usize {
236 self.user_messages.len().min(self.agent_responses.len())
237 }
238
239 pub fn is_empty(&self) -> bool {
241 self.user_messages.is_empty()
242 }
243}
/// File-backed store that keeps documents under `base_path/<category>/`.
///
/// Cloning copies only the path, so clones refer to the same on-disk data.
#[derive(Clone)]
pub struct StateStore {
    /// Root directory under which every category directory is created.
    pub base_path: PathBuf,
}
253
254impl StateStore {
255 pub fn new(base_path: PathBuf) -> Result<Self> {
266 Ok(Self { base_path })
267 }
268
269 fn validate_category(category: &str) -> Result<()> {
271 if category.contains("..") || category.contains('\\') {
272 bail!("invalid category name: '{category}'");
273 }
274 if category.is_empty()
275 || category.starts_with('/')
276 || category.ends_with('/')
277 || category.contains("//")
278 {
279 bail!("invalid category name: '{category}'");
280 }
281 Ok(())
282 }
283
284 fn validate_name(name: &str) -> Result<()> {
286 if name.contains("..") || name.contains('/') || name.contains('\\') {
287 bail!("invalid file name: '{name}'");
288 }
289 Ok(())
290 }
291
292 pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
294 Self::validate_category(category)?;
295 Self::validate_name(name)?;
296 let dir = self.base_path.join(category);
297 fs::create_dir_all(&dir).await?;
298 let path = dir.join(format!("{name}.md"));
299
300 let temp_path = dir.join(format!(
302 "{name}.{}.{}.tmp",
303 std::process::id(),
304 uuid::Uuid::new_v4()
305 ));
306 fs::write(&temp_path, content).await?;
307 tokio::fs::rename(&temp_path, &path).await?;
308
309 Ok(())
310 }
311
312 pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
314 Self::validate_category(category)?;
315 Self::validate_name(name)?;
316 let path = self.base_path.join(category).join(format!("{name}.md"));
317 match fs::read_to_string(&path).await {
318 Ok(content) => Ok(Some(content)),
319 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
320 Err(e) => Err(e.into()),
321 }
322 }
323
324 pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
326 Self::validate_category(category)?;
327 let dir = self.base_path.join(category);
328 if !dir.exists() {
329 return Ok(Vec::new());
330 }
331 let mut entries = fs::read_dir(&dir).await?;
332 let mut names = Vec::new();
333 while let Some(entry) = entries.next_entry().await? {
334 let path = entry.path();
335 if let Some(ext) = path.extension() {
336 if ext == "md" || ext == "json" {
337 if let Some(stem) = path.file_stem() {
338 names.push(stem.to_string_lossy().into_owned());
339 }
340 }
341 }
342 }
343 names.sort();
344 Ok(names)
345 }
346
347 pub async fn save_json<T: Serialize>(
349 &self,
350 category: &str,
351 name: &str,
352 data: &T,
353 ) -> Result<()> {
354 Self::validate_category(category)?;
355 Self::validate_name(name)?;
356 let dir = self.base_path.join(category);
357 fs::create_dir_all(&dir).await?;
358 let path = dir.join(format!("{name}.json"));
359
360 let content = serde_json::to_string_pretty(data)?;
361
362 let temp_path = dir.join(format!(
364 "{name}.{}.{}.tmp",
365 std::process::id(),
366 uuid::Uuid::new_v4()
367 ));
368 fs::write(&temp_path, &content).await?;
369 tokio::fs::rename(&temp_path, &path).await?;
370
371 Ok(())
372 }
373
374 pub async fn load_json<T: DeserializeOwned>(
376 &self,
377 category: &str,
378 name: &str,
379 ) -> Result<Option<T>> {
380 Self::validate_category(category)?;
381 Self::validate_name(name)?;
382 let path = self.base_path.join(category).join(format!("{name}.json"));
383 match fs::read_to_string(&path).await {
384 Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
385 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
386 Err(e) => Err(e.into()),
387 }
388 }
389
390 pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
392 Self::validate_category(category)?;
393 Self::validate_name(name)?;
394 let path = self.base_path.join(category).join(format!("{name}.json"));
395 if path.exists() {
396 tokio::fs::remove_file(path).await?;
397 Ok(true)
398 } else {
399 let path = self.base_path.join(category).join(format!("{name}.md"));
400 if path.exists() {
401 tokio::fs::remove_file(path).await?;
402 Ok(true)
403 } else {
404 Ok(false)
405 }
406 }
407 }
408}
409
410impl std::fmt::Debug for StateStore {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("StateStore")
413 .field("base_path", &self.base_path)
414 .finish()
415 }
416}
417
418impl StateStore {
419 pub async fn save_session(&self, session: &Session) -> Result<()> {
421 self.save_json("sessions", &session.id.0, session).await
422 }
423
424 pub async fn save_session_with_prune(
426 &self,
427 session: &Session,
428 prune_config: &PruneConfig,
429 ) -> Result<()> {
430 self.save_session(session).await?;
431 let store = self.clone();
433 let config = prune_config.clone();
434 tokio::spawn(async move {
435 if let Err(e) = store.prune_sessions(&config).await {
436 tracing::warn!(error = %e, "Background session pruning failed");
437 }
438 });
439 Ok(())
440 }
441
442 pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
444 self.load_json("sessions", &session_id.0).await
445 }
446
447 pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
449 let mut sessions = Vec::new();
450
451 if let Ok(names) = self.list_category("sessions").await {
452 for name in names {
453 if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
454 sessions.push(SessionSummary {
455 id: session.id.0.clone(),
456 user_id: session.user_id.clone(),
457 message_count: session.user_messages.len(),
458 active_seed_id: session.active_seed_id.clone(),
459 project_id: session
460 .metadata
461 .get("project_ids")
462 .and_then(|v| v.as_str())
463 .map(String::from),
464 created_at: session.created_at,
465 updated_at: session.updated_at,
466 });
467 }
468 }
469 }
470
471 sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
473 Ok(sessions)
474 }
475
476 pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
478 let path = self
479 .base_path
480 .join("sessions")
481 .join(format!("{}.json", session_id.0));
482 match fs::remove_file(&path).await {
483 Ok(()) => {
484 tracing::info!(session_id = %session_id, "Session deleted");
485 Ok(true)
486 }
487 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
488 Err(e) => Err(e.into()),
489 }
490 }
491
492 pub async fn get_or_create_session(
494 &self,
495 user_id: &str,
496 session_id: Option<&SessionId>,
497 ) -> Result<Session> {
498 if let Some(sid) = session_id {
499 if let Some(existing) = self.load_session(sid).await? {
500 return Ok(existing);
501 }
502 }
503
504 let session = match session_id {
506 Some(sid) => Session::with_id(user_id, sid.clone()),
507 None => Session::new(user_id),
508 };
509
510 self.save_session(&session).await?;
511 Ok(session)
512 }
513
514 pub async fn update_session(&self, session: &Session) -> Result<()> {
516 self.save_session(session).await
517 }
518
519 pub async fn prune_sessions(&self, config: &PruneConfig) -> Result<usize> {
524 let mut sessions = self.list_sessions().await?;
525 let mut pruned = 0;
526
527 if config.ttl_hours > 0 {
529 let cutoff = Utc::now() - chrono::Duration::hours(config.ttl_hours as i64);
530 let to_prune_ttl: Vec<String> = sessions
531 .iter()
532 .filter(|s| s.updated_at < cutoff)
533 .map(|s| s.id.clone())
534 .collect();
535
536 for id in &to_prune_ttl {
537 let sid = SessionId(id.clone());
538 if self.delete_session(&sid).await.is_ok() {
539 pruned += 1;
540 }
541 }
542
543 sessions.retain(|s| !to_prune_ttl.contains(&s.id));
545 }
546
547 if config.max_sessions > 0 && sessions.len() > config.max_sessions {
549 let excess = sessions.len() - config.max_sessions;
551 for session in sessions.into_iter().rev().take(excess) {
552 let sid = SessionId(session.id);
553 if self.delete_session(&sid).await.is_ok() {
554 pruned += 1;
555 }
556 }
557 }
558
559 if pruned > 0 {
560 tracing::info!(pruned = pruned, "Session pruning completed");
561 }
562
563 Ok(pruned)
564 }
565}
566
567#[derive(Debug, Clone, Serialize, Deserialize)]
569pub struct SessionSummary {
570 pub id: String,
572 pub user_id: String,
574 pub message_count: usize,
576 #[serde(skip_serializing_if = "Option::is_none")]
578 pub active_seed_id: Option<String>,
579 #[serde(skip_serializing_if = "Option::is_none")]
581 pub project_id: Option<String>,
582 pub created_at: DateTime<Utc>,
584 pub updated_at: DateTime<Utc>,
586}
587
/// Limits applied by `StateStore::prune_sessions`.
///
/// A value of `0` disables the corresponding limit.
#[derive(Clone, Debug)]
pub struct PruneConfig {
    /// Maximum number of sessions to keep.
    pub max_sessions: usize,
    /// Maximum age, in hours since the last update, before a session is pruned.
    pub ttl_hours: u64,
}

/// Defaults: keep at most 100 sessions, each for up to 168 hours (7 days).
impl Default for PruneConfig {
    fn default() -> Self {
        PruneConfig {
            ttl_hours: 7 * 24,
            max_sessions: 100,
        }
    }
}
605
/// Rate limiter that answers "may a prune run now?" at most once per
/// cooldown period.
pub struct PruneThrottle {
    /// Instant of the last granted prune; `None` until the first grant.
    last_prune: std::sync::Mutex<Option<std::time::Instant>>,
    /// Minimum number of whole seconds between two granted prunes.
    cooldown_secs: u64,
}
613
614impl PruneThrottle {
615 pub fn new(cooldown_secs: u64) -> Self {
617 Self {
618 last_prune: std::sync::Mutex::new(None),
619 cooldown_secs,
620 }
621 }
622
623 pub fn should_prune(&self) -> bool {
626 let mut guard = self.last_prune.lock().unwrap_or_else(|e| {
629 tracing::warn!("PruneThrottle mutex poisoned, recovering: {e}");
630 e.into_inner()
631 });
632 let now = std::time::Instant::now();
633 match *guard {
634 Some(last) => {
635 if now.duration_since(last).as_secs() >= self.cooldown_secs {
636 *guard = Some(now);
637 true
638 } else {
639 false
640 }
641 }
642 None => {
643 *guard = Some(now);
644 true
645 }
646 }
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory stays alive for the
    /// whole test body.
    fn temp_store() -> (tempfile::TempDir, StateStore) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = StateStore::new(temp_dir.path().to_path_buf()).unwrap();
        (temp_dir, store)
    }

    /// Saves `count` sessions for `user-0..user-<count-1>`, pausing briefly
    /// after each save so their `updated_at` values are strictly ordered.
    async fn seed_sessions(store: &StateStore, count: usize) {
        for i in 0..count {
            let mut session = Session::new(format!("user-{}", i));
            session.add_user_message(format!("Message {}", i));
            store.save_session(&session).await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }

    /// A saved session can be loaded back with its user and messages intact.
    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_guard, store) = temp_store();

        let mut session = Session::new("user-123");
        session.add_user_message("Hello");
        store.save_session(&session).await.unwrap();

        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
    }

    /// Listing returns the most recently updated session first.
    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_guard, store) = temp_store();
        seed_sessions(&store, 3).await;

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        assert_eq!(sessions[0].user_id, "user-2");
    }

    /// Deleting a stored session reports success and makes it unloadable.
    #[tokio::test]
    async fn test_delete_session() {
        let (_guard, store) = temp_store();

        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();

        let deleted = store.delete_session(&session.id).await.unwrap();
        assert!(deleted);

        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_none());
    }

    /// An existing id returns the stored session rather than a new one.
    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_guard, store) = temp_store();

        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();

        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    /// Without an id, a fresh empty session is created for the user.
    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_guard, store) = temp_store();

        let session = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(session.user_id, "user-456");
        assert!(session.user_messages.is_empty());
    }

    /// Count-based pruning removes the oldest sessions beyond the limit.
    #[tokio::test]
    async fn test_prune_sessions_by_count() {
        let (_guard, store) = temp_store();
        seed_sessions(&store, 5).await;

        // TTL disabled; only the count limit applies.
        let config = PruneConfig {
            max_sessions: 3,
            ttl_hours: 0,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 2);

        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 3);
        let remaining_ids: Vec<&str> = remaining.iter().map(|s| s.user_id.as_str()).collect();
        for expected in ["user-2", "user-3", "user-4"] {
            assert!(remaining_ids.contains(&expected));
        }
    }

    /// TTL-based pruning removes only sessions older than the cutoff.
    #[tokio::test]
    async fn test_prune_sessions_by_ttl() {
        let (_guard, store) = temp_store();

        let mut old_session = Session::new("old-user");
        old_session.updated_at = Utc::now() - chrono::Duration::hours(48);
        store.save_session(&old_session).await.unwrap();

        let mut recent_session = Session::new("recent-user");
        recent_session.add_user_message("Hello");
        store.save_session(&recent_session).await.unwrap();

        // Count limit disabled; only the TTL applies.
        let config = PruneConfig {
            max_sessions: 0,
            ttl_hours: 24,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 1);

        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].user_id, "recent-user");
    }
}