// ai_agents_runtime/optimization/maintenance.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::hash::{Hash, Hasher};
4use std::sync::Arc;
5
6use parking_lot::Mutex;
7use tokio::task::JoinHandle;
8
9use ai_agents_core::{AgentError, Result};
10
/// Categories of runtime work, used for optimization and for ordering background tasks.
///
/// The type is a plain `Copy` tag: it compares by variant and can be used as a hash-map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RuntimeTaskPurpose {
    /// Generation of the main, user-visible response.
    MainResponse,
    /// Selecting or committing a state transition.
    StateTransition,
    /// Deciding which skill a request is routed to.
    SkillRouting,
    /// Deciding the reasoning mode automatically.
    ReasoningJudge,
    /// Fact extraction that runs after a turn.
    PostTurnFacts,
    /// Relationship update that runs after a turn.
    PostTurnRelationship,
    /// Session maintenance that runs after a turn.
    PostTurnSessionMaintenance,
    /// Memory compression that runs after a turn.
    PostTurnCompression,
    /// Extraction of orchestration votes.
    OrchestrationVoteExtraction,
    /// Export of observability data.
    ObservabilityExport,
}
35
36/// Sequence key used to preserve order for actor or session scoped maintenance.
37#[derive(Debug, Clone, Eq)]
38pub struct MaintenanceSequenceKey {
39    /// Agent whose maintenance task owns the sequence.
40    pub agent_id: String,
41    /// Actor, session, or resource identifier for ordering.
42    pub scope_id: String,
43    /// Task kind that must remain ordered within the scope.
44    pub task_kind: RuntimeTaskPurpose,
45}
46
47impl MaintenanceSequenceKey {
48    /// Creates an actor-scoped sequence key.
49    pub fn actor(
50        agent_id: impl Into<String>,
51        actor_id: impl Into<String>,
52        task_kind: RuntimeTaskPurpose,
53    ) -> Self {
54        Self {
55            agent_id: agent_id.into(),
56            scope_id: actor_id.into(),
57            task_kind,
58        }
59    }
60}
61
62impl PartialEq for MaintenanceSequenceKey {
63    fn eq(&self, other: &Self) -> bool {
64        self.agent_id == other.agent_id
65            && self.scope_id == other.scope_id
66            && self.task_kind == other.task_kind
67    }
68}
69
70impl Hash for MaintenanceSequenceKey {
71    fn hash<H: Hasher>(&self, state: &mut H) {
72        self.agent_id.hash(state);
73        self.scope_id.hash(state);
74        self.task_kind.hash(state);
75    }
76}
77
78struct TrackedTask {
79    key: Option<MaintenanceSequenceKey>,
80    handle: JoinHandle<Result<()>>,
81}
82
83/// Bounded queue for post-turn maintenance that can be flushed by eval and shutdown paths.
84pub struct BackgroundMaintenanceQueue {
85    max_tasks: usize,
86    tasks: Mutex<Vec<TrackedTask>>,
87    locks: Mutex<HashMap<MaintenanceSequenceKey, Arc<tokio::sync::Mutex<()>>>>,
88}
89
90impl BackgroundMaintenanceQueue {
91    /// Creates a queue with a positive task limit.
92    pub fn new(max_tasks: usize) -> Self {
93        Self {
94            max_tasks: max_tasks.max(1),
95            tasks: Mutex::new(Vec::new()),
96            locks: Mutex::new(HashMap::new()),
97        }
98    }
99
100    /// Returns the number of tracked tasks that have not been flushed yet.
101    pub fn len(&self) -> usize {
102        self.tasks.lock().len()
103    }
104
105    /// Returns true when no tracked tasks remain.
106    pub fn is_empty(&self) -> bool {
107        self.len() == 0
108    }
109
110    /// Returns true when the queue cannot accept another tracked task.
111    pub fn is_full(&self) -> bool {
112        self.unfinished_count() >= self.max_tasks
113    }
114
115    /// Spawns a background task and applies per-key ordering when a key is supplied.
116    pub fn spawn<F>(&self, key: Option<MaintenanceSequenceKey>, future: F) -> Result<()>
117    where
118        F: Future<Output = Result<()>> + Send + 'static,
119    {
120        let mut tasks = self.tasks.lock();
121        if tasks
122            .iter()
123            .filter(|task| !task.handle.is_finished())
124            .count()
125            >= self.max_tasks
126        {
127            return Err(AgentError::Other(format!(
128                "background maintenance queue is full (limit {})",
129                self.max_tasks
130            )));
131        }
132
133        let lock = key.as_ref().map(|key| {
134            let mut locks = self.locks.lock();
135            locks
136                .entry(key.clone())
137                .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
138                .clone()
139        });
140
141        let handle = tokio::spawn(async move {
142            if let Some(lock) = lock {
143                let _guard = lock.lock().await;
144                future.await
145            } else {
146                future.await
147            }
148        });
149        tasks.push(TrackedTask { key, handle });
150        Ok(())
151    }
152
153    /// Waits for all tracked background tasks to complete.
154    pub async fn flush_all(&self) -> Result<()> {
155        let tasks = std::mem::take(&mut *self.tasks.lock());
156        for task in tasks {
157            match task.handle.await {
158                Ok(Ok(())) => {}
159                Ok(Err(error)) => return Err(error),
160                Err(error) => {
161                    return Err(AgentError::Other(format!(
162                        "background maintenance task failed to join: {}",
163                        error
164                    )));
165                }
166            }
167        }
168        Ok(())
169    }
170
171    /// Waits for tasks with the requested scope identifier and keeps unrelated work queued.
172    pub async fn flush_scope(&self, scope_id: &str) -> Result<()> {
173        self.flush_matching(|key| key.scope_id == scope_id).await
174    }
175
176    /// Waits for tasks with the requested purpose and keeps unrelated work queued.
177    pub async fn flush_purpose(&self, purpose: RuntimeTaskPurpose) -> Result<()> {
178        self.flush_matching(|key| key.task_kind == purpose).await
179    }
180
181    /// Waits for tasks with the requested scope and purpose.
182    pub async fn flush_scope_purpose(
183        &self,
184        scope_id: &str,
185        purpose: RuntimeTaskPurpose,
186    ) -> Result<()> {
187        self.flush_matching(|key| key.scope_id == scope_id && key.task_kind == purpose)
188            .await
189    }
190
191    async fn flush_matching(
192        &self,
193        matches_key: impl Fn(&MaintenanceSequenceKey) -> bool,
194    ) -> Result<()> {
195        let (matching, remaining): (Vec<_>, Vec<_>) = std::mem::take(&mut *self.tasks.lock())
196            .into_iter()
197            .partition(|task| task.key.as_ref().map(&matches_key).unwrap_or(false));
198        *self.tasks.lock() = remaining;
199        for task in matching {
200            match task.handle.await {
201                Ok(Ok(())) => {}
202                Ok(Err(error)) => return Err(error),
203                Err(error) => {
204                    return Err(AgentError::Other(format!(
205                        "background maintenance task failed to join: {}",
206                        error
207                    )));
208                }
209            }
210        }
211        Ok(())
212    }
213
214    fn unfinished_count(&self) -> usize {
215        self.tasks
216            .lock()
217            .iter()
218            .filter(|task| !task.handle.is_finished())
219            .count()
220    }
221}
222
223impl Default for BackgroundMaintenanceQueue {
224    fn default() -> Self {
225        Self::new(16)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key for the fixed `"agent"` / `"actor"` pair used by these tests.
    fn actor_key(purpose: RuntimeTaskPurpose) -> MaintenanceSequenceKey {
        MaintenanceSequenceKey::actor("agent", "actor", purpose)
    }

    /// A task that already failed no longer counts against capacity, yet its error must
    /// still come out of the next flush.
    #[tokio::test]
    async fn finished_task_error_surfaces_on_flush_after_capacity_check() {
        use std::time::Duration;

        let queue = BackgroundMaintenanceQueue::new(1);
        let spawned = queue.spawn(None, async {
            Err(AgentError::Other("background failed".to_string()))
        });
        spawned.unwrap();

        // Give the spawned task time to run to completion.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!queue.is_full());

        let failure = queue.flush_all().await.unwrap_err();
        assert!(failure.to_string().contains("background failed"));
    }

    /// Flushing one scope/purpose pair must leave tasks with another purpose queued.
    #[tokio::test]
    async fn flush_scope_purpose_keeps_unmatched_tasks() {
        let queue = BackgroundMaintenanceQueue::new(2);
        let (unblock, blocked) = tokio::sync::oneshot::channel::<()>();

        let facts_key = actor_key(RuntimeTaskPurpose::PostTurnFacts);
        queue.spawn(Some(facts_key), async { Ok(()) }).unwrap();

        let relationship_key = actor_key(RuntimeTaskPurpose::PostTurnRelationship);
        queue
            .spawn(Some(relationship_key), async move {
                // Stays pending until the test releases it below.
                let _ = blocked.await;
                Ok(())
            })
            .unwrap();

        let flushed = queue
            .flush_scope_purpose("actor", RuntimeTaskPurpose::PostTurnFacts)
            .await;
        flushed.unwrap();
        assert_eq!(queue.len(), 1);

        let _ = unblock.send(());
        queue.flush_all().await.unwrap();
        assert!(queue.is_empty());
    }
}