Skip to main content

ai_agents_runtime/optimization/
scheduler.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4
5use ai_agents_core::{AgentError, Result};
6use futures::stream::{FuturesUnordered, StreamExt};
7
8use super::branch::{
9    RuntimeBranch, RuntimeBranchOutcome, RuntimeBranchResult, RuntimeBranchStatus,
10};
11use super::turn::TurnOptimizationContext;
12
/// Small per-turn scheduler guard for branch limits and reservations.
///
/// Counts reserved tasks against a fixed parallelism limit.
#[derive(Debug)]
pub struct TurnBranchScheduler {
    /// Maximum number of tasks that may be reserved at once (`new` rejects 0).
    max_parallel_tasks: usize,
    /// Number of tasks currently reserved and not yet released.
    active_tasks: usize,
}
19
/// Pinned, boxed, `Send` future that resolves to one branch's [`RuntimeBranchResult`].
///
/// The future may borrow data for `'a`; it does not need to be `'static`.
pub type BranchFuture<'a> = Pin<Box<dyn Future<Output = RuntimeBranchResult> + Send + 'a>>;
21
/// A bounded set of in-flight runtime branches.
///
/// Each scheduled branch future lives in `futures`, and a snapshot of the branch
/// is kept in `pending`; all of them share one [`TurnBranchScheduler`] limit.
pub struct ScheduledBranchSet<'a> {
    /// Enforces the parallelism limit for scheduled futures and manual reservations.
    scheduler: TurnBranchScheduler,
    /// Branch snapshots keyed by `branch_id()`, inserted when scheduled and removed
    /// when the branch completes or is cancelled.
    pending: HashMap<String, RuntimeBranch>,
    /// In-flight futures; each awaits a branch body, marks the branch complete, and
    /// yields a [`RuntimeBranchOutcome`].
    futures: FuturesUnordered<Pin<Box<dyn Future<Output = RuntimeBranchOutcome> + Send + 'a>>>,
}
27
28impl TurnBranchScheduler {
29    pub fn new(max_parallel_tasks: usize) -> Result<Self> {
30        if max_parallel_tasks == 0 {
31            return Err(AgentError::InvalidSpec(
32                "runtime.optimization.max_parallel_runtime_tasks must be greater than 0".into(),
33            ));
34        }
35        Ok(Self {
36            max_parallel_tasks,
37            active_tasks: 0,
38        })
39    }
40
41    pub fn can_schedule_branch(&self) -> bool {
42        self.active_tasks < self.max_parallel_tasks
43    }
44
45    pub fn reserve_task(&mut self) -> bool {
46        if !self.can_schedule_branch() {
47            return false;
48        }
49        self.active_tasks += 1;
50        true
51    }
52
53    pub fn reserve_llm_branch(&mut self, turn: &mut TurnOptimizationContext) -> bool {
54        if !self.can_schedule_branch() || !turn.reserve_speculative_llm_call() {
55            return false;
56        }
57        self.active_tasks += 1;
58        true
59    }
60
61    pub fn reserve_branch(&mut self, turn: &mut TurnOptimizationContext) -> bool {
62        self.reserve_llm_branch(turn)
63    }
64
65    pub fn release_task(&mut self) {
66        self.active_tasks = self.active_tasks.saturating_sub(1);
67    }
68
69    pub fn complete_branch(&mut self, branch: &mut RuntimeBranch) -> Result<()> {
70        self.release_task();
71        branch.transition_to(RuntimeBranchStatus::Completed)
72    }
73}
74
75impl<'a> ScheduledBranchSet<'a> {
76    pub fn new(max_parallel_tasks: usize) -> Result<Self> {
77        Ok(Self {
78            scheduler: TurnBranchScheduler::new(max_parallel_tasks)?,
79            pending: HashMap::new(),
80            futures: FuturesUnordered::new(),
81        })
82    }
83
84    pub fn reserve_task(&mut self) -> bool {
85        self.scheduler.reserve_task()
86    }
87
88    pub fn release_task(&mut self) {
89        self.scheduler.release_task();
90    }
91
92    pub fn schedule(&mut self, branch: RuntimeBranch, future: BranchFuture<'a>) -> bool {
93        if !self.scheduler.can_schedule_branch() {
94            return false;
95        }
96        self.scheduler.active_tasks += 1;
97        let id = branch.branch_id();
98        self.pending.insert(id.clone(), branch.clone());
99        self.futures.push(Box::pin(async move {
100            let mut branch = branch;
101            let result = future.await;
102            let _ = branch.complete();
103            RuntimeBranchOutcome { branch, result }
104        }));
105        true
106    }
107
108    pub async fn next_completed(&mut self) -> Option<RuntimeBranchOutcome> {
109        let outcome = self.futures.next().await?;
110        self.pending.remove(&outcome.branch.branch_id());
111        self.scheduler.release_task();
112        Some(outcome)
113    }
114
115    pub fn cancel_pending(&mut self) -> Vec<RuntimeBranch> {
116        self.futures = FuturesUnordered::new();
117        let pending = std::mem::take(&mut self.pending)
118            .into_values()
119            .collect::<Vec<_>>();
120        for _ in &pending {
121            self.scheduler.release_task();
122        }
123        pending
124    }
125
126    pub fn is_empty(&self) -> bool {
127        self.pending.is_empty()
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn scheduler_rejects_zero_parallel_limit() {
        assert!(matches!(
            TurnBranchScheduler::new(0),
            Err(AgentError::InvalidSpec(_))
        ));
    }

    #[test]
    fn scheduler_enforces_parallel_limit() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 2);
        assert!(scheduler.reserve_branch(&mut turn));
        assert!(!scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn scheduler_enforces_speculative_call_limit() {
        let mut scheduler = TurnBranchScheduler::new(2).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 1);
        assert!(scheduler.reserve_branch(&mut turn));
        assert!(!scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn scheduler_release_frees_a_slot() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        assert!(scheduler.reserve_task());
        assert!(!scheduler.can_schedule_branch());
        assert!(!scheduler.reserve_task());
        scheduler.release_task();
        assert!(scheduler.can_schedule_branch());
        assert!(scheduler.reserve_task());
    }

    #[test]
    fn scheduler_release_without_reservation_is_noop() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        scheduler.release_task();
        // The count must not go below zero and grant extra capacity.
        assert!(scheduler.reserve_task());
        assert!(!scheduler.reserve_task());
    }

    #[test]
    fn parallel_limit_rejection_does_not_spend_speculative_budget() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 2);
        assert!(scheduler.reserve_branch(&mut turn));
        // Rejected by the parallel limit before the turn budget is consulted.
        assert!(!scheduler.reserve_branch(&mut turn));
        scheduler.release_task();
        // The second speculative call is therefore still available.
        assert!(scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn branch_set_rejects_zero_parallel_limit() {
        assert!(ScheduledBranchSet::new(0).is_err());
    }

    #[test]
    fn branch_set_manual_reservations_survive_cancel() {
        let mut set = ScheduledBranchSet::new(1).unwrap();
        assert!(set.is_empty());
        assert!(set.reserve_task());
        assert!(!set.reserve_task());
        set.release_task();
        assert!(set.reserve_task());
        assert!(set.is_empty());
        assert!(set.cancel_pending().is_empty());
        // Cancelling scheduled branches must not free manually reserved slots.
        assert!(!set.reserve_task());
    }
}