agent_kernel/
shutdown.rs

//! Graceful shutdown coordination for agent lifecycle.
//!
//! This module provides a state machine for coordinating graceful shutdown of agents,
//! giving in-flight work a bounded window (the drain timeout) to complete before termination.
5
6use std::sync::Arc;
7use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
8use std::time::Duration;
9
10use thiserror::Error;
11use tokio::sync::broadcast;
12use tokio::time::timeout;
13use tracing::{debug, warn};
14
/// Lifecycle phase of a shutdown coordinator.
///
/// The explicit discriminants are the values kept in the coordinator's atomic
/// state cell; `as_u8` and `from_u8` translate between the two forms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation; new work may be registered.
    Running = 0,
    /// Shutdown has begun and outstanding work is being drained.
    Draining = 1,
    /// Shutdown has finished, either because work drained or the timeout hit.
    Terminated = 2,
}

impl ShutdownState {
    /// Decodes a byte read from the atomic state cell.
    ///
    /// Bytes that match no known discriminant decode as `Terminated`, so an
    /// unexpected value is never mistaken for an active state.
    fn from_u8(value: u8) -> Self {
        const RUNNING: u8 = ShutdownState::Running as u8;
        const DRAINING: u8 = ShutdownState::Draining as u8;

        match value {
            RUNNING => Self::Running,
            DRAINING => Self::Draining,
            // `Terminated` itself and any out-of-range byte both end up here.
            _ => Self::Terminated,
        }
    }

