1use crate::error::{SageError, SageResult};
19use std::time::Duration;
20
/// Decides which children a supervisor restarts when one child has to be
/// restarted.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Strategy {
    /// Respawn only the child that exited (the default).
    #[default]
    OneForOne,
    /// Abort every child and respawn all of them, including children that had
    /// already exited.
    OneForAll,
    /// Abort and respawn the exited child plus every child that was added after
    /// it; earlier children keep running.
    RestForOne,
}
32
/// Per-child rule deciding whether the child's own exit triggers a restart.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RestartPolicy {
    /// Restart after every exit, successful or not (the default).
    #[default]
    Permanent,
    /// Restart only when the child exits with an error. A panicked or cancelled
    /// task is reported as an error.
    Transient,
    /// Never restart because of the child's own exit. Sibling-driven restarts
    /// (`OneForAll` / `RestForOne`) still respawn it.
    Temporary,
}
44
/// Restart budget of a supervisor: at most `max_restarts` restarts inside any
/// sliding window of length `within`. Exceeding it makes the supervisor's
/// `run` return an error.
#[derive(Clone, Debug)]
pub struct RestartConfig {
    /// Number of restarts tolerated inside the window.
    pub max_restarts: u32,
    /// Length of the sliding window the restarts are counted over.
    pub within: Duration,
}

