Skip to main content

oxios_kernel/
state_store.rs

1//! Filesystem-based state store.
2//!
3//! All state is persisted as markdown or JSON files organized
4//! by category. This is the "filesystem" of Oxios.
5
6use anyhow::{Result, bail};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned};
9use std::path::PathBuf;
10use tokio::fs;
11
12/// Unique identifier for a session.
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub struct SessionId(pub String);
15
16impl SessionId {
17    /// Creates a new random session ID.
18    pub fn new() -> Self {
19        Self(uuid::Uuid::new_v4().to_string())
20    }
21}
22
23impl Default for SessionId {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl std::fmt::Display for SessionId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.0)
32    }
33}
34
35impl Serialize for SessionId {
36    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37    where
38        S: Serializer,
39    {
40        serializer.serialize_str(&self.0)
41    }
42}
43
44impl<'de> Deserialize<'de> for SessionId {
45    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46    where
47        D: Deserializer<'de>,
48    {
49        let s = String::deserialize(deserializer)?;
50        Ok(Self(s))
51    }
52}
53
/// A single message sent by the user within a [`Session`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMessage {
    /// Message text.
    pub content: String,
    /// When the message was recorded (UTC). `Session::add_user_message`
    /// stamps this with the current time.
    pub timestamp: DateTime<Utc>,
}
62
/// One agent reply recorded in a [`Session`], together with details about
/// the orchestration run that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResponse {
    /// Response content.
    pub content: String,
    /// Session ID associated with this response, if one was recorded.
    pub session_id: Option<String>,
    /// Seed ID used for this response (if any).
    pub seed_id: Option<String>,
    /// Phase reached during orchestration, if recorded.
    pub phase_reached: Option<String>,
    /// Whether evaluation passed; `None` when no evaluation result was recorded.
    pub evaluation_passed: Option<bool>,
    /// Timestamp when the response was generated.
    pub timestamp: DateTime<Utc>,
    /// Index range into `Session::trajectory_steps` for the tool calls that
    /// occurred during this response. `None` when no tool calls were made, in
    /// which case the field is omitted from the serialized JSON.
    /// Used by the Web UI to render per-turn execution timelines.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trajectory_range: Option<TrajectoryRange>,
}
84
/// Half-open index range `start..end` into `Session::trajectory_steps`.
///
/// NOTE(review): nothing in this type enforces `start <= end` or that `end`
/// is within `trajectory_steps.len()` — consumers should bounds-check before
/// slicing with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryRange {
    /// First step index (inclusive).
    pub start: usize,
    /// One past the last step index (exclusive).
    pub end: usize,
}
93
/// A single tool execution step recorded in a session (RFC-015).
///
/// Persisted alongside the agent responses so that the Web UI can render the
/// execution timeline (tool calls, durations, errors) when the user re-opens
/// the session later. Mirrors `memory::sona::TrajectoryStep`, but is
/// duplicated here to avoid a kernel-state → memory dependency cycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryStepRecord {
    /// Name of the tool that was called.
    pub tool_name: String,
    /// Tool input arguments as arbitrary JSON.
    pub tool_args: serde_json::Value,
    /// Truncated tool output (the producer caps it at roughly 500 chars;
    /// not enforced by this type).
    pub output_summary: String,
    /// Wall-clock duration in milliseconds.
    pub duration_ms: u64,
    /// Whether the tool returned an error.
    pub is_error: bool,
    /// Provider-specific tool call ID, used to correlate start/end events.
    pub tool_call_id: String,
    /// Timestamp when the step started.
    pub timestamp: DateTime<Utc>,
}
117
/// Free-form key/value metadata attached to a [`Session`].
///
/// Keys read by `StateStore::list_sessions`: `"title"` (a string that
/// overrides the auto-generated listing title) and `"project_ids"` (a string
/// surfaced as `SessionSummary::project_id`; non-string values are ignored).
pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;
120
/// A single user conversation and everything recorded about it.
///
/// Holds the user-message and agent-response history, the tool execution
/// trajectory, and free-form metadata. Persisted as `sessions/<id>.json` by
/// `StateStore::save_session` and reloaded by `StateStore::load_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Unique session identifier; also used as the session's file name.
    pub id: SessionId,
    /// ID of the user who owns this session.
    pub user_id: String,
    /// User messages in insertion order. Defaults to empty when absent from
    /// the stored JSON.
    #[serde(default)]
    pub user_messages: Vec<UserMessage>,
    /// Agent responses in insertion order. Defaults to empty when absent from
    /// the stored JSON.
    #[serde(default)]
    pub agent_responses: Vec<AgentResponse>,
    /// RFC-015: tool execution trajectory accumulated for this session.
    /// Appended via `Session::extend_trajectory`; consumed by the Web UI to
    /// render the execution timeline when the session is re-opened. Omitted
    /// from the JSON when empty and defaults to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub trajectory_steps: Vec<TrajectoryStepRecord>,
    /// Currently active seed ID (if any); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// Currently active persona ID (reserved for future multi-persona
    /// support); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_persona_id: Option<String>,
    /// Timestamp when the session was created.
    pub created_at: DateTime<Utc>,
    /// Timestamp of the most recent mutation made through `Session`'s methods.
    pub updated_at: DateTime<Utc>,
    /// Arbitrary key-value metadata. Defaults to empty when absent from the
    /// stored JSON.
    #[serde(default)]
    pub metadata: SessionMetadata,
}
157
158impl Session {
159    /// Creates a new session for a user.
160    pub fn new(user_id: impl Into<String>) -> Self {
161        let now = Utc::now();
162        Self {
163            id: SessionId::new(),
164            user_id: user_id.into(),
165            user_messages: Vec::new(),
166            agent_responses: Vec::new(),
167            trajectory_steps: Vec::new(),
168            active_seed_id: None,
169            active_persona_id: None,
170            created_at: now,
171            updated_at: now,
172            metadata: SessionMetadata::new(),
173        }
174    }
175
176    /// Creates a session with a specific ID.
177    pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
178        let now = Utc::now();
179        Self {
180            id: session_id,
181            user_id: user_id.into(),
182            user_messages: Vec::new(),
183            agent_responses: Vec::new(),
184            trajectory_steps: Vec::new(),
185            active_seed_id: None,
186            active_persona_id: None,
187            created_at: now,
188            updated_at: now,
189            metadata: SessionMetadata::new(),
190        }
191    }
192
193    /// Adds a user message to the session.
194    pub fn add_user_message(&mut self, content: impl Into<String>) {
195        self.user_messages.push(UserMessage {
196            content: content.into(),
197            timestamp: Utc::now(),
198        });
199        self.updated_at = Utc::now();
200    }
201
202    /// Adds an agent response to the session.
203    pub fn add_agent_response(&mut self, response: AgentResponse) {
204        self.agent_responses.push(response);
205        self.updated_at = Utc::now();
206    }
207
208    /// Appends trajectory steps to the session (RFC-015).
209    ///
210    /// Called by the orchestrator after each run so the Web UI can
211    /// re-render the execution timeline when the user re-opens the session.
212    pub fn extend_trajectory(&mut self, steps: Vec<TrajectoryStepRecord>) {
213        if steps.is_empty() {
214            return;
215        }
216        self.trajectory_steps.extend(steps);
217        self.updated_at = Utc::now();
218    }
219
220    /// Returns the trajectory steps recorded in this session.
221    pub fn trajectory(&self) -> &[TrajectoryStepRecord] {
222        &self.trajectory_steps
223    }
224
225    /// Sets the active seed ID.
226    pub fn set_active_seed(&mut self, seed_id: Option<String>) {
227        self.active_seed_id = seed_id;
228        self.updated_at = Utc::now();
229    }
230
231    /// Sets the active persona ID.
232    pub fn set_active_persona(&mut self, persona_id: Option<String>) {
233        self.active_persona_id = persona_id;
234        self.updated_at = Utc::now();
235    }
236
237    /// Sets a metadata value.
238    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
239        self.metadata.insert(key.into(), value);
240        self.updated_at = Utc::now();
241    }
242
243    /// Gets a metadata value.
244    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
245        self.metadata.get(key)
246    }
247
248    /// Returns the total number of exchanges in this session.
249    pub fn exchange_count(&self) -> usize {
250        self.user_messages.len().min(self.agent_responses.len())
251    }
252
253    /// Returns true if the session is empty (no messages).
254    pub fn is_empty(&self) -> bool {
255        self.user_messages.is_empty()
256    }
257}
/// A filesystem-based persistent state store.
///
/// Files are organized as `<base_path>/<category>/<name>.md` or
/// `<base_path>/<category>/<name>.json`. Categories may be nested (`a/b`);
/// names are single path components. Saves write a temporary file in the
/// category directory and then rename it over the target, so readers never
/// observe a partially written file.
///
/// A clone copies only `base_path` and therefore operates on the same files.
#[derive(Clone)]
pub struct StateStore {
    /// Root directory under which all category directories live.
    pub base_path: PathBuf,
}
267
268impl StateStore {
269    /// Creates a new state store, initializing the directory if needed.
270    ///
271    /// # Example
272    ///
273    /// ```no_run
274    /// use oxios_kernel::state_store::StateStore;
275    /// use std::path::PathBuf;
276    ///
277    /// let store = StateStore::new(PathBuf::from("/tmp/oxios-state")).unwrap();
278    /// ```
279    pub fn new(base_path: PathBuf) -> Result<Self> {
280        Ok(Self { base_path })
281    }
282
283    /// Validate that a category name does not contain path traversal.
284    fn validate_category(category: &str) -> Result<()> {
285        if category.contains("..") || category.contains('\\') {
286            bail!("invalid category name: '{category}'");
287        }
288        if category.is_empty()
289            || category.starts_with('/')
290            || category.ends_with('/')
291            || category.contains("//")
292        {
293            bail!("invalid category name: '{category}'");
294        }
295        Ok(())
296    }
297
298    /// Validate that a file name does not contain path traversal.
299    fn validate_name(name: &str) -> Result<()> {
300        if name.contains("..") || name.contains('/') || name.contains('\\') {
301            bail!("invalid file name: '{name}'");
302        }
303        Ok(())
304    }
305
306    /// Save a markdown file under the given category.
307    pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
308        Self::validate_category(category)?;
309        Self::validate_name(name)?;
310        let dir = self.base_path.join(category);
311        fs::create_dir_all(&dir).await?;
312        let path = dir.join(format!("{name}.md"));
313
314        // Write to temp file first, then atomic rename
315        let temp_path = dir.join(format!(
316            "{name}.{}.{}.tmp",
317            std::process::id(),
318            uuid::Uuid::new_v4()
319        ));
320        fs::write(&temp_path, content).await?;
321        tokio::fs::rename(&temp_path, &path).await?;
322
323        Ok(())
324    }
325
326    /// Load a markdown file from the given category.
327    pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
328        Self::validate_category(category)?;
329        Self::validate_name(name)?;
330        let path = self.base_path.join(category).join(format!("{name}.md"));
331        match fs::read_to_string(&path).await {
332            Ok(content) => Ok(Some(content)),
333            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
334            Err(e) => Err(e.into()),
335        }
336    }
337
338    /// List all markdown files in a category (names without extension).
339    pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
340        Self::validate_category(category)?;
341        let dir = self.base_path.join(category);
342        if !dir.exists() {
343            return Ok(Vec::new());
344        }
345        let mut entries = fs::read_dir(&dir).await?;
346        let mut names = Vec::new();
347        while let Some(entry) = entries.next_entry().await? {
348            let path = entry.path();
349            if let Some(ext) = path.extension()
350                && (ext == "md" || ext == "json")
351                && let Some(stem) = path.file_stem()
352            {
353                names.push(stem.to_string_lossy().into_owned());
354            }
355        }
356        names.sort();
357        Ok(names)
358    }
359
360    /// Save a serializable value as JSON under the given category.
361    pub async fn save_json<T: Serialize>(
362        &self,
363        category: &str,
364        name: &str,
365        data: &T,
366    ) -> Result<()> {
367        Self::validate_category(category)?;
368        Self::validate_name(name)?;
369        let dir = self.base_path.join(category);
370        fs::create_dir_all(&dir).await?;
371        let path = dir.join(format!("{name}.json"));
372
373        let content = serde_json::to_string_pretty(data)?;
374
375        // Write to temp file first, then atomic rename
376        let temp_path = dir.join(format!(
377            "{name}.{}.{}.tmp",
378            std::process::id(),
379            uuid::Uuid::new_v4()
380        ));
381        fs::write(&temp_path, &content).await?;
382        tokio::fs::rename(&temp_path, &path).await?;
383
384        Ok(())
385    }
386
387    /// Load a deserializable value from JSON in the given category.
388    pub async fn load_json<T: DeserializeOwned>(
389        &self,
390        category: &str,
391        name: &str,
392    ) -> Result<Option<T>> {
393        Self::validate_category(category)?;
394        Self::validate_name(name)?;
395        let path = self.base_path.join(category).join(format!("{name}.json"));
396        match fs::read_to_string(&path).await {
397            Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
398            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
399            Err(e) => Err(e.into()),
400        }
401    }
402
403    /// Delete a file from the given category.
404    pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
405        Self::validate_category(category)?;
406        Self::validate_name(name)?;
407        let path = self.base_path.join(category).join(format!("{name}.json"));
408        if path.exists() {
409            tokio::fs::remove_file(path).await?;
410            Ok(true)
411        } else {
412            let path = self.base_path.join(category).join(format!("{name}.md"));
413            if path.exists() {
414                tokio::fs::remove_file(path).await?;
415                Ok(true)
416            } else {
417                Ok(false)
418            }
419        }
420    }
421}
422
423impl std::fmt::Debug for StateStore {
424    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        f.debug_struct("StateStore")
426            .field("base_path", &self.base_path)
427            .finish()
428    }
429}
430
431impl StateStore {
432    /// Saves a session to the sessions category.
433    pub async fn save_session(&self, session: &Session) -> Result<()> {
434        self.save_json("sessions", &session.id.0, session).await
435    }
436
437    /// Saves a session and then runs pruning if auto_prune is enabled.
438    pub async fn save_session_with_prune(
439        &self,
440        session: &Session,
441        prune_config: &PruneConfig,
442    ) -> Result<()> {
443        self.save_session(session).await?;
444        // Prune in the background — don't block the response
445        let store = self.clone();
446        let config = prune_config.clone();
447        tokio::spawn(async move {
448            if let Err(e) = store.prune_sessions(&config).await {
449                tracing::warn!(error = %e, "Background session pruning failed");
450            }
451        });
452        Ok(())
453    }
454
455    /// Loads a session by ID.
456    pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
457        self.load_json("sessions", &session_id.0).await
458    }
459
460    /// Lists all sessions (sorted by updated_at descending).
461    pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
462        let mut sessions = Vec::new();
463
464        if let Ok(names) = self.list_category("sessions").await {
465            for name in names {
466                if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
467                    sessions.push(SessionSummary {
468                        id: session.id.0.clone(),
469                        user_id: session.user_id.clone(),
470                        message_count: session.user_messages.len(),
471                        title: session
472                            .metadata
473                            .get("title")
474                            .and_then(|v| v.as_str())
475                            .map(String::from)
476                            .or_else(|| {
477                                // Auto-generate from first user message
478                                session.user_messages.first().map(|m| {
479                                    let s = m.content.lines().next().unwrap_or("");
480                                    if s.len() > 60 {
481                                        format!("{}…", &s[..s.ceil_char_boundary(59)])
482                                    } else {
483                                        s.to_string()
484                                    }
485                                })
486                            }),
487                        active_seed_id: session.active_seed_id.clone(),
488                        project_id: session
489                            .metadata
490                            .get("project_ids")
491                            .and_then(|v| v.as_str())
492                            .map(String::from),
493                        created_at: session.created_at,
494                        updated_at: session.updated_at,
495                    });
496                }
497            }
498        }
499
500        // Sort by updated_at descending (most recent first)
501        sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
502        Ok(sessions)
503    }
504
505    /// Deletes a session by ID.
506    pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
507        let path = self
508            .base_path
509            .join("sessions")
510            .join(format!("{}.json", session_id.0));
511        match fs::remove_file(&path).await {
512            Ok(()) => {
513                tracing::info!(session_id = %session_id, "Session deleted");
514                Ok(true)
515            }
516            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
517            Err(e) => Err(e.into()),
518        }
519    }
520
521    /// Gets or creates a session for a user, initializing with the given session ID.
522    pub async fn get_or_create_session(
523        &self,
524        user_id: &str,
525        session_id: Option<&SessionId>,
526    ) -> Result<Session> {
527        if let Some(sid) = session_id
528            && let Some(existing) = self.load_session(sid).await?
529        {
530            return Ok(existing);
531        }
532
533        // Create new session
534        let session = match session_id {
535            Some(sid) => Session::with_id(user_id, sid.clone()),
536            None => Session::new(user_id),
537        };
538
539        self.save_session(&session).await?;
540        Ok(session)
541    }
542
543    /// Updates an existing session, saving it to disk.
544    pub async fn update_session(&self, session: &Session) -> Result<()> {
545        self.save_session(session).await
546    }
547
548    /// Prune sessions based on configuration.
549    ///
550    /// Removes sessions that exceed TTL or exceed the maximum count.
551    /// Returns the number of sessions pruned.
552    pub async fn prune_sessions(&self, config: &PruneConfig) -> Result<usize> {
553        let mut sessions = self.list_sessions().await?;
554        let mut pruned = 0;
555
556        // TTL-based pruning: remove sessions older than ttl_hours
557        if config.ttl_hours > 0 {
558            let cutoff = Utc::now() - chrono::Duration::hours(config.ttl_hours as i64);
559            let to_prune_ttl: Vec<String> = sessions
560                .iter()
561                .filter(|s| s.updated_at < cutoff)
562                .map(|s| s.id.clone())
563                .collect();
564
565            for id in &to_prune_ttl {
566                let sid = SessionId(id.clone());
567                if self.delete_session(&sid).await.is_ok() {
568                    pruned += 1;
569                }
570            }
571
572            // Remove pruned sessions from the list for count-based pruning
573            sessions.retain(|s| !to_prune_ttl.contains(&s.id));
574        }
575
576        // Count-based pruning: keep only the most recent `max_sessions`
577        if config.max_sessions > 0 && sessions.len() > config.max_sessions {
578            // sessions are already sorted by updated_at descending
579            let excess = sessions.len() - config.max_sessions;
580            for session in sessions.into_iter().rev().take(excess) {
581                let sid = SessionId(session.id);
582                if self.delete_session(&sid).await.is_ok() {
583                    pruned += 1;
584                }
585            }
586        }
587
588        if pruned > 0 {
589            tracing::info!(pruned = pruned, "Session pruning completed");
590        }
591
592        Ok(pruned)
593    }
594
595    /// Prune agent records based on config.
596    ///
597    /// 1. TTL-based: delete agents with created_at older than ttl_hours.
598    /// 2. Count-based: if still over max_entries, delete oldest.
599    pub async fn prune_agents_by_config(
600        &self,
601        max_entries: usize,
602        ttl_hours: u64,
603        batch_size: usize,
604    ) -> Result<usize> {
605        let mut pruned = 0usize;
606
607        let names = self.list_category("agents").await?;
608        if names.is_empty() {
609            return Ok(0);
610        }
611
612        let now = Utc::now();
613
614        // 1. TTL-based pruning
615        let mut remaining: Vec<(String, DateTime<Utc>)> = Vec::with_capacity(names.len());
616
617        if ttl_hours > 0 {
618            let cutoff = now - chrono::Duration::hours(ttl_hours as i64);
619            for name in &names {
620                // Load just enough to check created_at
621                if let Ok(Some(info)) = self
622                    .load_json::<crate::types::AgentInfo>("agents", name)
623                    .await
624                {
625                    if info.created_at < cutoff {
626                        if self.delete_file("agents", name).await.unwrap_or(false) {
627                            pruned += 1;
628                        }
629                    } else {
630                        remaining.push((name.clone(), info.created_at));
631                    }
632                }
633            }
634        } else {
635            // Load all created_at for count-based pruning
636            for name in &names {
637                if let Ok(Some(info)) = self
638                    .load_json::<crate::types::AgentInfo>("agents", name)
639                    .await
640                {
641                    remaining.push((name.clone(), info.created_at));
642                }
643            }
644        }
645
646        // 2. Count-based pruning
647        if max_entries > 0 && remaining.len() > max_entries {
648            // Sort by created_at ascending (oldest first)
649            remaining.sort_by_key(|a| a.1);
650
651            let excess = remaining.len() - max_entries;
652            let to_delete = excess.min(batch_size);
653
654            for (name, _) in remaining.iter().take(to_delete) {
655                if self.delete_file("agents", name).await.unwrap_or(false) {
656                    pruned += 1;
657                }
658            }
659        }
660
661        if pruned > 0 {
662            tracing::info!(pruned = pruned, "Agent filesystem pruning completed");
663        }
664
665        Ok(pruned)
666    }
667}
668
/// Lightweight description of a session for listings (no message history).
///
/// Built by `StateStore::list_sessions` from a fully loaded [`Session`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
    /// Session ID.
    pub id: String,
    /// ID of the user who owns this session.
    pub user_id: String,
    /// Number of *user* messages in this session (agent responses are not counted).
    pub message_count: usize,
    /// Display title: the session's `title` metadata string if set, otherwise
    /// the first line of the first user message, truncated to about 60 bytes
    /// with a trailing `…`. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Active seed ID, if any. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub active_seed_id: Option<String>,
    /// Project ID(s) this session belongs to, taken from the session's
    /// `project_ids` metadata when it is a string. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last updated; listings are sorted by this, newest first.
    pub updated_at: DateTime<Utc>,
}
693
/// Retention policy applied by `StateStore::prune_sessions`.
///
/// Each limit can be switched off independently by setting it to `0`.
#[derive(Debug, Clone)]
pub struct PruneConfig {
    /// Upper bound on retained sessions; the least recently updated sessions
    /// are removed first. `0` means no limit.
    pub max_sessions: usize,
    /// Sessions whose `updated_at` is more than this many hours in the past
    /// are removed. `0` disables age-based pruning.
    pub ttl_hours: u64,
}

