agent_kernel/
scheduler.rs

1//! Cooperative scheduler facade for agent workloads.
2
3use std::future::Future;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7
8use thiserror::Error;
9use tokio::sync::Semaphore;
10use tokio::task::JoinHandle;
11
/// Configuration for a [`TaskScheduler`].
///
/// The only setting today is the maximum number of concurrent tasks allowed
/// per agent. Build one with [`SchedulerConfig::new`], or use
/// [`SchedulerConfig::default`], which allows 32 concurrent tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SchedulerConfig {
    /// Upper bound on the number of tasks allowed to run at the same time.
    max_concurrency: NonZeroUsize,
}

impl SchedulerConfig {
    /// Concurrency limit used by the [`Default`] configuration.
    ///
    /// Evaluated in a const context, so a zero here fails the build instead of
    /// panicking at runtime.
    const DEFAULT_MAX_CONCURRENCY: NonZeroUsize = match NonZeroUsize::new(32) {
        Some(limit) => limit,
        None => panic!("default concurrency limit must be non-zero"),
    };

    /// Creates a new configuration with the supplied concurrency limit.
    #[must_use]
    pub const fn new(max_concurrency: NonZeroUsize) -> Self {
        Self { max_concurrency }
    }

    /// Returns the configured concurrency limit.
    #[must_use]
    pub const fn max_concurrency(self) -> NonZeroUsize {
        self.max_concurrency
    }
}

impl Default for SchedulerConfig {
    /// Returns a configuration allowing up to 32 concurrent tasks.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_CONCURRENCY)
    }
}
37
38/// Lightweight wrapper around `tokio::spawn` that enforces per-agent concurrency.
39#[derive(Debug, Clone)]
40pub struct TaskScheduler {
41    semaphore: Arc<Semaphore>,
42    closed: Arc<AtomicBool>,
43    config: SchedulerConfig,
44}
45
46impl TaskScheduler {
47    /// Constructs a scheduler using the provided configuration.
48    #[must_use]
49    pub fn new(config: SchedulerConfig) -> Self {
50        let permits = config.max_concurrency().get();
51        Self {
52            semaphore: Arc::new(Semaphore::new(permits)),
53            closed: Arc::new(AtomicBool::new(false)),
54            config,
55        }
56    }
57
58    /// Returns the associated configuration.
59    #[must_use]
60    pub const fn config(&self) -> SchedulerConfig {
61        self.config
62    }
63
64    /// Returns `true` if the scheduler has been closed.
65    #[must_use]
66    pub fn is_closed(&self) -> bool {
67        self.closed.load(Ordering::Acquire)
68    }
69
70    /// Closes the scheduler, preventing new tasks from being spawned.
71    pub fn close(&self) {
72        self.closed.store(true, Ordering::Release);
73        self.semaphore.close();
74    }
75
76    /// Spawns a future, respecting the configured concurrency limit.
77    ///
78    /// # Errors
79    ///
80    /// Returns [`SchedulerError::Closed`] when the scheduler is closed before the
81    /// task is enqueued.
82    ///
83    /// # Panics
84    ///
85    /// Panics if the scheduler is closed while a task is awaiting a concurrency
86    /// permit. This indicates that `close` was invoked concurrently with task
87    /// submission.
88    pub fn spawn<F, T>(&self, future: F) -> SchedulerResult<JoinHandle<T>>
89    where
90        F: Future<Output = T> + Send + 'static,
91        T: Send + 'static,
92    {
93        if self.is_closed() {
94            return Err(SchedulerError::Closed);
95        }
96
97        let semaphore = Arc::clone(&self.semaphore);
98
99        let handle = tokio::spawn(async move {
100            let permit = semaphore
101                .acquire_owned()
102                .await
103                .expect("scheduler closed while awaiting permit");
104            let output = future.await;
105            drop(permit);
106            output
107        });
108
109        Ok(handle)
110    }
111}
112
113impl Default for TaskScheduler {
114    fn default() -> Self {
115        Self::new(SchedulerConfig::default())
116    }
117}
118
119/// Errors produced by the scheduler.
120#[derive(Debug, Error, PartialEq, Eq)]
121pub enum SchedulerError {
122    /// Scheduler is closed and will not accept new tasks.
123    #[error("scheduler closed")]
124    Closed,
125}
126
127/// Result alias for scheduler operations.
128pub type SchedulerResult<T> = Result<T, SchedulerError>;
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// With a limit of 2, three overlapping tasks never run more than two at a time.
    #[tokio::test]
    async fn respects_max_concurrency() {
        let config = SchedulerConfig::new(NonZeroUsize::new(2).unwrap());
        let scheduler = TaskScheduler::new(config);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..3 {
            let scheduler = scheduler.clone();
            let in_flight = Arc::clone(&in_flight);
            let max_seen = Arc::clone(&max_seen);
            handles.push(
                scheduler
                    .spawn(async move {
                        let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        max_seen.fetch_max(current, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(10)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                    })
                    .unwrap(),
            );
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(max_seen.load(Ordering::SeqCst), 2);
    }

    /// `spawn` on an already-closed scheduler is rejected up front.
    #[tokio::test]
    async fn close_prevents_new_tasks() {
        let scheduler = TaskScheduler::default();
        scheduler.close();

        let result = scheduler.spawn(async move {});
        assert_eq!(result.unwrap_err(), SchedulerError::Closed);
    }

    /// Closing fails tasks still queued for a permit, while a task that already
    /// holds a permit runs to completion. The default panic hook prints a
    /// message for the failed task; that output is expected.
    #[tokio::test]
    async fn close_fails_tasks_waiting_for_a_permit() {
        let config = SchedulerConfig::new(NonZeroUsize::new(1).unwrap());
        let scheduler = TaskScheduler::new(config);

        let (started_tx, started_rx) = oneshot::channel::<()>();
        let (release_tx, release_rx) = oneshot::channel::<()>();
        let running = scheduler
            .spawn(async move {
                started_tx.send(()).unwrap();
                release_rx.await.unwrap();
            })
            .unwrap();

        // Wait until `running` holds the only permit, so the next task must queue.
        started_rx.await.unwrap();
        let queued = scheduler.spawn(async {}).unwrap();

        scheduler.close();
        assert!(queued.await.unwrap_err().is_panic());

        release_tx.send(()).unwrap();
        running.await.unwrap();
    }
}
175}