1use astrid_approval::allowance::Allowance;
6use astrid_approval::budget::{BudgetSnapshot, BudgetTracker, WorkspaceBudgetTracker};
7use astrid_approval::{AllowanceStore, ApprovalManager, DeferredResolutionStore};
8use astrid_capabilities::CapabilityStore;
9use astrid_core::SessionId;
10use astrid_llm::Message;
11use astrid_workspace::escape::{EscapeHandler, EscapeState};
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
/// Live, in-memory state of a single agent conversation.
///
/// Bundles the message history with the stores and trackers attached to the
/// session. Only `Debug` is derived; persistence goes through
/// [`SerializableSession`].
#[derive(Debug)]
pub struct AgentSession {
    /// Unique identifier of this session.
    pub id: SessionId,
    /// Eight-byte identifier of the owning user (hex-encoded when serialized).
    pub user_id: [u8; 8],
    /// Messages recorded so far, in insertion order.
    pub messages: Vec<Message>,
    /// Capability store; a fresh in-memory store unless one is supplied.
    pub capabilities: Arc<CapabilityStore>,
    /// Allowance store; the same `Arc` is handed to `approval_manager` at construction.
    pub allowance_store: Arc<AllowanceStore>,
    /// Approval manager built over `allowance_store` and a deferred-resolution store.
    pub approval_manager: Arc<ApprovalManager>,
    /// System prompt text for the session.
    pub system_prompt: String,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Running token estimate; updated by `add_message`, reset by `clear_messages`.
    pub token_count: usize,
    /// Bookkeeping such as title, tags and counters.
    pub metadata: SessionMetadata,
    /// Escape handler; its state is exported and restored during serialization.
    pub escape_handler: EscapeHandler,
    /// Budget tracker; shares the caller's tracker when built via `with_shared_stores`.
    pub budget_tracker: Arc<BudgetTracker>,
    /// Optional workspace-level budget tracker (`None` until `with_workspace_budget`).
    pub workspace_budget_tracker: Option<Arc<WorkspaceBudgetTracker>>,
    /// Optional workspace path (`None` until `with_workspace`).
    pub workspace_path: Option<PathBuf>,
    /// Optional model name (`None` until `with_model`).
    pub model: Option<String>,
    /// `true` only for sessions created through `with_shared_stores`.
    pub is_subagent: bool,
    /// Optional capsule context string.
    // NOTE(review): only ever initialised to `None` in this file; confirm its
    // intended semantics with the callers that set it.
    pub capsule_context: Option<String>,
}
55
56impl AgentSession {
57 #[must_use]
59 pub fn new(user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
60 let allowance_store = Arc::new(AllowanceStore::new());
61 let deferred_queue = Arc::new(DeferredResolutionStore::new());
62 let approval_manager = Arc::new(ApprovalManager::new(
63 Arc::clone(&allowance_store),
64 deferred_queue,
65 ));
66 Self {
67 id: SessionId::new(),
68 user_id,
69 messages: Vec::new(),
70 capabilities: Arc::new(CapabilityStore::in_memory()),
71 allowance_store,
72 approval_manager,
73 system_prompt: system_prompt.into(),
74 created_at: Utc::now(),
75 token_count: 0,
76 metadata: SessionMetadata::default(),
77 escape_handler: EscapeHandler::new(),
78 budget_tracker: Arc::new(BudgetTracker::default()),
79 workspace_budget_tracker: None,
80 workspace_path: None,
81 model: None,
82 is_subagent: false,
83 capsule_context: None,
84 }
85 }
86
87 #[must_use]
89 pub fn with_id(id: SessionId, user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
90 let allowance_store = Arc::new(AllowanceStore::new());
91 let deferred_queue = Arc::new(DeferredResolutionStore::new());
92 let approval_manager = Arc::new(ApprovalManager::new(
93 Arc::clone(&allowance_store),
94 deferred_queue,
95 ));
96 Self {
97 id,
98 user_id,
99 messages: Vec::new(),
100 capabilities: Arc::new(CapabilityStore::in_memory()),
101 allowance_store,
102 approval_manager,
103 system_prompt: system_prompt.into(),
104 created_at: Utc::now(),
105 token_count: 0,
106 metadata: SessionMetadata::default(),
107 escape_handler: EscapeHandler::new(),
108 budget_tracker: Arc::new(BudgetTracker::default()),
109 workspace_budget_tracker: None,
110 workspace_path: None,
111 model: None,
112 is_subagent: false,
113 capsule_context: None,
114 }
115 }
116
117 #[must_use]
124 pub fn with_shared_stores(
125 id: SessionId,
126 user_id: [u8; 8],
127 system_prompt: impl Into<String>,
128 allowance_store: Arc<AllowanceStore>,
129 capabilities: Arc<CapabilityStore>,
130 budget_tracker: Arc<BudgetTracker>,
131 ) -> Self {
132 let deferred_queue = Arc::new(DeferredResolutionStore::new());
133 let approval_manager = Arc::new(ApprovalManager::new(
134 Arc::clone(&allowance_store),
135 deferred_queue,
136 ));
137 Self {
138 id,
139 user_id,
140 messages: Vec::new(),
141 capabilities,
142 allowance_store,
143 approval_manager,
144 system_prompt: system_prompt.into(),
145 created_at: Utc::now(),
146 token_count: 0,
147 metadata: SessionMetadata::default(),
148 escape_handler: EscapeHandler::new(),
149 budget_tracker,
150 workspace_budget_tracker: None,
151 workspace_path: None,
152 model: None,
153 is_subagent: true,
154 capsule_context: None,
155 }
156 }
157
158 #[must_use]
160 pub fn with_workspace(mut self, path: impl Into<PathBuf>) -> Self {
161 self.workspace_path = Some(path.into());
162 self
163 }
164
165 #[must_use]
167 pub fn with_model(mut self, model: impl Into<String>) -> Self {
168 self.model = Some(model.into());
169 self
170 }
171
172 #[must_use]
177 pub fn with_capability_store(mut self, store: Arc<CapabilityStore>) -> Self {
178 self.capabilities = store;
179 self
180 }
181
182 #[must_use]
184 pub fn with_workspace_budget(mut self, tracker: Arc<WorkspaceBudgetTracker>) -> Self {
185 self.workspace_budget_tracker = Some(tracker);
186 self
187 }
188
189 pub fn import_workspace_allowances(
194 &self,
195 allowances: Vec<astrid_approval::allowance::Allowance>,
196 ) {
197 self.allowance_store.import_allowances(allowances);
198 }
199
200 #[must_use]
202 pub fn export_workspace_allowances(&self) -> Vec<astrid_approval::allowance::Allowance> {
203 self.allowance_store.export_workspace_allowances()
204 }
205
206 pub async fn with_persistent_deferred_queue(
215 mut self,
216 store: astrid_storage::ScopedKvStore,
217 ) -> Result<Self, crate::error::RuntimeError> {
218 let deferred_queue = Arc::new(
219 DeferredResolutionStore::with_persistence(store)
220 .await
221 .map_err(|e| crate::error::RuntimeError::StorageError(e.to_string()))?,
222 );
223 self.approval_manager = Arc::new(ApprovalManager::new(
224 Arc::clone(&self.allowance_store),
225 deferred_queue,
226 ));
227 Ok(self)
228 }
229
230 pub fn add_message(&mut self, message: Message) {
232 let msg_tokens = match &message.content {
236 astrid_llm::MessageContent::Text(t) => t.len() / 4,
237 _ => 100, };
239 self.token_count = self.token_count.saturating_add(msg_tokens);
240 self.messages.push(message);
241 }
242
243 #[must_use]
245 pub fn last_messages(&self, n: usize) -> &[Message] {
246 let start = self.messages.len().saturating_sub(n);
247 &self.messages[start..]
248 }
249
250 pub fn clear_messages(&mut self) {
252 self.messages.clear();
253 self.token_count = 0;
254 }
255
256 #[must_use]
258 pub fn duration(&self) -> chrono::Duration {
259 #[allow(clippy::arithmetic_side_effects)]
261 {
262 Utc::now() - self.created_at
263 }
264 }
265
266 pub fn end_session(&self) {
270 self.allowance_store.clear_session_allowances();
271 }
272
273 #[must_use]
275 #[allow(clippy::cast_precision_loss)]
276 pub fn is_near_limit(&self, max_tokens: usize, threshold: f32) -> bool {
277 self.token_count as f32 > max_tokens as f32 * threshold
278 }
279}
280
/// Serializable bookkeeping attached to a session.
///
/// Every field has a `Default`, so `SessionMetadata::default()` yields an
/// empty record with all counters at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Optional title of the session.
    pub title: Option<String>,
    /// Tags attached to the session.
    pub tags: Vec<String>,
    /// Turn counter. NOTE(review): not updated anywhere in this file — maintained by callers.
    pub turn_count: usize,
    /// Tool-call counter. NOTE(review): not updated anywhere in this file — maintained by callers.
    pub tool_call_count: usize,
    /// Approval counter. NOTE(review): not updated anywhere in this file — maintained by callers.
    pub approval_count: usize,
    /// Free-form string key/value pairs.
    pub custom: std::collections::HashMap<String, String>,
}
297
/// Serde-friendly snapshot of an [`AgentSession`].
///
/// Fields marked `#[serde(default)]` may be missing from the input; they then
/// deserialize to an empty vector or `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableSession {
    /// Session id as a string; parsed as a UUID when restoring.
    pub id: String,
    /// Hex encoding of the 8-byte user id.
    pub user_id: String,
    /// Message history in serializable form.
    pub messages: Vec<SerializableMessage>,
    /// System prompt text.
    pub system_prompt: String,
    /// Creation timestamp (UTC) of the original session.
    pub created_at: DateTime<Utc>,
    /// Token estimate of the original session at snapshot time.
    pub token_count: usize,
    /// Session bookkeeping.
    pub metadata: SessionMetadata,
    /// Session allowances exported from the allowance store; imported again on restore.
    #[serde(default)]
    pub allowances: Vec<Allowance>,
    /// Budget tracker snapshot; when present the tracker is restored from it.
    #[serde(default)]
    pub budget_snapshot: Option<BudgetSnapshot>,
    /// Exported escape-handler state; when present it is restored into the handler.
    #[serde(default)]
    pub escape_state: Option<EscapeState>,
    /// Workspace path rendered as a string.
    #[serde(default)]
    pub workspace_path: Option<String>,
    /// Model name, if one was set.
    #[serde(default)]
    pub model: Option<String>,
    /// Git branch/commit captured from the workspace at snapshot time.
    /// Not consulted by `to_session`.
    #[serde(default)]
    pub git_state: Option<GitState>,
}
338
/// Git position of a workspace, as reported by `git rev-parse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Trimmed output of `git rev-parse --abbrev-ref HEAD`, if that command succeeded.
    pub branch: Option<String>,
    /// Trimmed output of `git rev-parse HEAD`, if that command succeeded.
    pub commit: Option<String>,
}
347
348impl GitState {
349 #[must_use]
353 pub fn capture(workspace_path: &std::path::Path) -> Option<Self> {
354 let branch = std::process::Command::new("git")
355 .args([
356 "-C",
357 &workspace_path.display().to_string(),
358 "rev-parse",
359 "--abbrev-ref",
360 "HEAD",
361 ])
362 .stdout(std::process::Stdio::piped())
363 .stderr(std::process::Stdio::null())
364 .output()
365 .ok()
366 .filter(|o| o.status.success())
367 .and_then(|o| String::from_utf8(o.stdout).ok())
368 .map(|s| s.trim().to_string());
369
370 let commit = std::process::Command::new("git")
371 .args([
372 "-C",
373 &workspace_path.display().to_string(),
374 "rev-parse",
375 "HEAD",
376 ])
377 .stdout(std::process::Stdio::piped())
378 .stderr(std::process::Stdio::null())
379 .output()
380 .ok()
381 .filter(|o| o.status.success())
382 .and_then(|o| String::from_utf8(o.stdout).ok())
383 .map(|s| s.trim().to_string());
384
385 if branch.is_some() || commit.is_some() {
387 Some(Self { branch, commit })
388 } else {
389 None
390 }
391 }
392}
393
/// Serializable form of a single conversation message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    /// Role name: `"system"`, `"user"`, `"assistant"` or `"tool"`.
    /// Messages with any other role are skipped when restoring a session.
    pub role: String,
    /// JSON encoding of the message content.
    pub content: serde_json::Value,
}
402
403impl From<&AgentSession> for SerializableSession {
404 fn from(session: &AgentSession) -> Self {
405 Self {
406 id: session.id.0.to_string(),
407 user_id: hex::encode(session.user_id),
408 messages: session
409 .messages
410 .iter()
411 .map(|m| SerializableMessage {
412 role: match m.role {
413 astrid_llm::MessageRole::System => "system".to_string(),
414 astrid_llm::MessageRole::User => "user".to_string(),
415 astrid_llm::MessageRole::Assistant => "assistant".to_string(),
416 astrid_llm::MessageRole::Tool => "tool".to_string(),
417 },
418 content: serde_json::to_value(&m.content).unwrap_or_default(),
419 })
420 .collect(),
421 system_prompt: session.system_prompt.clone(),
422 created_at: session.created_at,
423 token_count: session.token_count,
424 metadata: session.metadata.clone(),
425 allowances: session.allowance_store.export_session_allowances(),
426 budget_snapshot: Some(session.budget_tracker.snapshot()),
427 escape_state: Some(session.escape_handler.export_state()),
428 workspace_path: session
429 .workspace_path
430 .as_ref()
431 .map(|p| p.display().to_string()),
432 model: session.model.clone(),
433 git_state: session
434 .workspace_path
435 .as_ref()
436 .and_then(|p| GitState::capture(p)),
437 }
438 }
439}
440
441impl SerializableSession {
442 #[must_use]
446 pub fn to_session(&self) -> AgentSession {
447 let mut user_id = [0u8; 8];
448 if let Ok(bytes) = hex::decode(&self.user_id)
449 && bytes.len() >= 8
450 {
451 user_id.copy_from_slice(&bytes[..8]);
452 }
453
454 let id =
455 uuid::Uuid::parse_str(&self.id).map_or_else(|_| SessionId::new(), SessionId::from_uuid);
456
457 let messages: Vec<Message> = self
458 .messages
459 .iter()
460 .filter_map(|m| {
461 let content: astrid_llm::MessageContent =
462 serde_json::from_value(m.content.clone()).ok()?;
463 let role = match m.role.as_str() {
464 "system" => astrid_llm::MessageRole::System,
465 "user" => astrid_llm::MessageRole::User,
466 "assistant" => astrid_llm::MessageRole::Assistant,
467 "tool" => astrid_llm::MessageRole::Tool,
468 _ => return None,
469 };
470 Some(Message { role, content })
471 })
472 .collect();
473
474 let mut session = AgentSession::with_id(id, user_id, &self.system_prompt);
475 session.messages = messages;
476 session.created_at = self.created_at;
477 session.token_count = self.token_count;
478 session.metadata = self.metadata.clone();
479 session.workspace_path = self.workspace_path.as_ref().map(PathBuf::from);
480 session.model.clone_from(&self.model);
481
482 if !self.allowances.is_empty() {
484 session
485 .allowance_store
486 .import_allowances(self.allowances.clone());
487 }
488
489 if let Some(snapshot) = &self.budget_snapshot {
491 session.budget_tracker = Arc::new(BudgetTracker::restore(snapshot.clone()));
492 }
493
494 if let Some(escape_state) = &self.escape_state {
496 session.escape_handler.restore_state(escape_state.clone());
497 }
498
499 session
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_llm::Message;

    /// A new session starts with no messages and keeps the given prompt.
    #[test]
    fn test_session_creation() {
        let fresh = AgentSession::new([0u8; 8], "You are helpful");

        assert!(fresh.messages.is_empty());
        assert_eq!(fresh.system_prompt, "You are helpful");
    }

    /// Adding messages grows the history and the token estimate.
    #[test]
    fn test_add_message() {
        let mut chat = AgentSession::new([0u8; 8], "");
        chat.add_message(Message::user("Hello"));
        chat.add_message(Message::assistant("Hi!"));

        assert_eq!(chat.messages.len(), 2);
        assert!(chat.token_count > 0);
    }

    /// Prompt and message count survive a snapshot/restore cycle.
    #[test]
    fn test_serialization_roundtrip() {
        let mut original = AgentSession::new([1u8; 8], "Test prompt");
        original.add_message(Message::user("Hello"));
        original.add_message(Message::assistant("World"));

        let restored = SerializableSession::from(&original).to_session();

        assert_eq!(restored.system_prompt, original.system_prompt);
        assert_eq!(restored.messages.len(), original.messages.len());
    }

    /// Recorded spend survives a snapshot/restore cycle.
    #[test]
    fn test_budget_snapshot_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test");
        original.budget_tracker.record_cost(42.5);

        let restored = SerializableSession::from(&original).to_session();

        let drift = (restored.budget_tracker.spent() - 42.5).abs();
        assert!(drift < f64::EPSILON);
    }

    /// The workspace path survives a snapshot/restore cycle.
    #[test]
    fn test_workspace_path_roundtrip() {
        let original = AgentSession::new([1u8; 8], "Test").with_workspace("/home/user/project");

        let restored = SerializableSession::from(&original).to_session();

        let expected = Some(PathBuf::from("/home/user/project"));
        assert_eq!(restored.workspace_path, expected);
    }

    /// A child built with shared stores sees and affects the parent's budget,
    /// shares the stores by pointer, but has its own history and approval manager.
    #[test]
    fn test_with_shared_stores() {
        let parent = AgentSession::new([1u8; 8], "Parent");

        // Spend on the parent before the child exists.
        parent.budget_tracker.record_cost(10.0);

        let child = AgentSession::with_shared_stores(
            SessionId::new(),
            [1u8; 8],
            "Child",
            Arc::clone(&parent.allowance_store),
            Arc::clone(&parent.capabilities),
            Arc::clone(&parent.budget_tracker),
        );

        // The child observes the parent's earlier spend.
        let seen_by_child = child.budget_tracker.spent();
        assert!((seen_by_child - 10.0).abs() < f64::EPSILON);

        // Spend recorded through the child shows up on the parent.
        child.budget_tracker.record_cost(5.0);
        let seen_by_parent = parent.budget_tracker.spent();
        assert!((seen_by_parent - 15.0).abs() < f64::EPSILON);

        // The stores are the very same allocations.
        assert!(Arc::ptr_eq(&parent.budget_tracker, &child.budget_tracker));
        assert!(Arc::ptr_eq(&parent.allowance_store, &child.allowance_store));
        assert!(Arc::ptr_eq(&parent.capabilities, &child.capabilities));

        // The history is independent.
        assert!(child.messages.is_empty());

        // So is the approval manager.
        let same_manager = Arc::ptr_eq(&parent.approval_manager, &child.approval_manager);
        assert!(!same_manager);
    }

    /// Input lacking every `#[serde(default)]` field still deserializes and restores.
    #[test]
    fn test_backwards_compatible_deserialization() {
        let json = r#"{
            "id": "00000000-0000-0000-0000-000000000001",
            "user_id": "0101010101010101",
            "messages": [],
            "system_prompt": "Test",
            "created_at": "2024-01-01T00:00:00Z",
            "token_count": 0,
            "metadata": {
                "title": null,
                "tags": [],
                "turn_count": 0,
                "tool_call_count": 0,
                "approval_count": 0,
                "custom": {}
            }
        }"#;

        let legacy: SerializableSession = serde_json::from_str(json).unwrap();
        let restored = legacy.to_session();

        assert_eq!(restored.system_prompt, "Test");
        assert!(restored.workspace_path.is_none());
        assert!((restored.budget_tracker.spent() - 0.0_f64).abs() < f64::EPSILON);
    }
}