Skip to main content

astrid_runtime/
session.rs

1//! Agent session management.
2//!
3//! Sessions track conversation state, capabilities, and context.
4
5use astrid_approval::allowance::Allowance;
6use astrid_approval::budget::{BudgetSnapshot, BudgetTracker, WorkspaceBudgetTracker};
7use astrid_approval::{AllowanceStore, ApprovalManager, DeferredResolutionStore};
8use astrid_capabilities::CapabilityStore;
9use astrid_core::SessionId;
10use astrid_llm::Message;
11use astrid_workspace::escape::{EscapeHandler, EscapeState};
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
17/// An agent session.
18#[derive(Debug)]
19pub struct AgentSession {
20    /// Unique session identifier.
21    pub id: SessionId,
22    /// User identifier (key ID).
23    pub user_id: [u8; 8],
24    /// Conversation messages.
25    pub messages: Vec<Message>,
26    /// Session capabilities.
27    pub capabilities: Arc<CapabilityStore>,
28    /// Session allowance store.
29    pub allowance_store: Arc<AllowanceStore>,
30    /// Session approval manager.
31    pub approval_manager: Arc<ApprovalManager>,
32    /// System prompt.
33    pub system_prompt: String,
34    /// When the session was created.
35    pub created_at: DateTime<Utc>,
36    /// Estimated token count.
37    pub token_count: usize,
38    /// Session metadata.
39    pub metadata: SessionMetadata,
40    /// Workspace escape handler for tracking allowed paths.
41    pub escape_handler: EscapeHandler,
42    /// Per-session budget tracker.
43    pub budget_tracker: Arc<BudgetTracker>,
44    /// Workspace cumulative budget tracker (shared across sessions).
45    pub workspace_budget_tracker: Option<Arc<WorkspaceBudgetTracker>>,
46    /// Workspace path this session belongs to (for workspace-scoped listing).
47    pub workspace_path: Option<PathBuf>,
48    /// Model used for this session (e.g. `"claude-sonnet-4-20250514"`).
49    pub model: Option<String>,
50}
51
52impl AgentSession {
53    /// Create a new session.
54    #[must_use]
55    pub fn new(user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
56        let allowance_store = Arc::new(AllowanceStore::new());
57        let deferred_queue = Arc::new(DeferredResolutionStore::new());
58        let approval_manager = Arc::new(ApprovalManager::new(
59            Arc::clone(&allowance_store),
60            deferred_queue,
61        ));
62        Self {
63            id: SessionId::new(),
64            user_id,
65            messages: Vec::new(),
66            capabilities: Arc::new(CapabilityStore::in_memory()),
67            allowance_store,
68            approval_manager,
69            system_prompt: system_prompt.into(),
70            created_at: Utc::now(),
71            token_count: 0,
72            metadata: SessionMetadata::default(),
73            escape_handler: EscapeHandler::new(),
74            budget_tracker: Arc::new(BudgetTracker::default()),
75            workspace_budget_tracker: None,
76            workspace_path: None,
77            model: None,
78        }
79    }
80
81    /// Create with a specific session ID.
82    #[must_use]
83    pub fn with_id(id: SessionId, user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
84        let allowance_store = Arc::new(AllowanceStore::new());
85        let deferred_queue = Arc::new(DeferredResolutionStore::new());
86        let approval_manager = Arc::new(ApprovalManager::new(
87            Arc::clone(&allowance_store),
88            deferred_queue,
89        ));
90        Self {
91            id,
92            user_id,
93            messages: Vec::new(),
94            capabilities: Arc::new(CapabilityStore::in_memory()),
95            allowance_store,
96            approval_manager,
97            system_prompt: system_prompt.into(),
98            created_at: Utc::now(),
99            token_count: 0,
100            metadata: SessionMetadata::default(),
101            escape_handler: EscapeHandler::new(),
102            budget_tracker: Arc::new(BudgetTracker::default()),
103            workspace_budget_tracker: None,
104            workspace_path: None,
105            model: None,
106        }
107    }
108
109    /// Create a child session that shares parent's stores.
110    ///
111    /// The child inherits the parent's `AllowanceStore`, `CapabilityStore`, and
112    /// `BudgetTracker` (same `Arc` — spend is visible bidirectionally). The
113    /// `ApprovalManager` and `DeferredResolutionStore` are fresh (independent
114    /// handler registration and independent deferred queue).
115    #[must_use]
116    pub fn with_shared_stores(
117        id: SessionId,
118        user_id: [u8; 8],
119        system_prompt: impl Into<String>,
120        allowance_store: Arc<AllowanceStore>,
121        capabilities: Arc<CapabilityStore>,
122        budget_tracker: Arc<BudgetTracker>,
123    ) -> Self {
124        let deferred_queue = Arc::new(DeferredResolutionStore::new());
125        let approval_manager = Arc::new(ApprovalManager::new(
126            Arc::clone(&allowance_store),
127            deferred_queue,
128        ));
129        Self {
130            id,
131            user_id,
132            messages: Vec::new(),
133            capabilities,
134            allowance_store,
135            approval_manager,
136            system_prompt: system_prompt.into(),
137            created_at: Utc::now(),
138            token_count: 0,
139            metadata: SessionMetadata::default(),
140            escape_handler: EscapeHandler::new(),
141            budget_tracker,
142            workspace_budget_tracker: None,
143            workspace_path: None,
144            model: None,
145        }
146    }
147
148    /// Set the workspace path for this session.
149    #[must_use]
150    pub fn with_workspace(mut self, path: impl Into<PathBuf>) -> Self {
151        self.workspace_path = Some(path.into());
152        self
153    }
154
155    /// Set the model name for this session.
156    #[must_use]
157    pub fn with_model(mut self, model: impl Into<String>) -> Self {
158        self.model = Some(model.into());
159        self
160    }
161
162    /// Replace the capability store with a persistent one.
163    ///
164    /// Call this after session construction when a persistent store is available
165    /// (e.g. at daemon startup).
166    #[must_use]
167    pub fn with_capability_store(mut self, store: Arc<CapabilityStore>) -> Self {
168        self.capabilities = store;
169        self
170    }
171
172    /// Set the workspace cumulative budget tracker.
173    #[must_use]
174    pub fn with_workspace_budget(mut self, tracker: Arc<WorkspaceBudgetTracker>) -> Self {
175        self.workspace_budget_tracker = Some(tracker);
176        self
177    }
178
179    /// Import workspace-scoped allowances into this session.
180    ///
181    /// These allowances were previously persisted in the workspace `state.db`
182    /// and are loaded when a session is created or resumed in the same workspace.
183    pub fn import_workspace_allowances(
184        &self,
185        allowances: Vec<astrid_approval::allowance::Allowance>,
186    ) {
187        self.allowance_store.import_allowances(allowances);
188    }
189
190    /// Export workspace-scoped allowances from this session for persistence.
191    #[must_use]
192    pub fn export_workspace_allowances(&self) -> Vec<astrid_approval::allowance::Allowance> {
193        self.allowance_store.export_workspace_allowances()
194    }
195
196    /// Replace the deferred resolution queue with a persistent one.
197    ///
198    /// This reconstructs the `ApprovalManager` with the new persistent queue.
199    /// Call this after session construction when a persistent store is available.
200    ///
201    /// # Errors
202    ///
203    /// Returns an error if the persistent store cannot be initialized.
204    pub async fn with_persistent_deferred_queue(
205        mut self,
206        store: astrid_storage::ScopedKvStore,
207    ) -> Result<Self, astrid_core::error::SecurityError> {
208        let deferred_queue = Arc::new(DeferredResolutionStore::with_persistence(store).await?);
209        self.approval_manager = Arc::new(ApprovalManager::new(
210            Arc::clone(&self.allowance_store),
211            deferred_queue,
212        ));
213        Ok(self)
214    }
215
216    /// Add a message to the session.
217    pub fn add_message(&mut self, message: Message) {
218        // Rough token estimate (4 chars per token). This is a heuristic for
219        // context-limit warnings, not billing. Real context windows top out at
220        // ~200K tokens, so overflow of usize is not a practical concern.
221        let msg_tokens = match &message.content {
222            astrid_llm::MessageContent::Text(t) => t.len() / 4,
223            _ => 100, // Rough estimate for tool calls
224        };
225        self.token_count = self.token_count.saturating_add(msg_tokens);
226        self.messages.push(message);
227    }
228
229    /// Get the last N messages.
230    #[must_use]
231    pub fn last_messages(&self, n: usize) -> &[Message] {
232        let start = self.messages.len().saturating_sub(n);
233        &self.messages[start..]
234    }
235
236    /// Clear messages (keeping system prompt).
237    pub fn clear_messages(&mut self) {
238        self.messages.clear();
239        self.token_count = 0;
240    }
241
242    /// Get session duration.
243    #[must_use]
244    pub fn duration(&self) -> chrono::Duration {
245        // Safety: current time is always >= created_at
246        #[allow(clippy::arithmetic_side_effects)]
247        {
248            Utc::now() - self.created_at
249        }
250    }
251
252    /// Clean up session-scoped state.
253    ///
254    /// Clears session-only allowances, leaving workspace and persistent ones intact.
255    pub fn end_session(&self) {
256        self.allowance_store.clear_session_allowances();
257    }
258
259    /// Check if session is near context limit.
260    #[must_use]
261    #[allow(clippy::cast_precision_loss)]
262    pub fn is_near_limit(&self, max_tokens: usize, threshold: f32) -> bool {
263        self.token_count as f32 > max_tokens as f32 * threshold
264    }
265}
266
267/// Session metadata.
268#[derive(Debug, Clone, Default, Serialize, Deserialize)]
269pub struct SessionMetadata {
270    /// Session title (generated or user-provided).
271    pub title: Option<String>,
272    /// Tags for organization.
273    pub tags: Vec<String>,
274    /// Number of turns.
275    pub turn_count: usize,
276    /// Number of tool calls.
277    pub tool_call_count: usize,
278    /// Number of approvals granted.
279    pub approval_count: usize,
280    /// Custom key-value metadata.
281    pub custom: std::collections::HashMap<String, String>,
282}
283
284/// Serializable session state (for persistence).
285///
286/// Includes full security state: allowances, budget snapshot, escape handler
287/// state, and workspace path. This ensures "Allow Session" approvals,
288/// budget spend, and escape decisions survive daemon restarts.
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct SerializableSession {
291    /// Session ID.
292    pub id: String,
293    /// User ID (hex).
294    pub user_id: String,
295    /// Messages.
296    pub messages: Vec<SerializableMessage>,
297    /// System prompt.
298    pub system_prompt: String,
299    /// Created at.
300    pub created_at: DateTime<Utc>,
301    /// Token count.
302    pub token_count: usize,
303    /// Metadata.
304    pub metadata: SessionMetadata,
305    /// Session allowances (persisted so "Allow Session" survives restart).
306    #[serde(default)]
307    pub allowances: Vec<Allowance>,
308    /// Budget snapshot (persisted so budget is not reset on restart).
309    #[serde(default)]
310    pub budget_snapshot: Option<BudgetSnapshot>,
311    /// Escape handler state (persisted so "Allow Always" paths survive).
312    #[serde(default)]
313    pub escape_state: Option<EscapeState>,
314    /// Workspace path this session belongs to.
315    #[serde(default)]
316    pub workspace_path: Option<String>,
317    /// Model used for this session (e.g. "claude-sonnet-4-20250514").
318    #[serde(default)]
319    pub model: Option<String>,
320    /// Git state placeholder (branch, commit hash) for future worktree support.
321    #[serde(default)]
322    pub git_state: Option<GitState>,
323}
324
325/// Git repository state snapshot.
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct GitState {
328    /// Current branch name.
329    pub branch: Option<String>,
330    /// Current commit hash.
331    pub commit: Option<String>,
332}
333
334impl GitState {
335    /// Capture the current git state for a workspace path.
336    ///
337    /// Returns `None` if the path is not in a git repository or git is not available.
338    #[must_use]
339    pub fn capture(workspace_path: &std::path::Path) -> Option<Self> {
340        let branch = std::process::Command::new("git")
341            .args([
342                "-C",
343                &workspace_path.display().to_string(),
344                "rev-parse",
345                "--abbrev-ref",
346                "HEAD",
347            ])
348            .stdout(std::process::Stdio::piped())
349            .stderr(std::process::Stdio::null())
350            .output()
351            .ok()
352            .filter(|o| o.status.success())
353            .and_then(|o| String::from_utf8(o.stdout).ok())
354            .map(|s| s.trim().to_string());
355
356        let commit = std::process::Command::new("git")
357            .args([
358                "-C",
359                &workspace_path.display().to_string(),
360                "rev-parse",
361                "HEAD",
362            ])
363            .stdout(std::process::Stdio::piped())
364            .stderr(std::process::Stdio::null())
365            .output()
366            .ok()
367            .filter(|o| o.status.success())
368            .and_then(|o| String::from_utf8(o.stdout).ok())
369            .map(|s| s.trim().to_string());
370
371        // Only return Some if at least one field was captured.
372        if branch.is_some() || commit.is_some() {
373            Some(Self { branch, commit })
374        } else {
375            None
376        }
377    }
378}
379
380/// Serializable message.
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub struct SerializableMessage {
383    /// Role.
384    pub role: String,
385    /// Content (JSON).
386    pub content: serde_json::Value,
387}
388
389impl From<&AgentSession> for SerializableSession {
390    fn from(session: &AgentSession) -> Self {
391        Self {
392            id: session.id.0.to_string(),
393            user_id: hex::encode(session.user_id),
394            messages: session
395                .messages
396                .iter()
397                .map(|m| SerializableMessage {
398                    role: match m.role {
399                        astrid_llm::MessageRole::System => "system".to_string(),
400                        astrid_llm::MessageRole::User => "user".to_string(),
401                        astrid_llm::MessageRole::Assistant => "assistant".to_string(),
402                        astrid_llm::MessageRole::Tool => "tool".to_string(),
403                    },
404                    content: serde_json::to_value(&m.content).unwrap_or_default(),
405                })
406                .collect(),
407            system_prompt: session.system_prompt.clone(),
408            created_at: session.created_at,
409            token_count: session.token_count,
410            metadata: session.metadata.clone(),
411            allowances: session.allowance_store.export_session_allowances(),
412            budget_snapshot: Some(session.budget_tracker.snapshot()),
413            escape_state: Some(session.escape_handler.export_state()),
414            workspace_path: session
415                .workspace_path
416                .as_ref()
417                .map(|p| p.display().to_string()),
418            model: session.model.clone(),
419            git_state: session
420                .workspace_path
421                .as_ref()
422                .and_then(|p| GitState::capture(p)),
423        }
424    }
425}
426
427impl SerializableSession {
428    /// Convert back to an `AgentSession`.
429    ///
430    /// Restores full security state: allowances, budget, and escape handler.
431    #[must_use]
432    pub fn to_session(&self) -> AgentSession {
433        let mut user_id = [0u8; 8];
434        if let Ok(bytes) = hex::decode(&self.user_id)
435            && bytes.len() >= 8
436        {
437            user_id.copy_from_slice(&bytes[..8]);
438        }
439
440        let id =
441            uuid::Uuid::parse_str(&self.id).map_or_else(|_| SessionId::new(), SessionId::from_uuid);
442
443        let messages: Vec<Message> = self
444            .messages
445            .iter()
446            .filter_map(|m| {
447                let content: astrid_llm::MessageContent =
448                    serde_json::from_value(m.content.clone()).ok()?;
449                let role = match m.role.as_str() {
450                    "system" => astrid_llm::MessageRole::System,
451                    "user" => astrid_llm::MessageRole::User,
452                    "assistant" => astrid_llm::MessageRole::Assistant,
453                    "tool" => astrid_llm::MessageRole::Tool,
454                    _ => return None,
455                };
456                Some(Message { role, content })
457            })
458            .collect();
459
460        let mut session = AgentSession::with_id(id, user_id, &self.system_prompt);
461        session.messages = messages;
462        session.created_at = self.created_at;
463        session.token_count = self.token_count;
464        session.metadata = self.metadata.clone();
465        session.workspace_path = self.workspace_path.as_ref().map(PathBuf::from);
466        session.model.clone_from(&self.model);
467
468        // Restore session allowances
469        if !self.allowances.is_empty() {
470            session
471                .allowance_store
472                .import_allowances(self.allowances.clone());
473        }
474
475        // Restore budget from snapshot (prevents budget bypass via restart)
476        if let Some(snapshot) = &self.budget_snapshot {
477            session.budget_tracker = Arc::new(BudgetTracker::restore(snapshot.clone()));
478        }
479
480        // Restore escape handler state (preserves "AllowAlways" paths)
481        if let Some(escape_state) = &self.escape_state {
482            session.escape_handler.restore_state(escape_state.clone());
483        }
484
485        session
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_llm::Message;

    /// A new session starts with no messages and keeps the given prompt.
    #[test]
    fn test_session_creation() {
        let fresh = AgentSession::new([0u8; 8], "You are helpful");
        assert!(fresh.messages.is_empty());
        assert_eq!(fresh.system_prompt, "You are helpful");
    }

    /// Adding messages grows the history and the token estimate.
    #[test]
    fn test_add_message() {
        let mut chat = AgentSession::new([0u8; 8], "");
        chat.add_message(Message::user("Hello"));
        chat.add_message(Message::assistant("Hi!"));

        assert_eq!(chat.messages.len(), 2);
        assert!(chat.token_count > 0);
    }

    /// Prompt and message count survive a serialize/restore cycle.
    #[test]
    fn test_serialization_roundtrip() {
        let mut original = AgentSession::new([1u8; 8], "Test prompt");
        original.add_message(Message::user("Hello"));
        original.add_message(Message::assistant("World"));

        let persisted = SerializableSession::from(&original);
        let restored = persisted.to_session();

        assert_eq!(restored.system_prompt, original.system_prompt);
        assert_eq!(restored.messages.len(), original.messages.len());
    }

    /// Recorded spend is carried through the budget snapshot.
    #[test]
    fn test_budget_snapshot_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test");
        original.budget_tracker.record_cost(42.5);

        let persisted = SerializableSession::from(&original);
        let restored = persisted.to_session();

        assert!((restored.budget_tracker.spent() - 42.5).abs() < f64::EPSILON);
    }

    /// The workspace path is preserved across a serialize/restore cycle.
    #[test]
    fn test_workspace_path_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test").with_workspace("/home/user/project");

        let persisted = SerializableSession::from(&original);
        let restored = persisted.to_session();

        assert_eq!(
            restored.workspace_path,
            Some(PathBuf::from("/home/user/project"))
        );
    }

    /// A child built with shared stores sees, and contributes to, parent state.
    #[test]
    fn test_with_shared_stores() {
        let parent = AgentSession::new([1u8; 8], "Parent");

        // Spend recorded on the parent before the child exists.
        parent.budget_tracker.record_cost(10.0);

        let child = AgentSession::with_shared_stores(
            SessionId::new(),
            [1u8; 8],
            "Child",
            Arc::clone(&parent.allowance_store),
            Arc::clone(&parent.capabilities),
            Arc::clone(&parent.budget_tracker),
        );

        // The child observes the parent's earlier spend.
        assert!((child.budget_tracker.spent() - 10.0).abs() < f64::EPSILON);

        // Spend recorded through the child shows up on the parent (same Arc).
        child.budget_tracker.record_cost(5.0);
        assert!((parent.budget_tracker.spent() - 15.0).abs() < f64::EPSILON);

        // All three stores are literally the same allocation.
        assert!(Arc::ptr_eq(&parent.budget_tracker, &child.budget_tracker));
        assert!(Arc::ptr_eq(&parent.allowance_store, &child.allowance_store));
        assert!(Arc::ptr_eq(&parent.capabilities, &child.capabilities));

        // The conversation history is not shared.
        assert!(child.messages.is_empty());

        // The approval manager is a distinct instance (fresh handler registration).
        assert!(!Arc::ptr_eq(
            &parent.approval_manager,
            &child.approval_manager
        ));
    }

    /// Session files written before the optional fields existed still load.
    #[test]
    fn test_backwards_compatible_deserialization() {
        // Old session format: none of the `#[serde(default)]` fields are present.
        let legacy_json = r#"{
            "id": "00000000-0000-0000-0000-000000000001",
            "user_id": "0101010101010101",
            "messages": [],
            "system_prompt": "Test",
            "created_at": "2024-01-01T00:00:00Z",
            "token_count": 0,
            "metadata": {
                "title": null,
                "tags": [],
                "turn_count": 0,
                "tool_call_count": 0,
                "approval_count": 0,
                "custom": {}
            }
        }"#;

        let persisted: SerializableSession = serde_json::from_str(legacy_json).unwrap();
        let restored = persisted.to_session();
        assert_eq!(restored.system_prompt, "Test");
        assert!(restored.workspace_path.is_none());
        assert_eq!(restored.budget_tracker.spent(), 0.0);
    }
}