use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::accumulator::{AccumulatorState, NodeAccumulators, RunningStats};
use super::data_models::{LlmCallPrediction, PredictionTrieNode};
use crate::types::records::{CallKind, CallRecord, RunRecord};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
pub sensitivity_scale: u32,
pub w_critical: f64,
pub w_fanout: f64,
pub w_position: f64,
pub w_parallel: f64,
}
impl Default for SensitivityConfig {
fn default() -> Self {
Self {
sensitivity_scale: 5,
w_critical: 0.5,
w_fanout: 0.3,
w_position: 0.2,
w_parallel: 0.0,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct LlmCallContext {
pub path: Vec<String>,
pub call_index: u32,
pub remaining_calls: u32,
pub time_to_next_ms: Option<f64>,
pub output_tokens: u32,
pub call_duration_s: f64,
pub workflow_duration_s: f64,
pub parallel_slack_ratio: f64,
pub sensitivity_score: f64,
pub span_start_time: f64,
pub span_end_time: f64,
}
pub struct PredictionTrieBuilder {
accumulators: AccumulatorState,
sensitivity_config: Option<SensitivityConfig>,
}
impl PredictionTrieBuilder {
pub fn new(sensitivity_config: Option<SensitivityConfig>) -> Self {
Self {
accumulators: AccumulatorState::default(),
sensitivity_config,
}
}
pub fn with_accumulators(
accumulators: AccumulatorState,
sensitivity_config: Option<SensitivityConfig>,
) -> Self {
Self {
accumulators,
sensitivity_config,
}
}
pub fn add_run(&mut self, run: &RunRecord) {
let mut contexts = extract_llm_contexts(run);
if let Some(ref config) = self.sensitivity_config {
compute_sensitivity_scores(&mut contexts, config);
}
for ctx in &contexts {
self.update_accumulators(ctx);
}
}
pub fn build(&self) -> PredictionTrieNode {
let mut root = PredictionTrieNode::new("root");
for (path_key, node_accs) in &self.accumulators.nodes {
let node = get_or_create_node(&mut root, path_key);
populate_node_predictions(node, node_accs, &self.sensitivity_config);
}
root
}
pub fn accumulators(&self) -> &AccumulatorState {
&self.accumulators
}
fn update_accumulators(&mut self, ctx: &LlmCallContext) {
let has_sensitivity = self.sensitivity_config.is_some();
let root_accs = self.accumulators.nodes.entry(String::new()).or_default();
add_to_accumulators(root_accs, ctx, has_sensitivity);
for i in 0..ctx.path.len() {
let path_key = ctx.path[..=i].join("/");
let node_accs = self.accumulators.nodes.entry(path_key).or_default();
add_to_accumulators(node_accs, ctx, has_sensitivity);
}
}
}
fn extract_llm_contexts(run: &RunRecord) -> Vec<LlmCallContext> {
let workflow_duration_s = if let Some(end) = run.ended_at {
(end - run.started_at).num_milliseconds() as f64 / 1000.0
} else {
run.calls
.iter()
.filter_map(|c| c.ended_at)
.max()
.map(|end| (end - run.started_at).num_milliseconds() as f64 / 1000.0)
.unwrap_or(0.0)
};
let llm_calls: Vec<(usize, &CallRecord)> = run
.calls
.iter()
.enumerate()
.filter(|(_, c)| c.kind == CallKind::Llm && c.ended_at.is_some())
.collect();
let total_llm = llm_calls.len();
let mut call_counts: HashMap<String, u32> = HashMap::new();
let mut contexts = Vec::with_capacity(total_llm);
for (llm_pos, (orig_idx, call)) in llm_calls.iter().enumerate() {
let ended_at = call.ended_at.expect("filtered to completed calls");
let path = vec![call.name.clone()];
let counter = call_counts.entry(call.name.clone()).or_insert(0);
*counter += 1;
let call_index = *counter;
let remaining_calls = (total_llm - llm_pos - 1) as u32;
let time_to_next_ms = run
.calls
.iter()
.skip(orig_idx + 1)
.find(|c| c.kind == CallKind::Llm)
.map(|next_llm| {
next_llm
.started_at
.signed_duration_since(ended_at)
.num_milliseconds() as f64
});
let output_tokens = call.output_tokens.unwrap_or(0);
let call_duration_s = (ended_at - call.started_at).num_milliseconds() as f64 / 1000.0;
let span_start_time = call.started_at.timestamp() as f64;
let span_end_time = ended_at.timestamp() as f64;
contexts.push(LlmCallContext {
path,
call_index,
remaining_calls,
time_to_next_ms,
output_tokens,
call_duration_s,
workflow_duration_s,
parallel_slack_ratio: 0.0,
sensitivity_score: 0.0,
span_start_time,
span_end_time,
});
}
contexts
}
fn compute_sensitivity_scores(contexts: &mut [LlmCallContext], config: &SensitivityConfig) {
if contexts.is_empty() {
return;
}
let logical_positions = compute_logical_positions(contexts);
let num_logical_steps = logical_step_count(&logical_positions);
let max_logical_remaining = num_logical_steps.saturating_sub(1);
let group_sizes = logical_group_sizes(&logical_positions);
let raw_scores = compute_raw_sensitivity_scores(
contexts,
&logical_positions,
&group_sizes,
num_logical_steps,
max_logical_remaining,
config,
);
normalize_sensitivity_scores(contexts, &raw_scores);
}
fn logical_step_count(logical_positions: &[usize]) -> usize {
logical_positions
.iter()
.copied()
.max()
.map(|max_position| max_position + 1)
.unwrap_or(1)
}
fn logical_group_sizes(logical_positions: &[usize]) -> HashMap<usize, usize> {
let mut group_sizes = HashMap::new();
for &position in logical_positions {
*group_sizes.entry(position).or_insert(0) += 1;
}
group_sizes
}
fn compute_raw_sensitivity_scores(
contexts: &[LlmCallContext],
logical_positions: &[usize],
group_sizes: &HashMap<usize, usize>,
num_logical_steps: usize,
max_logical_remaining: usize,
config: &SensitivityConfig,
) -> Vec<f64> {
contexts
.iter()
.enumerate()
.map(|(index, ctx)| {
let logical_position = logical_positions[index];
let critical_path_weight = critical_path_weight(ctx);
let fanout_score = fanout_score(logical_position, max_logical_remaining);
let position_score = position_score(logical_position, num_logical_steps);
let parallel_penalty =
parallel_penalty(ctx.parallel_slack_ratio, group_sizes, logical_position);
config.w_critical * critical_path_weight
+ config.w_fanout * fanout_score
+ config.w_position * position_score
- config.w_parallel * parallel_penalty
})
.collect()
}
fn critical_path_weight(ctx: &LlmCallContext) -> f64 {
if ctx.workflow_duration_s > 0.0 {
(ctx.call_duration_s / ctx.workflow_duration_s).min(1.0)
} else {
1.0
}
}
fn fanout_score(logical_position: usize, max_logical_remaining: usize) -> f64 {
if max_logical_remaining > 0 {
max_logical_remaining.saturating_sub(logical_position) as f64 / max_logical_remaining as f64
} else {
0.0
}
}
fn position_score(logical_position: usize, num_logical_steps: usize) -> f64 {
if num_logical_steps > 1 {
let normalized_pos = logical_position as f64 / (num_logical_steps - 1) as f64;
(1.0 - normalized_pos).max(normalized_pos)
} else {
1.0
}
}
fn parallel_penalty(
parallel_slack_ratio: f64,
group_sizes: &HashMap<usize, usize>,
logical_position: usize,
) -> f64 {
let group_size = group_sizes.get(&logical_position).copied().unwrap_or(1);
if group_size > 1 {
let group_penalty = (group_size - 1) as f64 / group_size as f64;
(parallel_slack_ratio + group_penalty) / 2.0
} else {
parallel_slack_ratio
}
}
fn normalize_sensitivity_scores(contexts: &mut [LlmCallContext], raw_scores: &[f64]) {
let min_score = raw_scores.iter().copied().fold(f64::INFINITY, f64::min);
let max_score = raw_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let score_range = max_score - min_score;
for (ctx, &raw) in contexts.iter_mut().zip(raw_scores.iter()) {
ctx.sensitivity_score = if score_range > 0.0 {
(raw - min_score) / score_range
} else {
0.5
};
}
}
fn compute_logical_positions(contexts: &[LlmCallContext]) -> Vec<usize> {
if contexts.is_empty() {
return vec![];
}
let n = contexts.len();
let mut sorted_indices: Vec<usize> = (0..n).collect();
sorted_indices.sort_by(|&a, &b| {
contexts[a]
.span_start_time
.partial_cmp(&contexts[b].span_start_time)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut group_assignments = vec![0usize; n];
let mut current_group = 0usize;
let mut group_max_end = contexts[sorted_indices[0]].span_end_time;
group_assignments[sorted_indices[0]] = current_group;
for &idx in &sorted_indices[1..] {
if contexts[idx].span_start_time < group_max_end {
group_assignments[idx] = current_group;
group_max_end = group_max_end.max(contexts[idx].span_end_time);
} else {
current_group += 1;
group_assignments[idx] = current_group;
group_max_end = contexts[idx].span_end_time;
}
}
group_assignments
}
fn add_to_accumulators(accs: &mut NodeAccumulators, ctx: &LlmCallContext, has_sensitivity: bool) {
accs.remaining_calls
.entry(ctx.call_index)
.or_default()
.add_sample(ctx.remaining_calls as f64);
accs.output_tokens
.entry(ctx.call_index)
.or_default()
.add_sample(ctx.output_tokens as f64);
if let Some(ttm) = ctx.time_to_next_ms {
accs.interarrival_ms
.entry(ctx.call_index)
.or_default()
.add_sample(ttm);
}
accs.all_remaining_calls
.add_sample(ctx.remaining_calls as f64);
accs.all_output_tokens.add_sample(ctx.output_tokens as f64);
if let Some(ttm) = ctx.time_to_next_ms {
accs.all_interarrival_ms.add_sample(ttm);
}
if has_sensitivity {
accs.sensitivity
.entry(ctx.call_index)
.or_default()
.add_sample(ctx.sensitivity_score);
accs.all_sensitivity.add_sample(ctx.sensitivity_score);
}
}
fn get_or_create_node<'a>(
root: &'a mut PredictionTrieNode,
path_key: &str,
) -> &'a mut PredictionTrieNode {
if path_key.is_empty() {
return root;
}
let mut current = root;
for name in path_key.split('/') {
current = current
.children
.entry(name.to_string())
.or_insert_with(|| PredictionTrieNode::new(name));
}
current
}
fn populate_node_predictions(
node: &mut PredictionTrieNode,
accs: &NodeAccumulators,
sensitivity_config: &Option<SensitivityConfig>,
) {
let mut all_indices: std::collections::HashSet<u32> = std::collections::HashSet::new();
all_indices.extend(accs.remaining_calls.keys());
all_indices.extend(accs.interarrival_ms.keys());
all_indices.extend(accs.output_tokens.keys());
let scale = sensitivity_config.as_ref().map(|c| c.sensitivity_scale);
for idx in all_indices {
let remaining = accs
.remaining_calls
.get(&idx)
.map(|s| s.compute_metrics())
.unwrap_or_default();
let interarrival = accs
.interarrival_ms
.get(&idx)
.map(|s| s.compute_metrics())
.unwrap_or_default();
let output_tok = accs
.output_tokens
.get(&idx)
.map(|s| s.compute_metrics())
.unwrap_or_default();
let sensitivity = match (scale, accs.sensitivity.get(&idx)) {
(Some(s), Some(acc)) => score_to_sensitivity(acc, s),
_ => None,
};
node.predictions_by_call_index.insert(
idx,
LlmCallPrediction {
remaining_calls: remaining,
interarrival_ms: interarrival,
output_tokens: output_tok,
latency_sensitivity: sensitivity,
},
);
}
if accs.all_remaining_calls.has_samples() {
let sensitivity = match scale {
Some(s) if accs.all_sensitivity.has_samples() => {
score_to_sensitivity(&accs.all_sensitivity, s)
}
_ => None,
};
node.predictions_any_index = Some(LlmCallPrediction {
remaining_calls: accs.all_remaining_calls.compute_metrics(),
interarrival_ms: accs.all_interarrival_ms.compute_metrics(),
output_tokens: accs.all_output_tokens.compute_metrics(),
latency_sensitivity: sensitivity,
});
}
}
fn score_to_sensitivity(acc: &RunningStats, scale: u32) -> Option<u32> {
if !acc.has_samples() {
return None;
}
let mean_score = acc.compute_metrics().mean;
let raw = (mean_score * (scale as f64 - 1.0)).round() as i64 + 1;
Some(raw.clamp(1, scale as i64) as u32)
}
#[cfg(test)]
#[path = "../../tests/unit/trie/builder_tests.rs"]
mod tests;