ai_agents_runtime/optimization/
maintenance.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::hash::{Hash, Hasher};
4use std::sync::Arc;
5
6use parking_lot::Mutex;
7use tokio::task::JoinHandle;
8
9use ai_agents_core::{AgentError, Result};
10
/// Kind of runtime work a task performs.
///
/// This is the `task_kind` component of a `MaintenanceSequenceKey`, so tasks with different
/// purposes for the same agent and scope are keyed separately.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RuntimeTaskPurpose {
    MainResponse,
    StateTransition,
    SkillRouting,
    ReasoningJudge,
    // Post-turn work.
    PostTurnFacts,
    PostTurnRelationship,
    PostTurnSessionMaintenance,
    PostTurnCompression,
    OrchestrationVoteExtraction,
    ObservabilityExport,
}
35
36#[derive(Debug, Clone, Eq)]
38pub struct MaintenanceSequenceKey {
39 pub agent_id: String,
41 pub scope_id: String,
43 pub task_kind: RuntimeTaskPurpose,
45}
46
47impl MaintenanceSequenceKey {
48 pub fn actor(
50 agent_id: impl Into<String>,
51 actor_id: impl Into<String>,
52 task_kind: RuntimeTaskPurpose,
53 ) -> Self {
54 Self {
55 agent_id: agent_id.into(),
56 scope_id: actor_id.into(),
57 task_kind,
58 }
59 }
60}
61
62impl PartialEq for MaintenanceSequenceKey {
63 fn eq(&self, other: &Self) -> bool {
64 self.agent_id == other.agent_id
65 && self.scope_id == other.scope_id
66 && self.task_kind == other.task_kind
67 }
68}
69
70impl Hash for MaintenanceSequenceKey {
71 fn hash<H: Hasher>(&self, state: &mut H) {
72 self.agent_id.hash(state);
73 self.scope_id.hash(state);
74 self.task_kind.hash(state);
75 }
76}
77
78struct TrackedTask {
79 key: Option<MaintenanceSequenceKey>,
80 handle: JoinHandle<Result<()>>,
81}
82
83pub struct BackgroundMaintenanceQueue {
85 max_tasks: usize,
86 tasks: Mutex<Vec<TrackedTask>>,
87 locks: Mutex<HashMap<MaintenanceSequenceKey, Arc<tokio::sync::Mutex<()>>>>,
88}
89
90impl BackgroundMaintenanceQueue {
91 pub fn new(max_tasks: usize) -> Self {
93 Self {
94 max_tasks: max_tasks.max(1),
95 tasks: Mutex::new(Vec::new()),
96 locks: Mutex::new(HashMap::new()),
97 }
98 }
99
100 pub fn len(&self) -> usize {
102 self.tasks.lock().len()
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.len() == 0
108 }
109
110 pub fn is_full(&self) -> bool {
112 self.unfinished_count() >= self.max_tasks
113 }
114
115 pub fn spawn<F>(&self, key: Option<MaintenanceSequenceKey>, future: F) -> Result<()>
117 where
118 F: Future<Output = Result<()>> + Send + 'static,
119 {
120 let mut tasks = self.tasks.lock();
121 if tasks
122 .iter()
123 .filter(|task| !task.handle.is_finished())
124 .count()
125 >= self.max_tasks
126 {
127 return Err(AgentError::Other(format!(
128 "background maintenance queue is full (limit {})",
129 self.max_tasks
130 )));
131 }
132
133 let lock = key.as_ref().map(|key| {
134 let mut locks = self.locks.lock();
135 locks
136 .entry(key.clone())
137 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
138 .clone()
139 });
140
141 let handle = tokio::spawn(async move {
142 if let Some(lock) = lock {
143 let _guard = lock.lock().await;
144 future.await
145 } else {
146 future.await
147 }
148 });
149 tasks.push(TrackedTask { key, handle });
150 Ok(())
151 }
152
153 pub async fn flush_all(&self) -> Result<()> {
155 let tasks = std::mem::take(&mut *self.tasks.lock());
156 for task in tasks {
157 match task.handle.await {
158 Ok(Ok(())) => {}
159 Ok(Err(error)) => return Err(error),
160 Err(error) => {
161 return Err(AgentError::Other(format!(
162 "background maintenance task failed to join: {}",
163 error
164 )));
165 }
166 }
167 }
168 Ok(())
169 }
170
171 pub async fn flush_scope(&self, scope_id: &str) -> Result<()> {
173 self.flush_matching(|key| key.scope_id == scope_id).await
174 }
175
176 pub async fn flush_purpose(&self, purpose: RuntimeTaskPurpose) -> Result<()> {
178 self.flush_matching(|key| key.task_kind == purpose).await
179 }
180
181 pub async fn flush_scope_purpose(
183 &self,
184 scope_id: &str,
185 purpose: RuntimeTaskPurpose,
186 ) -> Result<()> {
187 self.flush_matching(|key| key.scope_id == scope_id && key.task_kind == purpose)
188 .await
189 }
190
191 async fn flush_matching(
192 &self,
193 matches_key: impl Fn(&MaintenanceSequenceKey) -> bool,
194 ) -> Result<()> {
195 let (matching, remaining): (Vec<_>, Vec<_>) = std::mem::take(&mut *self.tasks.lock())
196 .into_iter()
197 .partition(|task| task.key.as_ref().map(&matches_key).unwrap_or(false));
198 *self.tasks.lock() = remaining;
199 for task in matching {
200 match task.handle.await {
201 Ok(Ok(())) => {}
202 Ok(Err(error)) => return Err(error),
203 Err(error) => {
204 return Err(AgentError::Other(format!(
205 "background maintenance task failed to join: {}",
206 error
207 )));
208 }
209 }
210 }
211 Ok(())
212 }
213
214 fn unfinished_count(&self) -> usize {
215 self.tasks
216 .lock()
217 .iter()
218 .filter(|task| !task.handle.is_finished())
219 .count()
220 }
221}
222
223impl Default for BackgroundMaintenanceQueue {
224 fn default() -> Self {
225 Self::new(16)
226 }
227}
228
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Key for the fixed `agent` / `actor` pair used by these tests.
    fn actor_key(purpose: RuntimeTaskPurpose) -> MaintenanceSequenceKey {
        MaintenanceSequenceKey::actor("agent", "actor", purpose)
    }

    #[tokio::test]
    async fn finished_task_error_surfaces_on_flush_after_capacity_check() {
        let queue = BackgroundMaintenanceQueue::new(1);
        queue
            .spawn(None, async {
                Err(AgentError::Other("background failed".to_string()))
            })
            .unwrap();

        // Let the task finish; a finished task must not count toward capacity.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!queue.is_full());

        let err = queue.flush_all().await.unwrap_err();
        assert!(err.to_string().contains("background failed"));
    }

    #[tokio::test]
    async fn flush_scope_purpose_keeps_unmatched_tasks() {
        let queue = BackgroundMaintenanceQueue::new(2);
        let (release, gate) = tokio::sync::oneshot::channel::<()>();

        queue
            .spawn(Some(actor_key(RuntimeTaskPurpose::PostTurnFacts)), async {
                Ok(())
            })
            .unwrap();
        // Stays pending until released, so it must survive the scoped flush below.
        queue
            .spawn(
                Some(actor_key(RuntimeTaskPurpose::PostTurnRelationship)),
                async move {
                    let _ = gate.await;
                    Ok(())
                },
            )
            .unwrap();

        queue
            .flush_scope_purpose("actor", RuntimeTaskPurpose::PostTurnFacts)
            .await
            .unwrap();
        assert_eq!(queue.len(), 1);

        let _ = release.send(());
        queue.flush_all().await.unwrap();
        assert!(queue.is_empty());
    }
}