Skip to main content

nemo_flow_adaptive/trie/
builder.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Prediction trie builder with incremental accumulator merge.
5//!
6//! Ports the core algorithm from NAT's `trie_builder.py`: extract LLM call
7//! contexts from run records, compute 4-signal sensitivity scores with
8//! min-max normalization, update streaming accumulators at every trie node
9//! along the path, and build the final [`PredictionTrieNode`] tree.
10
11use std::collections::HashMap;
12
13use serde::{Deserialize, Serialize};
14
15use super::accumulator::{AccumulatorState, NodeAccumulators, RunningStats};
16use super::data_models::{LlmCallPrediction, PredictionTrieNode};
17use crate::types::records::{CallKind, CallRecord, RunRecord};
18
/// Configuration for auto-sensitivity scoring.
///
/// Weights and scale match NAT defaults from trie_builder.py lines 41-48.
///
/// The raw per-call score is the weighted sum
/// `w_critical * critical + w_fanout * fanout + w_position * position
/// - w_parallel * parallel_penalty`; note that the parallel term is
/// subtracted, not added.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
    /// Integer scale for quantized sensitivity (1..=scale).
    pub sensitivity_scale: u32,
    /// Weight for the critical-path signal (call duration relative to the
    /// workflow duration).
    pub w_critical: f64,
    /// Weight for the fan-out signal (share of logical steps still to come).
    pub w_fanout: f64,
    /// Weight for the U-shaped position signal (highest at the first and
    /// last logical step).
    pub w_position: f64,
    /// Weight for the parallel-penalty signal; this term lowers the score.
    pub w_parallel: f64,
}
35
36impl Default for SensitivityConfig {
37    fn default() -> Self {
38        Self {
39            sensitivity_scale: 5,
40            w_critical: 0.5,
41            w_fanout: 0.3,
42            w_position: 0.2,
43            w_parallel: 0.0,
44        }
45    }
46}
47
/// Internal context for a single LLM call extracted from a [`RunRecord`].
///
/// Built by `extract_llm_contexts`; `sensitivity_score` is filled in later
/// by the sensitivity-scoring pass when one is configured.
#[derive(Debug, Clone)]
pub(crate) struct LlmCallContext {
    /// Trie path segments for this call; joined with "/" to form accumulator keys.
    pub path: Vec<String>,
    /// 1-based occurrence counter of this call among same-named calls in the run.
    pub call_index: u32,
    /// Number of completed LLM calls that come after this one in the run.
    pub remaining_calls: u32,
    /// Milliseconds from this call's end to the next LLM call's start,
    /// or `None` when no later LLM call exists.
    pub time_to_next_ms: Option<f64>,
    /// Output token count; 0 when the record carries none.
    pub output_tokens: u32,
    /// Duration of this call in seconds.
    pub call_duration_s: f64,
    /// Duration of the whole run in seconds (0.0 when it cannot be determined).
    pub workflow_duration_s: f64,
    /// Input to the parallel-penalty signal; initialised to 0.0 on extraction.
    pub parallel_slack_ratio: f64,
    /// Normalized sensitivity score; initialised to 0.0 on extraction.
    pub sensitivity_score: f64,
    /// Call start as epoch seconds, used for overlap detection.
    pub span_start_time: f64,
    /// Call end as epoch seconds, used for overlap detection.
    pub span_end_time: f64,
}
63
/// Builds a [`PredictionTrieNode`] tree from [`RunRecord`]s via incremental
/// accumulator merge.
///
/// Each added run updates streaming accumulators keyed by trie path; the
/// trie itself is only materialised when `build` is called, so runs can be
/// added one at a time.
///
/// # Usage
///
/// ```ignore
/// let mut builder = PredictionTrieBuilder::new(Some(SensitivityConfig::default()));
/// builder.add_run(&run1);
/// builder.add_run(&run2);
/// let trie = builder.build();
/// ```
pub struct PredictionTrieBuilder {
    /// Per-path accumulators; the empty-string key holds the root node's data.
    accumulators: AccumulatorState,
    /// Sensitivity scoring settings; `None` disables sensitivity scoring.
    sensitivity_config: Option<SensitivityConfig>,
}
79
80impl PredictionTrieBuilder {
81    /// Creates a new builder with optional sensitivity scoring.
82    pub fn new(sensitivity_config: Option<SensitivityConfig>) -> Self {
83        Self {
84            accumulators: AccumulatorState::default(),
85            sensitivity_config,
86        }
87    }
88
89    /// Creates a builder seeded with pre-existing accumulators.
90    ///
91    /// Used by the learner pipeline to resume incremental learning
92    /// from a stored [`AccumulatorState`].
93    pub fn with_accumulators(
94        accumulators: AccumulatorState,
95        sensitivity_config: Option<SensitivityConfig>,
96    ) -> Self {
97        Self {
98            accumulators,
99            sensitivity_config,
100        }
101    }
102
103    /// Processes a single [`RunRecord`] and updates accumulators.
104    ///
105    /// Extracts LLM call contexts, optionally computes sensitivity scores,
106    /// and updates accumulators at every node along each call's path.
107    pub fn add_run(&mut self, run: &RunRecord) {
108        let mut contexts = extract_llm_contexts(run);
109        if let Some(ref config) = self.sensitivity_config {
110            compute_sensitivity_scores(&mut contexts, config);
111        }
112        for ctx in &contexts {
113            self.update_accumulators(ctx);
114        }
115    }
116
117    /// Constructs the prediction trie from accumulated data.
118    ///
119    /// Iterates all accumulated nodes, navigates/creates the trie path,
120    /// and populates predictions from the accumulators.
121    pub fn build(&self) -> PredictionTrieNode {
122        let mut root = PredictionTrieNode::new("root");
123
124        for (path_key, node_accs) in &self.accumulators.nodes {
125            let node = get_or_create_node(&mut root, path_key);
126            populate_node_predictions(node, node_accs, &self.sensitivity_config);
127        }
128
129        root
130    }
131
132    /// Returns a reference to the underlying accumulator state.
133    pub fn accumulators(&self) -> &AccumulatorState {
134        &self.accumulators
135    }
136
137    /// Updates accumulators at root + each ancestor + leaf for a given context.
138    fn update_accumulators(&mut self, ctx: &LlmCallContext) {
139        let has_sensitivity = self.sensitivity_config.is_some();
140
141        // Update root node (key = "")
142        let root_accs = self.accumulators.nodes.entry(String::new()).or_default();
143        add_to_accumulators(root_accs, ctx, has_sensitivity);
144
145        // Update each node along the path
146        for i in 0..ctx.path.len() {
147            let path_key = ctx.path[..=i].join("/");
148            let node_accs = self.accumulators.nodes.entry(path_key).or_default();
149            add_to_accumulators(node_accs, ctx, has_sensitivity);
150        }
151    }
152}
153
154/// Extracts [`LlmCallContext`]s from a [`RunRecord`].
155///
156/// Port of NAT's `_extract_llm_contexts` adapted for `RunRecord`/`CallRecord`.
157/// Only completed LLM calls (with `ended_at`) are extracted.
158fn extract_llm_contexts(run: &RunRecord) -> Vec<LlmCallContext> {
159    // Compute workflow duration
160    let workflow_duration_s = if let Some(end) = run.ended_at {
161        (end - run.started_at).num_milliseconds() as f64 / 1000.0
162    } else {
163        // Fall back to last call ended_at
164        run.calls
165            .iter()
166            .filter_map(|c| c.ended_at)
167            .max()
168            .map(|end| (end - run.started_at).num_milliseconds() as f64 / 1000.0)
169            .unwrap_or(0.0)
170    };
171
172    // Collect completed LLM calls with their original indices
173    let llm_calls: Vec<(usize, &CallRecord)> = run
174        .calls
175        .iter()
176        .enumerate()
177        .filter(|(_, c)| c.kind == CallKind::Llm && c.ended_at.is_some())
178        .collect();
179
180    let total_llm = llm_calls.len();
181
182    // Track call_index per parent key (for Phase 4, parent = call name)
183    let mut call_counts: HashMap<String, u32> = HashMap::new();
184
185    let mut contexts = Vec::with_capacity(total_llm);
186
187    for (llm_pos, (orig_idx, call)) in llm_calls.iter().enumerate() {
188        let ended_at = call.ended_at.expect("filtered to completed calls");
189
190        // Path: Phase 4 simplification -- single-element vec with call name
191        let path = vec![call.name.clone()];
192
193        // Call index per parent
194        let counter = call_counts.entry(call.name.clone()).or_insert(0);
195        *counter += 1;
196        let call_index = *counter;
197
198        // Remaining calls
199        let remaining_calls = (total_llm - llm_pos - 1) as u32;
200
201        // Time to next LLM start: scan forward in ALL calls to find next LLM start
202        let time_to_next_ms = run
203            .calls
204            .iter()
205            .skip(orig_idx + 1)
206            .find(|c| c.kind == CallKind::Llm)
207            .map(|next_llm| {
208                next_llm
209                    .started_at
210                    .signed_duration_since(ended_at)
211                    .num_milliseconds() as f64
212            });
213
214        // Output tokens
215        let output_tokens = call.output_tokens.unwrap_or(0);
216
217        // Call duration
218        let call_duration_s = (ended_at - call.started_at).num_milliseconds() as f64 / 1000.0;
219
220        // Span timestamps
221        let span_start_time = call.started_at.timestamp() as f64;
222        let span_end_time = ended_at.timestamp() as f64;
223
224        contexts.push(LlmCallContext {
225            path,
226            call_index,
227            remaining_calls,
228            time_to_next_ms,
229            output_tokens,
230            call_duration_s,
231            workflow_duration_s,
232            parallel_slack_ratio: 0.0,
233            sensitivity_score: 0.0,
234            span_start_time,
235            span_end_time,
236        });
237    }
238
239    contexts
240}
241
242/// Computes composite sensitivity scores for each call in a trace.
243///
244/// Direct port of NAT trie_builder.py lines 186-272: four weighted signals
245/// (critical path, fan-out, position, parallel penalty) with min-max
246/// normalization across the trace.
247fn compute_sensitivity_scores(contexts: &mut [LlmCallContext], config: &SensitivityConfig) {
248    if contexts.is_empty() {
249        return;
250    }
251
252    let logical_positions = compute_logical_positions(contexts);
253    let num_logical_steps = logical_step_count(&logical_positions);
254    let max_logical_remaining = num_logical_steps.saturating_sub(1);
255    let group_sizes = logical_group_sizes(&logical_positions);
256    let raw_scores = compute_raw_sensitivity_scores(
257        contexts,
258        &logical_positions,
259        &group_sizes,
260        num_logical_steps,
261        max_logical_remaining,
262        config,
263    );
264    normalize_sensitivity_scores(contexts, &raw_scores);
265}
266
/// Returns the number of logical steps, i.e. the highest position plus one.
///
/// An empty slice counts as a single step.
fn logical_step_count(logical_positions: &[usize]) -> usize {
    match logical_positions.iter().max() {
        Some(&highest) => highest + 1,
        None => 1,
    }
}
275
/// Counts how many calls share each logical position.
fn logical_group_sizes(logical_positions: &[usize]) -> HashMap<usize, usize> {
    logical_positions
        .iter()
        .fold(HashMap::new(), |mut counts, &step| {
            *counts.entry(step).or_insert(0) += 1;
            counts
        })
}
283
284fn compute_raw_sensitivity_scores(
285    contexts: &[LlmCallContext],
286    logical_positions: &[usize],
287    group_sizes: &HashMap<usize, usize>,
288    num_logical_steps: usize,
289    max_logical_remaining: usize,
290    config: &SensitivityConfig,
291) -> Vec<f64> {
292    contexts
293        .iter()
294        .enumerate()
295        .map(|(index, ctx)| {
296            let logical_position = logical_positions[index];
297            let critical_path_weight = critical_path_weight(ctx);
298            let fanout_score = fanout_score(logical_position, max_logical_remaining);
299            let position_score = position_score(logical_position, num_logical_steps);
300            let parallel_penalty =
301                parallel_penalty(ctx.parallel_slack_ratio, group_sizes, logical_position);
302
303            config.w_critical * critical_path_weight
304                + config.w_fanout * fanout_score
305                + config.w_position * position_score
306                - config.w_parallel * parallel_penalty
307        })
308        .collect()
309}
310
311fn critical_path_weight(ctx: &LlmCallContext) -> f64 {
312    if ctx.workflow_duration_s > 0.0 {
313        (ctx.call_duration_s / ctx.workflow_duration_s).min(1.0)
314    } else {
315        1.0
316    }
317}
318
/// Fraction of logical steps that still follow `logical_position`.
///
/// Returns 0.0 when there is no later step at all
/// (`max_logical_remaining == 0`); positions beyond the last step also
/// yield 0.0.
fn fanout_score(logical_position: usize, max_logical_remaining: usize) -> f64 {
    if max_logical_remaining == 0 {
        return 0.0;
    }
    let steps_after = max_logical_remaining.saturating_sub(logical_position);
    steps_after as f64 / max_logical_remaining as f64
}
326
/// U-shaped position signal: 1.0 at the first and last logical step,
/// dipping to 0.5 in the middle.
///
/// Returns 1.0 when there is at most one logical step.
fn position_score(logical_position: usize, num_logical_steps: usize) -> f64 {
    if num_logical_steps <= 1 {
        return 1.0;
    }
    let last_step = (num_logical_steps - 1) as f64;
    let from_start = logical_position as f64 / last_step;
    let from_end = 1.0 - from_start;
    from_end.max(from_start)
}
335
/// Penalty applied to calls that run alongside siblings.
///
/// When the call's logical position holds more than one call, the penalty
/// is the average of `parallel_slack_ratio` and `(n - 1) / n` for a group
/// of `n` calls. Otherwise (single call or unknown position) it is just
/// `parallel_slack_ratio`.
fn parallel_penalty(
    parallel_slack_ratio: f64,
    group_sizes: &HashMap<usize, usize>,
    logical_position: usize,
) -> f64 {
    match group_sizes.get(&logical_position) {
        Some(&members) if members > 1 => {
            let crowding = (members - 1) as f64 / members as f64;
            (parallel_slack_ratio + crowding) / 2.0
        }
        _ => parallel_slack_ratio,
    }
}
349
350fn normalize_sensitivity_scores(contexts: &mut [LlmCallContext], raw_scores: &[f64]) {
351    let min_score = raw_scores.iter().copied().fold(f64::INFINITY, f64::min);
352    let max_score = raw_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
353    let score_range = max_score - min_score;
354
355    for (ctx, &raw) in contexts.iter_mut().zip(raw_scores.iter()) {
356        ctx.sensitivity_score = if score_range > 0.0 {
357            (raw - min_score) / score_range
358        } else {
359            0.5
360        };
361    }
362}
363
364/// Assigns logical positions to calls, collapsing parallel siblings.
365///
366/// Uses standard interval-merging: contexts sorted by span start time,
367/// overlapping intervals get the same group index. Direct port of NAT's
368/// `_compute_logical_positions`.
369fn compute_logical_positions(contexts: &[LlmCallContext]) -> Vec<usize> {
370    if contexts.is_empty() {
371        return vec![];
372    }
373
374    let n = contexts.len();
375
376    // Sort indices by span_start_time
377    let mut sorted_indices: Vec<usize> = (0..n).collect();
378    sorted_indices.sort_by(|&a, &b| {
379        contexts[a]
380            .span_start_time
381            .partial_cmp(&contexts[b].span_start_time)
382            .unwrap_or(std::cmp::Ordering::Equal)
383    });
384
385    let mut group_assignments = vec![0usize; n];
386    let mut current_group = 0usize;
387    let mut group_max_end = contexts[sorted_indices[0]].span_end_time;
388
389    group_assignments[sorted_indices[0]] = current_group;
390
391    for &idx in &sorted_indices[1..] {
392        if contexts[idx].span_start_time < group_max_end {
393            // Overlaps with current group
394            group_assignments[idx] = current_group;
395            group_max_end = group_max_end.max(contexts[idx].span_end_time);
396        } else {
397            // New sequential step
398            current_group += 1;
399            group_assignments[idx] = current_group;
400            group_max_end = contexts[idx].span_end_time;
401        }
402    }
403
404    group_assignments
405}
406
407/// Adds context data to a node's accumulators.
408///
409/// Updates both per-call-index and aggregated (all_*) accumulators.
410fn add_to_accumulators(accs: &mut NodeAccumulators, ctx: &LlmCallContext, has_sensitivity: bool) {
411    // By call index
412    accs.remaining_calls
413        .entry(ctx.call_index)
414        .or_default()
415        .add_sample(ctx.remaining_calls as f64);
416    accs.output_tokens
417        .entry(ctx.call_index)
418        .or_default()
419        .add_sample(ctx.output_tokens as f64);
420    if let Some(ttm) = ctx.time_to_next_ms {
421        accs.interarrival_ms
422            .entry(ctx.call_index)
423            .or_default()
424            .add_sample(ttm);
425    }
426
427    // Aggregated across all indices
428    accs.all_remaining_calls
429        .add_sample(ctx.remaining_calls as f64);
430    accs.all_output_tokens.add_sample(ctx.output_tokens as f64);
431    if let Some(ttm) = ctx.time_to_next_ms {
432        accs.all_interarrival_ms.add_sample(ttm);
433    }
434
435    // Sensitivity accumulators
436    if has_sensitivity {
437        accs.sensitivity
438            .entry(ctx.call_index)
439            .or_default()
440            .add_sample(ctx.sensitivity_score);
441        accs.all_sensitivity.add_sample(ctx.sensitivity_score);
442    }
443}
444
445/// Navigates from root through path segments (split by "/"), creating nodes as needed.
446fn get_or_create_node<'a>(
447    root: &'a mut PredictionTrieNode,
448    path_key: &str,
449) -> &'a mut PredictionTrieNode {
450    if path_key.is_empty() {
451        return root;
452    }
453
454    let mut current = root;
455    for name in path_key.split('/') {
456        current = current
457            .children
458            .entry(name.to_string())
459            .or_insert_with(|| PredictionTrieNode::new(name));
460    }
461    current
462}
463
464/// Populates a trie node's predictions from its accumulators.
465fn populate_node_predictions(
466    node: &mut PredictionTrieNode,
467    accs: &NodeAccumulators,
468    sensitivity_config: &Option<SensitivityConfig>,
469) {
470    // Collect all call indices from all per-index maps
471    let mut all_indices: std::collections::HashSet<u32> = std::collections::HashSet::new();
472    all_indices.extend(accs.remaining_calls.keys());
473    all_indices.extend(accs.interarrival_ms.keys());
474    all_indices.extend(accs.output_tokens.keys());
475
476    let scale = sensitivity_config.as_ref().map(|c| c.sensitivity_scale);
477
478    for idx in all_indices {
479        let remaining = accs
480            .remaining_calls
481            .get(&idx)
482            .map(|s| s.compute_metrics())
483            .unwrap_or_default();
484        let interarrival = accs
485            .interarrival_ms
486            .get(&idx)
487            .map(|s| s.compute_metrics())
488            .unwrap_or_default();
489        let output_tok = accs
490            .output_tokens
491            .get(&idx)
492            .map(|s| s.compute_metrics())
493            .unwrap_or_default();
494        let sensitivity = match (scale, accs.sensitivity.get(&idx)) {
495            (Some(s), Some(acc)) => score_to_sensitivity(acc, s),
496            _ => None,
497        };
498
499        node.predictions_by_call_index.insert(
500            idx,
501            LlmCallPrediction {
502                remaining_calls: remaining,
503                interarrival_ms: interarrival,
504                output_tokens: output_tok,
505                latency_sensitivity: sensitivity,
506            },
507        );
508    }
509
510    // Aggregated predictions
511    if accs.all_remaining_calls.has_samples() {
512        let sensitivity = match scale {
513            Some(s) if accs.all_sensitivity.has_samples() => {
514                score_to_sensitivity(&accs.all_sensitivity, s)
515            }
516            _ => None,
517        };
518
519        node.predictions_any_index = Some(LlmCallPrediction {
520            remaining_calls: accs.all_remaining_calls.compute_metrics(),
521            interarrival_ms: accs.all_interarrival_ms.compute_metrics(),
522            output_tokens: accs.all_output_tokens.compute_metrics(),
523            latency_sensitivity: sensitivity,
524        });
525    }
526}
527
528/// Converts accumulated sensitivity scores to a clamped integer on [1, scale].
529///
530/// Returns `None` if the accumulator has no samples.
531fn score_to_sensitivity(acc: &RunningStats, scale: u32) -> Option<u32> {
532    if !acc.has_samples() {
533        return None;
534    }
535    let mean_score = acc.compute_metrics().mean;
536    let raw = (mean_score * (scale as f64 - 1.0)).round() as i64 + 1;
537    Some(raw.clamp(1, scale as i64) as u32)
538}
539
// Unit tests are compiled only for test builds and live in a separate file,
// located via the `#[path]` attribute below.
#[cfg(test)]
#[path = "../../tests/unit/trie/builder_tests.rs"]
mod tests;