Skip to main content

oxios_kernel/
state_store.rs

1//! Filesystem-based state store.
2//!
3//! All state is persisted as markdown or JSON files organized
4//! by category. This is the "filesystem" of Oxios.
5
6use anyhow::{bail, Result};
7use chrono::{DateTime, Utc};
8use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
9use std::path::PathBuf;
10use tokio::fs;
11
/// Opaque, string-backed identifier of a [`Session`].
///
/// The inner string is used verbatim as the session's file stem by the
/// state store, so it is also the on-disk key for the session.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
15
16impl SessionId {
17    /// Creates a new random session ID.
18    pub fn new() -> Self {
19        Self(uuid::Uuid::new_v4().to_string())
20    }
21}
22
23impl Default for SessionId {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl std::fmt::Display for SessionId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.0)
32    }
33}
34
35impl Serialize for SessionId {
36    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37    where
38        S: Serializer,
39    {
40        serializer.serialize_str(&self.0)
41    }
42}
43
44impl<'de> Deserialize<'de> for SessionId {
45    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46    where
47        D: Deserializer<'de>,
48    {
49        let s = String::deserialize(deserializer)?;
50        Ok(Self(s))
51    }
52}
53
54/// A user message in a session.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct UserMessage {
57    /// Message content.
58    pub content: String,
59    /// Timestamp when the message was sent.
60    pub timestamp: DateTime<Utc>,
61}
62
63/// An agent response in a session.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AgentResponse {
66    /// Response content.
67    pub content: String,
68    /// Session ID associated with this response.
69    pub session_id: Option<String>,
70    /// Seed ID used for this response (if any).
71    pub seed_id: Option<String>,
72    /// Phase reached during orchestration.
73    pub phase_reached: Option<String>,
74    /// Whether evaluation passed.
75    pub evaluation_passed: Option<bool>,
76    /// Timestamp when the response was generated.
77    pub timestamp: DateTime<Utc>,
78}
79
80/// Arbitrary key-value metadata for a session.
81pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;
82
83/// A session represents a single user conversation.
84///
85/// Sessions track the full message history and metadata for
86/// a user conversation. They are created per user interaction
87/// and persisted for later retrieval.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Session {
90    /// Unique session identifier.
91    pub id: SessionId,
92    /// User ID who owns this session.
93    pub user_id: String,
94    /// All user messages in this session.
95    #[serde(default)]
96    pub user_messages: Vec<UserMessage>,
97    /// All agent responses in this session.
98    #[serde(default)]
99    pub agent_responses: Vec<AgentResponse>,
100    /// Currently active seed ID (if any).
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub active_seed_id: Option<String>,
103    /// Currently active persona ID (for future multi-persona support).
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub active_persona_id: Option<String>,
106    /// Timestamp when the session was created.
107    pub created_at: DateTime<Utc>,
108    /// Timestamp when the session was last updated.
109    pub updated_at: DateTime<Utc>,
110    /// Arbitrary key-value metadata.
111    #[serde(default)]
112    pub metadata: SessionMetadata,
113}
114
115impl Session {
116    /// Creates a new session for a user.
117    pub fn new(user_id: impl Into<String>) -> Self {
118        let now = Utc::now();
119        Self {
120            id: SessionId::new(),
121            user_id: user_id.into(),
122            user_messages: Vec::new(),
123            agent_responses: Vec::new(),
124            active_seed_id: None,
125            active_persona_id: None,
126            created_at: now,
127            updated_at: now,
128            metadata: SessionMetadata::new(),
129        }
130    }
131
132    /// Creates a session with a specific ID.
133    pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
134        let now = Utc::now();
135        Self {
136            id: session_id,
137            user_id: user_id.into(),
138            user_messages: Vec::new(),
139            agent_responses: Vec::new(),
140            active_seed_id: None,
141            active_persona_id: None,
142            created_at: now,
143            updated_at: now,
144            metadata: SessionMetadata::new(),
145        }
146    }
147
148    /// Adds a user message to the session.
149    pub fn add_user_message(&mut self, content: impl Into<String>) {
150        self.user_messages.push(UserMessage {
151            content: content.into(),
152            timestamp: Utc::now(),
153        });
154        self.updated_at = Utc::now();
155    }
156
157    /// Adds an agent response to the session.
158    pub fn add_agent_response(&mut self, response: AgentResponse) {
159        self.agent_responses.push(response);
160        self.updated_at = Utc::now();
161    }
162
163    /// Sets the active seed ID.
164    pub fn set_active_seed(&mut self, seed_id: Option<String>) {
165        self.active_seed_id = seed_id;
166        self.updated_at = Utc::now();
167    }
168
169    /// Sets the active persona ID.
170    pub fn set_active_persona(&mut self, persona_id: Option<String>) {
171        self.active_persona_id = persona_id;
172        self.updated_at = Utc::now();
173    }
174
175    /// Sets a metadata value.
176    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
177        self.metadata.insert(key.into(), value);
178        self.updated_at = Utc::now();
179    }
180
181    /// Gets a metadata value.
182    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
183        self.metadata.get(key)
184    }
185
186    /// Returns the total number of exchanges in this session.
187    pub fn exchange_count(&self) -> usize {
188        self.user_messages.len().min(self.agent_responses.len())
189    }
190
191    /// Returns true if the session is empty (no messages).
192    pub fn is_empty(&self) -> bool {
193        self.user_messages.is_empty()
194    }
195}
/// Persistent key/value state backed by plain files on disk.
///
/// Each entry lives at `<base_path>/<category>/<name>.md` (markdown) or
/// `<base_path>/<category>/<name>.json` (JSON). Cloning a store only clones
/// the root path; clones operate on the same files.
#[derive(Clone)]
pub struct StateStore {
    /// Directory under which every category directory is created.
    pub base_path: PathBuf,
}
205
206impl StateStore {
207    /// Creates a new state store, initializing the directory if needed.
208    ///
209    /// # Example
210    ///
211    /// ```ignore
212    /// use oxios_kernel::StateStore;
213    /// use std::path::PathBuf;
214    ///
215    /// let store = StateStore::new(PathBuf::from("/tmp/oxios-state")).unwrap();
216    /// ```
217    pub fn new(base_path: PathBuf) -> Result<Self> {
218        Ok(Self { base_path })
219    }
220
221    /// Validate that a category name does not contain path traversal.
222    fn validate_category(category: &str) -> Result<()> {
223        if category.contains("..") || category.contains('\\') {
224            bail!("invalid category name: '{}'", category);
225        }
226        if category.is_empty()
227            || category.starts_with('/')
228            || category.ends_with('/')
229            || category.contains("//")
230        {
231            bail!("invalid category name: '{}'", category);
232        }
233        Ok(())
234    }
235
236    /// Validate that a file name does not contain path traversal.
237    fn validate_name(name: &str) -> Result<()> {
238        if name.contains("..") || name.contains('/') || name.contains('\\') {
239            bail!("invalid file name: '{}'", name);
240        }
241        Ok(())
242    }
243
244    /// Save a markdown file under the given category.
245    pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
246        Self::validate_category(category)?;
247        Self::validate_name(name)?;
248        let dir = self.base_path.join(category);
249        fs::create_dir_all(&dir).await?;
250        let path = dir.join(format!("{name}.md"));
251
252        // Write to temp file first, then atomic rename
253        let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
254        fs::write(&temp_path, content).await?;
255        tokio::fs::rename(&temp_path, &path).await?;
256
257        Ok(())
258    }
259
260    /// Load a markdown file from the given category.
261    pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
262        Self::validate_category(category)?;
263        Self::validate_name(name)?;
264        let path = self.base_path.join(category).join(format!("{name}.md"));
265        match fs::read_to_string(&path).await {
266            Ok(content) => Ok(Some(content)),
267            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
268            Err(e) => Err(e.into()),
269        }
270    }
271
272    /// List all markdown files in a category (names without extension).
273    pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
274        Self::validate_category(category)?;
275        let dir = self.base_path.join(category);
276        if !dir.exists() {
277            return Ok(Vec::new());
278        }
279        let mut entries = fs::read_dir(&dir).await?;
280        let mut names = Vec::new();
281        while let Some(entry) = entries.next_entry().await? {
282            let path = entry.path();
283            if let Some(ext) = path.extension() {
284                if ext == "md" || ext == "json" {
285                    if let Some(stem) = path.file_stem() {
286                        names.push(stem.to_string_lossy().into_owned());
287                    }
288                }
289            }
290        }
291        names.sort();
292        Ok(names)
293    }
294
295    /// Save a serializable value as JSON under the given category.
296    pub async fn save_json<T: Serialize>(
297        &self,
298        category: &str,
299        name: &str,
300        data: &T,
301    ) -> Result<()> {
302        Self::validate_category(category)?;
303        Self::validate_name(name)?;
304        let dir = self.base_path.join(category);
305        fs::create_dir_all(&dir).await?;
306        let path = dir.join(format!("{name}.json"));
307
308        let content = serde_json::to_string_pretty(data)?;
309
310        // Write to temp file first, then atomic rename
311        let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
312        fs::write(&temp_path, &content).await?;
313        tokio::fs::rename(&temp_path, &path).await?;
314
315        Ok(())
316    }
317
318    /// Load a deserializable value from JSON in the given category.
319    pub async fn load_json<T: DeserializeOwned>(
320        &self,
321        category: &str,
322        name: &str,
323    ) -> Result<Option<T>> {
324        Self::validate_category(category)?;
325        Self::validate_name(name)?;
326        let path = self.base_path.join(category).join(format!("{name}.json"));
327        match fs::read_to_string(&path).await {
328            Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
329            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
330            Err(e) => Err(e.into()),
331        }
332    }
333
334    /// Delete a file from the given category.
335    pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
336        Self::validate_category(category)?;
337        Self::validate_name(name)?;
338        let path = self.base_path.join(category).join(format!("{name}.json"));
339        if path.exists() {
340            tokio::fs::remove_file(path).await?;
341            Ok(true)
342        } else {
343            let path = self.base_path.join(category).join(format!("{name}.md"));
344            if path.exists() {
345                tokio::fs::remove_file(path).await?;
346                Ok(true)
347            } else {
348                Ok(false)
349            }
350        }
351    }
352}
353
354impl std::fmt::Debug for StateStore {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        f.debug_struct("StateStore")
357            .field("base_path", &self.base_path)
358            .finish()
359    }
360}
361
362impl StateStore {
363    /// Saves a session to the sessions category.
364    pub async fn save_session(&self, session: &Session) -> Result<()> {
365        self.save_json("sessions", &session.id.0, session).await
366    }
367
368    /// Loads a session by ID.
369    pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
370        self.load_json("sessions", &session_id.0).await
371    }
372
373    /// Lists all sessions (sorted by updated_at descending).
374    pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
375        let mut sessions = Vec::new();
376
377        if let Ok(names) = self.list_category("sessions").await {
378            for name in names {
379                if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
380                    sessions.push(SessionSummary {
381                        id: session.id.0.clone(),
382                        user_id: session.user_id.clone(),
383                        message_count: session.user_messages.len(),
384                        active_seed_id: session.active_seed_id.clone(),
385                        created_at: session.created_at,
386                        updated_at: session.updated_at,
387                    });
388                }
389            }
390        }
391
392        // Sort by updated_at descending (most recent first)
393        sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
394        Ok(sessions)
395    }
396
397    /// Deletes a session by ID.
398    pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
399        let path = self
400            .base_path
401            .join("sessions")
402            .join(format!("{}.json", session_id.0));
403        match fs::remove_file(&path).await {
404            Ok(()) => {
405                tracing::info!(session_id = %session_id, "Session deleted");
406                Ok(true)
407            }
408            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
409            Err(e) => Err(e.into()),
410        }
411    }
412
413    /// Gets or creates a session for a user, initializing with the given session ID.
414    pub async fn get_or_create_session(
415        &self,
416        user_id: &str,
417        session_id: Option<&SessionId>,
418    ) -> Result<Session> {
419        if let Some(sid) = session_id {
420            if let Some(existing) = self.load_session(sid).await? {
421                return Ok(existing);
422            }
423        }
424
425        // Create new session
426        let session = match session_id {
427            Some(sid) => Session::with_id(user_id, sid.clone()),
428            None => Session::new(user_id),
429        };
430
431        self.save_session(&session).await?;
432        Ok(session)
433    }
434
435    /// Updates an existing session, saving it to disk.
436    pub async fn update_session(&self, session: &Session) -> Result<()> {
437        self.save_session(session).await
438    }
439}
440
441/// Summary of a session for listing (without full message history).
442#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct SessionSummary {
444    /// Session ID.
445    pub id: String,
446    /// User ID who owns this session.
447    pub user_id: String,
448    /// Number of messages in this session.
449    pub message_count: usize,
450    /// Active seed ID if any.
451    #[serde(skip_serializing_if = "Option::is_none")]
452    pub active_seed_id: Option<String>,
453    /// When the session was created.
454    pub created_at: DateTime<Utc>,
455    /// When the session was last updated.
456    pub updated_at: DateTime<Utc>,
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store rooted in a fresh temporary directory.
    ///
    /// The returned `TempDir` guard deletes the directory on drop, so callers
    /// must keep it bound (e.g. `let (_dir, store) = temp_store();`) for the
    /// whole test.
    fn temp_store() -> (tempfile::TempDir, StateStore) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = StateStore::new(temp_dir.path().to_path_buf()).unwrap();
        (temp_dir, store)
    }

    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_dir, store) = temp_store();

        let mut session = Session::new("user-123");
        session.add_user_message("Hello");
        session.set_metadata("channel", serde_json::json!("cli"));

        store.save_session(&session).await.unwrap();
        let loaded = store
            .load_session(&session.id)
            .await
            .unwrap()
            .expect("saved session should load");
        assert_eq!(loaded.id, session.id);
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
        assert_eq!(loaded.user_messages[0].content, "Hello");
        assert_eq!(loaded.get_metadata("channel"), Some(&serde_json::json!("cli")));
    }

    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_dir, store) = temp_store();

        for i in 0..3 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            // Space out `updated_at` values so the ordering is unambiguous.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        // Most recently updated first.
        assert_eq!(sessions[0].user_id, "user-2");
        assert_eq!(sessions[2].user_id, "user-0");
        assert!(sessions.iter().all(|s| s.message_count == 1));
    }

    #[tokio::test]
    async fn test_list_sessions_on_empty_store() {
        let (_dir, store) = temp_store();
        assert!(store.list_sessions().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_delete_session() {
        let (_dir, store) = temp_store();

        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();

        assert!(store.delete_session(&session.id).await.unwrap());
        assert!(store.load_session(&session.id).await.unwrap().is_none());
        // A second delete finds nothing to remove.
        assert!(!store.delete_session(&session.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_dir, store) = temp_store();

        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();

        // The same ID must return the stored session, not a fresh one.
        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_dir, store) = temp_store();

        let session = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(session.user_id, "user-456");
        assert!(session.user_messages.is_empty());
        // The freshly created session is persisted immediately.
        assert!(store.load_session(&session.id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_get_or_create_session_with_unknown_id_uses_it() {
        let (_dir, store) = temp_store();

        let sid = SessionId("fixed-id".to_string());
        let session = store
            .get_or_create_session("user-789", Some(&sid))
            .await
            .unwrap();
        assert_eq!(session.id, sid);
        assert_eq!(session.user_id, "user-789");
        assert!(store.load_session(&sid).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_markdown_roundtrip() {
        let (_dir, store) = temp_store();

        store.save_markdown("notes", "a", "# Hi").await.unwrap();
        assert_eq!(
            store.load_markdown("notes", "a").await.unwrap().as_deref(),
            Some("# Hi")
        );
        assert!(store.load_markdown("notes", "missing").await.unwrap().is_none());

        // Saving again replaces the previous content.
        store.save_markdown("notes", "a", "# Bye").await.unwrap();
        assert_eq!(
            store.load_markdown("notes", "a").await.unwrap().as_deref(),
            Some("# Bye")
        );
    }

    #[tokio::test]
    async fn test_list_category_lists_md_and_json_stems() {
        let (_dir, store) = temp_store();

        store.save_markdown("notes", "a", "text").await.unwrap();
        store.save_json("notes", "b", &42u32).await.unwrap();

        assert_eq!(store.list_category("notes").await.unwrap(), vec!["a", "b"]);
        assert!(store.list_category("missing").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_delete_file_reports_whether_anything_was_removed() {
        let (_dir, store) = temp_store();

        store.save_markdown("notes", "a", "text").await.unwrap();
        assert!(store.delete_file("notes", "a").await.unwrap());
        assert!(!store.delete_file("notes", "a").await.unwrap());
    }

    #[tokio::test]
    async fn test_rejects_path_traversal() {
        let (_dir, store) = temp_store();

        assert!(store.save_json("sessions", "../escape", &1u32).await.is_err());
        assert!(store.save_markdown("../outside", "x", "data").await.is_err());
        assert!(store.load_markdown("/abs", "x").await.is_err());
        assert!(store.load_json::<u32>("a//b", "x").await.is_err());
        assert!(store.list_category("").await.is_err());
        assert!(store.delete_file("notes", "a\\b").await.is_err());
    }
}