Skip to main content

ai_agent/utils/
abort_controller.rs

1//! AbortController utilities
2
3use std::sync::Arc;
4use std::sync::atomic::Ordering;
5
/// Listener limit used whenever a caller does not choose one explicitly.
const DEFAULT_MAX_LISTENERS: usize = 50;
8
9/// Creates an AbortController with proper event listener limits set.
10/// This prevents MaxListenersExceededWarning when multiple listeners
11/// are attached to the abort signal.
12///
13/// # Arguments
14/// * `max_listeners` - Maximum number of listeners (default: 50)
15///
16/// # Returns
17/// AbortController with configured listener limit
18pub fn create_abort_controller(max_listeners: usize) -> AbortController {
19    AbortController::new(max_listeners)
20}
21
22/// Creates an AbortController with default max listeners
23pub fn create_abort_controller_default() -> AbortController {
24    create_abort_controller(DEFAULT_MAX_LISTENERS)
25}
26
27/// AbortController implementation for Rust
28/// Provides similar functionality to the JavaScript AbortController
29pub struct AbortController {
30    signal: Arc<AbortSignal>,
31}
32
33impl AbortController {
34    /// Create a new AbortController with custom max listeners
35    pub fn new(max_listeners: usize) -> Self {
36        Self {
37            signal: Arc::new(AbortSignal::new(max_listeners)),
38        }
39    }
40
41    /// Get the abort signal
42    pub fn signal(&self) -> &Arc<AbortSignal> {
43        &self.signal
44    }
45
46    /// Abort the controller with an optional reason
47    pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
48        self.signal.abort(reason);
49    }
50
51    /// Check if aborted
52    pub fn is_aborted(&self) -> bool {
53        self.signal.is_aborted()
54    }
55}
56
57impl Default for AbortController {
58    fn default() -> Self {
59        Self::new(DEFAULT_MAX_LISTENERS)
60    }
61}
62
63impl Clone for AbortController {
64    fn clone(&self) -> Self {
65        Self {
66            signal: Arc::clone(&self.signal),
67        }
68    }
69}
70
71/// AbortSignal implementation
72pub struct AbortSignal {
73    aborted: std::sync::atomic::AtomicBool,
74    reason: std::sync::Mutex<Option<Arc<dyn std::any::Any + Send + Sync>>>,
75    listeners: std::sync::Mutex<Vec<AbortCallback>>,
76    max_listeners: usize,
77}
78
/// Callback invoked when an [`AbortSignal`] is aborted; it receives the abort
/// reason (if any) as a borrowed `dyn Any`.
pub type AbortCallback = Box<dyn Fn(Option<&dyn std::any::Any>) + Send + Sync>;
80
81impl AbortSignal {
82    /// Create a new AbortSignal with custom max listeners
83    pub fn new(max_listeners: usize) -> Self {
84        Self {
85            aborted: std::sync::atomic::AtomicBool::new(false),
86            reason: std::sync::Mutex::new(None),
87            listeners: std::sync::Mutex::new(Vec::new()),
88            max_listeners,
89        }
90    }
91
92    /// Check if aborted
93    pub fn is_aborted(&self) -> bool {
94        self.aborted.load(Ordering::SeqCst)
95    }
96
97    /// Expose the underlying AtomicBool for callers that need to pass it
98    /// to functions expecting &AtomicBool abort signals.
99    pub fn abort_flag(&self) -> &std::sync::atomic::AtomicBool {
100        &self.aborted
101    }
102
103    /// Get the abort reason
104    pub fn reason(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
105        self.reason.lock().ok().and_then(|guard| guard.clone())
106    }
107
108    /// Abort the signal
109    pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
110        if self.aborted.swap(true, Ordering::SeqCst) {
111            return; // Already aborted
112        }
113
114        *self.reason.lock().unwrap() = reason.clone();
115
116        // Notify all listeners - iterate directly over the locked guard
117        // This is safe because we hold the lock during iteration
118        let reason_ref = reason.as_deref().map(|a| a as &dyn std::any::Any);
119        for listener in self.listeners.lock().unwrap().iter() {
120            listener(reason_ref);
121        }
122    }
123
124    /// Add an abort listener
125    /// Returns the number of listeners after adding
126    pub fn add_event_listener(&self, callback: AbortCallback) -> usize {
127        let mut listeners = self.listeners.lock().unwrap();
128        if listeners.len() >= self.max_listeners {
129            log::warn!(
130                "Max listeners ({}) exceeded for AbortSignal",
131                self.max_listeners
132            );
133        }
134        listeners.push(callback);
135        listeners.len()
136    }
137
138    /// Remove an abort listener
139    #[allow(dead_code)]
140    pub fn remove_event_listener(&self, _callback: &AbortCallback) {
141        // Note: Full implementation would require function pointer comparison
142        // For now, this is a placeholder
143    }
144
145    /// Get the number of listeners
146    #[allow(dead_code)]
147    pub fn listener_count(&self) -> usize {
148        self.listeners.lock().unwrap().len()
149    }
150}
151
152impl Default for AbortSignal {
153    fn default() -> Self {
154        Self::new(DEFAULT_MAX_LISTENERS)
155    }
156}
157
158impl Clone for AbortSignal {
159    fn clone(&self) -> Self {
160        Self {
161            aborted: std::sync::atomic::AtomicBool::new(self.aborted.load(Ordering::SeqCst)),
162            reason: std::sync::Mutex::new(self.reason.lock().ok().and_then(|g| g.clone())),
163            listeners: std::sync::Mutex::new(Vec::new()), // Cloned signals don't share listeners
164            max_listeners: self.max_listeners,
165        }
166    }
167}
168
169impl std::fmt::Debug for AbortSignal {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("AbortSignal")
172            .field("aborted", &self.aborted.load(Ordering::SeqCst))
173            .field("max_listeners", &self.max_listeners)
174            .finish()
175    }
176}
177
178/// Creates a child AbortController that aborts when its parent aborts.
179/// Aborting the child does NOT affect the parent.
180///
181/// Memory-safe: Uses Arc so that parent doesn't retain abandoned children.
182/// If the child is dropped without being aborted, it can still be GC'd.
183/// When the child IS aborted, the parent listener is removed to prevent
184/// accumulation of dead handlers.
185///
186/// # Arguments
187/// * `parent` - The parent AbortController
188/// * `max_listeners` - Maximum number of listeners (default: 50)
189///
190/// # Returns
191/// Child AbortController
192#[allow(dead_code)]
193pub fn create_child_abort_controller(
194    parent: &AbortController,
195    max_listeners: Option<usize>,
196) -> AbortController {
197    let max_listeners = max_listeners.unwrap_or(DEFAULT_MAX_LISTENERS);
198    let child = AbortController::new(max_listeners);
199
200    // Fast path: parent already aborted, no listener setup needed
201    if parent.is_aborted() {
202        child.abort(parent.signal.reason());
203        return child;
204    }
205
206    // Clone the child signal to use in the closure
207    let child_signal = Arc::clone(&child.signal);
208    let parent_signal = Arc::clone(parent.signal());
209
210    // Get the reason now, before moving into closure
211    let reason = parent_signal.reason();
212
213    // Use a wrapper to handle the propagation
214    // Note: We need both signals to be Send + Sync, which they are
215    parent_signal.add_event_listener(Box::new(move |_reason| {
216        // Propagate the captured reason to child
217        child_signal.abort(reason.clone());
218    }));
219
220    child
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::sync::atomic::AtomicBool;

    /// A freshly created controller starts out not aborted.
    #[test]
    fn test_create_abort_controller() {
        assert!(!create_abort_controller(50).is_aborted());
    }

    /// Aborting without a reason flips the aborted state.
    #[test]
    fn test_abort_controller_abort() {
        let ctrl = create_abort_controller(50);
        ctrl.abort(None);
        assert!(ctrl.is_aborted());
    }

    /// A supplied reason is retained and can be read back from the signal.
    #[test]
    fn test_abort_with_reason() {
        let ctrl = create_abort_controller(50);
        let reason: Arc<dyn Any + Send + Sync> = Arc::new("test reason".to_string());
        ctrl.abort(Some(reason));

        assert!(ctrl.is_aborted());
        assert!(ctrl.signal().reason().is_some());
    }

    /// A registered listener runs when the controller is aborted.
    #[test]
    fn test_abort_listener() {
        let ctrl = create_abort_controller(50);
        let fired = Arc::new(AtomicBool::new(false));
        let observer = Arc::clone(&fired);

        let on_abort: AbortCallback = Box::new(move |_reason| {
            fired.store(true, Ordering::SeqCst);
        });
        ctrl.signal().add_event_listener(on_abort);

        ctrl.abort(None);
        assert!(observer.load(Ordering::SeqCst));
    }

    /// Aborting the parent also aborts a live child.
    #[test]
    fn test_child_abort_controller() {
        let parent = create_abort_controller(50);
        let child = create_child_abort_controller(&parent, None);

        assert!(!parent.is_aborted());
        assert!(!child.is_aborted());

        parent.abort(None);

        assert!(parent.is_aborted());
        assert!(child.is_aborted());
    }

    /// A child created from an already-aborted parent starts out aborted.
    #[test]
    fn test_child_already_aborted_parent() {
        let parent = create_abort_controller(50);
        parent.abort(None);

        assert!(create_child_abort_controller(&parent, None).is_aborted());
    }

    /// The exposed flag tracks the controller's aborted state.
    #[test]
    fn test_abort_flag_reflects_state() {
        let ctrl = create_abort_controller(50);
        let flag = ctrl.signal().abort_flag();
        assert!(!flag.load(Ordering::SeqCst));

        ctrl.abort(None);
        assert!(flag.load(Ordering::SeqCst));
    }

    /// The flag reference borrows from the controller itself, not from a
    /// mutex guard, so it stays usable across the `abort` call.
    #[test]
    fn test_abort_flag_survives_guard() {
        let abort_ctrl = create_abort_controller(50);
        let flag = abort_ctrl.signal().abort_flag();
        assert!(!flag.load(Ordering::SeqCst));

        abort_ctrl.abort(None);
        assert!(flag.load(Ordering::SeqCst));
    }
}