// sage_runtime/supervisor.rs
//! Supervision tree implementation for Sage v2.
//!
//! This module provides Erlang/OTP-style supervision trees for managing
//! agent lifecycles with automatic restart capabilities.
//!
//! # Supervision Strategies
//!
//! - **OneForOne**: Restart only the failed child
//! - **OneForAll**: Restart all children if one fails
//! - **RestForOne**: Restart the failed child and all children started after it
//!
//! # Restart Policies
//!
//! - **Permanent**: Always restart, regardless of exit reason
//! - **Transient**: Restart only on abnormal termination (error)
//! - **Temporary**: Never restart
//!
//! # Example
//!
//! ```ignore
//! use sage_runtime::supervisor::{Supervisor, Strategy, RestartPolicy};
//!
//! let mut supervisor = Supervisor::new(Strategy::OneForOne, Default::default());
//!
//! supervisor.add_child("Worker", RestartPolicy::Permanent, || {
//!     sage_runtime::spawn(|mut ctx| async move {
//!         // Agent logic
//!         ctx.emit(())
//!     })
//! });
//!
//! supervisor.run().await?;
//! ```

use crate::error::{SageError, SageResult};
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
41
/// OTP-style strategy controlling which siblings are restarted after a
/// child exits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Strategy {
    /// Only the child that exited is restarted.
    #[default]
    OneForOne,
    /// Every child is restarted whenever any one of them exits.
    OneForAll,
    /// The exited child plus every child added after it are restarted.
    RestForOne,
}
53
/// Per-child policy deciding whether an exited child is respawned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RestartPolicy {
    /// Restart after every exit, no matter how the child terminated.
    #[default]
    Permanent,
    /// Restart only when the child exited with an error.
    Transient,
    /// Never restart, regardless of how the child exited.
    Temporary,
}
65
/// Restart-intensity limits acting as a circuit breaker.
///
/// If more than `max_restarts` restarts happen inside a sliding window of
/// length `within`, the supervisor gives up instead of restarting forever.
#[derive(Debug, Clone)]
pub struct RestartConfig {
    /// Upper bound on restarts tolerated inside the window.
    pub max_restarts: u32,
    /// Length of the sliding window over which restarts are counted.
    pub within: Duration,
}

