Skip to main content

agentzero_channels/
interruption.rs

1//! Same-sender same-channel message interruption.
2//!
3//! When `interrupt_on_new_message` is enabled, a new message from the same
4//! sender in the same channel cancels any in-flight turn for that sender.
5//! The handler task checks the cancellation token and aborts if interrupted.
6
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
/// A lightweight, cloneable cancellation flag backed by a shared `AtomicBool`.
///
/// Clones share the same underlying flag, so a `cancel` issued through any
/// clone is observed by every other clone. Nothing in this type ever resets
/// the flag once it has been set.
#[derive(Debug, Clone)]
pub struct CancelToken {
    /// Shared flag; `true` once `cancel` has been called on any clone.
    cancelled: Arc<AtomicBool>,
}

impl CancelToken {
    /// Create a fresh token in the non-cancelled state.
    fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Mark this token (and every clone of it) as cancelled.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Report whether `cancel` has been called on this token or any clone.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }
}
35
/// Identifies one tracked turn by the pair (sender, channel).
///
/// Two keys are equal, and hash alike, only when both fields match.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TurnKey {
    /// Sender half of the key.
    pub sender: String,
    /// Channel half of the key.
    pub channel: String,
}

impl TurnKey {
    /// Build a key from any values convertible into owned `String`s.
    pub fn new(sender: impl Into<String>, channel: impl Into<String>) -> Self {
        let sender: String = sender.into();
        let channel: String = channel.into();
        Self { sender, channel }
    }
}
51
52/// Tracks active turns and manages cancellation tokens.
53pub struct InterruptionDetector {
54    active_turns: Arc<Mutex<HashMap<TurnKey, CancelToken>>>,
55}
56
57impl InterruptionDetector {
58    pub fn new() -> Self {
59        Self {
60            active_turns: Arc::new(Mutex::new(HashMap::new())),
61        }
62    }
63
64    /// Register a new turn for the given key.
65    ///
66    /// If there is already an active turn for this key, it is cancelled first.
67    /// Returns the cancellation token for the new turn — the handler should
68    /// check `token.is_cancelled()` periodically.
69    pub async fn start_turn(&self, key: TurnKey) -> CancelToken {
70        let mut turns = self.active_turns.lock().await;
71
72        // Cancel any existing turn for this sender+channel
73        if let Some(existing) = turns.remove(&key) {
74            existing.cancel();
75        }
76
77        let token = CancelToken::new();
78        turns.insert(key, token.clone());
79        token
80    }
81
82    /// Finish a turn (normal completion). Removes the token from tracking.
83    pub async fn finish_turn(&self, key: &TurnKey) {
84        self.active_turns.lock().await.remove(key);
85    }
86
87    /// Check if a turn is currently active for the given key.
88    pub async fn has_active_turn(&self, key: &TurnKey) -> bool {
89        self.active_turns.lock().await.contains_key(key)
90    }
91
92    /// Get the number of active turns.
93    pub async fn active_count(&self) -> usize {
94        self.active_turns.lock().await.len()
95    }
96
97    /// Cancel all active turns (e.g. on shutdown).
98    pub async fn cancel_all(&self) {
99        let mut turns = self.active_turns.lock().await;
100        for (_, token) in turns.drain() {
101            token.cancel();
102        }
103    }
104}
105
106impl Default for InterruptionDetector {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `TurnKey` from string slices.
    fn key(sender: &str, channel: &str) -> TurnKey {
        TurnKey::new(sender, channel)
    }

    /// A freshly started turn is tracked and its token is live.
    #[tokio::test]
    async fn start_turn_returns_active_token() {
        let detector = InterruptionDetector::new();
        let alice_tg = key("alice", "telegram");

        let token = detector.start_turn(alice_tg.clone()).await;

        assert!(!token.is_cancelled());
        assert!(detector.has_active_turn(&alice_tg).await);
        assert_eq!(detector.active_count().await, 1);
    }

    /// Starting a second turn on the same key cancels the first one.
    #[tokio::test]
    async fn new_turn_cancels_previous() {
        let detector = InterruptionDetector::new();
        let alice_tg = key("alice", "telegram");

        let first = detector.start_turn(alice_tg.clone()).await;
        let second = detector.start_turn(alice_tg).await;

        assert!(first.is_cancelled(), "previous turn should be cancelled");
        assert!(!second.is_cancelled(), "new turn should be active");
        assert_eq!(detector.active_count().await, 1);
    }

    /// Turns from different senders on one channel do not interfere.
    #[tokio::test]
    async fn different_senders_are_independent() {
        let detector = InterruptionDetector::new();

        let alice = detector.start_turn(key("alice", "telegram")).await;
        let bob = detector.start_turn(key("bob", "telegram")).await;

        assert!(!alice.is_cancelled());
        assert!(!bob.is_cancelled());
        assert_eq!(detector.active_count().await, 2);
    }

    /// Turns from one sender on different channels do not interfere.
    #[tokio::test]
    async fn different_channels_are_independent() {
        let detector = InterruptionDetector::new();

        let telegram = detector.start_turn(key("alice", "telegram")).await;
        let slack = detector.start_turn(key("alice", "slack")).await;

        assert!(!telegram.is_cancelled());
        assert!(!slack.is_cancelled());
        assert_eq!(detector.active_count().await, 2);
    }

    /// Finishing a turn drops it from the tracking map.
    #[tokio::test]
    async fn finish_turn_removes_tracking() {
        let detector = InterruptionDetector::new();
        let alice_tg = key("alice", "telegram");
        let _token = detector.start_turn(alice_tg.clone()).await;

        detector.finish_turn(&alice_tg).await;

        assert!(!detector.has_active_turn(&alice_tg).await);
        assert_eq!(detector.active_count().await, 0);
    }

    /// `cancel_all` cancels every tracked token and empties the map.
    #[tokio::test]
    async fn cancel_all_cancels_everything() {
        let detector = InterruptionDetector::new();
        let alice = detector.start_turn(key("alice", "telegram")).await;
        let bob = detector.start_turn(key("bob", "slack")).await;

        detector.cancel_all().await;

        assert!(alice.is_cancelled());
        assert!(bob.is_cancelled());
        assert_eq!(detector.active_count().await, 0);
    }

    /// Only the newest of several back-to-back turns stays live.
    #[tokio::test]
    async fn rapid_interruption_sequence() {
        let detector = InterruptionDetector::new();
        let alice_tg = key("alice", "telegram");

        let first = detector.start_turn(alice_tg.clone()).await;
        let second = detector.start_turn(alice_tg.clone()).await;
        let third = detector.start_turn(alice_tg.clone()).await;

        assert!(first.is_cancelled());
        assert!(second.is_cancelled());
        assert!(!third.is_cancelled());
        assert_eq!(detector.active_count().await, 1);
    }

    /// A handler polling its token observes the interruption.
    #[tokio::test]
    async fn handler_detects_cancellation() {
        let detector = InterruptionDetector::new();
        let alice_tg = key("alice", "telegram");

        let handler_token = detector.start_turn(alice_tg.clone()).await;

        // Before any new message, the handler sees a live token.
        assert!(!handler_token.is_cancelled());

        // A new message for the same key interrupts the running turn.
        let _replacement = detector.start_turn(alice_tg).await;

        // The handler's next poll observes the cancellation.
        assert!(handler_token.is_cancelled());
    }
}