1use std::collections::HashMap;
34use std::fmt;
35use std::sync::{Arc, RwLock};
36use std::time::{Duration, Instant, SystemTime};
37
/// Identifier of an agent session; a plain `String` alias used as the key in
/// `SessionManager` and stored on every `AgentContext`.
pub type SessionId = String;
40
/// Per-session state for an agent: variables, permissions, audit log, open
/// transaction, resource budget and tool bookkeeping.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Session this context belongs to.
    pub session_id: SessionId,
    /// Base directory used by `resolve_path` for relative paths
    /// (`new` sets it to `/agents/{session_id}`).
    pub working_dir: String,
    /// Named values written by `set_var`, read by `get_var`/`peek_var` and
    /// expanded by `substitute_vars`.
    pub variables: HashMap<String, ContextValue>,
    /// Permission set consulted by `check_fs_permission` / `check_db_permission`.
    pub permissions: AgentPermissions,
    /// Wall-clock creation time; `age()` is measured from here.
    pub started_at: SystemTime,
    /// Monotonic time of the last recorded activity (refreshed by `set_var` and
    /// `get_var`); drives `idle_time()` / `is_expired()`.
    pub last_activity: Instant,
    /// Audit entries, appended in the order operations were recorded.
    pub audit: Vec<AuditEntry>,
    /// Currently open transaction, or `None` when no transaction is active.
    pub transaction: Option<TransactionScope>,
    /// Usage counters and limits enforced by `consume_budget`.
    pub budget: OperationBudget,
    /// Tools registered through `register_tool`.
    pub tool_registry: Vec<ToolDefinition>,
    /// History of tool invocations recorded through `record_tool_call`.
    pub tool_calls: Vec<ToolCallRecord>,
}
67
/// Description of a tool that an agent may invoke.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Optional parameter schema as text (presumably a serialized JSON Schema —
    /// TODO confirm with producers of this value).
    pub parameters_schema: Option<String>,
    /// Whether invoking the tool should require confirmation. Not enforced in
    /// this file.
    pub requires_confirmation: bool,
}
80
/// Record of a single tool invocation.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
    /// Identifier of this call.
    pub call_id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Call arguments as text (format not enforced here).
    pub arguments: String,
    /// Successful output, if any.
    pub result: Option<String>,
    /// Error message, if the call failed.
    pub error: Option<String>,
    /// When the call was recorded.
    pub timestamp: SystemTime,
}
97
/// Dynamically typed value stored in an agent's variable map.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextValue {
    String(String),
    Number(f64),
    Bool(bool),
    List(Vec<ContextValue>),
    Object(HashMap<String, ContextValue>),
    Null,
}

