arbiter_behavior/
tracker.rs1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::classifier::OperationType;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallRecord {
18 pub tool_name: String,
20 pub method: String,
22 pub operation_type: OperationType,
24 pub timestamp: DateTime<Utc>,
26}
27
28#[derive(Clone)]
30pub struct BehaviorTracker {
31 records: Arc<RwLock<HashMap<Uuid, Vec<CallRecord>>>>,
33}
34
35impl BehaviorTracker {
36 pub fn new() -> Self {
38 Self {
39 records: Arc::new(RwLock::new(HashMap::new())),
40 }
41 }
42
43 pub async fn record_call(
45 &self,
46 session_id: Uuid,
47 tool_name: String,
48 method: String,
49 operation_type: OperationType,
50 ) {
51 let record = CallRecord {
52 tool_name,
53 method,
54 operation_type,
55 timestamp: Utc::now(),
56 };
57
58 tracing::trace!(
59 session_id = %session_id,
60 tool = %record.tool_name,
61 op = ?record.operation_type,
62 "recording tool call"
63 );
64
65 let mut records = self.records.write().await;
66 records
67 .entry(session_id)
68 .or_insert_with(Vec::new)
69 .push(record);
70 }
71
72 pub async fn get_records(&self, session_id: Uuid) -> Vec<CallRecord> {
74 let records = self.records.read().await;
75 records.get(&session_id).cloned().unwrap_or_default()
76 }
77
78 pub async fn clear_session(&self, session_id: Uuid) {
80 let mut records = self.records.write().await;
81 records.remove(&session_id);
82 }
83}
84
85impl Default for BehaviorTracker {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91#[cfg(test)]
92mod tests {
93 use super::*;
94
95 #[tokio::test]
96 async fn track_and_retrieve_calls() {
97 let tracker = BehaviorTracker::new();
98 let session_id = Uuid::new_v4();
99
100 tracker
101 .record_call(
102 session_id,
103 "read_file".into(),
104 "tools/call".into(),
105 OperationType::Read,
106 )
107 .await;
108
109 tracker
110 .record_call(
111 session_id,
112 "list_dir".into(),
113 "tools/call".into(),
114 OperationType::Read,
115 )
116 .await;
117
118 let records = tracker.get_records(session_id).await;
119 assert_eq!(records.len(), 2);
120 assert_eq!(records[0].tool_name, "read_file");
121 assert_eq!(records[1].tool_name, "list_dir");
122 }
123
124 #[tokio::test]
125 async fn clear_session_records() {
126 let tracker = BehaviorTracker::new();
127 let session_id = Uuid::new_v4();
128
129 tracker
130 .record_call(
131 session_id,
132 "read_file".into(),
133 "tools/call".into(),
134 OperationType::Read,
135 )
136 .await;
137
138 tracker.clear_session(session_id).await;
139 let records = tracker.get_records(session_id).await;
140 assert!(records.is_empty());
141 }
142}