    /// Encodes this state as the byte stored in the atomic state cell.
    fn as_u8(self) -> u8 {
        self as u8
    }
}
42
43/// Errors from shutdown operations.
44#[derive(Debug, Error)]
45pub enum ShutdownError {
46    /// Shutdown was already initiated.
47    #[error("shutdown already initiated")]
48    AlreadyShuttingDown,
49
50    /// Drain timeout expired before all work completed.
51    #[error("drain timeout expired with {remaining} tasks still in flight")]
52    DrainTimeout {
53        /// Number of tasks still in flight when timeout expired.
54        remaining: u32,
55    },
56
57    /// Broadcast channel error.
58    #[error("broadcast error: {0}")]
59    BroadcastError(String),
60}
61
62/// Result type for shutdown operations.
63pub type ShutdownResult<T> = Result<T, ShutdownError>;
64
65/// Coordinates graceful shutdown of agent workloads.
66///
67/// The shutdown coordinator manages a state machine with three states:
68/// - `Running`: Normal operation, new work can be registered
69/// - `Draining`: Shutdown initiated, no new work accepted, existing work completes
70/// - `Terminated`: All work completed or timeout expired
71///
72/// # Example
73///
74/// ```ignore
75/// let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(30)));
76/// let guard = coordinator.register_work();
77/// // ... do work ...
78/// drop(guard); // Work completes
79///
80/// coordinator.shutdown().await.unwrap();
81/// assert_eq!(coordinator.state(), ShutdownState::Terminated);
82/// ```
83#[derive(Debug)]
84pub struct ShutdownCoordinator {
85    state: AtomicU8,
86    drain_timeout: Duration,
87    in_flight: AtomicU32,
88    shutdown_signal: broadcast::Sender<()>,
89}
90
91impl ShutdownCoordinator {
92    /// Creates a new shutdown coordinator with the specified drain timeout.
93    ///
94    /// # Arguments
95    ///
96    /// * `drain_timeout` - Maximum duration to wait for in-flight work to complete
97    pub fn new(drain_timeout: Duration) -> Self {
98        let (tx, _) = broadcast::channel(1);
99        Self {
100            state: AtomicU8::new(ShutdownState::Running.as_u8()),
101            drain_timeout,
102            in_flight: AtomicU32::new(0),
103            shutdown_signal: tx,
104        }
105    }
106
107    /// Registers a unit of in-flight work.
108    ///
109    /// Returns a `WorkGuard` that automatically decrements the in-flight counter
110    /// when dropped. If shutdown is already in progress, returns an error.
111    ///
112    /// # Errors
113    ///
114    /// Returns `ShutdownError::AlreadyShuttingDown` if shutdown has been initiated.
115    pub fn register_work(self: &Arc<Self>) -> ShutdownResult<WorkGuard> {
116        let current_state = ShutdownState::from_u8(self.state.load(Ordering::SeqCst));
117
118        if current_state != ShutdownState::Running {
119            return Err(ShutdownError::AlreadyShuttingDown);
120        }
121
122        self.in_flight.fetch_add(1, Ordering::SeqCst);
123        Ok(WorkGuard {
124            coordinator: Arc::clone(self),
125        })
126    }
127
128    /// Initiates graceful shutdown.
129    ///
130    /// Transitions the coordinator to `Draining` state and waits for in-flight work
131    /// to complete up to the configured drain timeout. If the timeout expires,
132    /// returns an error but still transitions to `Terminated`.
133    ///
134    /// # Errors
135    ///
136    /// Returns `ShutdownError::DrainTimeout` if in-flight work doesn't complete
137    /// within the drain timeout.
138    pub async fn shutdown(&self) -> ShutdownResult<()> {
139        // Transition to Draining
140        let old_state = self.state.compare_exchange(
141            ShutdownState::Running.as_u8(),
142            ShutdownState::Draining.as_u8(),
143            Ordering::SeqCst,
144            Ordering::SeqCst,
145        );
146
147        if old_state.is_err() {
148            return Err(ShutdownError::AlreadyShuttingDown);
149        }
150
151        debug!("shutdown initiated, draining in-flight work");
152
153        // Broadcast shutdown signal
154        let _ = self.shutdown_signal.send(());
155
156        // Wait for in-flight work to complete or timeout
157        let result = timeout(self.drain_timeout, self.wait_for_drain()).await;
158
159        // Transition to Terminated
160        self.state
161            .store(ShutdownState::Terminated.as_u8(), Ordering::SeqCst);
162
163        match result {
164            Ok(_) => {
165                debug!("shutdown complete, all work drained");
166                Ok(())
167            }
168            Err(_) => {
169                let remaining = self.in_flight.load(Ordering::SeqCst);
170                warn!(remaining, "drain timeout expired");
171                Err(ShutdownError::DrainTimeout { remaining })
172            }
173        }
174    }
175
176    /// Returns the current shutdown state.
177    #[must_use]
178    pub fn state(&self) -> ShutdownState {
179        ShutdownState::from_u8(self.state.load(Ordering::SeqCst))
180    }
181
182    /// Subscribes to shutdown signal.
183    ///
184    /// Returns a receiver that will be notified when shutdown is initiated.
185    #[must_use]
186    pub fn subscribe(&self) -> broadcast::Receiver<()> {
187        self.shutdown_signal.subscribe()
188    }
189
190    /// Returns the number of in-flight work items.
191    #[must_use]
192    pub fn in_flight_count(&self) -> u32 {
193        self.in_flight.load(Ordering::SeqCst)
194    }
195
196    /// Waits for all in-flight work to complete.
197    async fn wait_for_drain(&self) {
198        loop {
199            if self.in_flight.load(Ordering::SeqCst) == 0 {
200                break;
201            }
202            tokio::time::sleep(Duration::from_millis(10)).await;
203        }
204    }
205}
206
207/// RAII guard for tracking in-flight work.
208///
209/// When dropped, automatically decrements the in-flight work counter.
210pub struct WorkGuard {
211    coordinator: Arc<ShutdownCoordinator>,
212}
213
214impl Drop for WorkGuard {
215    fn drop(&mut self) {
216        self.coordinator.in_flight.fetch_sub(1, Ordering::SeqCst);
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the runtime until `coordinator` has left `Running`.
    ///
    /// Used instead of a fixed sleep so that tests observe a spawned `shutdown`
    /// deterministically, rather than hoping it ran within N milliseconds.
    async fn wait_until_not_running(coordinator: &ShutdownCoordinator) {
        tokio::time::timeout(Duration::from_secs(1), async {
            while coordinator.state() == ShutdownState::Running {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("shutdown did not start within 1s");
    }

    #[test]
    fn initial_state_is_running() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        assert_eq!(coordinator.state(), ShutdownState::Running);
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[test]
    fn register_work_increments_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let _guard = coordinator.register_work().unwrap();
        assert_eq!(coordinator.in_flight_count(), 1);
    }

    #[test]
    fn work_guard_drop_decrements_counter() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        {
            let _guard = coordinator.register_work().unwrap();
            assert_eq!(coordinator.in_flight_count(), 1);
        }
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn register_work_fails_during_shutdown() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        wait_until_not_running(&coordinator).await;
        let err = coordinator.register_work().err().expect("registration must be rejected");
        assert_eq!(err.to_string(), "shutdown already initiated");
        // A rejected attempt must not leave anything behind in the counter.
        assert_eq!(coordinator.in_flight_count(), 1);

        drop(guard);
        assert!(shutdown_task.await.unwrap().is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn register_work_fails_after_termination() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        coordinator.shutdown().await.unwrap();

        assert_eq!(coordinator.state(), ShutdownState::Terminated);
        assert!(coordinator.register_work().is_err());
        assert_eq!(coordinator.in_flight_count(), 0);
    }

    #[tokio::test]
    async fn second_shutdown_is_rejected() {
        let coordinator = ShutdownCoordinator::new(Duration::from_secs(1));
        coordinator.shutdown().await.unwrap();

        let err = coordinator.shutdown().await.unwrap_err();
        assert_eq!(err.to_string(), "shutdown already initiated");
    }

    #[tokio::test]
    async fn shutdown_waits_for_work_to_complete() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let guard = coordinator.register_work().unwrap();

        let coordinator_clone = Arc::clone(&coordinator);
        let shutdown_task = tokio::spawn(async move { coordinator_clone.shutdown().await });

        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(guard);

        let result = shutdown_task.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_timeout_with_pending_work() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_millis(100)));
        let _guard = coordinator.register_work().unwrap();

        let err = coordinator.shutdown().await.unwrap_err();
        assert_eq!(
            err.to_string(),
            "drain timeout expired with 1 tasks still in flight"
        );
        assert_eq!(coordinator.state(), ShutdownState::Terminated);
    }

    #[tokio::test]
    async fn shutdown_signal_broadcast() {
        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(1)));
        let mut rx = coordinator.subscribe();

        let coordinator_clone = Arc::clone(&coordinator);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = coordinator_clone.shutdown().await;
        });

        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(matches!(result, Ok(Ok(()))));
    }
}