nemo_flow_adaptive/
tool_parallelism_learner.rs1use std::collections::{BTreeMap, BTreeSet, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11use crate::acg::canonicalize::sha256_hex;
12use chrono::{DateTime, Utc};
13use serde_json::json;
14use uuid::Uuid;
15
16use crate::error::{AdaptiveError, Result};
17use crate::learner::traits::Learner;
18use crate::storage::traits::StorageBackendDyn;
19use crate::types::cache::HotCache;
20use crate::types::metadata::{MetadataEnvelope, ParallelHint};
21use crate::types::plan::{ExecutionPlan, ParallelGroup};
22use crate::types::records::{CallKind, RunRecord};
23
/// Learner that looks for tool calls which ran concurrently within a run and records
/// them as parallel groups (plus per-tool hints) in the agent's execution plan.
#[derive(Debug)]
pub struct ToolParallelismLearner {
    /// Agent whose execution plan is loaded, updated, and stored by this learner.
    agent_id: String,
}
28
29impl ToolParallelismLearner {
30 pub fn new(agent_id: impl Into<String>) -> Self {
38 Self {
39 agent_id: agent_id.into(),
40 }
41 }
42}
43
44impl Learner for ToolParallelismLearner {
45 fn process_run<'a>(
46 &'a self,
47 run: &'a RunRecord,
48 backend: &'a dyn StorageBackendDyn,
49 hot_cache: &'a Arc<RwLock<HotCache>>,
50 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51 Box::pin(async move {
52 let observed_cohorts = derive_observed_cohorts(run);
53 if observed_cohorts.is_empty() {
54 return Ok(());
55 }
56
57 let mut plan = backend
58 .load_plan_dyn(&self.agent_id)
59 .await?
60 .unwrap_or_else(|| empty_execution_plan(&self.agent_id, run.id));
61 plan.agent_id = self.agent_id.clone();
62
63 merge_observed_cohorts(&mut plan, &observed_cohorts, run.id);
64 backend.store_plan(&plan)?;
65
66 let mut guard = hot_cache.write().map_err(|error| {
67 AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
68 })?;
69 guard.plan = Some(plan.clone());
70 Ok(())
71 })
72 }
73}
74
75#[derive(Clone)]
76struct ObservedToolCall {
77 name: String,
78 started_at: DateTime<Utc>,
79 ended_at: DateTime<Utc>,
80}
81
82fn derive_observed_cohorts(run: &RunRecord) -> Vec<Vec<String>> {
83 let mut calls: Vec<ObservedToolCall> = run
84 .calls
85 .iter()
86 .filter(|call| call.kind == CallKind::Tool)
87 .filter_map(|call| {
88 call.ended_at.map(|ended_at| ObservedToolCall {
89 name: call.name.clone(),
90 started_at: call.started_at,
91 ended_at,
92 })
93 })
94 .collect();
95 calls.sort_by_key(|call| call.started_at);
96
97 let mut active: Vec<ObservedToolCall> = Vec::new();
98 let mut cohorts: HashSet<Vec<String>> = HashSet::new();
99
100 for current in calls {
101 active.retain(|call| call.ended_at > current.started_at);
102 if active.len() + 1 > 1 {
103 let mut tool_names: Vec<String> = active.iter().map(|call| call.name.clone()).collect();
104 tool_names.push(current.name.clone());
105 tool_names.sort();
106 cohorts.insert(tool_names);
107 }
108 active.push(current);
109 }
110
111 let mut observed: Vec<Vec<String>> = cohorts.into_iter().collect();
112 observed.sort();
113 observed
114}
115
116fn merge_observed_cohorts(
117 plan: &mut ExecutionPlan,
118 observed_cohorts: &[Vec<String>],
119 run_id: Uuid,
120) {
121 let mut groups_by_id: BTreeMap<String, ParallelGroup> = plan
122 .parallel_groups
123 .iter()
124 .cloned()
125 .map(|group| (group.group_id.clone(), group))
126 .collect();
127 let mut hints_by_key: BTreeMap<(String, String), ParallelHint> = plan
128 .metadata_template
129 .parallel_hints
130 .iter()
131 .cloned()
132 .map(|hint| ((hint.tool_name.clone(), hint.group_id.clone()), hint))
133 .collect();
134
135 for cohort in observed_cohorts {
136 let group = build_parallel_group(cohort);
137 let group_id = group.group_id.clone();
138 let mut unique_tool_names: BTreeSet<String> = BTreeSet::new();
139 for tool_name in &group.tool_names {
140 if unique_tool_names.insert(tool_name.clone()) {
141 hints_by_key.insert(
142 (tool_name.clone(), group_id.clone()),
143 ParallelHint {
144 tool_name: tool_name.clone(),
145 group_id: group_id.clone(),
146 explicit: false,
147 },
148 );
149 }
150 }
151 groups_by_id.insert(group_id, group);
152 }
153
154 plan.parallel_groups = groups_by_id.into_values().collect();
155 plan.metadata_template.agent_id = plan.agent_id.clone();
156 plan.metadata_template.run_id = run_id;
157 plan.metadata_template.parallel_hints = hints_by_key.into_values().collect();
158}
159
160fn build_parallel_group(tool_names: &[String]) -> ParallelGroup {
161 let joined = tool_names.join("|");
162 let group_hash = sha256_hex(&joined);
163 ParallelGroup {
164 group_id: format!("fanout:{}", &group_hash[..12]),
165 tool_names: tool_names.to_vec(),
166 }
167}
168
169fn empty_execution_plan(agent_id: &str, run_id: Uuid) -> ExecutionPlan {
170 ExecutionPlan {
171 agent_id: agent_id.to_string(),
172 parallel_groups: vec![],
173 metadata_template: MetadataEnvelope {
174 run_id,
175 agent_id: agent_id.to_string(),
176 parallel_hints: vec![],
177 extensions: json!({}),
178 },
179 }
180}
181
182#[cfg(test)]
183#[path = "../tests/unit/tool_parallelism_learner_tests.rs"]
184mod tests;