// ai_agent/utils/abort_controller.rs
use std::sync::atomic::Ordering;
use std::sync::Arc;
5
/// Listener count above which an [`AbortSignal`] logs a warning when more callbacks are
/// registered. It is a soft limit; registration is never refused.
const DEFAULT_MAX_LISTENERS: usize = 50;
8
9pub fn create_abort_controller(max_listeners: usize) -> AbortController {
19 AbortController::new(max_listeners)
20}
21
22pub fn create_abort_controller_default() -> AbortController {
24 create_abort_controller(DEFAULT_MAX_LISTENERS)
25}
26
27pub struct AbortController {
30 signal: Arc<AbortSignal>,
31}
32
33impl AbortController {
34 pub fn new(max_listeners: usize) -> Self {
36 Self {
37 signal: Arc::new(AbortSignal::new(max_listeners)),
38 }
39 }
40
41 pub fn signal(&self) -> &Arc<AbortSignal> {
43 &self.signal
44 }
45
46 pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
48 self.signal.abort(reason);
49 }
50
51 pub fn is_aborted(&self) -> bool {
53 self.signal.is_aborted()
54 }
55}
56
57impl Default for AbortController {
58 fn default() -> Self {
59 Self::new(DEFAULT_MAX_LISTENERS)
60 }
61}
62
63impl Clone for AbortController {
64 fn clone(&self) -> Self {
65 Self {
66 signal: Arc::clone(&self.signal),
67 }
68 }
69}
70
71pub struct AbortSignal {
73 aborted: std::sync::atomic::AtomicBool,
74 reason: std::sync::Mutex<Option<Arc<dyn std::any::Any + Send + Sync>>>,
75 listeners: std::sync::Mutex<Vec<AbortCallback>>,
76 max_listeners: usize,
77}
78
/// Callback run when an [`AbortSignal`] aborts. It receives a borrowed view of the abort
/// reason, or `None` when the abort carried no reason.
pub type AbortCallback = Box<dyn Fn(Option<&dyn std::any::Any>) + Sync + Send>;
80
81impl AbortSignal {
82 pub fn new(max_listeners: usize) -> Self {
84 Self {
85 aborted: std::sync::atomic::AtomicBool::new(false),
86 reason: std::sync::Mutex::new(None),
87 listeners: std::sync::Mutex::new(Vec::new()),
88 max_listeners,
89 }
90 }
91
92 pub fn is_aborted(&self) -> bool {
94 self.aborted.load(Ordering::SeqCst)
95 }
96
97 pub fn abort_flag(&self) -> &std::sync::atomic::AtomicBool {
100 &self.aborted
101 }
102
103 pub fn reason(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
105 self.reason.lock().ok().and_then(|guard| guard.clone())
106 }
107
108 pub fn abort(&self, reason: Option<Arc<dyn std::any::Any + Send + Sync>>) {
110 if self.aborted.swap(true, Ordering::SeqCst) {
111 return; }
113
114 *self.reason.lock().unwrap() = reason.clone();
115
116 let reason_ref = reason.as_deref().map(|a| a as &dyn std::any::Any);
119 for listener in self.listeners.lock().unwrap().iter() {
120 listener(reason_ref);
121 }
122 }
123
124 pub fn add_event_listener(&self, callback: AbortCallback) -> usize {
127 let mut listeners = self.listeners.lock().unwrap();
128 if listeners.len() >= self.max_listeners {
129 log::warn!(
130 "Max listeners ({}) exceeded for AbortSignal",
131 self.max_listeners
132 );
133 }
134 listeners.push(callback);
135 listeners.len()
136 }
137
138 #[allow(dead_code)]
140 pub fn remove_event_listener(&self, _callback: &AbortCallback) {
141 }
144
145 #[allow(dead_code)]
147 pub fn listener_count(&self) -> usize {
148 self.listeners.lock().unwrap().len()
149 }
150}
151
152impl Default for AbortSignal {
153 fn default() -> Self {
154 Self::new(DEFAULT_MAX_LISTENERS)
155 }
156}
157
158impl Clone for AbortSignal {
159 fn clone(&self) -> Self {
160 Self {
161 aborted: std::sync::atomic::AtomicBool::new(self.aborted.load(Ordering::SeqCst)),
162 reason: std::sync::Mutex::new(self.reason.lock().ok().and_then(|g| g.clone())),
163 listeners: std::sync::Mutex::new(Vec::new()), max_listeners: self.max_listeners,
165 }
166 }
167}
168
169impl std::fmt::Debug for AbortSignal {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("AbortSignal")
172 .field("aborted", &self.aborted.load(Ordering::SeqCst))
173 .field("max_listeners", &self.max_listeners)
174 .finish()
175 }
176}
177
178#[allow(dead_code)]
193pub fn create_child_abort_controller(
194 parent: &AbortController,
195 max_listeners: Option<usize>,
196) -> AbortController {
197 let max_listeners = max_listeners.unwrap_or(DEFAULT_MAX_LISTENERS);
198 let child = AbortController::new(max_listeners);
199
200 if parent.is_aborted() {
202 child.abort(parent.signal.reason());
203 return child;
204 }
205
206 let child_signal = Arc::clone(&child.signal);
208 let parent_signal = Arc::clone(parent.signal());
209
210 let reason = parent_signal.reason();
212
213 parent_signal.add_event_listener(Box::new(move |_reason| {
216 child_signal.abort(reason.clone());
218 }));
219
220 child
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize};

    /// Wraps `text` in the type-erased reason type accepted by `abort`.
    fn reason_of(text: &str) -> Arc<dyn std::any::Any + Send + Sync> {
        Arc::new(text.to_string())
    }

    #[test]
    fn test_create_abort_controller() {
        let controller = create_abort_controller(50);
        assert!(!controller.is_aborted());
    }

    #[test]
    fn test_abort_controller_abort() {
        let controller = create_abort_controller(50);
        controller.abort(None);
        assert!(controller.is_aborted());
    }

    #[test]
    fn test_abort_with_reason() {
        let controller = create_abort_controller(50);
        controller.abort(Some(reason_of("test reason")));

        assert!(controller.is_aborted());
        let stored_reason = controller
            .signal()
            .reason()
            .expect("abort(Some(..)) must record the reason");
        assert_eq!(
            stored_reason.downcast_ref::<String>().map(String::as_str),
            Some("test reason")
        );
    }

    #[test]
    fn test_abort_listener() {
        let controller = create_abort_controller(50);
        let called = Arc::new(AtomicBool::new(false));
        let called_in_listener = Arc::clone(&called);

        controller
            .signal()
            .add_event_listener(Box::new(move |_reason| {
                called_in_listener.store(true, Ordering::SeqCst);
            }));

        controller.abort(None);
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_abort_is_idempotent() {
        let controller = create_abort_controller(50);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_listener = Arc::clone(&calls);
        controller
            .signal()
            .add_event_listener(Box::new(move |_reason| {
                calls_in_listener.fetch_add(1, Ordering::SeqCst);
            }));

        controller.abort(Some(reason_of("first")));
        controller.abort(Some(reason_of("second")));

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        let stored = controller
            .signal()
            .reason()
            .expect("first abort records a reason");
        assert_eq!(
            stored.downcast_ref::<String>().map(String::as_str),
            Some("first")
        );
    }

    #[test]
    fn test_cloned_controller_shares_signal() {
        let original = create_abort_controller(50);
        let copy = original.clone();
        assert!(Arc::ptr_eq(original.signal(), copy.signal()));

        copy.abort(None);
        assert!(original.is_aborted());
    }

    #[test]
    fn test_child_abort_controller() {
        let parent = create_abort_controller(50);
        let child = create_child_abort_controller(&parent, None);

        assert!(!parent.is_aborted());
        assert!(!child.is_aborted());

        parent.abort(None);

        assert!(parent.is_aborted());
        assert!(child.is_aborted());
    }

    #[test]
    fn test_child_abort_does_not_abort_parent() {
        let parent = create_abort_controller(50);
        let child = create_child_abort_controller(&parent, None);

        child.abort(None);

        assert!(child.is_aborted());
        assert!(!parent.is_aborted());
    }

    #[test]
    fn test_child_already_aborted_parent() {
        let parent = create_abort_controller(50);
        parent.abort(None);

        let child = create_child_abort_controller(&parent, None);

        assert!(child.is_aborted());
    }

    #[test]
    fn test_abort_flag_reflects_state() {
        let controller = create_abort_controller(50);
        let flag = controller.signal().abort_flag();
        assert!(!flag.load(Ordering::SeqCst));

        controller.abort(None);
        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn test_abort_flag_survives_guard() {
        let abort_ctrl = create_abort_controller(50);
        let flag = abort_ctrl.signal().abort_flag();
        assert!(!flag.load(Ordering::SeqCst));
        abort_ctrl.abort(None);
        assert!(flag.load(Ordering::SeqCst));
    }
}