impl PruneConfig {
    /// Default cap on the number of stored sessions.
    const DEFAULT_MAX_SESSIONS: usize = 100;
    /// Default session lifetime: one week, in hours.
    const DEFAULT_TTL_HOURS: u64 = 7 * 24;
}

impl Default for PruneConfig {
    /// Keeps at most 100 sessions, each for at most 7 days (168 hours).
    fn default() -> Self {
        PruneConfig {
            max_sessions: Self::DEFAULT_MAX_SESSIONS,
            ttl_hours: Self::DEFAULT_TTL_HOURS,
        }
    }
}
711
712/// Tracks the last time a prune was performed, enabling cooldown.
713pub struct PruneThrottle {
714    /// Instant of the last prune (monotonic).
715    last_prune: std::sync::Mutex<Option<std::time::Instant>>,
716    /// Minimum seconds between prune runs.
717    cooldown_secs: u64,
718}
719
720impl PruneThrottle {
721    /// Create a new throttle with the given cooldown.
722    pub fn new(cooldown_secs: u64) -> Self {
723        Self {
724            last_prune: std::sync::Mutex::new(None),
725            cooldown_secs,
726        }
727    }
728
729    /// Check if enough time has elapsed since the last prune.
730    /// Returns `true` if prune should proceed.
731    pub fn should_prune(&self) -> bool {
732        // SAFETY: parking_lot::Mutex never poisons, but std::sync::Mutex does.
733        // Recover from poison by taking the inner value so pruning continues.
734        let mut guard = self.last_prune.lock().unwrap_or_else(|e| {
735            tracing::warn!("PruneThrottle mutex poisoned, recovering: {e}");
736            e.into_inner()
737        });
738        let now = std::time::Instant::now();
739        match *guard {
740            Some(last) => {
741                if now.duration_since(last).as_secs() >= self.cooldown_secs {
742                    *guard = Some(now);
743                    true
744                } else {
745                    false
746                }
747            }
748            None => {
749                *guard = Some(now);
750                true
751            }
752        }
753    }
754}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory lives as long as the
    /// caller keeps it bound (bind it to a named variable, not `_`).
    fn temp_store() -> (tempfile::TempDir, StateStore) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = StateStore::new(temp_dir.path().to_path_buf()).unwrap();
        (temp_dir, store)
    }

    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_dir, store) = temp_store();

        let mut session = Session::new("user-123");
        session.add_user_message("Hello");

        store.save_session(&session).await.unwrap();
        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_dir, store) = temp_store();

        for i in 0..3 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        // Most recently updated should be first.
        assert_eq!(sessions[0].user_id, "user-2");
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (_dir, store) = temp_store();

        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();

        let deleted = store.delete_session(&session.id).await.unwrap();
        assert!(deleted);

        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_delete_missing_session_returns_false() {
        let (_dir, store) = temp_store();
        assert!(!store.delete_session(&SessionId::new()).await.unwrap());
    }

    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_dir, store) = temp_store();

        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();

        // Get-or-create with the same ID should return the existing session.
        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_dir, store) = temp_store();

        let session = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(session.user_id, "user-456");
        assert!(session.user_messages.is_empty());
    }

    #[tokio::test]
    async fn test_prune_sessions_by_count() {
        let (_dir, store) = temp_store();

        for i in 0..5 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let config = PruneConfig {
            max_sessions: 3,
            ttl_hours: 0,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 2);

        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 3);
        // The oldest sessions (user-0, user-1) should be the ones pruned.
        let remaining_ids: Vec<&str> = remaining.iter().map(|s| s.user_id.as_str()).collect();
        assert!(remaining_ids.contains(&"user-2"));
        assert!(remaining_ids.contains(&"user-3"));
        assert!(remaining_ids.contains(&"user-4"));
    }

    #[tokio::test]
    async fn test_prune_sessions_by_ttl() {
        let (_dir, store) = temp_store();

        let mut old_session = Session::new("old-user");
        old_session.updated_at = Utc::now() - chrono::Duration::hours(48);
        store.save_session(&old_session).await.unwrap();

        let mut recent_session = Session::new("recent-user");
        recent_session.add_user_message("Hello");
        store.save_session(&recent_session).await.unwrap();

        let config = PruneConfig {
            max_sessions: 0,
            ttl_hours: 24,
        };
        let pruned = store.prune_sessions(&config).await.unwrap();
        assert_eq!(pruned, 1);

        let remaining = store.list_sessions().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].user_id, "recent-user");
    }

    #[tokio::test]
    async fn test_prune_with_limits_disabled_keeps_everything() {
        let (_dir, store) = temp_store();
        for i in 0..3 {
            store.save_session(&Session::new(format!("user-{i}"))).await.unwrap();
        }
        let config = PruneConfig {
            max_sessions: 0,
            ttl_hours: 0,
        };
        assert_eq!(store.prune_sessions(&config).await.unwrap(), 0);
        assert_eq!(store.list_sessions().await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_markdown_round_trip_and_missing_file() {
        let (_dir, store) = temp_store();
        store.save_markdown("notes", "hello", "# Hi\n").await.unwrap();
        let loaded = store.load_markdown("notes", "hello").await.unwrap();
        assert_eq!(loaded.as_deref(), Some("# Hi\n"));
        assert!(store.load_markdown("notes", "absent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_list_category_mixes_formats_and_ignores_other_files() {
        let (dir, store) = temp_store();
        store.save_markdown("mixed", "b", "text").await.unwrap();
        store.save_json("mixed", "a", &42u32).await.unwrap();
        tokio::fs::write(dir.path().join("mixed").join("ignored.txt"), "x")
            .await
            .unwrap();

        assert_eq!(store.list_category("mixed").await.unwrap(), vec!["a", "b"]);
        assert!(store.list_category("never-written").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_nested_category_round_trip() {
        let (_dir, store) = temp_store();
        store.save_json("a/b", "c", &vec![1, 2, 3]).await.unwrap();
        let loaded: Option<Vec<i32>> = store.load_json("a/b", "c").await.unwrap();
        assert_eq!(loaded, Some(vec![1, 2, 3]));
    }

    #[tokio::test]
    async fn test_path_traversal_is_rejected() {
        let (_dir, store) = temp_store();
        for bad_category in ["", "..", "../x", "/abs", "a/", "a//b", "a\\b"] {
            assert!(
                store.save_json(bad_category, "n", &1).await.is_err(),
                "category {bad_category:?} should be rejected"
            );
        }
        for bad_name in ["..", "../x", "a/b", "a\\b"] {
            assert!(
                store.save_json("ok", bad_name, &1).await.is_err(),
                "name {bad_name:?} should be rejected on save"
            );
            assert!(
                store.load_markdown("ok", bad_name).await.is_err(),
                "name {bad_name:?} should be rejected on load"
            );
        }
    }

    #[tokio::test]
    async fn test_delete_file_reports_whether_something_was_removed() {
        let (_dir, store) = temp_store();
        store.save_json("things", "x", &true).await.unwrap();
        store.save_markdown("things", "y", "doc").await.unwrap();

        assert!(store.delete_file("things", "x").await.unwrap());
        assert!(store.delete_file("things", "y").await.unwrap());
        assert!(!store.delete_file("things", "x").await.unwrap());
        assert!(store.list_category("things").await.unwrap().is_empty());
    }

    #[test]
    fn test_session_id_serializes_as_plain_string() {
        let id = SessionId("abc".to_string());
        assert_eq!(serde_json::to_string(&id).unwrap(), "\"abc\"");
        let back: SessionId = serde_json::from_str("\"abc\"").unwrap();
        assert_eq!(back, id);
        assert_eq!(id.to_string(), "abc");
    }

    #[test]
    fn test_session_json_omits_empty_optionals_and_round_trips() {
        let session = Session::new("u");
        let json = serde_json::to_value(&session).unwrap();
        assert!(json.get("trajectory_steps").is_none());
        assert!(json.get("active_seed_id").is_none());
        assert!(json.get("active_persona_id").is_none());

        let back: Session = serde_json::from_value(json).unwrap();
        assert_eq!(back.id, session.id);
        assert_eq!(back.user_id, "u");
        assert!(back.trajectory().is_empty());
    }

    #[test]
    fn test_exchange_count_and_trajectory() {
        let mut session = Session::new("u");
        assert!(session.is_empty());

        session.add_user_message("one");
        session.add_user_message("two");
        session.add_agent_response(AgentResponse {
            content: "reply".into(),
            session_id: None,
            seed_id: None,
            phase_reached: None,
            evaluation_passed: None,
            timestamp: Utc::now(),
            trajectory_range: None,
        });
        assert!(!session.is_empty());
        assert_eq!(session.exchange_count(), 1);

        session.extend_trajectory(Vec::new());
        assert!(session.trajectory().is_empty());
        session.extend_trajectory(vec![TrajectoryStepRecord {
            tool_name: "read".into(),
            tool_args: serde_json::json!({ "path": "a" }),
            output_summary: "ok".into(),
            duration_ms: 3,
            is_error: false,
            tool_call_id: "call-1".into(),
            timestamp: Utc::now(),
        }]);
        assert_eq!(session.trajectory().len(), 1);
    }

    #[test]
    fn test_prune_throttle_cooldown() {
        let throttle = PruneThrottle::new(3600);
        assert!(throttle.should_prune());
        assert!(!throttle.should_prune());

        let unthrottled = PruneThrottle::new(0);
        assert!(unthrottled.should_prune());
        assert!(unthrottled.should_prune());
    }

    #[test]
    fn test_prune_config_default() {
        let cfg = PruneConfig::default();
        assert_eq!((cfg.max_sessions, cfg.ttl_hours), (100, 168));
    }
}