impl Default for RestartConfig {
    /// Defaults to at most 5 restarts per 60-second window.
    fn default() -> Self {
        RestartConfig {
            max_restarts: 5,
            within: Duration::from_secs(60),
        }
    }
}
83
84/// Tracks restart history for circuit breaker functionality.
85struct RestartTracker {
86    timestamps: VecDeque<Instant>,
87    config: RestartConfig,
88}
89
90impl RestartTracker {
91    fn new(config: RestartConfig) -> Self {
92        Self {
93            timestamps: VecDeque::new(),
94            config,
95        }
96    }
97
98    /// Record a restart and check if we've exceeded the limit.
99    /// Returns true if we should allow the restart, false if circuit breaker trips.
100    fn record_restart(&mut self) -> bool {
101        let now = Instant::now();
102
103        // Remove old timestamps outside the window
104        while let Some(&oldest) = self.timestamps.front() {
105            if now.duration_since(oldest) > self.config.within {
106                self.timestamps.pop_front();
107            } else {
108                break;
109            }
110        }
111
112        // Check if we're at the limit
113        if self.timestamps.len() >= self.config.max_restarts as usize {
114            return false; // Circuit breaker trips
115        }
116
117        self.timestamps.push_back(now);
118        true
119    }
120}
121
/// Factory invoked once per (re)start of a child; each call must produce a
/// fresh boxed future representing one complete run of the agent. The future
/// is boxed/pinned so children with different concrete future types can be
/// stored in one `Vec`, and `Send` so it can be moved onto a spawned task.
pub type SpawnFn = Box<dyn Fn() -> Pin<Box<dyn Future<Output = SageResult<()>> + Send>> + Send>;
124
125/// Handle to a supervised child.
126struct ChildHandle {
127    name: String,
128    restart_policy: RestartPolicy,
129    spawn_fn: SpawnFn,
130    handle: Option<JoinHandle<SageResult<()>>>,
131}
132
133impl ChildHandle {
134    fn new(name: String, restart_policy: RestartPolicy, spawn_fn: SpawnFn) -> Self {
135        Self {
136            name,
137            restart_policy,
138            spawn_fn,
139            handle: None,
140        }
141    }
142
143    /// Spawn (or respawn) this child.
144    fn spawn(&mut self) {
145        let future = (self.spawn_fn)();
146        self.handle = Some(tokio::spawn(async move { future.await }));
147    }
148
149    /// Check if the child is running.
150    fn is_running(&self) -> bool {
151        self.handle
152            .as_ref()
153            .map(|h| !h.is_finished())
154            .unwrap_or(false)
155    }
156
157    /// Take the join handle (for awaiting).
158    fn take_handle(&mut self) -> Option<JoinHandle<SageResult<()>>> {
159        self.handle.take()
160    }
161}
162
/// A supervisor that manages child agents with restart strategies.
///
/// Children are started in insertion order by [`Supervisor::run`], which
/// then monitors them and applies the configured [`Strategy`] and each
/// child's [`RestartPolicy`] whenever one exits.
pub struct Supervisor {
    // Which siblings to restart when a child exits.
    strategy: Strategy,
    // Children in insertion order; order matters for `RestForOne`.
    children: Vec<ChildHandle>,
    // Sliding-window circuit breaker shared by all children.
    restart_tracker: RestartTracker,
}
169
170impl Supervisor {
171    /// Create a new supervisor with the given strategy and restart configuration.
172    pub fn new(strategy: Strategy, config: RestartConfig) -> Self {
173        Self {
174            strategy,
175            children: Vec::new(),
176            restart_tracker: RestartTracker::new(config),
177        }
178    }
179
180    /// Add a child to the supervisor.
181    ///
182    /// The spawn function should create the agent and return its future.
183    pub fn add_child<F, Fut>(&mut self, name: impl Into<String>, restart_policy: RestartPolicy, spawn_fn: F)
184    where
185        F: Fn() -> Fut + Send + 'static,
186        Fut: Future<Output = SageResult<()>> + Send + 'static,
187    {
188        let spawn_fn: SpawnFn = Box::new(move || Box::pin(spawn_fn()));
189        self.children.push(ChildHandle::new(name.into(), restart_policy, spawn_fn));
190    }
191
192    /// Start all children and begin supervision.
193    ///
194    /// This method runs until all children have terminated (according to their
195    /// restart policies) or the circuit breaker trips.
196    pub async fn run(&mut self) -> SageResult<()> {
197        // Start all children
198        for child in &mut self.children {
199            child.spawn();
200        }
201
202        // Monitor loop
203        loop {
204            // Wait for any child to complete
205            let (index, result) = self.wait_for_child_exit().await;
206
207            // Check if all children are done
208            if index.is_none() {
209                // All children have finished
210                break;
211            }
212
213            let index = index.unwrap();
214            let child_name = self.children[index].name.clone();
215            let restart_policy = self.children[index].restart_policy;
216
217            // Determine if we should restart
218            let should_restart = match (restart_policy, &result) {
219                (RestartPolicy::Permanent, _) => true,
220                (RestartPolicy::Transient, Err(_)) => true,
221                (RestartPolicy::Transient, Ok(_)) => false,
222                (RestartPolicy::Temporary, _) => false,
223            };
224
225            if should_restart {
226                // Check circuit breaker
227                if !self.restart_tracker.record_restart() {
228                    return Err(SageError::Supervisor(format!(
229                        "Maximum restart intensity reached for supervisor (child '{}' failed too many times)",
230                        child_name
231                    )));
232                }
233
234                // Apply restart strategy
235                match self.strategy {
236                    Strategy::OneForOne => {
237                        self.restart_child(index);
238                    }
239                    Strategy::OneForAll => {
240                        self.restart_all();
241                    }
242                    Strategy::RestForOne => {
243                        self.restart_rest(index);
244                    }
245                }
246            }
247
248            // Check if any children are still running
249            if !self.any_running() {
250                break;
251            }
252        }
253
254        Ok(())
255    }
256
257    /// Wait for any child to exit, returning the index and result.
258    async fn wait_for_child_exit(&mut self) -> (Option<usize>, SageResult<()>) {
259        use futures::future::select_all;
260
261        // Collect all running children's handles with their indices
262        let handles_with_indices: Vec<(usize, JoinHandle<SageResult<()>>)> = self
263            .children
264            .iter_mut()
265            .enumerate()
266            .filter_map(|(i, c)| c.take_handle().map(|h| (i, h)))
267            .collect();
268
269        if handles_with_indices.is_empty() {
270            return (None, Ok(()));
271        }
272
273        // We need to track indices separately since select_all works on the handles
274        let indices: Vec<usize> = handles_with_indices.iter().map(|(i, _)| *i).collect();
275        let handles: Vec<JoinHandle<SageResult<()>>> =
276            handles_with_indices.into_iter().map(|(_, h)| h).collect();
277
278        // Wait for any handle to complete
279        let (join_result, completed_idx, remaining_handles) = select_all(handles).await;
280
281        // Get the original child index
282        let child_index = indices[completed_idx];
283
284        // Convert JoinError to SageError
285        let final_result = join_result.unwrap_or_else(|e| Err(SageError::Agent(e.to_string())));
286
287        // Put back the remaining handles to their respective children
288        // Build list of (handle, original_index) pairs for non-completed handles
289        let mut remaining_iter = remaining_handles.into_iter();
290        for (pos, &original_idx) in indices.iter().enumerate() {
291            if pos != completed_idx {
292                if let (Some(handle), Some(child)) =
293                    (remaining_iter.next(), self.children.get_mut(original_idx))
294                {
295                    child.handle = Some(handle);
296                }
297            }
298        }
299
300        (Some(child_index), final_result)
301    }
302
303    /// Restart a single child.
304    fn restart_child(&mut self, index: usize) {
305        if let Some(child) = self.children.get_mut(index) {
306            child.spawn();
307        }
308    }
309
310    /// Restart all children (stop all first, then start all).
311    fn restart_all(&mut self) {
312        // Abort all running children
313        for child in &mut self.children {
314            if let Some(handle) = child.take_handle() {
315                handle.abort();
316            }
317        }
318
319        // Start all children
320        for child in &mut self.children {
321            child.spawn();
322        }
323    }
324
325    /// Restart the failed child and all children started after it.
326    fn restart_rest(&mut self, from_index: usize) {
327        // Abort children from index onwards
328        for child in self.children.iter_mut().skip(from_index) {
329            if let Some(handle) = child.take_handle() {
330                handle.abort();
331            }
332        }
333
334        // Restart children from index onwards
335        for child in self.children.iter_mut().skip(from_index) {
336            child.spawn();
337        }
338    }
339
340    /// Check if any children are still running.
341    fn any_running(&self) -> bool {
342        self.children.iter().any(|c| c.is_running())
343    }
344}
345
// Unit tests cover each restart policy, each strategy, and the circuit
// breaker. Strategy tests use short sleeps so a long-running sibling is
// still alive at the moment another child fails.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_one_for_one_restart() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        // Use Transient policy - restart on error, stop on success
        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    Ok(())
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);
        // Two failures trigger two restarts; the third run succeeds.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_transient_no_restart_on_success() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        // Transient + clean exit: no restart expected.
        supervisor.add_child("Worker", RestartPolicy::Transient, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only ran once
    }

    #[tokio::test]
    async fn test_temporary_never_restarts() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForOne, RestartConfig::default());

        // Temporary + failure: still no restart expected.
        supervisor.add_child("Worker", RestartPolicy::Temporary, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Simulated failure".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok()); // Supervisor should succeed even if child fails
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only ran once
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        // Tight budget so the breaker trips quickly.
        let config = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };

        let mut supervisor = Supervisor::new(Strategy::OneForOne, config);

        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(SageError::Agent("Always fails".to_string()))
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_err()); // Circuit breaker should trip
        assert!(counter.load(Ordering::SeqCst) <= 4); // At most 4 attempts (1 + 3 restarts)
    }

    #[tokio::test]
    async fn test_permanent_restarts_on_success() {
        // Permanent policy restarts even when child exits normally.
        // This test verifies the circuit breaker eventually stops it.
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let config = RestartConfig {
            max_restarts: 3,
            within: Duration::from_secs(60),
        };

        let mut supervisor = Supervisor::new(Strategy::OneForOne, config);

        supervisor.add_child("Worker", RestartPolicy::Permanent, move || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(()) // Exits successfully each time
            }
        });

        let result = supervisor.run().await;
        // Circuit breaker trips because Permanent keeps restarting even on success
        assert!(result.is_err());
        assert!(counter.load(Ordering::SeqCst) <= 4);
    }

    #[tokio::test]
    async fn test_rest_for_one_restarts_downstream() {
        // RestForOne: when child fails, it and all children added after it restart.
        let counter1 = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::new(AtomicU32::new(0));
        let counter3 = Arc::new(AtomicU32::new(0));
        let counter1_clone = counter1.clone();
        let counter2_clone = counter2.clone();
        let counter3_clone = counter3.clone();

        let mut supervisor = Supervisor::new(Strategy::RestForOne, RestartConfig::default());

        // Child 1: Always succeeds
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let counter = counter1_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                // Wait a bit so it doesn't exit before child 2 fails
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        // Child 2: Fails twice then succeeds (this triggers RestForOne)
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let counter = counter2_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    Ok(())
                }
            }
        });

        // Child 3: Succeeds but should be restarted when Child2 fails
        supervisor.add_child("Child3", RestartPolicy::Temporary, move || {
            let counter = counter3_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                // Wait a bit so it doesn't exit before child 2 fails
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(())
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        // Child1 should only run once (it's before the failing child)
        assert_eq!(counter1.load(Ordering::SeqCst), 1, "Child1 should run only once");

        // Child2 runs 3 times (2 failures + 1 success)
        assert_eq!(counter2.load(Ordering::SeqCst), 3, "Child2 should run 3 times");

        // Child3 should be restarted when Child2 fails (2 restarts + initial)
        assert!(
            counter3.load(Ordering::SeqCst) >= 2,
            "Child3 should be restarted at least once with RestForOne, got {}",
            counter3.load(Ordering::SeqCst)
        );
    }

    #[tokio::test]
    async fn test_one_for_all_restarts_all() {
        // OneForAll: when any child fails, all children restart.
        let counter1 = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::new(AtomicU32::new(0));
        let counter1_clone = counter1.clone();
        let counter2_clone = counter2.clone();

        let mut supervisor = Supervisor::new(Strategy::OneForAll, RestartConfig::default());

        // Child 1: Always succeeds but runs longer
        supervisor.add_child("Child1", RestartPolicy::Temporary, move || {
            let counter = counter1_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            }
        });

        // Child 2: Fails twice then succeeds (this triggers OneForAll)
        supervisor.add_child("Child2", RestartPolicy::Transient, move || {
            let counter = counter2_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(SageError::Agent("Simulated failure".to_string()))
                } else {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    Ok(())
                }
            }
        });

        let result = supervisor.run().await;
        assert!(result.is_ok(), "supervisor failed: {:?}", result);

        // Child2 runs 3 times (2 failures + 1 success)
        assert_eq!(counter2.load(Ordering::SeqCst), 3, "Child2 should run 3 times");

        // Child1 should be restarted when Child2 fails (OneForAll restarts all)
        assert!(
            counter1.load(Ordering::SeqCst) >= 2,
            "Child1 should be restarted at least once with OneForAll, got {}",
            counter1.load(Ordering::SeqCst)
        );
    }
}
581}