// ai_agents_runtime/optimization/scheduler.rs
//! Concurrency-limited scheduling of runtime branches within a single turn.

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;

use ai_agents_core::{AgentError, Result};
use futures::stream::{FuturesUnordered, StreamExt};

use super::branch::{
    RuntimeBranch, RuntimeBranchOutcome, RuntimeBranchResult, RuntimeBranchStatus,
};
use super::turn::TurnOptimizationContext;

/// Tracks how many runtime tasks are currently reserved and refuses new
/// reservations once a fixed upper bound is reached.
#[derive(Debug)]
pub struct TurnBranchScheduler {
    /// Upper bound on simultaneously reserved tasks (rejected when zero by `new`).
    max_parallel_tasks: usize,
    /// Number of task slots that are currently reserved.
    active_tasks: usize,
}

20pub type BranchFuture<'a> = Pin<Box<dyn Future<Output = RuntimeBranchResult> + Send + 'a>>;
21
22pub struct ScheduledBranchSet<'a> {
23 scheduler: TurnBranchScheduler,
24 pending: HashMap<String, RuntimeBranch>,
25 futures: FuturesUnordered<Pin<Box<dyn Future<Output = RuntimeBranchOutcome> + Send + 'a>>>,
26}
27
28impl TurnBranchScheduler {
29 pub fn new(max_parallel_tasks: usize) -> Result<Self> {
30 if max_parallel_tasks == 0 {
31 return Err(AgentError::InvalidSpec(
32 "runtime.optimization.max_parallel_runtime_tasks must be greater than 0".into(),
33 ));
34 }
35 Ok(Self {
36 max_parallel_tasks,
37 active_tasks: 0,
38 })
39 }
40
41 pub fn can_schedule_branch(&self) -> bool {
42 self.active_tasks < self.max_parallel_tasks
43 }
44
45 pub fn reserve_task(&mut self) -> bool {
46 if !self.can_schedule_branch() {
47 return false;
48 }
49 self.active_tasks += 1;
50 true
51 }
52
53 pub fn reserve_llm_branch(&mut self, turn: &mut TurnOptimizationContext) -> bool {
54 if !self.can_schedule_branch() || !turn.reserve_speculative_llm_call() {
55 return false;
56 }
57 self.active_tasks += 1;
58 true
59 }
60
61 pub fn reserve_branch(&mut self, turn: &mut TurnOptimizationContext) -> bool {
62 self.reserve_llm_branch(turn)
63 }
64
65 pub fn release_task(&mut self) {
66 self.active_tasks = self.active_tasks.saturating_sub(1);
67 }
68
69 pub fn complete_branch(&mut self, branch: &mut RuntimeBranch) -> Result<()> {
70 self.release_task();
71 branch.transition_to(RuntimeBranchStatus::Completed)
72 }
73}
74
75impl<'a> ScheduledBranchSet<'a> {
76 pub fn new(max_parallel_tasks: usize) -> Result<Self> {
77 Ok(Self {
78 scheduler: TurnBranchScheduler::new(max_parallel_tasks)?,
79 pending: HashMap::new(),
80 futures: FuturesUnordered::new(),
81 })
82 }
83
84 pub fn reserve_task(&mut self) -> bool {
85 self.scheduler.reserve_task()
86 }
87
88 pub fn release_task(&mut self) {
89 self.scheduler.release_task();
90 }
91
92 pub fn schedule(&mut self, branch: RuntimeBranch, future: BranchFuture<'a>) -> bool {
93 if !self.scheduler.can_schedule_branch() {
94 return false;
95 }
96 self.scheduler.active_tasks += 1;
97 let id = branch.branch_id();
98 self.pending.insert(id.clone(), branch.clone());
99 self.futures.push(Box::pin(async move {
100 let mut branch = branch;
101 let result = future.await;
102 let _ = branch.complete();
103 RuntimeBranchOutcome { branch, result }
104 }));
105 true
106 }
107
108 pub async fn next_completed(&mut self) -> Option<RuntimeBranchOutcome> {
109 let outcome = self.futures.next().await?;
110 self.pending.remove(&outcome.branch.branch_id());
111 self.scheduler.release_task();
112 Some(outcome)
113 }
114
115 pub fn cancel_pending(&mut self) -> Vec<RuntimeBranch> {
116 self.futures = FuturesUnordered::new();
117 let pending = std::mem::take(&mut self.pending)
118 .into_values()
119 .collect::<Vec<_>>();
120 for _ in &pending {
121 self.scheduler.release_task();
122 }
123 pending
124 }
125
126 pub fn is_empty(&self) -> bool {
127 self.pending.is_empty()
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn scheduler_enforces_parallel_limit() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 2);
        assert!(scheduler.reserve_branch(&mut turn));
        assert!(!scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn scheduler_enforces_speculative_call_limit() {
        let mut scheduler = TurnBranchScheduler::new(2).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 1);
        assert!(scheduler.reserve_branch(&mut turn));
        assert!(!scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn scheduler_rejects_zero_parallel_tasks() {
        assert!(matches!(
            TurnBranchScheduler::new(0),
            Err(AgentError::InvalidSpec(_))
        ));
        assert!(ScheduledBranchSet::new(0).is_err());
    }

    #[test]
    fn full_scheduler_does_not_spend_speculative_budget() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        let mut turn = TurnOptimizationContext::new("input", HashMap::new(), 2);
        assert!(scheduler.reserve_branch(&mut turn));
        // Refused for parallelism, so the second speculative call must remain available.
        assert!(!scheduler.reserve_branch(&mut turn));
        scheduler.release_task();
        assert!(scheduler.reserve_branch(&mut turn));
        scheduler.release_task();
        // Both speculative calls are now spent.
        assert!(!scheduler.reserve_branch(&mut turn));
    }

    #[test]
    fn release_task_saturates_at_zero() {
        let mut scheduler = TurnBranchScheduler::new(1).unwrap();
        // Releasing while idle must not create extra capacity.
        scheduler.release_task();
        assert!(scheduler.reserve_task());
        assert!(!scheduler.reserve_task());
    }

    #[test]
    fn cancel_pending_keeps_external_reservations() {
        let mut set = ScheduledBranchSet::new(1).unwrap();
        assert!(set.is_empty());
        assert!(set.reserve_task());
        assert!(set.cancel_pending().is_empty());
        // The slot reserved outside `schedule` is still held after cancelling.
        assert!(!set.reserve_task());
        set.release_task();
        assert!(set.reserve_task());
    }
}