// astrid_runtime/session.rs

//! Agent session management.
//!
//! An [`AgentSession`] tracks conversation state, capabilities, and context for
//! one agent; [`SerializableSession`] is its persistable snapshot.
4
5use astrid_approval::allowance::Allowance;
6use astrid_approval::budget::{BudgetSnapshot, BudgetTracker, WorkspaceBudgetTracker};
7use astrid_approval::{AllowanceStore, ApprovalManager, DeferredResolutionStore};
8use astrid_capabilities::CapabilityStore;
9use astrid_core::SessionId;
10use astrid_llm::Message;
11use astrid_workspace::escape::{EscapeHandler, EscapeState};
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
17/// An agent session.
18#[derive(Debug)]
19pub struct AgentSession {
20    /// Unique session identifier.
21    pub id: SessionId,
22    /// User identifier (key ID).
23    pub user_id: [u8; 8],
24    /// Conversation messages.
25    pub messages: Vec<Message>,
26    /// Session capabilities.
27    pub capabilities: Arc<CapabilityStore>,
28    /// Session allowance store.
29    pub allowance_store: Arc<AllowanceStore>,
30    /// Session approval manager.
31    pub approval_manager: Arc<ApprovalManager>,
32    /// System prompt.
33    pub system_prompt: String,
34    /// When the session was created.
35    pub created_at: DateTime<Utc>,
36    /// Estimated token count.
37    pub token_count: usize,
38    /// Session metadata.
39    pub metadata: SessionMetadata,
40    /// Workspace escape handler for tracking allowed paths.
41    pub escape_handler: EscapeHandler,
42    /// Per-session budget tracker.
43    pub budget_tracker: Arc<BudgetTracker>,
44    /// Workspace cumulative budget tracker (shared across sessions).
45    pub workspace_budget_tracker: Option<Arc<WorkspaceBudgetTracker>>,
46    /// Workspace path this session belongs to (for workspace-scoped listing).
47    pub workspace_path: Option<PathBuf>,
48    /// Model used for this session (e.g. `"claude-sonnet-4-20250514"`).
49    pub model: Option<String>,
50    /// Whether this session belongs to a sub-agent (skip spark preamble in `run_loop`).
51    pub is_subagent: bool,
52    /// Plugin-provided context (fetched dynamically per subagent/session, not persisted).
53    pub capsule_context: Option<String>,
54}
55
56impl AgentSession {
57    /// Create a new session.
58    #[must_use]
59    pub fn new(user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
60        let allowance_store = Arc::new(AllowanceStore::new());
61        let deferred_queue = Arc::new(DeferredResolutionStore::new());
62        let approval_manager = Arc::new(ApprovalManager::new(
63            Arc::clone(&allowance_store),
64            deferred_queue,
65        ));
66        Self {
67            id: SessionId::new(),
68            user_id,
69            messages: Vec::new(),
70            capabilities: Arc::new(CapabilityStore::in_memory()),
71            allowance_store,
72            approval_manager,
73            system_prompt: system_prompt.into(),
74            created_at: Utc::now(),
75            token_count: 0,
76            metadata: SessionMetadata::default(),
77            escape_handler: EscapeHandler::new(),
78            budget_tracker: Arc::new(BudgetTracker::default()),
79            workspace_budget_tracker: None,
80            workspace_path: None,
81            model: None,
82            is_subagent: false,
83            capsule_context: None,
84        }
85    }
86
87    /// Create with a specific session ID.
88    #[must_use]
89    pub fn with_id(id: SessionId, user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
90        let allowance_store = Arc::new(AllowanceStore::new());
91        let deferred_queue = Arc::new(DeferredResolutionStore::new());
92        let approval_manager = Arc::new(ApprovalManager::new(
93            Arc::clone(&allowance_store),
94            deferred_queue,
95        ));
96        Self {
97            id,
98            user_id,
99            messages: Vec::new(),
100            capabilities: Arc::new(CapabilityStore::in_memory()),
101            allowance_store,
102            approval_manager,
103            system_prompt: system_prompt.into(),
104            created_at: Utc::now(),
105            token_count: 0,
106            metadata: SessionMetadata::default(),
107            escape_handler: EscapeHandler::new(),
108            budget_tracker: Arc::new(BudgetTracker::default()),
109            workspace_budget_tracker: None,
110            workspace_path: None,
111            model: None,
112            is_subagent: false,
113            capsule_context: None,
114        }
115    }
116
117    /// Create a child session that shares parent's stores.
118    ///
119    /// The child inherits the parent's `AllowanceStore`, `CapabilityStore`, and
120    /// `BudgetTracker` (same `Arc` — spend is visible bidirectionally). The
121    /// `ApprovalManager` and `DeferredResolutionStore` are fresh (independent
122    /// handler registration and independent deferred queue).
123    #[must_use]
124    pub fn with_shared_stores(
125        id: SessionId,
126        user_id: [u8; 8],
127        system_prompt: impl Into<String>,
128        allowance_store: Arc<AllowanceStore>,
129        capabilities: Arc<CapabilityStore>,
130        budget_tracker: Arc<BudgetTracker>,
131    ) -> Self {
132        let deferred_queue = Arc::new(DeferredResolutionStore::new());
133        let approval_manager = Arc::new(ApprovalManager::new(
134            Arc::clone(&allowance_store),
135            deferred_queue,
136        ));
137        Self {
138            id,
139            user_id,
140            messages: Vec::new(),
141            capabilities,
142            allowance_store,
143            approval_manager,
144            system_prompt: system_prompt.into(),
145            created_at: Utc::now(),
146            token_count: 0,
147            metadata: SessionMetadata::default(),
148            escape_handler: EscapeHandler::new(),
149            budget_tracker,
150            workspace_budget_tracker: None,
151            workspace_path: None,
152            model: None,
153            is_subagent: true,
154            capsule_context: None,
155        }
156    }
157
158    /// Set the workspace path for this session.
159    #[must_use]
160    pub fn with_workspace(mut self, path: impl Into<PathBuf>) -> Self {
161        self.workspace_path = Some(path.into());
162        self
163    }
164
165    /// Set the model name for this session.
166    #[must_use]
167    pub fn with_model(mut self, model: impl Into<String>) -> Self {
168        self.model = Some(model.into());
169        self
170    }
171
172    /// Replace the capability store with a persistent one.
173    ///
174    /// Call this after session construction when a persistent store is available
175    /// (e.g. at daemon startup).
176    #[must_use]
177    pub fn with_capability_store(mut self, store: Arc<CapabilityStore>) -> Self {
178        self.capabilities = store;
179        self
180    }
181
182    /// Set the workspace cumulative budget tracker.
183    #[must_use]
184    pub fn with_workspace_budget(mut self, tracker: Arc<WorkspaceBudgetTracker>) -> Self {
185        self.workspace_budget_tracker = Some(tracker);
186        self
187    }
188
189    /// Import workspace-scoped allowances into this session.
190    ///
191    /// These allowances were previously persisted in the workspace `state.db`
192    /// and are loaded when a session is created or resumed in the same workspace.
193    pub fn import_workspace_allowances(
194        &self,
195        allowances: Vec<astrid_approval::allowance::Allowance>,
196    ) {
197        self.allowance_store.import_allowances(allowances);
198    }
199
200    /// Export workspace-scoped allowances from this session for persistence.
201    #[must_use]
202    pub fn export_workspace_allowances(&self) -> Vec<astrid_approval::allowance::Allowance> {
203        self.allowance_store.export_workspace_allowances()
204    }
205
206    /// Replace the deferred resolution queue with a persistent one.
207    ///
208    /// This reconstructs the `ApprovalManager` with the new persistent queue.
209    /// Call this after session construction when a persistent store is available.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if the persistent store cannot be initialized.
214    pub async fn with_persistent_deferred_queue(
215        mut self,
216        store: astrid_storage::ScopedKvStore,
217    ) -> Result<Self, crate::error::RuntimeError> {
218        let deferred_queue = Arc::new(
219            DeferredResolutionStore::with_persistence(store)
220                .await
221                .map_err(|e| crate::error::RuntimeError::StorageError(e.to_string()))?,
222        );
223        self.approval_manager = Arc::new(ApprovalManager::new(
224            Arc::clone(&self.allowance_store),
225            deferred_queue,
226        ));
227        Ok(self)
228    }
229
230    /// Add a message to the session.
231    pub fn add_message(&mut self, message: Message) {
232        // Rough token estimate (4 chars per token). This is a heuristic for
233        // context-limit warnings, not billing. Real context windows top out at
234        // ~200K tokens, so overflow of usize is not a practical concern.
235        let msg_tokens = match &message.content {
236            astrid_llm::MessageContent::Text(t) => t.len() / 4,
237            _ => 100, // Rough estimate for tool calls
238        };
239        self.token_count = self.token_count.saturating_add(msg_tokens);
240        self.messages.push(message);
241    }
242
243    /// Get the last N messages.
244    #[must_use]
245    pub fn last_messages(&self, n: usize) -> &[Message] {
246        let start = self.messages.len().saturating_sub(n);
247        &self.messages[start..]
248    }
249
250    /// Clear messages (keeping system prompt).
251    pub fn clear_messages(&mut self) {
252        self.messages.clear();
253        self.token_count = 0;
254    }
255
256    /// Get session duration.
257    #[must_use]
258    pub fn duration(&self) -> chrono::Duration {
259        // Safety: current time is always >= created_at
260        #[allow(clippy::arithmetic_side_effects)]
261        {
262            Utc::now() - self.created_at
263        }
264    }
265
266    /// Clean up session-scoped state.
267    ///
268    /// Clears session-only allowances, leaving workspace and persistent ones intact.
269    pub fn end_session(&self) {
270        self.allowance_store.clear_session_allowances();
271    }
272
273    /// Check if session is near context limit.
274    #[must_use]
275    #[allow(clippy::cast_precision_loss)]
276    pub fn is_near_limit(&self, max_tokens: usize, threshold: f32) -> bool {
277        self.token_count as f32 > max_tokens as f32 * threshold
278    }
279}
280
281/// Session metadata.
282#[derive(Debug, Clone, Default, Serialize, Deserialize)]
283pub struct SessionMetadata {
284    /// Session title (generated or user-provided).
285    pub title: Option<String>,
286    /// Tags for organization.
287    pub tags: Vec<String>,
288    /// Number of turns.
289    pub turn_count: usize,
290    /// Number of tool calls.
291    pub tool_call_count: usize,
292    /// Number of approvals granted.
293    pub approval_count: usize,
294    /// Custom key-value metadata.
295    pub custom: std::collections::HashMap<String, String>,
296}
297
298/// Serializable session state (for persistence).
299///
300/// Includes full security state: allowances, budget snapshot, escape handler
301/// state, and workspace path. This ensures "Allow Session" approvals,
302/// budget spend, and escape decisions survive daemon restarts.
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct SerializableSession {
305    /// Session ID.
306    pub id: String,
307    /// User ID (hex).
308    pub user_id: String,
309    /// Messages.
310    pub messages: Vec<SerializableMessage>,
311    /// System prompt.
312    pub system_prompt: String,
313    /// Created at.
314    pub created_at: DateTime<Utc>,
315    /// Token count.
316    pub token_count: usize,
317    /// Metadata.
318    pub metadata: SessionMetadata,
319    /// Session allowances (persisted so "Allow Session" survives restart).
320    #[serde(default)]
321    pub allowances: Vec<Allowance>,
322    /// Budget snapshot (persisted so budget is not reset on restart).
323    #[serde(default)]
324    pub budget_snapshot: Option<BudgetSnapshot>,
325    /// Escape handler state (persisted so "Allow Always" paths survive).
326    #[serde(default)]
327    pub escape_state: Option<EscapeState>,
328    /// Workspace path this session belongs to.
329    #[serde(default)]
330    pub workspace_path: Option<String>,
331    /// Model used for this session (e.g. "claude-sonnet-4-20250514").
332    #[serde(default)]
333    pub model: Option<String>,
334    /// Git state placeholder (branch, commit hash) for future worktree support.
335    #[serde(default)]
336    pub git_state: Option<GitState>,
337}
338
339/// Git repository state snapshot.
340#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct GitState {
342    /// Current branch name.
343    pub branch: Option<String>,
344    /// Current commit hash.
345    pub commit: Option<String>,
346}
347
348impl GitState {
349    /// Capture the current git state for a workspace path.
350    ///
351    /// Returns `None` if the path is not in a git repository or git is not available.
352    #[must_use]
353    pub fn capture(workspace_path: &std::path::Path) -> Option<Self> {
354        let branch = std::process::Command::new("git")
355            .args([
356                "-C",
357                &workspace_path.display().to_string(),
358                "rev-parse",
359                "--abbrev-ref",
360                "HEAD",
361            ])
362            .stdout(std::process::Stdio::piped())
363            .stderr(std::process::Stdio::null())
364            .output()
365            .ok()
366            .filter(|o| o.status.success())
367            .and_then(|o| String::from_utf8(o.stdout).ok())
368            .map(|s| s.trim().to_string());
369
370        let commit = std::process::Command::new("git")
371            .args([
372                "-C",
373                &workspace_path.display().to_string(),
374                "rev-parse",
375                "HEAD",
376            ])
377            .stdout(std::process::Stdio::piped())
378            .stderr(std::process::Stdio::null())
379            .output()
380            .ok()
381            .filter(|o| o.status.success())
382            .and_then(|o| String::from_utf8(o.stdout).ok())
383            .map(|s| s.trim().to_string());
384
385        // Only return Some if at least one field was captured.
386        if branch.is_some() || commit.is_some() {
387            Some(Self { branch, commit })
388        } else {
389            None
390        }
391    }
392}
393
394/// Serializable message.
395#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct SerializableMessage {
397    /// Role.
398    pub role: String,
399    /// Content (JSON).
400    pub content: serde_json::Value,
401}
402
403impl From<&AgentSession> for SerializableSession {
404    fn from(session: &AgentSession) -> Self {
405        Self {
406            id: session.id.0.to_string(),
407            user_id: hex::encode(session.user_id),
408            messages: session
409                .messages
410                .iter()
411                .map(|m| SerializableMessage {
412                    role: match m.role {
413                        astrid_llm::MessageRole::System => "system".to_string(),
414                        astrid_llm::MessageRole::User => "user".to_string(),
415                        astrid_llm::MessageRole::Assistant => "assistant".to_string(),
416                        astrid_llm::MessageRole::Tool => "tool".to_string(),
417                    },
418                    content: serde_json::to_value(&m.content).unwrap_or_default(),
419                })
420                .collect(),
421            system_prompt: session.system_prompt.clone(),
422            created_at: session.created_at,
423            token_count: session.token_count,
424            metadata: session.metadata.clone(),
425            allowances: session.allowance_store.export_session_allowances(),
426            budget_snapshot: Some(session.budget_tracker.snapshot()),
427            escape_state: Some(session.escape_handler.export_state()),
428            workspace_path: session
429                .workspace_path
430                .as_ref()
431                .map(|p| p.display().to_string()),
432            model: session.model.clone(),
433            git_state: session
434                .workspace_path
435                .as_ref()
436                .and_then(|p| GitState::capture(p)),
437        }
438    }
439}
440
441impl SerializableSession {
442    /// Convert back to an `AgentSession`.
443    ///
444    /// Restores full security state: allowances, budget, and escape handler.
445    #[must_use]
446    pub fn to_session(&self) -> AgentSession {
447        let mut user_id = [0u8; 8];
448        if let Ok(bytes) = hex::decode(&self.user_id)
449            && bytes.len() >= 8
450        {
451            user_id.copy_from_slice(&bytes[..8]);
452        }
453
454        let id =
455            uuid::Uuid::parse_str(&self.id).map_or_else(|_| SessionId::new(), SessionId::from_uuid);
456
457        let messages: Vec<Message> = self
458            .messages
459            .iter()
460            .filter_map(|m| {
461                let content: astrid_llm::MessageContent =
462                    serde_json::from_value(m.content.clone()).ok()?;
463                let role = match m.role.as_str() {
464                    "system" => astrid_llm::MessageRole::System,
465                    "user" => astrid_llm::MessageRole::User,
466                    "assistant" => astrid_llm::MessageRole::Assistant,
467                    "tool" => astrid_llm::MessageRole::Tool,
468                    _ => return None,
469                };
470                Some(Message { role, content })
471            })
472            .collect();
473
474        let mut session = AgentSession::with_id(id, user_id, &self.system_prompt);
475        session.messages = messages;
476        session.created_at = self.created_at;
477        session.token_count = self.token_count;
478        session.metadata = self.metadata.clone();
479        session.workspace_path = self.workspace_path.as_ref().map(PathBuf::from);
480        session.model.clone_from(&self.model);
481
482        // Restore session allowances
483        if !self.allowances.is_empty() {
484            session
485                .allowance_store
486                .import_allowances(self.allowances.clone());
487        }
488
489        // Restore budget from snapshot (prevents budget bypass via restart)
490        if let Some(snapshot) = &self.budget_snapshot {
491            session.budget_tracker = Arc::new(BudgetTracker::restore(snapshot.clone()));
492        }
493
494        // Restore escape handler state (preserves "AllowAlways" paths)
495        if let Some(escape_state) = &self.escape_state {
496            session.escape_handler.restore_state(escape_state.clone());
497        }
498
499        session
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_llm::Message;

    /// Push a session through its serializable form and back.
    fn roundtrip(session: &AgentSession) -> AgentSession {
        SerializableSession::from(session).to_session()
    }

    /// Assert that the session's budget tracker reports `expected` spend.
    fn assert_spent(session: &AgentSession, expected: f64) {
        assert!((session.budget_tracker.spent() - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_session_creation() {
        let session = AgentSession::new([0u8; 8], "You are helpful");

        assert_eq!(session.system_prompt, "You are helpful");
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_add_message() {
        let mut session = AgentSession::new([0u8; 8], "");
        for message in [Message::user("Hello"), Message::assistant("Hi!")] {
            session.add_message(message);
        }

        assert_eq!(session.messages.len(), 2);
        assert!(session.token_count > 0);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut original = AgentSession::new([1u8; 8], "Test prompt");
        original.add_message(Message::user("Hello"));
        original.add_message(Message::assistant("World"));

        let restored = roundtrip(&original);

        assert_eq!(restored.system_prompt, original.system_prompt);
        assert_eq!(restored.messages.len(), original.messages.len());
    }

    #[test]
    fn test_budget_snapshot_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test");
        original.budget_tracker.record_cost(42.5);

        let restored = roundtrip(&original);

        assert_spent(&restored, 42.5);
    }

    #[test]
    fn test_workspace_path_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test").with_workspace("/home/user/project");

        let restored = roundtrip(&original);

        assert_eq!(
            restored.workspace_path,
            Some(PathBuf::from("/home/user/project"))
        );
    }

    #[test]
    fn test_with_shared_stores() {
        let parent = AgentSession::new([1u8; 8], "Parent");

        // Spend recorded on the parent before the child exists.
        parent.budget_tracker.record_cost(10.0);

        let child = AgentSession::with_shared_stores(
            SessionId::new(),
            [1u8; 8],
            "Child",
            Arc::clone(&parent.allowance_store),
            Arc::clone(&parent.capabilities),
            Arc::clone(&parent.budget_tracker),
        );

        // The child sees the parent's earlier spend.
        assert_spent(&child, 10.0);

        // The parent sees spend recorded through the child (same Arc).
        child.budget_tracker.record_cost(5.0);
        assert_spent(&parent, 15.0);

        // All three stores are the very same allocations.
        assert!(Arc::ptr_eq(&parent.budget_tracker, &child.budget_tracker));
        assert!(Arc::ptr_eq(&parent.allowance_store, &child.allowance_store));
        assert!(Arc::ptr_eq(&parent.capabilities, &child.capabilities));

        // The conversation is not inherited.
        assert!(child.messages.is_empty());

        // The approval manager is a distinct instance (fresh handler registration).
        assert!(!Arc::ptr_eq(
            &parent.approval_manager,
            &child.approval_manager
        ));
    }

    #[test]
    fn test_backwards_compatible_deserialization() {
        // A snapshot in the old format, lacking every `#[serde(default)]` field,
        // must still deserialize.
        let json = r#"{
            "id": "00000000-0000-0000-0000-000000000001",
            "user_id": "0101010101010101",
            "messages": [],
            "system_prompt": "Test",
            "created_at": "2024-01-01T00:00:00Z",
            "token_count": 0,
            "metadata": {
                "title": null,
                "tags": [],
                "turn_count": 0,
                "tool_call_count": 0,
                "approval_count": 0,
                "custom": {}
            }
        }"#;

        let snapshot: SerializableSession = serde_json::from_str(json).unwrap();
        let session = snapshot.to_session();

        assert_eq!(session.system_prompt, "Test");
        assert!(session.workspace_path.is_none());
        assert_spent(&session, 0.0);
    }
}