chasm_cli/agency/
session.rs

1// Copyright (c) 2024-2026 Nervosys LLC
2// SPDX-License-Identifier: Apache-2.0
3//! Session Management
4//!
5//! Manages conversation sessions with persistent state.
6
7#![allow(dead_code)]
8
9use crate::agency::error::{AgencyError, AgencyResult};
10use crate::agency::models::{AgencyMessage, MessageRole, TokenUsage};
11use chrono::{DateTime, Utc};
12use rusqlite::{params, Connection, OptionalExtension};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::Path;
16use std::sync::{Arc, Mutex};
17
18/// Session state (JSON-serializable key-value store)
19#[derive(Debug, Clone, Default, Serialize, Deserialize)]
20pub struct SessionState {
21    /// Arbitrary state data
22    pub data: HashMap<String, serde_json::Value>,
23}
24
25impl SessionState {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
31        self.data
32            .get(key)
33            .and_then(|v| serde_json::from_value(v.clone()).ok())
34    }
35
36    pub fn set<T: Serialize>(&mut self, key: impl Into<String>, value: T) {
37        if let Ok(v) = serde_json::to_value(value) {
38            self.data.insert(key.into(), v);
39        }
40    }
41
42    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
43        self.data.remove(key)
44    }
45
46    pub fn contains(&self, key: &str) -> bool {
47        self.data.contains_key(key)
48    }
49
50    pub fn clear(&mut self) {
51        self.data.clear();
52    }
53}
54
55/// A conversation session
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct Session {
58    /// Unique session ID
59    pub id: String,
60    /// Associated agent name
61    pub agent_name: String,
62    /// User ID (optional)
63    #[serde(default)]
64    pub user_id: Option<String>,
65    /// Session title/name
66    #[serde(default)]
67    pub title: Option<String>,
68    /// Conversation messages
69    pub messages: Vec<AgencyMessage>,
70    /// Session state
71    #[serde(default)]
72    pub state: SessionState,
73    /// Token usage
74    #[serde(default)]
75    pub token_usage: TokenUsage,
76    /// Creation timestamp
77    pub created_at: DateTime<Utc>,
78    /// Last update timestamp
79    pub updated_at: DateTime<Utc>,
80    /// Custom metadata
81    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
82    pub metadata: HashMap<String, serde_json::Value>,
83}
84
85impl Session {
86    /// Create a new session
87    pub fn new(agent_name: impl Into<String>, user_id: Option<String>) -> Self {
88        let now = Utc::now();
89        Self {
90            id: generate_session_id(),
91            agent_name: agent_name.into(),
92            user_id,
93            title: None,
94            messages: Vec::new(),
95            state: SessionState::new(),
96            token_usage: TokenUsage::default(),
97            created_at: now,
98            updated_at: now,
99            metadata: HashMap::new(),
100        }
101    }
102
103    /// Add a message to the session
104    pub fn add_message(&mut self, message: AgencyMessage) {
105        if let Some(tokens) = message.tokens {
106            self.token_usage.total_tokens += tokens;
107            match message.role {
108                MessageRole::User | MessageRole::System => {
109                    self.token_usage.prompt_tokens += tokens;
110                }
111                MessageRole::Assistant | MessageRole::Tool => {
112                    self.token_usage.completion_tokens += tokens;
113                }
114            }
115        }
116        self.messages.push(message);
117        self.updated_at = Utc::now();
118    }
119
120    /// Get messages formatted for model API
121    pub fn to_api_messages(&self) -> Vec<serde_json::Value> {
122        self.messages
123            .iter()
124            .map(|m| {
125                serde_json::json!({
126                    "role": m.role.to_string(),
127                    "content": m.content
128                })
129            })
130            .collect()
131    }
132
133    /// Get the last N messages
134    pub fn last_messages(&self, n: usize) -> &[AgencyMessage] {
135        let start = self.messages.len().saturating_sub(n);
136        &self.messages[start..]
137    }
138
139    /// Clear messages but keep state
140    pub fn clear_messages(&mut self) {
141        self.messages.clear();
142        self.token_usage = TokenUsage::default();
143        self.updated_at = Utc::now();
144    }
145
146    /// Rewind to before a specific message
147    pub fn rewind_to(&mut self, message_id: &str) -> Option<Vec<AgencyMessage>> {
148        if let Some(pos) = self.messages.iter().position(|m| m.id == message_id) {
149            let removed: Vec<_> = self.messages.drain(pos..).collect();
150            self.updated_at = Utc::now();
151            // Recalculate token usage
152            self.recalculate_tokens();
153            Some(removed)
154        } else {
155            None
156        }
157    }
158
159    fn recalculate_tokens(&mut self) {
160        let mut usage = TokenUsage::default();
161        for m in &self.messages {
162            if let Some(tokens) = m.tokens {
163                usage.total_tokens += tokens;
164                match m.role {
165                    MessageRole::User | MessageRole::System => {
166                        usage.prompt_tokens += tokens;
167                    }
168                    MessageRole::Assistant | MessageRole::Tool => {
169                        usage.completion_tokens += tokens;
170                    }
171                }
172            }
173        }
174        self.token_usage = usage;
175    }
176}
177
178/// Generate a unique session ID
179fn generate_session_id() -> String {
180    format!(
181        "session-{}-{}",
182        Utc::now().timestamp_millis(),
183        &uuid::Uuid::new_v4().to_string()[..8]
184    )
185}
186
187/// Generate a unique message ID
188pub fn generate_message_id() -> String {
189    format!(
190        "msg-{}-{}",
191        Utc::now().timestamp_millis(),
192        &uuid::Uuid::new_v4().to_string()[..8]
193    )
194}
195
196/// Session manager with SQLite persistence
197pub struct SessionManager {
198    conn: Arc<Mutex<Connection>>,
199}
200
201impl SessionManager {
202    /// Create a new session manager with the given database path
203    pub fn new(db_path: impl AsRef<Path>) -> AgencyResult<Self> {
204        let conn = Connection::open(db_path)?;
205        let manager = Self {
206            conn: Arc::new(Mutex::new(conn)),
207        };
208        manager.init_schema()?;
209        Ok(manager)
210    }
211
212    /// Create an in-memory session manager (for testing)
213    pub fn in_memory() -> AgencyResult<Self> {
214        let conn = Connection::open_in_memory()?;
215        let manager = Self {
216            conn: Arc::new(Mutex::new(conn)),
217        };
218        manager.init_schema()?;
219        Ok(manager)
220    }
221
222    /// Initialize database schema
223    fn init_schema(&self) -> AgencyResult<()> {
224        let conn = self
225            .conn
226            .lock()
227            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
228        conn.execute_batch(
229            r#"
230            CREATE TABLE IF NOT EXISTS Agency_sessions (
231                id TEXT PRIMARY KEY,
232                agent_name TEXT NOT NULL,
233                user_id TEXT,
234                title TEXT,
235                messages TEXT NOT NULL,
236                state TEXT NOT NULL,
237                token_usage TEXT NOT NULL,
238                metadata TEXT,
239                created_at TEXT NOT NULL,
240                updated_at TEXT NOT NULL
241            );
242
243            CREATE INDEX IF NOT EXISTS idx_Agency_sessions_agent ON Agency_sessions(agent_name);
244            CREATE INDEX IF NOT EXISTS idx_Agency_sessions_user ON Agency_sessions(user_id);
245            CREATE INDEX IF NOT EXISTS idx_Agency_sessions_updated ON Agency_sessions(updated_at DESC);
246            "#,
247        )?;
248        Ok(())
249    }
250
251    /// Create a new session
252    pub fn create(
253        &self,
254        agent_name: impl Into<String>,
255        user_id: Option<String>,
256    ) -> AgencyResult<Session> {
257        let session = Session::new(agent_name, user_id);
258        self.save(&session)?;
259        Ok(session)
260    }
261
262    /// Save a session
263    pub fn save(&self, session: &Session) -> AgencyResult<()> {
264        let conn = self
265            .conn
266            .lock()
267            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
268        conn.execute(
269            r#"
270            INSERT OR REPLACE INTO Agency_sessions 
271            (id, agent_name, user_id, title, messages, state, token_usage, metadata, created_at, updated_at)
272            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
273            "#,
274            params![
275                session.id,
276                session.agent_name,
277                session.user_id,
278                session.title,
279                serde_json::to_string(&session.messages)?,
280                serde_json::to_string(&session.state)?,
281                serde_json::to_string(&session.token_usage)?,
282                serde_json::to_string(&session.metadata)?,
283                session.created_at.to_rfc3339(),
284                session.updated_at.to_rfc3339(),
285            ],
286        )?;
287        Ok(())
288    }
289
290    /// Get a session by ID
291    pub fn get(&self, id: &str) -> AgencyResult<Option<Session>> {
292        let conn = self
293            .conn
294            .lock()
295            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
296        let session = conn
297            .query_row(
298                "SELECT * FROM Agency_sessions WHERE id = ?1",
299                params![id],
300                |row| {
301                    Ok(Session {
302                        id: row.get(0)?,
303                        agent_name: row.get(1)?,
304                        user_id: row.get(2)?,
305                        title: row.get(3)?,
306                        messages: serde_json::from_str(&row.get::<_, String>(4)?)
307                            .unwrap_or_default(),
308                        state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
309                        token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
310                            .unwrap_or_default(),
311                        metadata: serde_json::from_str(&row.get::<_, String>(7)?)
312                            .unwrap_or_default(),
313                        created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
314                            .map(|dt| dt.with_timezone(&Utc))
315                            .unwrap_or_else(|_| Utc::now()),
316                        updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
317                            .map(|dt| dt.with_timezone(&Utc))
318                            .unwrap_or_else(|_| Utc::now()),
319                    })
320                },
321            )
322            .optional()?;
323        Ok(session)
324    }
325
326    /// List sessions for an agent
327    pub fn list_by_agent(
328        &self,
329        agent_name: &str,
330        limit: Option<u32>,
331    ) -> AgencyResult<Vec<Session>> {
332        let conn = self
333            .conn
334            .lock()
335            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
336        let limit = limit.unwrap_or(100);
337        let mut stmt = conn.prepare(
338            "SELECT * FROM Agency_sessions WHERE agent_name = ?1 ORDER BY updated_at DESC LIMIT ?2",
339        )?;
340        let sessions = stmt
341            .query_map(params![agent_name, limit], |row| {
342                Ok(Session {
343                    id: row.get(0)?,
344                    agent_name: row.get(1)?,
345                    user_id: row.get(2)?,
346                    title: row.get(3)?,
347                    messages: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
348                    state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
349                    token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
350                        .unwrap_or_default(),
351                    metadata: serde_json::from_str(&row.get::<_, String>(7)?).unwrap_or_default(),
352                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
353                        .map(|dt| dt.with_timezone(&Utc))
354                        .unwrap_or_else(|_| Utc::now()),
355                    updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
356                        .map(|dt| dt.with_timezone(&Utc))
357                        .unwrap_or_else(|_| Utc::now()),
358                })
359            })?
360            .filter_map(|r| r.ok())
361            .collect();
362        Ok(sessions)
363    }
364
365    /// List sessions for a user
366    pub fn list_by_user(&self, user_id: &str, limit: Option<u32>) -> AgencyResult<Vec<Session>> {
367        let conn = self
368            .conn
369            .lock()
370            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
371        let limit = limit.unwrap_or(100);
372        let mut stmt = conn.prepare(
373            "SELECT * FROM Agency_sessions WHERE user_id = ?1 ORDER BY updated_at DESC LIMIT ?2",
374        )?;
375        let sessions = stmt
376            .query_map(params![user_id, limit], |row| {
377                Ok(Session {
378                    id: row.get(0)?,
379                    agent_name: row.get(1)?,
380                    user_id: row.get(2)?,
381                    title: row.get(3)?,
382                    messages: serde_json::from_str(&row.get::<_, String>(4)?).unwrap_or_default(),
383                    state: serde_json::from_str(&row.get::<_, String>(5)?).unwrap_or_default(),
384                    token_usage: serde_json::from_str(&row.get::<_, String>(6)?)
385                        .unwrap_or_default(),
386                    metadata: serde_json::from_str(&row.get::<_, String>(7)?).unwrap_or_default(),
387                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
388                        .map(|dt| dt.with_timezone(&Utc))
389                        .unwrap_or_else(|_| Utc::now()),
390                    updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
391                        .map(|dt| dt.with_timezone(&Utc))
392                        .unwrap_or_else(|_| Utc::now()),
393                })
394            })?
395            .filter_map(|r| r.ok())
396            .collect();
397        Ok(sessions)
398    }
399
400    /// Delete a session
401    pub fn delete(&self, id: &str) -> AgencyResult<bool> {
402        let conn = self
403            .conn
404            .lock()
405            .map_err(|e| AgencyError::DatabaseError(e.to_string()))?;
406        let rows = conn.execute("DELETE FROM Agency_sessions WHERE id = ?1", params![id])?;
407        Ok(rows > 0)
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a message with no token count; use struct-update syntax to set `tokens`.
    fn make_message(id: &str, role: MessageRole, content: &str) -> AgencyMessage {
        AgencyMessage {
            id: id.to_string(),
            role,
            content: content.to_string(),
            tool_calls: vec![],
            tool_result: None,
            timestamp: Utc::now(),
            tokens: None,
            agent_name: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_session_state() {
        let mut state = SessionState::new();
        state.set("count", 42);
        state.set("name", "test");

        assert_eq!(state.get::<i32>("count"), Some(42));
        assert_eq!(state.get::<String>("name"), Some("test".to_string()));
        // A type mismatch reads as None, just like a missing key.
        assert_eq!(state.get::<i32>("name"), None);
        assert!(state.contains("count"));
        assert!(!state.contains("missing"));

        assert!(state.remove("count").is_some());
        assert!(!state.contains("count"));
        state.clear();
        assert!(!state.contains("name"));
    }

    #[test]
    fn test_session_messages() {
        let mut session = Session::new("test_agent", None);
        session.add_message(AgencyMessage {
            tokens: Some(5),
            ..make_message("msg1", MessageRole::User, "Hello")
        });
        session.add_message(AgencyMessage {
            tokens: Some(7),
            ..make_message("msg2", MessageRole::Assistant, "Hi")
        });

        assert_eq!(session.messages.len(), 2);
        assert_eq!(session.token_usage.prompt_tokens, 5);
        assert_eq!(session.token_usage.completion_tokens, 7);
        assert_eq!(session.token_usage.total_tokens, 12);
    }

    #[test]
    fn test_rewind_to_recalculates_tokens() {
        let mut session = Session::new("test_agent", None);
        session.add_message(AgencyMessage {
            tokens: Some(5),
            ..make_message("msg1", MessageRole::User, "Hello")
        });
        session.add_message(AgencyMessage {
            tokens: Some(7),
            ..make_message("msg2", MessageRole::Assistant, "Hi")
        });
        session.add_message(AgencyMessage {
            tokens: Some(3),
            ..make_message("msg3", MessageRole::User, "Bye")
        });
        assert_eq!(session.token_usage.total_tokens, 15);

        // Unknown ID: nothing changes.
        assert!(session.rewind_to("missing").is_none());
        assert_eq!(session.messages.len(), 3);

        // The target message itself is removed along with everything after it.
        let removed = session.rewind_to("msg2").expect("msg2 exists");
        let removed_ids: Vec<&str> = removed.iter().map(|m| m.id.as_str()).collect();
        assert_eq!(removed_ids, ["msg2", "msg3"]);
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.token_usage.total_tokens, 5);
        assert_eq!(session.token_usage.prompt_tokens, 5);
        assert_eq!(session.token_usage.completion_tokens, 0);
    }

    #[test]
    fn test_last_messages() {
        let mut session = Session::new("test_agent", None);
        for id in &["m1", "m2", "m3"] {
            session.add_message(make_message(id, MessageRole::User, "x"));
        }

        let last_two: Vec<&str> = session
            .last_messages(2)
            .iter()
            .map(|m| m.id.as_str())
            .collect();
        assert_eq!(last_two, ["m2", "m3"]);
        assert_eq!(session.last_messages(10).len(), 3);
        assert!(session.last_messages(0).is_empty());
    }

    #[test]
    fn test_clear_messages_keeps_state() {
        let mut session = Session::new("test_agent", None);
        session.state.set("keep", true);
        session.add_message(AgencyMessage {
            tokens: Some(4),
            ..make_message("msg1", MessageRole::User, "Hello")
        });

        session.clear_messages();
        assert!(session.messages.is_empty());
        assert_eq!(session.token_usage.total_tokens, 0);
        assert!(session.state.contains("keep"));
    }

    #[test]
    fn test_session_manager() -> AgencyResult<()> {
        let manager = SessionManager::in_memory()?;
        let session = manager.create("test_agent", Some("user1".to_string()))?;

        let loaded = manager.get(&session.id)?;
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().agent_name, "test_agent");

        let sessions = manager.list_by_agent("test_agent", None)?;
        assert_eq!(sessions.len(), 1);

        assert!(manager.delete(&session.id)?);
        let deleted = manager.get(&session.id)?;
        assert!(deleted.is_none());
        // Deleting again removes nothing.
        assert!(!manager.delete(&session.id)?);

        Ok(())
    }

    #[test]
    fn test_save_and_reload_round_trip() -> AgencyResult<()> {
        let manager = SessionManager::in_memory()?;
        let mut session = manager.create("test_agent", None)?;
        session.title = Some("Greeting".to_string());
        session.state.set("turns", 2);
        session.add_message(AgencyMessage {
            tokens: Some(5),
            ..make_message("msg1", MessageRole::User, "Hello")
        });
        session.add_message(AgencyMessage {
            tokens: Some(7),
            ..make_message("msg2", MessageRole::Assistant, "Hi there")
        });
        manager.save(&session)?;

        let loaded = manager
            .get(&session.id)?
            .expect("saved session should load");
        assert_eq!(loaded.title.as_deref(), Some("Greeting"));
        assert_eq!(loaded.state.get::<i32>("turns"), Some(2));
        let ids: Vec<&str> = loaded.messages.iter().map(|m| m.id.as_str()).collect();
        assert_eq!(ids, ["msg1", "msg2"]);
        assert_eq!(loaded.messages[1].content, "Hi there");
        assert_eq!(loaded.token_usage.total_tokens, 12);

        Ok(())
    }

    #[test]
    fn test_list_by_user_and_limit() -> AgencyResult<()> {
        let manager = SessionManager::in_memory()?;
        manager.create("agent_a", Some("user1".to_string()))?;
        manager.create("agent_b", Some("user1".to_string()))?;
        manager.create("agent_b", Some("user2".to_string()))?;

        assert_eq!(manager.list_by_user("user1", None)?.len(), 2);
        assert_eq!(manager.list_by_user("user1", Some(1))?.len(), 1);
        assert!(manager.list_by_user("nobody", None)?.is_empty());
        assert_eq!(manager.list_by_agent("agent_b", None)?.len(), 2);

        Ok(())
    }
}