/// Renders the value in a JSON-like textual form.
///
/// Strings are wrapped in double quotes without escaping, numbers and booleans
/// use their own `Display`, lists render as `[a, b]` and objects as
/// `{"k": v}` in the map's iteration order. `Null` renders as `null`.
impl fmt::Display for ContextValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Null => f.write_str("null"),
            Self::Bool(flag) => write!(f, "{}", flag),
            Self::Number(num) => write!(f, "{}", num),
            Self::String(text) => write!(f, "\"{}\"", text),
            Self::List(items) => {
                f.write_str("[")?;
                let mut sep = "";
                for item in items {
                    f.write_str(sep)?;
                    write!(f, "{}", item)?;
                    sep = ", ";
                }
                f.write_str("]")
            }
            Self::Object(fields) => {
                f.write_str("{")?;
                let mut sep = "";
                for (key, val) in fields {
                    f.write_str(sep)?;
                    write!(f, "\"{}\": {}", key, val)?;
                    sep = ", ";
                }
                f.write_str("}")
            }
        }
    }
}
139
/// Full permission set of an agent. The derived `Default` grants nothing: every
/// flag is `false` and every allow-list is empty.
#[derive(Debug, Clone, Default)]
pub struct AgentPermissions {
    /// Filesystem permissions (checked by `check_fs_permission`).
    pub filesystem: FsPermissions,
    /// Database permissions (checked by `check_db_permission`).
    pub database: DbPermissions,
    /// Whether calculator use is allowed. Not checked in this file.
    pub calculator: bool,
    /// Network permissions. Not checked in this file.
    pub network: NetworkPermissions,
}
152
/// Filesystem permissions consulted by `check_fs_permission`.
#[derive(Debug, Clone, Default)]
pub struct FsPermissions {
    /// Allows `FsRead` and `FsList`.
    pub read: bool,
    /// Allows `FsWrite`.
    pub write: bool,
    /// Allows `FsMkdir`.
    pub mkdir: bool,
    /// Allows `FsDelete`.
    pub delete: bool,
    /// Path prefixes the agent may touch. An empty list disables the path
    /// restriction entirely; the entry `"*"` matches any path.
    pub allowed_paths: Vec<String>,
}
167
/// Database permissions consulted by `check_db_permission`.
#[derive(Debug, Clone, Default)]
pub struct DbPermissions {
    /// Allows `DbQuery`.
    pub read: bool,
    /// Allows `DbInsert` and `DbUpdate`.
    pub write: bool,
    /// Table-creation permission. Not checked in this file.
    pub create: bool,
    /// Gates `DbDelete` in `check_db_permission`.
    pub drop: bool,
    /// Permitted tables: exact names, `"*"` for all, or `prefix*` patterns.
    /// An empty list disables the table restriction.
    pub allowed_tables: Vec<String>,
}
182
/// Network permissions. Nothing in this file enforces them.
#[derive(Debug, Clone, Default)]
pub struct NetworkPermissions {
    /// Whether HTTP access is allowed.
    pub http: bool,
    /// Domains the agent may contact.
    pub allowed_domains: Vec<String>,
}
191
/// One audited operation, as stored in `AgentContext::audit` and flattened by
/// `export_audit`.
#[derive(Debug, Clone)]
pub struct AuditEntry {
    /// Wall-clock time the entry was recorded.
    pub timestamp: SystemTime,
    /// Kind of operation performed.
    pub operation: AuditOperation,
    /// Resource acted on (variable name, `tx:<id>`, …).
    pub resource: String,
    /// Outcome of the operation.
    pub result: AuditResult,
    /// Extra key/value pairs included in `export_audit` output.
    pub metadata: HashMap<String, String>,
}
206
/// Kinds of operations recorded in the audit log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuditOperation {
    FsRead,
    FsWrite,
    FsMkdir,
    FsDelete,
    FsList,
    DbQuery,
    DbInsert,
    DbUpdate,
    DbDelete,
    Calculate,
    VarSet,
    VarGet,
    TxBegin,
    TxCommit,
    TxRollback,
}

/// Renders the dotted, lower-case operation label (e.g. `fs.read`, `tx.commit`)
/// used in error messages and audit exports.
impl fmt::Display for AuditOperation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::FsRead => "fs.read",
            Self::FsWrite => "fs.write",
            Self::FsMkdir => "fs.mkdir",
            Self::FsDelete => "fs.delete",
            Self::FsList => "fs.list",
            Self::DbQuery => "db.query",
            Self::DbInsert => "db.insert",
            Self::DbUpdate => "db.update",
            Self::DbDelete => "db.delete",
            Self::Calculate => "calc",
            Self::VarSet => "var.set",
            Self::VarGet => "var.get",
            Self::TxBegin => "tx.begin",
            Self::TxCommit => "tx.commit",
            Self::TxRollback => "tx.rollback",
        };
        f.write_str(label)
    }
}
248
/// Outcome attached to an audit entry.
#[derive(Debug, Clone)]
pub enum AuditResult {
    /// The operation succeeded.
    Success,
    /// The operation failed; carries the error message.
    Error(String),
    /// The operation was refused; carries the reason. Not produced in this file.
    Denied(String),
}
256
/// State of the transaction opened by `begin_transaction`.
#[derive(Debug, Clone)]
pub struct TransactionScope {
    /// Caller-supplied transaction id (used in audit resources as `tx:<id>`).
    pub tx_id: u64,
    /// Monotonic time the transaction was opened.
    pub started_at: Instant,
    /// Savepoint names, in the order `savepoint` recorded them.
    pub savepoints: Vec<String>,
    /// Writes recorded by `record_pending_write`; returned by `rollback_transaction`.
    pub pending_writes: Vec<PendingWrite>,
}
269
/// A write performed inside a transaction, kept so it can be undone on rollback.
#[derive(Debug, Clone)]
pub struct PendingWrite {
    /// Kind of resource that was written.
    pub resource_type: ResourceType,
    /// Key identifying the resource (e.g. a file path).
    pub resource_key: String,
    /// Contents before the write. `None` presumably means the resource did not
    /// exist beforehand — NOTE(review): confirm with callers.
    pub original_value: Option<Vec<u8>>,
}
280
/// Kind of resource referenced by a `PendingWrite`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResourceType {
    File,
    Directory,
    Table,
    Variable,
}
289
/// Resource limits and usage counters checked by `consume_budget`.
/// A `None` limit means "unlimited".
#[derive(Debug, Clone)]
pub struct OperationBudget {
    /// Token limit, if any.
    pub max_tokens: Option<u64>,
    /// Tokens consumed so far.
    pub tokens_used: u64,
    /// Cost limit, if any (unit defined by callers).
    pub max_cost: Option<u64>,
    /// Cost consumed so far.
    pub cost_used: u64,
    /// Operation-count limit, if any.
    pub max_operations: Option<u64>,
    /// Operations consumed so far.
    pub operations_used: u64,
}

