agent_kernel/
scheduler.rs1use std::future::Future;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7
8use thiserror::Error;
9use tokio::sync::Semaphore;
10use tokio::task::JoinHandle;
11
/// Tuning knobs for a [`TaskScheduler`].
///
/// Today the only setting is the maximum number of spawned futures that may
/// run at the same time. Storing it as a [`NonZeroUsize`] makes a zero limit,
/// under which no task could ever obtain a permit, unrepresentable.
///
/// The type is a small `Copy` value and implements `PartialEq`/`Eq`/`Hash` so
/// configurations can be compared (e.g. against `TaskScheduler::config()`)
/// or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SchedulerConfig {
    /// Upper bound on the number of futures running concurrently.
    max_concurrency: NonZeroUsize,
}
17
18impl SchedulerConfig {
19 #[must_use]
21 pub const fn new(max_concurrency: NonZeroUsize) -> Self {
22 Self { max_concurrency }
23 }
24
25 #[must_use]
27 pub const fn max_concurrency(self) -> NonZeroUsize {
28 self.max_concurrency
29 }
30}
31
32impl Default for SchedulerConfig {
33 fn default() -> Self {
34 Self::new(NonZeroUsize::new(32).expect("non-zero"))
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct TaskScheduler {
41 semaphore: Arc<Semaphore>,
42 closed: Arc<AtomicBool>,
43 config: SchedulerConfig,
44}
45
46impl TaskScheduler {
47 #[must_use]
49 pub fn new(config: SchedulerConfig) -> Self {
50 let permits = config.max_concurrency().get();
51 Self {
52 semaphore: Arc::new(Semaphore::new(permits)),
53 closed: Arc::new(AtomicBool::new(false)),
54 config,
55 }
56 }
57
58 #[must_use]
60 pub const fn config(&self) -> SchedulerConfig {
61 self.config
62 }
63
64 #[must_use]
66 pub fn is_closed(&self) -> bool {
67 self.closed.load(Ordering::Acquire)
68 }
69
70 pub fn close(&self) {
72 self.closed.store(true, Ordering::Release);
73 self.semaphore.close();
74 }
75
76 pub fn spawn<F, T>(&self, future: F) -> SchedulerResult<JoinHandle<T>>
89 where
90 F: Future<Output = T> + Send + 'static,
91 T: Send + 'static,
92 {
93 if self.is_closed() {
94 return Err(SchedulerError::Closed);
95 }
96
97 let semaphore = Arc::clone(&self.semaphore);
98
99 let handle = tokio::spawn(async move {
100 let permit = semaphore
101 .acquire_owned()
102 .await
103 .expect("scheduler closed while awaiting permit");
104 let output = future.await;
105 drop(permit);
106 output
107 });
108
109 Ok(handle)
110 }
111}
112
113impl Default for TaskScheduler {
114 fn default() -> Self {
115 Self::new(SchedulerConfig::default())
116 }
117}
118
119#[derive(Debug, Error, PartialEq, Eq)]
121pub enum SchedulerError {
122 #[error("scheduler closed")]
124 Closed,
125}
126
127pub type SchedulerResult<T> = Result<T, SchedulerError>;
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// With a limit of two, three overlapping tasks must never exceed two in
    /// flight, and the limit must actually be reached.
    #[tokio::test]
    async fn respects_max_concurrency() {
        let config = SchedulerConfig::new(NonZeroUsize::new(2).unwrap());
        let scheduler = TaskScheduler::new(config);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(3);
        for _ in 0..3 {
            // `spawn` borrows the scheduler, so no per-iteration clone is needed.
            let in_flight = Arc::clone(&in_flight);
            let max_seen = Arc::clone(&max_seen);
            let handle = scheduler
                .spawn(async move {
                    let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    max_seen.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                })
                .unwrap();
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(max_seen.load(Ordering::SeqCst), 2);
    }

    /// The value produced by the spawned future is delivered through the handle.
    #[tokio::test]
    async fn spawn_returns_future_output() {
        let scheduler = TaskScheduler::default();
        let handle = scheduler.spawn(async { 21 * 2 }).unwrap();
        assert_eq!(handle.await.unwrap(), 42);
    }

    /// The default scheduler is built from the default 32-slot configuration.
    #[test]
    fn default_config_allows_32_tasks() {
        assert_eq!(SchedulerConfig::default().max_concurrency().get(), 32);
        assert_eq!(TaskScheduler::default().config().max_concurrency().get(), 32);
    }

    #[tokio::test]
    async fn close_prevents_new_tasks() {
        let scheduler = TaskScheduler::default();
        scheduler.close();

        assert!(scheduler.is_closed());
        let result = scheduler.spawn(async move {});
        assert_eq!(result.unwrap_err(), SchedulerError::Closed);
    }

    /// Clones share the closed flag, so closing one rejects spawns on all.
    #[test]
    fn close_is_visible_through_clones() {
        let scheduler = TaskScheduler::default();
        let clone = scheduler.clone();
        scheduler.close();

        assert!(clone.is_closed());
        assert_eq!(clone.spawn(async {}).unwrap_err(), SchedulerError::Closed);
    }
}