impl Default for RestartConfig {
    /// Five restarts per sixty seconds.
    fn default() -> Self {
        RestartConfig {
            within: Duration::from_secs(60),
            max_restarts: 5,
        }
    }
}
62
63#[cfg(not(target_arch = "wasm32"))]
68mod native {
69 use super::*;
70 use std::collections::VecDeque;
71 use std::future::Future;
72 use std::pin::Pin;
73 use std::time::Instant;
74 use tokio::task::JoinHandle;
75
76 struct RestartTracker {
78 timestamps: VecDeque<Instant>,
79 config: RestartConfig,
80 }
81
82 impl RestartTracker {
83 fn new(config: RestartConfig) -> Self {
84 Self {
85 timestamps: VecDeque::new(),
86 config,
87 }
88 }
89
90 fn record_restart(&mut self) -> bool {
93 let now = Instant::now();
94
95 while let Some(&oldest) = self.timestamps.front() {
97 if now.duration_since(oldest) > self.config.within {
98 self.timestamps.pop_front();
99 } else {
100 break;
101 }
102 }
103
104 if self.timestamps.len() >= self.config.max_restarts as usize {
106 return false; }
108
109 self.timestamps.push_back(now);
110 true
111 }
112 }
113
114 pub type SpawnFn = Box<dyn Fn() -> Pin<Box<dyn Future<Output = SageResult<()>> + Send>> + Send>;
116
117 struct ChildHandle {
119 name: String,
120 restart_policy: RestartPolicy,
121 spawn_fn: SpawnFn,
122 handle: Option<JoinHandle<SageResult<()>>>,
123 }
124
125 impl ChildHandle {
126 fn new(name: String, restart_policy: RestartPolicy, spawn_fn: SpawnFn) -> Self {
127 Self {
128 name,
129 restart_policy,
130 spawn_fn,
131 handle: None,
132 }
133 }
134
135 fn spawn(&mut self) {
137 let future = (self.spawn_fn)();
138 self.handle = Some(tokio::spawn(future));
139 }
140
141 fn is_running(&self) -> bool {
143 self.handle
144 .as_ref()
145 .map(|h| !h.is_finished())
146 .unwrap_or(false)
147 }
148
149 fn take_handle(&mut self) -> Option<JoinHandle<SageResult<()>>> {
151 self.handle.take()
152 }
153 }
154
155 pub struct Supervisor {
157 strategy: Strategy,
158 children: Vec<ChildHandle>,
159 restart_tracker: RestartTracker,
160 }
161
162 impl Supervisor {
163 pub fn new(strategy: Strategy, config: RestartConfig) -> Self {
165 Self {
166 strategy,
167 children: Vec::new(),
168 restart_tracker: RestartTracker::new(config),
169 }
170 }
171
172 pub fn add_child<F, Fut>(
176 &mut self,
177 name: impl Into<String>,
178 restart_policy: RestartPolicy,
179 spawn_fn: F,
180 ) where
181 F: Fn() -> Fut + Send + 'static,
182 Fut: Future<Output = SageResult<()>> + Send + 'static,
183 {
184 let spawn_fn: SpawnFn = Box::new(move || Box::pin(spawn_fn()));
185 self.children
186 .push(ChildHandle::new(name.into(), restart_policy, spawn_fn));
187 }
188
189 pub async fn run(&mut self) -> SageResult<()> {
194 for child in &mut self.children {
196 child.spawn();
197 }
198
199 loop {
201 let (index, result) = self.wait_for_child_exit().await;
203
204 if index.is_none() {
206 break;
208 }
209
210 let index = index.unwrap();
211 let child_name = self.children[index].name.clone();
212 let restart_policy = self.children[index].restart_policy;
213
214 let should_restart = match (restart_policy, &result) {
216 (RestartPolicy::Permanent, _) => true,
217 (RestartPolicy::Transient, Err(_)) => true,
218 (RestartPolicy::Transient, Ok(_)) => false,
219 (RestartPolicy::Temporary, _) => false,
220 };
221
222 if should_restart {
223 if !self.restart_tracker.record_restart() {
225 return Err(SageError::Supervisor(format!(
226 "Maximum restart intensity reached for supervisor (child '{}' failed too many times)",
227 child_name
228 )));
229 }
230
231 match self.strategy {
233 Strategy::OneForOne => {
234 self.restart_child(index);
235 }
236 Strategy::OneForAll => {
237 self.restart_all();
238 }
239 Strategy::RestForOne => {
240 self.restart_rest(index);
241 }
242 }
243 }
244
245 if !self.any_running() {
247 break;
248 }
249 }
250
251 Ok(())
252 }
253
254 async fn wait_for_child_exit(&mut self) -> (Option<usize>, SageResult<()>) {
256 use futures::future::select_all;
257
258 let handles_with_indices: Vec<(usize, JoinHandle<SageResult<()>>)> = self
260 .children
261 .iter_mut()
262 .enumerate()
263 .filter_map(|(i, c)| c.take_handle().map(|h| (i, h)))
264 .collect();
265
266 if handles_with_indices.is_empty() {
267 return (None, Ok(()));
268 }
269
270 let indices: Vec<usize> = handles_with_indices.iter().map(|(i, _)| *i).collect();
272 let handles: Vec<JoinHandle<SageResult<()>>> =
273 handles_with_indices.into_iter().map(|(_, h)| h).collect();
274
275 let (join_result, completed_idx, remaining_handles) = select_all(handles).await;
277
278 let child_index = indices[completed_idx];
280
281 let final_result =
283 join_result.unwrap_or_else(|e| Err(SageError::Agent(e.to_string())));
284
285 let mut remaining_iter = remaining_handles.into_iter();
287 for (pos, &original_idx) in indices.iter().enumerate() {
288 if pos != completed_idx {
289 if let (Some(handle), Some(child)) =
290 (remaining_iter.next(), self.children.get_mut(original_idx))
291 {
292 child.handle = Some(handle);
293 }
294 }
295 }
296
297 (Some(child_index), final_result)
298 }
299
300 fn restart_child(&mut self, index: usize) {
302 if let Some(child) = self.children.get_mut(index) {
303 child.spawn();
304 }
305 }
306
307 fn restart_all(&mut self) {
309 for child in &mut self.children {
311 if let Some(handle) = child.take_handle() {
312 handle.abort();
313 }
314 }
315
316 for child in &mut self.children {
318 child.spawn();
319 }
320 }
321
322 fn restart_rest(&mut self, from_index: usize) {
324 for child in self.children.iter_mut().skip(from_index) {
326 if let Some(handle) = child.take_handle() {
327 handle.abort();
328 }
329 }
330
331 for child in self.children.iter_mut().skip(from_index) {
333 child.spawn();
334 }
335 }
336
337 fn any_running(&self) -> bool {
339 self.children.iter().any(|c| c.is_running())
340 }
341 }
342}
343
344#[cfg(not(target_arch = "wasm32"))]
345pub use native::{SpawnFn, Supervisor};
346
347#[cfg(target_arch = "wasm32")]
352mod wasm_stub {
353 use super::*;
354 use std::future::Future;
355
356 pub struct Supervisor {
361 _strategy: Strategy,
362 }
363
364 impl Supervisor {
365 pub fn new(strategy: Strategy, _config: RestartConfig) -> Self {
367 Self {
368 _strategy: strategy,
369 }
370 }
371
372 pub fn add_child<F, Fut>(
374 &mut self,
375 _name: impl Into<String>,
376 _restart_policy: RestartPolicy,
377 _spawn_fn: F,
378 ) where
379 F: Fn() -> Fut + 'static,
380 Fut: Future<Output = SageResult<()>> + 'static,
381 {
382 }
384
385 pub async fn run(&mut self) -> SageResult<()> {
389 Err(SageError::Supervisor(
390 "Supervision trees are not yet supported in the WASM target".to_string(),
391 ))
392 }
393 }
394}
395
396#[cfg(target_arch = "wasm32")]
397pub use wasm_stub::Supervisor;
398
399#[cfg(test)]
400mod tests {
401 use super::*;
402 use std::sync::atomic::{AtomicU32, Ordering};
403 use std::sync::Arc;
404
405 #[tokio::test]
406 async fn test_one_for_one_restart() {
407 let counter = Arc::new(AtomicU32::new(0));
408 let counter_clone = counter.clone();
409
410 let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());
411
412 supervisor.add_child("Worker", RestartPolicy::Transient, move || {
414 let counter = counter_clone.clone();
415 async move {
416 let count = counter.fetch_add(1, Ordering::SeqCst);
417 if count < 2 {
418 Err(SageError::Agent("Simulated failure".to_string()))
419 } else {
420 Ok(())
421 }
422 }
423 });
424
425 let result = supervisor.run().await;
426 assert!(result.is_ok(), "supervisor failed: {:?}", result);
427 assert_eq!(counter.load(Ordering::SeqCst), 3);
428 }
429
430 #[tokio::test]
431 async fn test_transient_no_restart_on_success() {
432 let counter = Arc::new(AtomicU32::new(0));
433 let counter_clone = counter.clone();
434
435 let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());
436
437 supervisor.add_child("Worker", RestartPolicy::Transient, move || {
438 let counter = counter_clone.clone();
439 async move {
440 counter.fetch_add(1, Ordering::SeqCst);
441 Ok(())
442 }
443 });
444
445 let result = supervisor.run().await;
446 assert!(result.is_ok());
447 assert_eq!(counter.load(Ordering::SeqCst), 1); }
449
450 #[tokio::test]
451 async fn test_temporary_never_restarts() {
452 let counter = Arc::new(AtomicU32::new(0));
453 let counter_clone = counter.clone();
454
455 let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());
456
457 supervisor.add_child("Worker", RestartPolicy::Temporary, move || {
458 let counter = counter_clone.clone();
459 async move {
460 counter.fetch_add(1, Ordering::SeqCst);
461 Err(SageError::Agent("Simulated failure".to_string()))
462 }
463 });
464
465 let result = supervisor.run().await;
466 assert!(result.is_ok()); assert_eq!(counter.load(Ordering::SeqCst), 1); }
469
470 #[tokio::test]
471 async fn test_circuit_breaker() {
472 let counter = Arc::new(AtomicU32::new(0));
473 let counter_clone = counter.clone();
474
475 let config = RestartConfig {
476 max_restarts: 3,
477 within: Duration::from_secs(60),
478 };
479
480 let mut supervisor = Supervisor::new(Strategy::OneForOne, config);
481
482 supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
483 let counter = counter_clone.clone();
484 async move {
485 counter.fetch_add(1, Ordering::SeqCst);
486 Err(SageError::Agent("Always fails".to_string()))
487 }
488 });
489
490 let result = supervisor.run().await;
491 assert!(result.is_err()); assert!(counter.load(Ordering::SeqCst) <= 4); }
494
495 #[tokio::test]
496 async fn test_permanent_restarts_on_success() {
497 let counter = Arc::new(AtomicU32::new(0));
500 let counter_clone = counter.clone();
501
502 let config = RestartConfig {
503 max_restarts: 3,
504 within: Duration::from_secs(60),
505 };
506
507 let mut supervisor = Supervisor::new(Strategy::OneForOne, config);
508
509 supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
510 let counter = counter_clone.clone();
511 async move {
512 counter.fetch_add(1, Ordering::SeqCst);
513 Ok(()) }
515 });
516
517 let result = supervisor.run().await;
518 assert!(result.is_err());
520 assert!(counter.load(Ordering::SeqCst) <= 4);
521 }
522
523 #[tokio::test]
524 async fn test_rest_for_one_restarts_downstream() {
525 let counter1 = Arc::new(AtomicU32::new(0));
527 let counter2 = Arc::new(AtomicU32::new(0));
528 let counter3 = Arc::new(AtomicU32::new(0));
529 let counter1_clone = counter1.clone();
530 let counter2_clone = counter2.clone();
531 let counter3_clone = counter3.clone();
532
533 let mut supervisor = Supervisor::new(Strategy::RestForOne, RestartConfig::default());
534
535 supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
537 let counter = counter1_clone.clone();
538 async move {
539 counter.fetch_add(1, Ordering::SeqCst);
540 tokio::time::sleep(Duration::from_millis(50)).await;
542 Ok(())
543 }
544 });
545
546 supervisor.add_child("Child2", RestartPolicy::Transient, move || {
548 let counter = counter2_clone.clone();
549 async move {
550 let count = counter.fetch_add(1, Ordering::SeqCst);
551 if count < 2 {
552 Err(SageError::Agent("Simulated failure".to_string()))
553 } else {
554 Ok(())
555 }
556 }
557 });
558
559 supervisor.add_child("Child3", RestartPolicy::Temporary, move || {
561 let counter = counter3_clone.clone();
562 async move {
563 counter.fetch_add(1, Ordering::SeqCst);
564 tokio::time::sleep(Duration::from_millis(50)).await;
566 Ok(())
567 }
568 });
569
570 let result = supervisor.run().await;
571 assert!(result.is_ok(), "supervisor failed: {:?}", result);
572
573 assert_eq!(
575 counter1.load(Ordering::SeqCst),
576 1,
577 "Child1 should run only once"
578 );
579
580 assert_eq!(
582 counter2.load(Ordering::SeqCst),
583 3,
584 "Child2 should run 3 times"
585 );
586
587 assert!(
589 counter3.load(Ordering::SeqCst) >= 2,
590 "Child3 should be restarted at least once with RestForOne, got {}",
591 counter3.load(Ordering::SeqCst)
592 );
593 }
594
595 #[tokio::test]
596 async fn test_one_for_all_restarts_all() {
597 let counter1 = Arc::new(AtomicU32::new(0));
599 let counter2 = Arc::new(AtomicU32::new(0));
600 let counter1_clone = counter1.clone();
601 let counter2_clone = counter2.clone();
602
603 let mut supervisor = Supervisor::new(Strategy::OneForAll, RestartConfig::default());
604
605 supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
607 let counter = counter1_clone.clone();
608 async move {
609 counter.fetch_add(1, Ordering::SeqCst);
610 tokio::time::sleep(Duration::from_millis(100)).await;
611 Ok(())
612 }
613 });
614
615 supervisor.add_child("Child2", RestartPolicy::Transient, move || {
617 let counter = counter2_clone.clone();
618 async move {
619 let count = counter.fetch_add(1, Ordering::SeqCst);
620 if count < 2 {
621 Err(SageError::Agent("Simulated failure".to_string()))
622 } else {
623 tokio::time::sleep(Duration::from_millis(10)).await;
624 Ok(())
625 }
626 }
627 });
628
629 let result = supervisor.run().await;
630 assert!(result.is_ok(), "supervisor failed: {:?}", result);
631
632 assert_eq!(
634 counter2.load(Ordering::SeqCst),
635 3,
636 "Child2 should run 3 times"
637 );
638
639 assert!(
641 counter1.load(Ordering::SeqCst) >= 2,
642 "Child1 should be restarted at least once with OneForAll, got {}",
643 counter1.load(Ordering::SeqCst)
644 );
645 }
646}