/// Starts with all counters at zero, no token or cost limit, and a cap of
/// 10 000 operations.
impl Default for OperationBudget {
    fn default() -> Self {
        OperationBudget {
            tokens_used: 0,
            cost_used: 0,
            operations_used: 0,
            max_tokens: None,
            max_cost: None,
            max_operations: Some(10_000),
        }
    }
}
319
/// Errors returned by `AgentContext` operations.
#[derive(Debug, Clone)]
pub enum ContextError {
    PermissionDenied(String),
    VariableNotFound(String),
    BudgetExceeded(String),
    TransactionError(String),
    InvalidPath(String),
    SessionExpired,
}

/// Human-readable message: a fixed category prefix followed by the detail text.
impl fmt::Display for ContextError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SessionExpired => f.write_str("Session expired"),
            Self::PermissionDenied(detail) => write!(f, "Permission denied: {}", detail),
            Self::VariableNotFound(var) => write!(f, "Variable not found: {}", var),
            Self::BudgetExceeded(detail) => write!(f, "Budget exceeded: {}", detail),
            Self::TransactionError(detail) => write!(f, "Transaction error: {}", detail),
            Self::InvalidPath(bad) => write!(f, "Invalid path: {}", bad),
        }
    }
}

impl std::error::Error for ContextError {}
345
346impl AgentContext {
347 pub fn new(session_id: SessionId) -> Self {
349 let now = Instant::now();
350 Self {
351 session_id: session_id.clone(),
352 working_dir: format!("/agents/{}", session_id),
353 variables: HashMap::new(),
354 permissions: AgentPermissions::default(),
355 started_at: SystemTime::now(),
356 last_activity: now,
357 audit: Vec::new(),
358 transaction: None,
359 budget: OperationBudget::default(),
360 tool_registry: Vec::new(),
361 tool_calls: Vec::new(),
362 }
363 }
364
365 pub fn with_working_dir(session_id: SessionId, working_dir: String) -> Self {
367 let mut ctx = Self::new(session_id);
368 ctx.working_dir = working_dir;
369 ctx
370 }
371
372 pub fn with_full_permissions(session_id: SessionId) -> Self {
374 let mut ctx = Self::new(session_id);
375 ctx.permissions = AgentPermissions {
376 filesystem: FsPermissions {
377 read: true,
378 write: true,
379 mkdir: true,
380 delete: true,
381 allowed_paths: vec!["/".into()],
382 },
383 database: DbPermissions {
384 read: true,
385 write: true,
386 create: true,
387 drop: true,
388 allowed_tables: vec!["*".into()],
389 },
390 calculator: true,
391 network: NetworkPermissions::default(),
392 };
393 ctx
394 }
395
396 pub fn register_tool(&mut self, tool: ToolDefinition) {
398 self.tool_registry.push(tool);
399 }
400
401 pub fn record_tool_call(&mut self, call: ToolCallRecord) {
403 self.tool_calls.push(call);
404 }
405
406 pub fn set_var(&mut self, name: &str, value: ContextValue) {
408 self.variables.insert(name.to_string(), value.clone());
409 self.touch();
410 self.audit(AuditOperation::VarSet, name, AuditResult::Success);
411 }
412
413 pub fn get_var(&mut self, name: &str) -> Option<ContextValue> {
415 self.touch();
416 let result = self.variables.get(name).cloned();
417 if result.is_some() {
418 self.audit(AuditOperation::VarGet, name, AuditResult::Success);
419 } else {
420 self.audit(
421 AuditOperation::VarGet,
422 name,
423 AuditResult::Error("not found".into()),
424 );
425 }
426 result
427 }
428
429 pub fn peek_var(&self, name: &str) -> Option<&ContextValue> {
431 self.variables.get(name)
432 }
433
434 fn touch(&mut self) {
436 self.last_activity = Instant::now();
437 }
438
439 fn audit(&mut self, operation: AuditOperation, resource: &str, result: AuditResult) {
441 self.audit.push(AuditEntry {
442 timestamp: SystemTime::now(),
443 operation,
444 resource: resource.to_string(),
445 result,
446 metadata: HashMap::new(),
447 });
448 }
449
450 pub fn check_fs_permission(&self, path: &str, op: AuditOperation) -> Result<(), ContextError> {
452 let perm = match op {
453 AuditOperation::FsRead | AuditOperation::FsList => self.permissions.filesystem.read,
454 AuditOperation::FsWrite => self.permissions.filesystem.write,
455 AuditOperation::FsMkdir => self.permissions.filesystem.mkdir,
456 AuditOperation::FsDelete => self.permissions.filesystem.delete,
457 _ => {
458 return Err(ContextError::PermissionDenied(
459 "invalid fs operation".into(),
460 ));
461 }
462 };
463
464 if !perm {
465 return Err(ContextError::PermissionDenied(format!(
466 "{} not allowed",
467 op
468 )));
469 }
470
471 if !self.permissions.filesystem.allowed_paths.is_empty() {
473 let allowed = self
474 .permissions
475 .filesystem
476 .allowed_paths
477 .iter()
478 .any(|p| path.starts_with(p) || p == "*");
479 if !allowed {
480 return Err(ContextError::PermissionDenied(format!(
481 "path {} not in allowed paths",
482 path
483 )));
484 }
485 }
486
487 Ok(())
488 }
489
490 pub fn check_db_permission(&self, table: &str, op: AuditOperation) -> Result<(), ContextError> {
492 let perm = match op {
493 AuditOperation::DbQuery => self.permissions.database.read,
494 AuditOperation::DbInsert | AuditOperation::DbUpdate => self.permissions.database.write,
495 AuditOperation::DbDelete => self.permissions.database.drop,
496 _ => {
497 return Err(ContextError::PermissionDenied(
498 "invalid db operation".into(),
499 ));
500 }
501 };
502
503 if !perm {
504 return Err(ContextError::PermissionDenied(format!(
505 "{} not allowed",
506 op
507 )));
508 }
509
510 if !self.permissions.database.allowed_tables.is_empty() {
512 let allowed = self.permissions.database.allowed_tables.iter().any(|t| {
513 t == "*" || t == table || (t.ends_with('*') && table.starts_with(&t[..t.len() - 1]))
514 });
515 if !allowed {
516 return Err(ContextError::PermissionDenied(format!(
517 "table {} not in allowed tables",
518 table
519 )));
520 }
521 }
522
523 Ok(())
524 }
525
526 pub fn consume_budget(&mut self, tokens: u64, cost: u64) -> Result<(), ContextError> {
528 self.budget.operations_used += 1;
529 self.budget.tokens_used += tokens;
530 self.budget.cost_used += cost;
531
532 if let Some(max) = self.budget.max_operations
533 && self.budget.operations_used > max
534 {
535 return Err(ContextError::BudgetExceeded("max operations".into()));
536 }
537 if let Some(max) = self.budget.max_tokens
538 && self.budget.tokens_used > max
539 {
540 return Err(ContextError::BudgetExceeded("max tokens".into()));
541 }
542 if let Some(max) = self.budget.max_cost
543 && self.budget.cost_used > max
544 {
545 return Err(ContextError::BudgetExceeded("max cost".into()));
546 }
547
548 Ok(())
549 }
550
551 pub fn begin_transaction(&mut self, tx_id: u64) -> Result<(), ContextError> {
553 if self.transaction.is_some() {
554 return Err(ContextError::TransactionError(
555 "already in transaction".into(),
556 ));
557 }
558
559 self.transaction = Some(TransactionScope {
560 tx_id,
561 started_at: Instant::now(),
562 savepoints: Vec::new(),
563 pending_writes: Vec::new(),
564 });
565
566 self.audit(
567 AuditOperation::TxBegin,
568 &format!("tx:{}", tx_id),
569 AuditResult::Success,
570 );
571 Ok(())
572 }
573
574 pub fn commit_transaction(&mut self) -> Result<(), ContextError> {
576 let tx = self
577 .transaction
578 .take()
579 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
580
581 self.audit(
582 AuditOperation::TxCommit,
583 &format!("tx:{}", tx.tx_id),
584 AuditResult::Success,
585 );
586 Ok(())
587 }
588
589 pub fn rollback_transaction(&mut self) -> Result<Vec<PendingWrite>, ContextError> {
591 let tx = self
592 .transaction
593 .take()
594 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
595
596 self.audit(
597 AuditOperation::TxRollback,
598 &format!("tx:{}", tx.tx_id),
599 AuditResult::Success,
600 );
601
602 Ok(tx.pending_writes)
603 }
604
605 pub fn savepoint(&mut self, name: &str) -> Result<(), ContextError> {
607 let tx = self
608 .transaction
609 .as_mut()
610 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
611
612 tx.savepoints.push(name.to_string());
613 Ok(())
614 }
615
616 pub fn record_pending_write(
618 &mut self,
619 resource_type: ResourceType,
620 resource_key: String,
621 original_value: Option<Vec<u8>>,
622 ) -> Result<(), ContextError> {
623 let tx = self
624 .transaction
625 .as_mut()
626 .ok_or_else(|| ContextError::TransactionError("no active transaction".into()))?;
627
628 tx.pending_writes.push(PendingWrite {
629 resource_type,
630 resource_key,
631 original_value,
632 });
633 Ok(())
634 }
635
636 pub fn resolve_path(&self, path: &str) -> String {
638 if path.starts_with('/') {
639 path.to_string()
640 } else {
641 format!("{}/{}", self.working_dir, path)
642 }
643 }
644
645 pub fn substitute_vars(&self, input: &str) -> String {
647 let mut result = input.to_string();
648
649 for (name, value) in &self.variables {
650 let pattern = format!("${}", name);
651 let replacement = match value {
652 ContextValue::String(s) => s.clone(),
653 ContextValue::Number(n) => n.to_string(),
654 ContextValue::Bool(b) => b.to_string(),
655 _ => value.to_string(),
656 };
657 result = result.replace(&pattern, &replacement);
658 }
659
660 result
661 }
662
663 pub fn age(&self) -> Duration {
665 SystemTime::now()
666 .duration_since(self.started_at)
667 .unwrap_or_default()
668 }
669
670 pub fn idle_time(&self) -> Duration {
672 self.last_activity.elapsed()
673 }
674
675 pub fn is_expired(&self, idle_timeout: Duration) -> bool {
677 self.idle_time() > idle_timeout
678 }
679
680 pub fn export_audit(&self) -> Vec<HashMap<String, String>> {
682 self.audit
683 .iter()
684 .map(|entry| {
685 let mut m = HashMap::new();
686 m.insert(
687 "timestamp".into(),
688 entry
689 .timestamp
690 .duration_since(SystemTime::UNIX_EPOCH)
691 .map(|d| d.as_secs().to_string())
692 .unwrap_or_default(),
693 );
694 m.insert("operation".into(), entry.operation.to_string());
695 m.insert("resource".into(), entry.resource.clone());
696 m.insert(
697 "result".into(),
698 match &entry.result {
699 AuditResult::Success => "success".into(),
700 AuditResult::Error(e) => format!("error:{}", e),
701 AuditResult::Denied(r) => format!("denied:{}", r),
702 },
703 );
704 for (k, v) in &entry.metadata {
705 m.insert(k.clone(), v.clone());
706 }
707 m
708 })
709 .collect()
710 }
711}
712
713pub struct SessionManager {
715 sessions: RwLock<HashMap<SessionId, Arc<RwLock<AgentContext>>>>,
716 idle_timeout: Duration,
717}
718
719impl SessionManager {
720 pub fn new(idle_timeout: Duration) -> Self {
722 Self {
723 sessions: RwLock::new(HashMap::new()),
724 idle_timeout,
725 }
726 }
727
728 pub fn create_session(&self, session_id: SessionId) -> Arc<RwLock<AgentContext>> {
730 let ctx = Arc::new(RwLock::new(AgentContext::new(session_id.clone())));
731 self.sessions
732 .write()
733 .unwrap()
734 .insert(session_id, ctx.clone());
735 ctx
736 }
737
738 pub fn get_session(&self, session_id: &str) -> Option<Arc<RwLock<AgentContext>>> {
740 let sessions = self.sessions.read().unwrap();
741 sessions.get(session_id).cloned()
742 }
743
744 pub fn get_or_create(&self, session_id: SessionId) -> Arc<RwLock<AgentContext>> {
746 if let Some(ctx) = self.get_session(&session_id) {
747 return ctx;
748 }
749 self.create_session(session_id)
750 }
751
752 pub fn remove_session(&self, session_id: &str) -> Option<Arc<RwLock<AgentContext>>> {
754 self.sessions.write().unwrap().remove(session_id)
755 }
756
757 pub fn cleanup_expired(&self) -> usize {
759 let mut sessions = self.sessions.write().unwrap();
760 let initial_count = sessions.len();
761
762 sessions.retain(|_, ctx| {
763 let ctx = ctx.read().unwrap();
764 !ctx.is_expired(self.idle_timeout)
765 });
766
767 initial_count - sessions.len()
768 }
769
770 pub fn session_count(&self) -> usize {
772 self.sessions.read().unwrap().len()
773 }
774}
775
776impl Default for SessionManager {
777 fn default() -> Self {
778 Self::new(Duration::from_secs(3600)) }
780}
781
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = AgentContext::new("test-session".into());
        assert_eq!(ctx.session_id, "test-session");
        assert_eq!(ctx.working_dir, "/agents/test-session");
    }

    #[test]
    fn test_variables() {
        let mut ctx = AgentContext::new("test".into());
        ctx.set_var("model", ContextValue::String("gpt-4".into()));
        ctx.set_var("budget", ContextValue::Number(1000.0));

        assert_eq!(
            ctx.get_var("model"),
            Some(ContextValue::String("gpt-4".into()))
        );
        assert_eq!(ctx.get_var("budget"), Some(ContextValue::Number(1000.0)));
    }

    #[test]
    fn test_get_var_missing_is_audited_as_error() {
        let mut ctx = AgentContext::new("test".into());
        assert_eq!(ctx.get_var("missing"), None);
        assert!(matches!(
            ctx.audit.last().map(|e| &e.result),
            Some(AuditResult::Error(_))
        ));
    }

    #[test]
    fn test_variable_substitution() {
        let mut ctx = AgentContext::new("test".into());
        ctx.set_var("name", ContextValue::String("Alice".into()));
        ctx.set_var("count", ContextValue::Number(42.0));

        let result = ctx.substitute_vars("Hello $name, you have $count items");
        assert_eq!(result, "Hello Alice, you have 42 items");
    }

    #[test]
    fn test_context_value_display() {
        let value = ContextValue::List(vec![
            ContextValue::Number(1.0),
            ContextValue::Bool(true),
            ContextValue::Null,
            ContextValue::String("x".into()),
        ]);
        assert_eq!(value.to_string(), "[1, true, null, \"x\"]");
    }

    #[test]
    fn test_path_resolution() {
        let ctx = AgentContext::with_working_dir("test".into(), "/home/agent".into());

        assert_eq!(ctx.resolve_path("data.json"), "/home/agent/data.json");
        assert_eq!(ctx.resolve_path("/absolute/path"), "/absolute/path");
    }

    #[test]
    fn test_permissions() {
        let mut ctx = AgentContext::new("test".into());
        ctx.permissions.filesystem.read = true;
        ctx.permissions.filesystem.allowed_paths = vec!["/allowed".into()];

        assert!(
            ctx.check_fs_permission("/allowed/file", AuditOperation::FsRead)
                .is_ok()
        );
        assert!(
            ctx.check_fs_permission("/forbidden/file", AuditOperation::FsRead)
                .is_err()
        );
        assert!(
            ctx.check_fs_permission("/allowed/file", AuditOperation::FsWrite)
                .is_err()
        );
    }

    #[test]
    fn test_db_permissions() {
        let mut ctx = AgentContext::new("test".into());
        ctx.permissions.database.read = true;
        ctx.permissions.database.allowed_tables = vec!["logs_*".into(), "users".into()];

        assert!(ctx.check_db_permission("logs_2024", AuditOperation::DbQuery).is_ok());
        assert!(ctx.check_db_permission("users", AuditOperation::DbQuery).is_ok());
        assert!(ctx.check_db_permission("orders", AuditOperation::DbQuery).is_err());
        assert!(ctx.check_db_permission("users", AuditOperation::DbInsert).is_err());
        assert!(ctx.check_db_permission("users", AuditOperation::FsRead).is_err());
    }

    #[test]
    fn test_budget() {
        let mut ctx = AgentContext::new("test".into());
        ctx.budget.max_operations = Some(3);

        assert!(ctx.consume_budget(100, 10).is_ok());
        assert!(ctx.consume_budget(100, 10).is_ok());
        assert!(ctx.consume_budget(100, 10).is_ok());
        assert!(ctx.consume_budget(100, 10).is_err());
    }

    #[test]
    fn test_budget_tokens() {
        let mut ctx = AgentContext::new("test".into());
        ctx.budget.max_tokens = Some(150);

        assert!(ctx.consume_budget(100, 0).is_ok());
        assert!(ctx.consume_budget(100, 0).is_err());
    }

    #[test]
    fn test_transaction() {
        let mut ctx = AgentContext::new("test".into());

        assert!(ctx.begin_transaction(1).is_ok());
        assert!(ctx.begin_transaction(2).is_err());

        ctx.record_pending_write(
            ResourceType::File,
            "/test/file".into(),
            Some(b"original".to_vec()),
        )
        .unwrap();

        let pending = ctx.rollback_transaction().unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].resource_key, "/test/file");
        assert!(ctx.transaction.is_none());
        assert!(ctx.rollback_transaction().is_err());
    }

    #[test]
    fn test_commit_and_savepoint() {
        let mut ctx = AgentContext::new("test".into());
        assert!(ctx.commit_transaction().is_err());
        assert!(ctx.savepoint("sp").is_err());

        ctx.begin_transaction(7).unwrap();
        ctx.savepoint("sp1").unwrap();
        ctx.commit_transaction().unwrap();
        assert!(ctx.transaction.is_none());

        let ops: Vec<AuditOperation> = ctx.audit.iter().map(|e| e.operation.clone()).collect();
        assert_eq!(ops, vec![AuditOperation::TxBegin, AuditOperation::TxCommit]);
    }

    #[test]
    fn test_session_manager() {
        let mgr = SessionManager::default();

        let _s1 = mgr.create_session("s1".into());
        let _s2 = mgr.create_session("s2".into());

        assert_eq!(mgr.session_count(), 2);
        assert!(mgr.get_session("s1").is_some());
        assert!(mgr.get_session("s3").is_none());

        mgr.remove_session("s1");
        assert_eq!(mgr.session_count(), 1);
    }

    #[test]
    fn test_get_or_create_returns_same_session() {
        let mgr = SessionManager::new(Duration::from_secs(3600));
        let a = mgr.get_or_create("s".into());
        let b = mgr.get_or_create("s".into());

        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(mgr.session_count(), 1);
        assert_eq!(mgr.cleanup_expired(), 0);
        assert_eq!(mgr.session_count(), 1);
    }
}
887}