agentzero_channels/
interruption.rs1use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
/// Cooperative cancellation flag for a single turn, as returned by
/// `InterruptionDetector::start_turn`.
///
/// All clones share one underlying flag, so cancelling any clone is visible
/// through every other clone. Nothing forcibly stops work: holders are expected
/// to poll [`CancelToken::is_cancelled`]. Cancellation is one-way; this type
/// offers no way to reset the flag.
#[derive(Debug, Clone)]
pub struct CancelToken {
    /// Shared flag; `true` once any clone has called `cancel`.
    cancelled: Arc<AtomicBool>,
}

impl CancelToken {
    /// Creates a fresh, not-yet-cancelled token with its own flag allocation.
    fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks this token, and every clone of it, as cancelled.
    ///
    /// Calling it more than once has no additional effect.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once `cancel` has been called on this token or any clone.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }
}
35
/// Identifies one sender on one channel; turns are tracked per key.
///
/// Two keys are equal only when both `sender` and `channel` match, so the same
/// sender on different channels (or different senders on the same channel)
/// map to distinct keys.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TurnKey {
    /// Identifier of the message sender (e.g. `"alice"` in this module's tests).
    pub sender: String,
    /// Name of the channel (e.g. `"telegram"` or `"slack"` in this module's tests).
    pub channel: String,
}

impl TurnKey {
    /// Builds a key from any values convertible into owned `String`s.
    pub fn new(sender: impl Into<String>, channel: impl Into<String>) -> Self {
        let sender: String = sender.into();
        let channel: String = channel.into();
        Self { sender, channel }
    }
}
51
52pub struct InterruptionDetector {
54 active_turns: Arc<Mutex<HashMap<TurnKey, CancelToken>>>,
55}
56
57impl InterruptionDetector {
58 pub fn new() -> Self {
59 Self {
60 active_turns: Arc::new(Mutex::new(HashMap::new())),
61 }
62 }
63
64 pub async fn start_turn(&self, key: TurnKey) -> CancelToken {
70 let mut turns = self.active_turns.lock().await;
71
72 if let Some(existing) = turns.remove(&key) {
74 existing.cancel();
75 }
76
77 let token = CancelToken::new();
78 turns.insert(key, token.clone());
79 token
80 }
81
82 pub async fn finish_turn(&self, key: &TurnKey) {
84 self.active_turns.lock().await.remove(key);
85 }
86
87 pub async fn has_active_turn(&self, key: &TurnKey) -> bool {
89 self.active_turns.lock().await.contains_key(key)
90 }
91
92 pub async fn active_count(&self) -> usize {
94 self.active_turns.lock().await.len()
95 }
96
97 pub async fn cancel_all(&self) {
99 let mut turns = self.active_turns.lock().await;
100 for (_, token) in turns.drain() {
101 token.cancel();
102 }
103 }
104}
105
106impl Default for InterruptionDetector {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`TurnKey`] from string literals.
    fn tk(sender: &str, channel: &str) -> TurnKey {
        TurnKey::new(sender, channel)
    }

    #[tokio::test]
    async fn start_turn_returns_active_token() {
        let detector = InterruptionDetector::new();
        let alice = tk("alice", "telegram");

        let token = detector.start_turn(alice.clone()).await;

        assert!(!token.is_cancelled());
        assert!(detector.has_active_turn(&alice).await);
        assert_eq!(detector.active_count().await, 1);
    }

    #[tokio::test]
    async fn new_turn_cancels_previous() {
        let detector = InterruptionDetector::new();
        let alice = tk("alice", "telegram");

        let first = detector.start_turn(alice.clone()).await;
        let second = detector.start_turn(alice).await;

        assert!(first.is_cancelled(), "previous turn should be cancelled");
        assert!(!second.is_cancelled(), "new turn should be active");
        assert_eq!(detector.active_count().await, 1);
    }

    #[tokio::test]
    async fn different_senders_are_independent() {
        let detector = InterruptionDetector::new();

        let mut tokens = Vec::new();
        for sender in ["alice", "bob"].iter() {
            tokens.push(detector.start_turn(tk(sender, "telegram")).await);
        }

        for token in &tokens {
            assert!(!token.is_cancelled());
        }
        assert_eq!(detector.active_count().await, 2);
    }

    #[tokio::test]
    async fn different_channels_are_independent() {
        let detector = InterruptionDetector::new();

        let mut tokens = Vec::new();
        for channel in ["telegram", "slack"].iter() {
            tokens.push(detector.start_turn(tk("alice", channel)).await);
        }

        for token in &tokens {
            assert!(!token.is_cancelled());
        }
        assert_eq!(detector.active_count().await, 2);
    }

    #[tokio::test]
    async fn finish_turn_removes_tracking() {
        let detector = InterruptionDetector::new();
        let alice = tk("alice", "telegram");
        let _token = detector.start_turn(alice.clone()).await;

        detector.finish_turn(&alice).await;

        assert!(!detector.has_active_turn(&alice).await);
        assert_eq!(detector.active_count().await, 0);
    }

    #[tokio::test]
    async fn cancel_all_cancels_everything() {
        let detector = InterruptionDetector::new();
        let alice_tg = detector.start_turn(tk("alice", "telegram")).await;
        let bob_slack = detector.start_turn(tk("bob", "slack")).await;

        detector.cancel_all().await;

        assert!(alice_tg.is_cancelled());
        assert!(bob_slack.is_cancelled());
        assert_eq!(detector.active_count().await, 0);
    }

    #[tokio::test]
    async fn rapid_interruption_sequence() {
        let detector = InterruptionDetector::new();
        let alice = tk("alice", "telegram");

        let mut tokens = Vec::with_capacity(3);
        for _ in 0..3 {
            tokens.push(detector.start_turn(alice.clone()).await);
        }

        // Every turn except the newest was superseded by the one after it.
        let (latest, superseded) = tokens.split_last().expect("three turns were started");
        assert!(superseded.iter().all(CancelToken::is_cancelled));
        assert!(!latest.is_cancelled());
        assert_eq!(detector.active_count().await, 1);
    }

    #[tokio::test]
    async fn handler_detects_cancellation() {
        let detector = InterruptionDetector::new();
        let alice = tk("alice", "telegram");

        // The handler's token starts out live...
        let handler_token = detector.start_turn(alice.clone()).await;
        assert!(!handler_token.is_cancelled());

        // ...and flips to cancelled once a newer turn for the same key starts.
        let _interrupting = detector.start_turn(alice).await;
        assert!(handler_token.is_cancelled());
    }
}