// ai_agent/session_state.rs
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{MutexGuard, PoisonError};

11#[derive(Debug, Clone, PartialEq, Eq, Default)]
13pub enum SessionState {
14 #[default]
15 Idle,
16 Running,
17 RequiresAction { details: Option<RequiresActionDetails> },
18}
19
20impl SessionState {
21 pub fn as_str(&self) -> &str {
22 match self {
23 SessionState::Idle => "idle",
24 SessionState::Running => "running",
25 SessionState::RequiresAction { .. } => "requires_action",
26 }
27 }
28}
29
30#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct RequiresActionDetails {
33 pub typ: ActionType,
34 pub permission_denial: Option<PermissionDenialInfo>,
35}
36
/// The kind of action a session in [`SessionState::RequiresAction`] is waiting on.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ActionType {
    /// A permission decision.
    Permission,
    /// An answer to a question.
    Question,
    /// An interrupt.
    Interrupt,
}

/// Identifies a tool invocation whose permission was denied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PermissionDenialInfo {
    /// Name of the tool, e.g. `"Bash"`.
    pub tool_name: String,
    /// Identifier of the specific tool-use request.
    pub tool_use_id: String,
}

50#[derive(Debug, Default)]
52pub struct SessionStateManager {
53 state: std::sync::Mutex<SessionState>,
54 permission_denial_count: AtomicU32,
55}
56
57impl SessionStateManager {
58 pub fn new() -> Self {
59 Self {
60 state: std::sync::Mutex::new(SessionState::Idle),
61 permission_denial_count: AtomicU32::new(0),
62 }
63 }
64
65 pub fn get_state(&self) -> SessionState {
66 self.state.lock().unwrap().clone()
67 }
68
69 pub fn set_state(&self, state: SessionState) {
70 *self.state.lock().unwrap() = state;
71 }
72
73 pub fn start_running(&self) {
74 *self.state.lock().unwrap() = SessionState::Running;
75 }
76
77 pub fn stop(&self) {
78 *self.state.lock().unwrap() = SessionState::Idle;
79 }
80
81 pub fn require_action(&self, details: RequiresActionDetails) {
82 *self.state.lock().unwrap() =
83 SessionState::RequiresAction {
84 details: Some(details),
85 };
86 }
87
88 pub fn clear_action(&self) {
89 *self.state.lock().unwrap() = SessionState::Idle;
90 }
91
92 pub fn permission_denial_count(&self) -> u32 {
93 self.permission_denial_count.load(Ordering::Relaxed)
94 }
95
96 pub fn increment_permission_denial(&self) {
97 self.permission_denial_count.fetch_add(1, Ordering::Relaxed);
98 }
99
100 pub fn reset_permission_denial(&self) {
101 self.permission_denial_count.store(0, Ordering::Relaxed);
102 }
103
104 pub fn is_consistently_denied(&self, threshold: u32) -> bool {
106 self.permission_denial_count.load(Ordering::Relaxed) >= threshold
107 }
108}
109
110impl Clone for SessionStateManager {
111 fn clone(&self) -> Self {
112 let state = self.state.lock().unwrap().clone();
113 Self {
114 state: std::sync::Mutex::new(state),
115 permission_denial_count: AtomicU32::new(self.permission_denial_count.load(Ordering::Relaxed)),
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the permission-request details used by the transition test.
    fn bash_permission_details() -> RequiresActionDetails {
        let denial = PermissionDenialInfo {
            tool_name: "Bash".to_string(),
            tool_use_id: "test-123".to_string(),
        };
        RequiresActionDetails {
            typ: ActionType::Permission,
            permission_denial: Some(denial),
        }
    }

    #[test]
    fn test_session_state_transitions() {
        let mgr = SessionStateManager::new();

        // A freshly created manager starts out idle.
        assert_eq!(mgr.get_state(), SessionState::Idle);

        mgr.start_running();
        assert_eq!(mgr.get_state(), SessionState::Running);

        mgr.require_action(bash_permission_details());
        assert_eq!(mgr.get_state().as_str(), "requires_action");

        // Clearing the action returns to idle; stopping while idle stays idle.
        mgr.clear_action();
        assert_eq!(mgr.get_state(), SessionState::Idle);
        mgr.stop();
        assert_eq!(mgr.get_state(), SessionState::Idle);
    }

    #[test]
    fn test_permission_denial_count() {
        const THRESHOLD: u32 = 3;
        let mgr = SessionStateManager::new();
        assert_eq!(mgr.permission_denial_count(), 0);
        assert!(!mgr.is_consistently_denied(THRESHOLD));

        // Two denials: counted, but still below the threshold.
        for _ in 0..2 {
            mgr.increment_permission_denial();
        }
        assert_eq!(mgr.permission_denial_count(), 2);
        assert!(!mgr.is_consistently_denied(THRESHOLD));

        // The third denial reaches the threshold.
        mgr.increment_permission_denial();
        assert!(mgr.is_consistently_denied(THRESHOLD));

        mgr.reset_permission_denial();
        assert_eq!(mgr.permission_denial_count(), 0);
    }

    #[test]
    fn test_session_state_as_str() {
        let cases = [
            (SessionState::Idle, "idle"),
            (SessionState::Running, "running"),
            (SessionState::RequiresAction { details: None }, "requires_action"),
        ];
        for (state, expected) in &cases {
            assert_eq!(state.as_str(), *expected);
        }
    }
}