agent_kernel/
shutdown.rs

1//! Graceful shutdown coordination for agent lifecycle.
2//!
3//! This module provides a state machine for coordinating graceful shutdown of agents,
4//! ensuring in-flight work completes before termination.
5
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
8use std::time::Duration;
9
10use thiserror::Error;
11use tokio::sync::broadcast;
12use tokio::time::timeout;
13use tracing::{debug, warn};
14
/// Lifecycle phase of a [`ShutdownCoordinator`].
///
/// The explicit discriminants are the values kept in the coordinator's `AtomicU8`
/// state field; `#[repr(u8)]` pins the enum to exactly that width.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation; new work is accepted.
    Running = 0,
    /// Shutdown has begun and in-flight work is being drained.
    Draining = 1,
    /// Shutdown has finished.
    Terminated = 2,
}

impl ShutdownState {
    /// Decodes a raw value loaded from the atomic state field.
    ///
    /// `0` decodes as `Running` and `1` as `Draining`; every other value decodes as
    /// `Terminated`.
    fn from_u8(raw: u8) -> Self {
        if raw == Self::Running as u8 {
            Self::Running
        } else if raw == Self::Draining as u8 {
            Self::Draining
        } else {
            Self::Terminated
        }
    }

    /// Encodes the state as its discriminant for storage in the atomic field.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
41
42/// Errors from shutdown operations.
43#[derive(Debug, Error)]
44pub enum ShutdownError {
45    /// Shutdown was already initiated.
46    #[error("shutdown already initiated")]
47    AlreadyShuttingDown,
48
49    /// Drain timeout expired before all work completed.
50    #[error("drain timeout expired with {remaining} tasks still in flight")]
51    DrainTimeout {
52        /// Number of tasks still in flight when timeout expired.
53        remaining: u32,
54    },
55
56    /// Broadcast channel error.
57    #[error("broadcast error: {0}")]
58    BroadcastError(String),
59}
60
61/// Result type for shutdown operations.
62pub type ShutdownResult<T> = Result<T, ShutdownError>;
63
64/// Coordinates graceful shutdown of agent workloads.
65///
66/// The shutdown coordinator manages a state machine with three states:
67/// - `Running`: Normal operation, new work can be registered
68/// - `Draining`: Shutdown initiated, no new work accepted, existing work completes
69/// - `Terminated`: All work completed or timeout expired
70///
71/// # Example
72///
73/// ```ignore
74/// let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(30)));
75/// let guard = coordinator.register_work();
76/// // ... do work ...
77/// drop(guard); // Work completes
78///
79/// coordinator.shutdown().await.unwrap();
80/// assert_eq!(coordinator.state(), ShutdownState::Terminated);
81/// ```
82#[derive(Debug)]
83pub struct ShutdownCoordinator {
84    state: AtomicU8,
85    drain_timeout: Duration,
86    in_flight: AtomicU32,
87    shutdown_signal: broadcast::Sender<()>,
88}
89
90impl ShutdownCoordinator {
91    /// Creates a new shutdown coordinator with the specified drain timeout.
92    ///
93    /// # Arguments
94    ///
95    /// * `drain_timeout` - Maximum duration to wait for in-flight work to complete
96    #[must_use]
97    pub fn new(drain_timeout: Duration) -> Self {
98        let (tx, _) = broadcast::channel(1);
99        Self {
100            state: AtomicU8::new(ShutdownState::Running.as_u8()),
101            drain_timeout,
102            in_flight: AtomicU32::new(0),
103            shutdown_signal: tx,
104        }
105    }
106
107    /// Registers a unit of in-flight work.
108    ///
109    /// Returns a `WorkGuard` that automatically decrements the in-flight counter
110    /// when dropped. If shutdown is already in progress, returns an error.
111    ///
112    /// # Errors
113    ///
114    /// Returns `ShutdownError::AlreadyShuttingDown` if shutdown has been initiated.
115    pub fn register_work(self: &Arc<Self>) -> ShutdownResult<WorkGuard> {
116        let current_state = ShutdownState::from_u8(self.state.load(Ordering::SeqCst));
117
118        if current_state != ShutdownState::Running {
119            return Err(ShutdownError::AlreadyShuttingDown);
120        }
121
122        self.in_flight.fetch_add(1, Ordering::SeqCst);
123        Ok(WorkGuard {
124            coordinator: Arc::clone(self),
125        })
126    }
127
128    /// Initiates graceful shutdown.
129    ///
130    /// Transitions the coordinator to `Draining` state and waits for in-flight work
131    /// to complete up to the configured drain timeout. If the timeout expires,
132    /// returns an error but still transitions to `Terminated`.
133    ///
134    /// # Errors
135    ///
136    /// Returns `ShutdownError::DrainTimeout` if in-flight work doesn't complete
137    /// within the drain timeout.
138    pub async fn shutdown(&self) -> ShutdownResult<()> {
139        // Transition to Draining
140        let old_state = self.state.compare_exchange(
141            ShutdownState::Running.as_u8(),
142            ShutdownState::Draining.as_u8(),
143            Ordering::SeqCst,
144            Ordering::SeqCst,
145        );
146
147        if old_state.is_err() {
148            return Err(ShutdownError::AlreadyShuttingDown);
149        }
150
151        debug!("shutdown initiated, draining in-flight work");
152
153        // Broadcast shutdown signal
154        let _ = self.shutdown_signal.send(());
155
156        // Wait for in-flight work to complete or timeout
157        let result = timeout(self.drain_timeout, self.wait_for_drain()).await;
158
159        // Transition to Terminated
160        self.state
161            .store(ShutdownState::Terminated.as_u8(), Ordering::SeqCst);
162
163        if result.is_ok() {
164            debug!("shutdown complete, all work drained");
165            Ok(())
166        } else {
167            let remaining = self.in_flight.load(Ordering::SeqCst);
168            warn!(remaining, "drain timeout expired");
169            Err(ShutdownError::DrainTimeout { remaining })
170        }
171    }
172
173    /// Returns the current shutdown state.
174    #[must_use]
175    pub fn state(&self) -> ShutdownState {
176        ShutdownState::from_u8(self.state.load(Ordering::SeqCst))
177    }
178
179    /// Subscribes to shutdown signal.
180    ///
181    /// Returns a receiver that will be notified when shutdown is initiated.
182    #[must_use]
183    pub fn subscribe(&self) -> broadcast::Receiver<()> {
184        self.shutdown_signal.subscribe()
185    }
186
187    /// Returns the number of in-flight work items.
188    #[must_use]
189    pub fn in_flight_count(&self) -> u32 {
190        self.in_flight.load(Ordering::SeqCst)
191    }
192
193    /// Waits for all in-flight work to complete.
194    async fn wait_for_drain(&self) {
195        loop {
196            if self.in_flight.load(Ordering::SeqCst) == 0 {
197                break;
198            }
199            tokio::time::sleep(Duration::from_millis(10)).await;
200        }
201    }
202}
203
204/// RAII guard for tracking in-flight work.
205///
206/// When dropped, automatically decrements the in-flight work counter.
207pub struct WorkGuard {
208    coordinator: Arc<ShutdownCoordinator>,
209}
210
211impl Drop for WorkGuard {
212    fn drop(&mut self) {
213        self.coordinator.in_flight.fetch_sub(1, Ordering::SeqCst);
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the runtime until `coordinator` has left `Running`.
    ///
    /// Used instead of fixed sleeps, so that tests which spawn `shutdown` do not depend
    /// on scheduler timing. Once this returns, the spawned task has performed the
    /// Running -> Draining transition.
    async fn wait_until_shutdown_started(coordinator: &ShutdownCoordinator) {
        while coordinator.state() == ShutdownState::Running {
            tokio::task::yield_now().await;
        }
    }

    #[test]
    fn initial_state_is_running() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        assert_eq!(coordinator.state(), ShutdownState::Running);
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[test]
    fn register_work_increments_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let _guard = coordinator.register_work().unwrap();
        assert_eq!(coordinator.in_flight_count(), 1);
    }

    #[test]
    fn work_guard_drop_decrements_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        {
            let _guard = coordinator.register_work().unwrap();
            assert_eq!(coordinator.in_flight_count(), 1);
        }
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn register_work_fails_during_shutdown() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        wait_until_shutdown_started(&coordinator).await;
        let result = coordinator.register_work();
        assert!(matches!(result, Err(ShutdownError::AlreadyShuttingDown)));
        // A rejected registration must not leave a phantom in-flight item behind.
        assert_eq!(coordinator.in_flight_count(), 1);

        drop(guard);
        assert!(shutdown_task.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn register_work_fails_after_termination() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        coordinator.shutdown().await.unwrap();

        assert!(matches!(
            coordinator.register_work(),
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_waits_for_work_to_complete() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        wait_until_shutdown_started(&coordinator).await;
        // The guard is still held, so shutdown cannot have finished draining yet.
        assert_eq!(coordinator.state(), ShutdownState::Draining);
        drop(guard);

        let result = shutdown_task.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_timeout_with_pending_work() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_millis(100)));
        let _guard = coordinator.register_work().unwrap();

        let result = coordinator.shutdown().await;
        assert!(matches!(
            result,
            Err(ShutdownError::DrainTimeout { remaining: 1 })
        ));
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn second_shutdown_is_rejected() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        assert!(coordinator.shutdown().await.is_ok());
        assert!(matches!(
            coordinator.shutdown().await,
            Err(ShutdownError::AlreadyShuttingDown)
        ));
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_signal_broadcast() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        // Subscribing before shutdown starts guarantees the signal is buffered for us.
        let mut rx = coordinator.subscribe();

        let coordinator_clone = Arc::clone(&coordinator);
        tokio::spawn(async move {
            let _ = coordinator_clone.shutdown().await;
        });

        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(matches!(result, Ok(Ok(()))));
    }
}