Skip to main content

nemo_flow_adaptive/
tool_parallelism_learner.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Learner that derives tool parallelism plans from observed runs.
5
6use std::collections::{BTreeMap, BTreeSet, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11use crate::acg::canonicalize::sha256_hex;
12use chrono::{DateTime, Utc};
13use serde_json::json;
14use uuid::Uuid;
15
16use crate::error::{AdaptiveError, Result};
17use crate::learner::traits::Learner;
18use crate::storage::traits::StorageBackendDyn;
19use crate::types::cache::HotCache;
20use crate::types::metadata::{MetadataEnvelope, ParallelHint};
21use crate::types::plan::{ExecutionPlan, ParallelGroup};
22use crate::types::records::{CallKind, RunRecord};
23
/// Learner that discovers tool fan-out groups from run telemetry.
///
/// The learner is scoped to a single agent: every plan it loads, updates, and
/// stores is keyed by the `agent_id` supplied at construction time.
#[derive(Debug, Clone)]
pub struct ToolParallelismLearner {
    /// Identifier of the agent whose execution plan this learner maintains.
    agent_id: String,
}

impl ToolParallelismLearner {
    /// Create a new tool-parallelism learner.
    ///
    /// # Parameters
    /// - `agent_id`: Agent identifier whose execution plan should be updated.
    ///   Anything convertible into a `String` is accepted.
    ///
    /// # Returns
    /// A configured [`ToolParallelismLearner`].
    pub fn new(agent_id: impl Into<String>) -> Self {
        let agent_id = agent_id.into();
        Self { agent_id }
    }
}
43
44impl Learner for ToolParallelismLearner {
45    fn process_run<'a>(
46        &'a self,
47        run: &'a RunRecord,
48        backend: &'a dyn StorageBackendDyn,
49        hot_cache: &'a Arc<RwLock<HotCache>>,
50    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
51        Box::pin(async move {
52            let observed_cohorts = derive_observed_cohorts(run);
53            if observed_cohorts.is_empty() {
54                return Ok(());
55            }
56
57            let mut plan = backend
58                .load_plan_dyn(&self.agent_id)
59                .await?
60                .unwrap_or_else(|| empty_execution_plan(&self.agent_id, run.id));
61            plan.agent_id = self.agent_id.clone();
62
63            merge_observed_cohorts(&mut plan, &observed_cohorts, run.id);
64            backend.store_plan(&plan)?;
65
66            let mut guard = hot_cache.write().map_err(|error| {
67                AdaptiveError::Internal(format!("hot cache lock poisoned: {error}"))
68            })?;
69            guard.plan = Some(plan.clone());
70            Ok(())
71        })
72    }
73}
74
75#[derive(Clone)]
76struct ObservedToolCall {
77    name: String,
78    started_at: DateTime<Utc>,
79    ended_at: DateTime<Utc>,
80}
81
82fn derive_observed_cohorts(run: &RunRecord) -> Vec<Vec<String>> {
83    let mut calls: Vec<ObservedToolCall> = run
84        .calls
85        .iter()
86        .filter(|call| call.kind == CallKind::Tool)
87        .filter_map(|call| {
88            call.ended_at.map(|ended_at| ObservedToolCall {
89                name: call.name.clone(),
90                started_at: call.started_at,
91                ended_at,
92            })
93        })
94        .collect();
95    calls.sort_by_key(|call| call.started_at);
96
97    let mut active: Vec<ObservedToolCall> = Vec::new();
98    let mut cohorts: HashSet<Vec<String>> = HashSet::new();
99
100    for current in calls {
101        active.retain(|call| call.ended_at > current.started_at);
102        if active.len() + 1 > 1 {
103            let mut tool_names: Vec<String> = active.iter().map(|call| call.name.clone()).collect();
104            tool_names.push(current.name.clone());
105            tool_names.sort();
106            cohorts.insert(tool_names);
107        }
108        active.push(current);
109    }
110
111    let mut observed: Vec<Vec<String>> = cohorts.into_iter().collect();
112    observed.sort();
113    observed
114}
115
116fn merge_observed_cohorts(
117    plan: &mut ExecutionPlan,
118    observed_cohorts: &[Vec<String>],
119    run_id: Uuid,
120) {
121    let mut groups_by_id: BTreeMap<String, ParallelGroup> = plan
122        .parallel_groups
123        .iter()
124        .cloned()
125        .map(|group| (group.group_id.clone(), group))
126        .collect();
127    let mut hints_by_key: BTreeMap<(String, String), ParallelHint> = plan
128        .metadata_template
129        .parallel_hints
130        .iter()
131        .cloned()
132        .map(|hint| ((hint.tool_name.clone(), hint.group_id.clone()), hint))
133        .collect();
134
135    for cohort in observed_cohorts {
136        let group = build_parallel_group(cohort);
137        let group_id = group.group_id.clone();
138        let mut unique_tool_names: BTreeSet<String> = BTreeSet::new();
139        for tool_name in &group.tool_names {
140            if unique_tool_names.insert(tool_name.clone()) {
141                hints_by_key.insert(
142                    (tool_name.clone(), group_id.clone()),
143                    ParallelHint {
144                        tool_name: tool_name.clone(),
145                        group_id: group_id.clone(),
146                        explicit: false,
147                    },
148                );
149            }
150        }
151        groups_by_id.insert(group_id, group);
152    }
153
154    plan.parallel_groups = groups_by_id.into_values().collect();
155    plan.metadata_template.agent_id = plan.agent_id.clone();
156    plan.metadata_template.run_id = run_id;
157    plan.metadata_template.parallel_hints = hints_by_key.into_values().collect();
158}
159
160fn build_parallel_group(tool_names: &[String]) -> ParallelGroup {
161    let joined = tool_names.join("|");
162    let group_hash = sha256_hex(&joined);
163    ParallelGroup {
164        group_id: format!("fanout:{}", &group_hash[..12]),
165        tool_names: tool_names.to_vec(),
166    }
167}
168
169fn empty_execution_plan(agent_id: &str, run_id: Uuid) -> ExecutionPlan {
170    ExecutionPlan {
171        agent_id: agent_id.to_string(),
172        parallel_groups: vec![],
173        metadata_template: MetadataEnvelope {
174            run_id,
175            agent_id: agent_id.to_string(),
176            parallel_hints: vec![],
177            extensions: json!({}),
178        },
179    }
180}
181
182#[cfg(test)]
183#[path = "../tests/unit/tool_parallelism_learner_tests.rs"]
184mod tests;