Skip to main content

ai_agent/utils/
abort_controller.rs

1//! AbortController utilities
2
3use std::sync::atomic::Ordering;
4use std::sync::Arc;
5
/// Listener count above which an [`AbortSignal`] created through the default
/// constructors logs a "max listeners exceeded" warning.
const DEFAULT_MAX_LISTENERS: usize = 50;
8
9/// Creates an AbortController with proper event listener limits set.
10/// This prevents MaxListenersExceededWarning when multiple listeners
11/// are attached to the abort signal.
12///
13/// # Arguments
14/// * `max_listeners` - Maximum number of listeners (default: 50)
15///
16/// # Returns
17/// AbortController with configured listener limit
18pub fn create_abort_controller(max_listeners: usize) -> AbortController {
19    AbortController::new(max_listeners)
20}
21
22/// Creates an AbortController with default max listeners
23pub fn create_abort_controller_default() -> AbortController {
24    create_abort_controller(DEFAULT_MAX_LISTENERS)
25}
26
27/// AbortController implementation for Rust
28/// Provides similar functionality to the JavaScript AbortController
29pub struct AbortController {
30    signal: Arc<AbortSignal>,
31}
32
33impl AbortController {
34    /// Create a new AbortController with custom max listeners
35    pub fn new(max_listeners: usize) -> Self {
36        Self {
37            signal: Arc::new(AbortSignal::new(max_listeners)),
38        }
39    }
40
41    /// Get the abort signal
42    pub fn signal(&self) -> &Arc<AbortSignal> {
43        &self.signal
44    }
45
46    /// Abort the controller with an optional reason
47    pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
48        self.signal.abort(reason);
49    }
50
51    /// Check if aborted
52    pub fn is_aborted(&self) -> bool {
53        self.signal.is_aborted()
54    }
55}
56
57impl Default for AbortController {
58    fn default() -> Self {
59        Self::new(DEFAULT_MAX_LISTENERS)
60    }
61}
62
63impl Clone for AbortController {
64    fn clone(&self) -> Self {
65        Self {
66            signal: Arc::clone(&self.signal),
67        }
68    }
69}
70
71/// AbortSignal implementation
72pub struct AbortSignal {
73    aborted: std::sync::atomic::AtomicBool,
74    reason: std::sync::Mutex<Option<Arc<dyn std::any::Any + Send + Sync>>>,
75    listeners: std::sync::Mutex<Vec<AbortCallback>>,
76    max_listeners: usize,
77}
78
/// Callback registered with `AbortSignal::add_event_listener`.
///
/// It receives the abort reason, if one was supplied, borrowed as `&dyn Any`.
pub type AbortCallback = Box<dyn Fn(Option<&dyn std::any::Any>) + Sync + Send>;
80
81impl AbortSignal {
82    /// Create a new AbortSignal with custom max listeners
83    pub fn new(max_listeners: usize) -> Self {
84        Self {
85            aborted: std::sync::atomic::AtomicBool::new(false),
86            reason: std::sync::Mutex::new(None),
87            listeners: std::sync::Mutex::new(Vec::new()),
88            max_listeners,
89        }
90    }
91
92    /// Check if aborted
93    pub fn is_aborted(&self) -> bool {
94        self.aborted.load(Ordering::SeqCst)
95    }
96
97    /// Get the abort reason
98    pub fn reason(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
99        self.reason.lock().ok().and_then(|guard| guard.clone())
100    }
101
102    /// Abort the signal
103    pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
104        if self.aborted.swap(true, Ordering::SeqCst) {
105            return; // Already aborted
106        }
107
108        *self.reason.lock().unwrap() = reason.clone();
109
110        // Notify all listeners - iterate directly over the locked guard
111        // This is safe because we hold the lock during iteration
112        let reason_ref = reason.as_deref().map(|a| a as &dyn std::any::Any);
113        for listener in self.listeners.lock().unwrap().iter() {
114            listener(reason_ref);
115        }
116    }
117
118    /// Add an abort listener
119    /// Returns the number of listeners after adding
120    pub fn add_event_listener(&self, callback: AbortCallback) -> usize {
121        let mut listeners = self.listeners.lock().unwrap();
122        if listeners.len() >= self.max_listeners {
123            log::warn!(
124                "Max listeners ({}) exceeded for AbortSignal",
125                self.max_listeners
126            );
127        }
128        listeners.push(callback);
129        listeners.len()
130    }
131
132    /// Remove an abort listener
133    #[allow(dead_code)]
134    pub fn remove_event_listener(&self, _callback: &AbortCallback) {
135        // Note: Full implementation would require function pointer comparison
136        // For now, this is a placeholder
137    }
138
139    /// Get the number of listeners
140    #[allow(dead_code)]
141    pub fn listener_count(&self) -> usize {
142        self.listeners.lock().unwrap().len()
143    }
144}
145
146impl Default for AbortSignal {
147    fn default() -> Self {
148        Self::new(DEFAULT_MAX_LISTENERS)
149    }
150}
151
152impl Clone for AbortSignal {
153    fn clone(&self) -> Self {
154        Self {
155            aborted: std::sync::atomic::AtomicBool::new(self.aborted.load(Ordering::SeqCst)),
156            reason: std::sync::Mutex::new(self.reason.lock().ok().and_then(|g| g.clone())),
157            listeners: std::sync::Mutex::new(Vec::new()), // Cloned signals don't share listeners
158            max_listeners: self.max_listeners,
159        }
160    }
161}
162
163/// Creates a child AbortController that aborts when its parent aborts.
164/// Aborting the child does NOT affect the parent.
165///
166/// Memory-safe: Uses Arc so that parent doesn't retain abandoned children.
167/// If the child is dropped without being aborted, it can still be GC'd.
168/// When the child IS aborted, the parent listener is removed to prevent
169/// accumulation of dead handlers.
170///
171/// # Arguments
172/// * `parent` - The parent AbortController
173/// * `max_listeners` - Maximum number of listeners (default: 50)
174///
175/// # Returns
176/// Child AbortController
177#[allow(dead_code)]
178pub fn create_child_abort_controller(
179    parent: &AbortController,
180    max_listeners: Option<usize>,
181) -> AbortController {
182    let max_listeners = max_listeners.unwrap_or(DEFAULT_MAX_LISTENERS);
183    let child = AbortController::new(max_listeners);
184
185    // Fast path: parent already aborted, no listener setup needed
186    if parent.is_aborted() {
187        child.abort(parent.signal.reason());
188        return child;
189    }
190
191    // Clone the child signal to use in the closure
192    let child_signal = Arc::clone(&child.signal);
193    let parent_signal = Arc::clone(parent.signal());
194
195    // Get the reason now, before moving into closure
196    let reason = parent_signal.reason();
197
198    // Use a wrapper to handle the propagation
199    // Note: We need both signals to be Send + Sync, which they are
200    parent_signal.add_event_listener(Box::new(move |_reason| {
201        // Propagate the captured reason to child
202        child_signal.abort(reason.clone());
203    }));
204
205    child
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Wraps `value` in the type-erased form accepted by `abort`.
    fn reason_of<T: Any + Send + Sync>(value: T) -> Option<Arc<dyn Any + Send + Sync>> {
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(value);
        Some(erased)
    }

    #[test]
    fn test_create_abort_controller() {
        let controller = create_abort_controller(50);
        assert!(!controller.is_aborted());
        assert!(controller.signal().reason().is_none());
    }

    #[test]
    fn test_abort_controller_abort() {
        let controller = create_abort_controller(50);
        controller.abort(None);
        assert!(controller.is_aborted());
        assert!(controller.signal().is_aborted());
    }

    #[test]
    fn test_abort_with_reason() {
        let controller = create_abort_controller(50);
        controller.abort(reason_of("test reason".to_string()));

        assert!(controller.is_aborted());
        let stored_reason = controller
            .signal()
            .reason()
            .expect("abort reason should be stored");
        assert_eq!(
            stored_reason.downcast_ref::<String>().map(String::as_str),
            Some("test reason")
        );
    }

    #[test]
    fn test_second_abort_keeps_first_reason_and_does_not_renotify() {
        let controller = create_abort_controller(50);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_listener = Arc::clone(&calls);
        controller
            .signal()
            .add_event_listener(Box::new(move |_reason| {
                calls_in_listener.fetch_add(1, Ordering::SeqCst);
            }));

        controller.abort(reason_of(1_u32));
        controller.abort(reason_of(2_u32));

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        let stored_reason = controller
            .signal()
            .reason()
            .expect("first abort reason should be kept");
        assert_eq!(stored_reason.downcast_ref::<u32>(), Some(&1));
    }

    #[test]
    fn test_abort_listener() {
        let controller = create_abort_controller(50);
        let called = Arc::new(AtomicBool::new(false));
        let called_in_listener = Arc::clone(&called);

        controller
            .signal()
            .add_event_listener(Box::new(move |_reason| {
                called_in_listener.store(true, Ordering::SeqCst);
            }));

        controller.abort(None);
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_listener_receives_reason() {
        let controller = create_abort_controller(50);
        let seen = Arc::new(Mutex::new(None::<String>));
        let seen_in_listener = Arc::clone(&seen);

        controller
            .signal()
            .add_event_listener(Box::new(move |reason: Option<&dyn Any>| {
                let text = reason.and_then(|r| r.downcast_ref::<String>()).cloned();
                *seen_in_listener.lock().unwrap() = text;
            }));

        controller.abort(reason_of("stop".to_string()));
        assert_eq!(seen.lock().unwrap().as_deref(), Some("stop"));
    }

    #[test]
    fn test_listener_limit_warns_but_still_registers() {
        let controller = create_abort_controller(1);
        let signal = controller.signal();
        assert_eq!(signal.add_event_listener(Box::new(|_reason| {})), 1);
        assert_eq!(signal.add_event_listener(Box::new(|_reason| {})), 2);
        assert_eq!(signal.listener_count(), 2);
    }

    #[test]
    fn test_cloned_controller_shares_signal() {
        let controller = create_abort_controller_default();
        let handle = controller.clone();
        assert!(Arc::ptr_eq(controller.signal(), handle.signal()));

        handle.abort(None);
        assert!(controller.is_aborted());
    }

    #[test]
    fn test_signal_clone_is_independent_snapshot() {
        let original = AbortSignal::new(3);
        let before_abort = original.clone();
        original.abort(reason_of(5_i32));
        let after_abort = original.clone();

        assert!(!before_abort.is_aborted());
        assert!(after_abort.is_aborted());
        let copied = after_abort
            .reason()
            .expect("clone should copy the reason");
        assert_eq!(copied.downcast_ref::<i32>(), Some(&5));
        assert_eq!(after_abort.listener_count(), 0);
    }

    #[test]
    fn test_child_abort_controller() {
        let parent = create_abort_controller(50);
        let child = create_child_abort_controller(&parent, None);

        assert!(!parent.is_aborted());
        assert!(!child.is_aborted());

        parent.abort(None);

        assert!(parent.is_aborted());
        assert!(child.is_aborted());
    }

    #[test]
    fn test_child_abort_does_not_abort_parent() {
        let parent = create_abort_controller(50);
        let child = create_child_abort_controller(&parent, None);

        child.abort(None);

        assert!(child.is_aborted());
        assert!(!parent.is_aborted());
    }

    #[test]
    fn test_child_already_aborted_parent() {
        let parent = create_abort_controller(50);
        parent.abort(None);

        let child = create_child_abort_controller(&parent, None);

        assert!(child.is_aborted());
    }

    #[test]
    fn test_child_of_aborted_parent_inherits_reason() {
        let parent = create_abort_controller(50);
        parent.abort(reason_of("parent gone".to_string()));

        let child = create_child_abort_controller(&parent, Some(5));

        assert!(child.is_aborted());
        let reason = child
            .signal()
            .reason()
            .expect("child should inherit the parent's reason");
        assert_eq!(
            reason.downcast_ref::<String>().map(String::as_str),
            Some("parent gone")
        );
    }
}