1use astrid_approval::allowance::Allowance;
6use astrid_approval::budget::{BudgetSnapshot, BudgetTracker, WorkspaceBudgetTracker};
7use astrid_approval::{AllowanceStore, ApprovalManager, DeferredResolutionStore};
8use astrid_capabilities::CapabilityStore;
9use astrid_core::SessionId;
10use astrid_llm::Message;
11use astrid_workspace::escape::{EscapeHandler, EscapeState};
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
/// An agent conversation session: message history plus the security stores
/// (capabilities, allowances, approvals, budget) that govern it.
#[derive(Debug)]
pub struct AgentSession {
    /// Identifier of this session.
    pub id: SessionId,
    /// 8-byte user identifier (serialized as a hex string).
    pub user_id: [u8; 8],
    /// Conversation history, in the order messages were added.
    pub messages: Vec<Message>,
    /// Capability store; may be shared with other sessions (see `with_shared_stores`).
    pub capabilities: Arc<CapabilityStore>,
    /// Allowance store; also handed to `approval_manager` at construction and
    /// may be shared with other sessions.
    pub allowance_store: Arc<AllowanceStore>,
    /// Approval manager; the constructors always build a fresh one per session.
    pub approval_manager: Arc<ApprovalManager>,
    /// System prompt for this session.
    pub system_prompt: String,
    /// Creation timestamp (taken from the snapshot when a session is restored).
    pub created_at: DateTime<Utc>,
    /// Running, approximate token count of `messages`.
    pub token_count: usize,
    /// Descriptive metadata and counters.
    pub metadata: SessionMetadata,
    /// Escape handler; its state is exported on serialization and restored on load.
    pub escape_handler: EscapeHandler,
    /// Spend tracker for this session; may be shared with other sessions.
    pub budget_tracker: Arc<BudgetTracker>,
    /// Optional workspace-level budget tracker.
    pub workspace_budget_tracker: Option<Arc<WorkspaceBudgetTracker>>,
    /// Workspace directory, if any; used to capture git state on serialization.
    pub workspace_path: Option<PathBuf>,
    /// Model identifier, if one was set.
    pub model: Option<String>,
}
51
52impl AgentSession {
53 #[must_use]
55 pub fn new(user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
56 let allowance_store = Arc::new(AllowanceStore::new());
57 let deferred_queue = Arc::new(DeferredResolutionStore::new());
58 let approval_manager = Arc::new(ApprovalManager::new(
59 Arc::clone(&allowance_store),
60 deferred_queue,
61 ));
62 Self {
63 id: SessionId::new(),
64 user_id,
65 messages: Vec::new(),
66 capabilities: Arc::new(CapabilityStore::in_memory()),
67 allowance_store,
68 approval_manager,
69 system_prompt: system_prompt.into(),
70 created_at: Utc::now(),
71 token_count: 0,
72 metadata: SessionMetadata::default(),
73 escape_handler: EscapeHandler::new(),
74 budget_tracker: Arc::new(BudgetTracker::default()),
75 workspace_budget_tracker: None,
76 workspace_path: None,
77 model: None,
78 }
79 }
80
81 #[must_use]
83 pub fn with_id(id: SessionId, user_id: [u8; 8], system_prompt: impl Into<String>) -> Self {
84 let allowance_store = Arc::new(AllowanceStore::new());
85 let deferred_queue = Arc::new(DeferredResolutionStore::new());
86 let approval_manager = Arc::new(ApprovalManager::new(
87 Arc::clone(&allowance_store),
88 deferred_queue,
89 ));
90 Self {
91 id,
92 user_id,
93 messages: Vec::new(),
94 capabilities: Arc::new(CapabilityStore::in_memory()),
95 allowance_store,
96 approval_manager,
97 system_prompt: system_prompt.into(),
98 created_at: Utc::now(),
99 token_count: 0,
100 metadata: SessionMetadata::default(),
101 escape_handler: EscapeHandler::new(),
102 budget_tracker: Arc::new(BudgetTracker::default()),
103 workspace_budget_tracker: None,
104 workspace_path: None,
105 model: None,
106 }
107 }
108
109 #[must_use]
116 pub fn with_shared_stores(
117 id: SessionId,
118 user_id: [u8; 8],
119 system_prompt: impl Into<String>,
120 allowance_store: Arc<AllowanceStore>,
121 capabilities: Arc<CapabilityStore>,
122 budget_tracker: Arc<BudgetTracker>,
123 ) -> Self {
124 let deferred_queue = Arc::new(DeferredResolutionStore::new());
125 let approval_manager = Arc::new(ApprovalManager::new(
126 Arc::clone(&allowance_store),
127 deferred_queue,
128 ));
129 Self {
130 id,
131 user_id,
132 messages: Vec::new(),
133 capabilities,
134 allowance_store,
135 approval_manager,
136 system_prompt: system_prompt.into(),
137 created_at: Utc::now(),
138 token_count: 0,
139 metadata: SessionMetadata::default(),
140 escape_handler: EscapeHandler::new(),
141 budget_tracker,
142 workspace_budget_tracker: None,
143 workspace_path: None,
144 model: None,
145 }
146 }
147
148 #[must_use]
150 pub fn with_workspace(mut self, path: impl Into<PathBuf>) -> Self {
151 self.workspace_path = Some(path.into());
152 self
153 }
154
155 #[must_use]
157 pub fn with_model(mut self, model: impl Into<String>) -> Self {
158 self.model = Some(model.into());
159 self
160 }
161
162 #[must_use]
167 pub fn with_capability_store(mut self, store: Arc<CapabilityStore>) -> Self {
168 self.capabilities = store;
169 self
170 }
171
172 #[must_use]
174 pub fn with_workspace_budget(mut self, tracker: Arc<WorkspaceBudgetTracker>) -> Self {
175 self.workspace_budget_tracker = Some(tracker);
176 self
177 }
178
179 pub fn import_workspace_allowances(
184 &self,
185 allowances: Vec<astrid_approval::allowance::Allowance>,
186 ) {
187 self.allowance_store.import_allowances(allowances);
188 }
189
190 #[must_use]
192 pub fn export_workspace_allowances(&self) -> Vec<astrid_approval::allowance::Allowance> {
193 self.allowance_store.export_workspace_allowances()
194 }
195
196 pub async fn with_persistent_deferred_queue(
205 mut self,
206 store: astrid_storage::ScopedKvStore,
207 ) -> Result<Self, astrid_core::error::SecurityError> {
208 let deferred_queue = Arc::new(DeferredResolutionStore::with_persistence(store).await?);
209 self.approval_manager = Arc::new(ApprovalManager::new(
210 Arc::clone(&self.allowance_store),
211 deferred_queue,
212 ));
213 Ok(self)
214 }
215
216 pub fn add_message(&mut self, message: Message) {
218 let msg_tokens = match &message.content {
222 astrid_llm::MessageContent::Text(t) => t.len() / 4,
223 _ => 100, };
225 self.token_count = self.token_count.saturating_add(msg_tokens);
226 self.messages.push(message);
227 }
228
229 #[must_use]
231 pub fn last_messages(&self, n: usize) -> &[Message] {
232 let start = self.messages.len().saturating_sub(n);
233 &self.messages[start..]
234 }
235
236 pub fn clear_messages(&mut self) {
238 self.messages.clear();
239 self.token_count = 0;
240 }
241
242 #[must_use]
244 pub fn duration(&self) -> chrono::Duration {
245 #[allow(clippy::arithmetic_side_effects)]
247 {
248 Utc::now() - self.created_at
249 }
250 }
251
252 pub fn end_session(&self) {
256 self.allowance_store.clear_session_allowances();
257 }
258
259 #[must_use]
261 #[allow(clippy::cast_precision_loss)]
262 pub fn is_near_limit(&self, max_tokens: usize, threshold: f32) -> bool {
263 self.token_count as f32 > max_tokens as f32 * threshold
264 }
265}
266
/// Descriptive metadata and counters attached to a session.
///
/// Serialized as part of `SerializableSession`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Optional human-readable title.
    pub title: Option<String>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Turn counter.
    pub turn_count: usize,
    /// Tool-call counter.
    pub tool_call_count: usize,
    /// Approval counter.
    pub approval_count: usize,
    /// Arbitrary string key/value metadata.
    pub custom: std::collections::HashMap<String, String>,
}
283
/// Serializable snapshot of an `AgentSession`, used for persistence.
///
/// Fields marked `#[serde(default)]` may be absent from the input, so snapshots
/// written before those fields existed still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableSession {
    /// Session UUID as a string.
    pub id: String,
    /// Hex-encoded 8-byte user identifier.
    pub user_id: String,
    /// Conversation history.
    pub messages: Vec<SerializableMessage>,
    /// System prompt.
    pub system_prompt: String,
    /// Session creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Approximate token count at the time of the snapshot.
    pub token_count: usize,
    /// Session metadata.
    pub metadata: SessionMetadata,
    /// Session-scoped allowances exported from the allowance store.
    #[serde(default)]
    pub allowances: Vec<Allowance>,
    /// Budget tracker snapshot, if recorded.
    #[serde(default)]
    pub budget_snapshot: Option<BudgetSnapshot>,
    /// Escape handler state, if recorded.
    #[serde(default)]
    pub escape_state: Option<EscapeState>,
    /// Workspace path rendered as a display string (lossy for non-UTF-8 paths).
    #[serde(default)]
    pub workspace_path: Option<String>,
    /// Model identifier, if one was set.
    #[serde(default)]
    pub model: Option<String>,
    /// Git branch/commit captured at serialization time (not used when restoring).
    #[serde(default)]
    pub git_state: Option<GitState>,
}
324
/// Git checkout information for a workspace, as captured by `GitState::capture`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Trimmed output of `git rev-parse --abbrev-ref HEAD`, if that succeeded.
    pub branch: Option<String>,
    /// Trimmed output of `git rev-parse HEAD`, if that succeeded.
    pub commit: Option<String>,
}
333
334impl GitState {
335 #[must_use]
339 pub fn capture(workspace_path: &std::path::Path) -> Option<Self> {
340 let branch = std::process::Command::new("git")
341 .args([
342 "-C",
343 &workspace_path.display().to_string(),
344 "rev-parse",
345 "--abbrev-ref",
346 "HEAD",
347 ])
348 .stdout(std::process::Stdio::piped())
349 .stderr(std::process::Stdio::null())
350 .output()
351 .ok()
352 .filter(|o| o.status.success())
353 .and_then(|o| String::from_utf8(o.stdout).ok())
354 .map(|s| s.trim().to_string());
355
356 let commit = std::process::Command::new("git")
357 .args([
358 "-C",
359 &workspace_path.display().to_string(),
360 "rev-parse",
361 "HEAD",
362 ])
363 .stdout(std::process::Stdio::piped())
364 .stderr(std::process::Stdio::null())
365 .output()
366 .ok()
367 .filter(|o| o.status.success())
368 .and_then(|o| String::from_utf8(o.stdout).ok())
369 .map(|s| s.trim().to_string());
370
371 if branch.is_some() || commit.is_some() {
373 Some(Self { branch, commit })
374 } else {
375 None
376 }
377 }
378}
379
/// Serializable form of a single conversation message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableMessage {
    /// Role name: `"system"`, `"user"`, `"assistant"` or `"tool"`.
    pub role: String,
    /// JSON encoding of the message content.
    pub content: serde_json::Value,
}
388
389impl From<&AgentSession> for SerializableSession {
390 fn from(session: &AgentSession) -> Self {
391 Self {
392 id: session.id.0.to_string(),
393 user_id: hex::encode(session.user_id),
394 messages: session
395 .messages
396 .iter()
397 .map(|m| SerializableMessage {
398 role: match m.role {
399 astrid_llm::MessageRole::System => "system".to_string(),
400 astrid_llm::MessageRole::User => "user".to_string(),
401 astrid_llm::MessageRole::Assistant => "assistant".to_string(),
402 astrid_llm::MessageRole::Tool => "tool".to_string(),
403 },
404 content: serde_json::to_value(&m.content).unwrap_or_default(),
405 })
406 .collect(),
407 system_prompt: session.system_prompt.clone(),
408 created_at: session.created_at,
409 token_count: session.token_count,
410 metadata: session.metadata.clone(),
411 allowances: session.allowance_store.export_session_allowances(),
412 budget_snapshot: Some(session.budget_tracker.snapshot()),
413 escape_state: Some(session.escape_handler.export_state()),
414 workspace_path: session
415 .workspace_path
416 .as_ref()
417 .map(|p| p.display().to_string()),
418 model: session.model.clone(),
419 git_state: session
420 .workspace_path
421 .as_ref()
422 .and_then(|p| GitState::capture(p)),
423 }
424 }
425}
426
427impl SerializableSession {
428 #[must_use]
432 pub fn to_session(&self) -> AgentSession {
433 let mut user_id = [0u8; 8];
434 if let Ok(bytes) = hex::decode(&self.user_id)
435 && bytes.len() >= 8
436 {
437 user_id.copy_from_slice(&bytes[..8]);
438 }
439
440 let id =
441 uuid::Uuid::parse_str(&self.id).map_or_else(|_| SessionId::new(), SessionId::from_uuid);
442
443 let messages: Vec<Message> = self
444 .messages
445 .iter()
446 .filter_map(|m| {
447 let content: astrid_llm::MessageContent =
448 serde_json::from_value(m.content.clone()).ok()?;
449 let role = match m.role.as_str() {
450 "system" => astrid_llm::MessageRole::System,
451 "user" => astrid_llm::MessageRole::User,
452 "assistant" => astrid_llm::MessageRole::Assistant,
453 "tool" => astrid_llm::MessageRole::Tool,
454 _ => return None,
455 };
456 Some(Message { role, content })
457 })
458 .collect();
459
460 let mut session = AgentSession::with_id(id, user_id, &self.system_prompt);
461 session.messages = messages;
462 session.created_at = self.created_at;
463 session.token_count = self.token_count;
464 session.metadata = self.metadata.clone();
465 session.workspace_path = self.workspace_path.as_ref().map(PathBuf::from);
466 session.model.clone_from(&self.model);
467
468 if !self.allowances.is_empty() {
470 session
471 .allowance_store
472 .import_allowances(self.allowances.clone());
473 }
474
475 if let Some(snapshot) = &self.budget_snapshot {
477 session.budget_tracker = Arc::new(BudgetTracker::restore(snapshot.clone()));
478 }
479
480 if let Some(escape_state) = &self.escape_state {
482 session.escape_handler.restore_state(escape_state.clone());
483 }
484
485 session
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two `f64` values are equal to within `f64::EPSILON`.
    fn assert_f64_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_session_creation() {
        let session = AgentSession::new([0u8; 8], "You are helpful");
        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt, "You are helpful");
        assert_eq!(session.token_count, 0);
        assert!(session.workspace_path.is_none());
        assert!(session.model.is_none());
    }

    #[test]
    fn test_add_message() {
        let mut session = AgentSession::new([0u8; 8], "");
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi!"));

        assert_eq!(session.messages.len(), 2);
        assert!(session.token_count > 0);
    }

    #[test]
    fn test_last_messages_and_clear() {
        let mut session = AgentSession::new([0u8; 8], "");
        session.add_message(Message::user("one"));
        session.add_message(Message::user("two"));
        session.add_message(Message::user("three"));

        assert_eq!(session.last_messages(2).len(), 2);
        assert_eq!(session.last_messages(10).len(), 3);
        assert!(session.last_messages(0).is_empty());

        session.clear_messages();
        assert!(session.messages.is_empty());
        assert_eq!(session.token_count, 0);
    }

    #[test]
    fn test_is_near_limit() {
        let mut session = AgentSession::new([0u8; 8], "");
        session.token_count = 80;
        assert!(!session.is_near_limit(100, 0.8));
        session.token_count = 81;
        assert!(session.is_near_limit(100, 0.8));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut session = AgentSession::new([1u8; 8], "Test prompt");
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("World"));

        let serializable = SerializableSession::from(&session);
        let restored = serializable.to_session();

        assert_eq!(restored.id.0, session.id.0);
        assert_eq!(restored.user_id, session.user_id);
        assert_eq!(restored.system_prompt, session.system_prompt);
        assert_eq!(restored.messages.len(), session.messages.len());
        assert_eq!(restored.token_count, session.token_count);
    }

    #[test]
    fn test_budget_snapshot_roundtrip() {
        let session = AgentSession::new([1u8; 8], "Test");
        session.budget_tracker.record_cost(42.5);

        let serializable = SerializableSession::from(&session);
        let restored = serializable.to_session();

        assert_f64_eq(restored.budget_tracker.spent(), 42.5);
    }

    #[test]
    fn test_workspace_path_roundtrip() {
        let session = AgentSession::new([1u8; 8], "Test").with_workspace("/home/user/project");

        let serializable = SerializableSession::from(&session);
        let restored = serializable.to_session();

        assert_eq!(
            restored.workspace_path,
            Some(PathBuf::from("/home/user/project"))
        );
    }

    #[test]
    fn test_model_roundtrip() {
        let session = AgentSession::new([1u8; 8], "Test").with_model("some-model");

        let restored = SerializableSession::from(&session).to_session();

        assert_eq!(restored.model.as_deref(), Some("some-model"));
    }

    #[test]
    fn test_with_shared_stores() {
        let parent = AgentSession::new([1u8; 8], "Parent");
        parent.budget_tracker.record_cost(10.0);

        let child = AgentSession::with_shared_stores(
            SessionId::new(),
            [1u8; 8],
            "Child",
            Arc::clone(&parent.allowance_store),
            Arc::clone(&parent.capabilities),
            Arc::clone(&parent.budget_tracker),
        );

        // The child sees spend recorded by the parent, and vice versa.
        assert_f64_eq(child.budget_tracker.spent(), 10.0);
        child.budget_tracker.record_cost(5.0);
        assert_f64_eq(parent.budget_tracker.spent(), 15.0);

        assert!(Arc::ptr_eq(&parent.budget_tracker, &child.budget_tracker));
        assert!(Arc::ptr_eq(&parent.allowance_store, &child.allowance_store));
        assert!(Arc::ptr_eq(&parent.capabilities, &child.capabilities));

        assert!(child.messages.is_empty());

        // The approval manager is never shared.
        assert!(!Arc::ptr_eq(
            &parent.approval_manager,
            &child.approval_manager
        ));
    }

    #[test]
    fn test_backwards_compatible_deserialization() {
        let json = r#"{
            "id": "00000000-0000-0000-0000-000000000001",
            "user_id": "0101010101010101",
            "messages": [],
            "system_prompt": "Test",
            "created_at": "2024-01-01T00:00:00Z",
            "token_count": 0,
            "metadata": {
                "title": null,
                "tags": [],
                "turn_count": 0,
                "tool_call_count": 0,
                "approval_count": 0,
                "custom": {}
            }
        }"#;

        let serializable: SerializableSession = serde_json::from_str(json).unwrap();
        let session = serializable.to_session();
        assert_eq!(session.system_prompt, "Test");
        assert_eq!(
            session.id.0.to_string(),
            "00000000-0000-0000-0000-000000000001"
        );
        assert_eq!(session.user_id, [1u8; 8]);
        assert!(session.workspace_path.is_none());
        assert!(session.model.is_none());
        assert_f64_eq(session.budget_tracker.spent(), 0.0);
    }

    #[test]
    fn test_invalid_user_id_falls_back_to_zero() {
        let json = r#"{
            "id": "not-a-uuid",
            "user_id": "0101",
            "messages": [],
            "system_prompt": "Test",
            "created_at": "2024-01-01T00:00:00Z",
            "token_count": 0,
            "metadata": {
                "title": null,
                "tags": [],
                "turn_count": 0,
                "tool_call_count": 0,
                "approval_count": 0,
                "custom": {}
            }
        }"#;

        let serializable: SerializableSession = serde_json::from_str(json).unwrap();
        let session = serializable.to_session();
        // Too short to hold 8 bytes, so it falls back to all zeroes.
        assert_eq!(session.user_id, [0u8; 8]);
        assert_eq!(session.system_prompt, "Test");
    }
}