Skip to main content

arbiter_behavior/
tracker.rs

1//! Behavioral call sequence tracker.
2//!
3//! Tracks the sequence of tool calls within a task session for
4//! anomaly detection purposes.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::classifier::OperationType;
14
15/// A record of a single tool call in a session.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CallRecord {
18    /// The tool that was called.
19    pub tool_name: String,
20    /// The MCP method.
21    pub method: String,
22    /// Classified operation type.
23    pub operation_type: OperationType,
24    /// When the call was made.
25    pub timestamp: DateTime<Utc>,
26}
27
28/// Tracks call sequences per session.
29#[derive(Clone)]
30pub struct BehaviorTracker {
31    /// session_id -> ordered list of call records.
32    records: Arc<RwLock<HashMap<Uuid, Vec<CallRecord>>>>,
33}
34
35impl BehaviorTracker {
36    /// Create a new behavior tracker.
37    pub fn new() -> Self {
38        Self {
39            records: Arc::new(RwLock::new(HashMap::new())),
40        }
41    }
42
43    /// Record a tool call for a session.
44    pub async fn record_call(
45        &self,
46        session_id: Uuid,
47        tool_name: String,
48        method: String,
49        operation_type: OperationType,
50    ) {
51        let record = CallRecord {
52            tool_name,
53            method,
54            operation_type,
55            timestamp: Utc::now(),
56        };
57
58        tracing::trace!(
59            session_id = %session_id,
60            tool = %record.tool_name,
61            op = ?record.operation_type,
62            "recording tool call"
63        );
64
65        let mut records = self.records.write().await;
66        records
67            .entry(session_id)
68            .or_insert_with(Vec::new)
69            .push(record);
70    }
71
72    /// Get all call records for a session.
73    pub async fn get_records(&self, session_id: Uuid) -> Vec<CallRecord> {
74        let records = self.records.read().await;
75        records.get(&session_id).cloned().unwrap_or_default()
76    }
77
78    /// Remove records for a session (called on session close/cleanup).
79    pub async fn clear_session(&self, session_id: Uuid) {
80        let mut records = self.records.write().await;
81        records.remove(&session_id);
82    }
83}
84
85impl Default for BehaviorTracker {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Record a `tools/call` read of `tool` for `session_id`.
    async fn record_read(tracker: &BehaviorTracker, session_id: Uuid, tool: &str) {
        tracker
            .record_call(
                session_id,
                tool.to_string(),
                "tools/call".to_string(),
                OperationType::Read,
            )
            .await;
    }

    #[tokio::test]
    async fn track_and_retrieve_calls() {
        let tracker = BehaviorTracker::new();
        let session_id = Uuid::new_v4();

        record_read(&tracker, session_id, "read_file").await;
        record_read(&tracker, session_id, "list_dir").await;

        let records = tracker.get_records(session_id).await;
        let tools: Vec<&str> = records.iter().map(|r| r.tool_name.as_str()).collect();
        assert_eq!(tools, ["read_file", "list_dir"]);
        assert!(records.iter().all(|r| r.method == "tools/call"));
    }

    #[tokio::test]
    async fn clear_session_records() {
        let tracker = BehaviorTracker::new();
        let session_id = Uuid::new_v4();

        record_read(&tracker, session_id, "read_file").await;

        tracker.clear_session(session_id).await;
        assert!(tracker.get_records(session_id).await.is_empty());
    }

    #[tokio::test]
    async fn unknown_session_has_no_records() {
        let tracker = BehaviorTracker::new();
        assert!(tracker.get_records(Uuid::new_v4()).await.is_empty());
    }

    #[tokio::test]
    async fn clearing_unknown_session_is_noop() {
        let tracker = BehaviorTracker::new();
        let kept = Uuid::new_v4();
        record_read(&tracker, kept, "read_file").await;

        tracker.clear_session(Uuid::new_v4()).await;
        assert_eq!(tracker.get_records(kept).await.len(), 1);
    }

    #[tokio::test]
    async fn sessions_are_isolated() {
        let tracker = BehaviorTracker::new();
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());

        record_read(&tracker, first, "read_file").await;
        record_read(&tracker, second, "list_dir").await;
        tracker.clear_session(first).await;

        assert!(tracker.get_records(first).await.is_empty());
        let remaining = tracker.get_records(second).await;
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].tool_name, "list_dir");
    }

    #[tokio::test]
    async fn clones_share_state() {
        let tracker = BehaviorTracker::new();
        let handle = tracker.clone();
        let session_id = Uuid::new_v4();

        record_read(&handle, session_id, "read_file").await;

        let records = tracker.get_records(session_id).await;
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].tool_name, "read_file");
    }
}