1use std::collections::HashMap;
12
13use serde::{Deserialize, Serialize};
14
15use super::accumulator::{AccumulatorState, NodeAccumulators, RunningStats};
16use super::data_models::{LlmCallPrediction, PredictionTrieNode};
17use crate::types::records::{CallKind, CallRecord, RunRecord};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct SensitivityConfig {
24 pub sensitivity_scale: u32,
26 pub w_critical: f64,
28 pub w_fanout: f64,
30 pub w_position: f64,
32 pub w_parallel: f64,
34}
35
36impl Default for SensitivityConfig {
37 fn default() -> Self {
38 Self {
39 sensitivity_scale: 5,
40 w_critical: 0.5,
41 w_fanout: 0.3,
42 w_position: 0.2,
43 w_parallel: 0.0,
44 }
45 }
46}
47
/// Everything the builder records about one completed LLM call in a run.
#[derive(Debug, Clone)]
pub(crate) struct LlmCallContext {
    /// Trie path of the call; each prefix of it receives the call's samples.
    pub path: Vec<String>,
    /// 1-based occurrence number of this call among same-named calls in the run.
    pub call_index: u32,
    /// Number of completed LLM calls that come after this one in the run.
    pub remaining_calls: u32,
    /// Milliseconds from this call's end to the start of the next LLM call,
    /// or `None` when no LLM call follows.
    pub time_to_next_ms: Option<f64>,
    /// Output tokens produced by the call (0 when the record has none).
    pub output_tokens: u32,
    /// Wall-clock duration of the call, in seconds.
    pub call_duration_s: f64,
    /// Duration of the whole run, in seconds.
    pub workflow_duration_s: f64,
    /// Slack ratio fed into the parallelism penalty.
    pub parallel_slack_ratio: f64,
    /// Normalized sensitivity score; filled in by the scoring pass.
    pub sensitivity_score: f64,
    /// Call start as a Unix timestamp in seconds.
    pub span_start_time: f64,
    /// Call end as a Unix timestamp in seconds.
    pub span_end_time: f64,
}
63
64pub struct PredictionTrieBuilder {
76 accumulators: AccumulatorState,
77 sensitivity_config: Option<SensitivityConfig>,
78}
79
80impl PredictionTrieBuilder {
81 pub fn new(sensitivity_config: Option<SensitivityConfig>) -> Self {
83 Self {
84 accumulators: AccumulatorState::default(),
85 sensitivity_config,
86 }
87 }
88
89 pub fn with_accumulators(
94 accumulators: AccumulatorState,
95 sensitivity_config: Option<SensitivityConfig>,
96 ) -> Self {
97 Self {
98 accumulators,
99 sensitivity_config,
100 }
101 }
102
103 pub fn add_run(&mut self, run: &RunRecord) {
108 let mut contexts = extract_llm_contexts(run);
109 if let Some(ref config) = self.sensitivity_config {
110 compute_sensitivity_scores(&mut contexts, config);
111 }
112 for ctx in &contexts {
113 self.update_accumulators(ctx);
114 }
115 }
116
117 pub fn build(&self) -> PredictionTrieNode {
122 let mut root = PredictionTrieNode::new("root");
123
124 for (path_key, node_accs) in &self.accumulators.nodes {
125 let node = get_or_create_node(&mut root, path_key);
126 populate_node_predictions(node, node_accs, &self.sensitivity_config);
127 }
128
129 root
130 }
131
132 pub fn accumulators(&self) -> &AccumulatorState {
134 &self.accumulators
135 }
136
137 fn update_accumulators(&mut self, ctx: &LlmCallContext) {
139 let has_sensitivity = self.sensitivity_config.is_some();
140
141 let root_accs = self.accumulators.nodes.entry(String::new()).or_default();
143 add_to_accumulators(root_accs, ctx, has_sensitivity);
144
145 for i in 0..ctx.path.len() {
147 let path_key = ctx.path[..=i].join("/");
148 let node_accs = self.accumulators.nodes.entry(path_key).or_default();
149 add_to_accumulators(node_accs, ctx, has_sensitivity);
150 }
151 }
152}
153
154fn extract_llm_contexts(run: &RunRecord) -> Vec<LlmCallContext> {
159 let workflow_duration_s = if let Some(end) = run.ended_at {
161 (end - run.started_at).num_milliseconds() as f64 / 1000.0
162 } else {
163 run.calls
165 .iter()
166 .filter_map(|c| c.ended_at)
167 .max()
168 .map(|end| (end - run.started_at).num_milliseconds() as f64 / 1000.0)
169 .unwrap_or(0.0)
170 };
171
172 let llm_calls: Vec<(usize, &CallRecord)> = run
174 .calls
175 .iter()
176 .enumerate()
177 .filter(|(_, c)| c.kind == CallKind::Llm && c.ended_at.is_some())
178 .collect();
179
180 let total_llm = llm_calls.len();
181
182 let mut call_counts: HashMap<String, u32> = HashMap::new();
184
185 let mut contexts = Vec::with_capacity(total_llm);
186
187 for (llm_pos, (orig_idx, call)) in llm_calls.iter().enumerate() {
188 let ended_at = call.ended_at.expect("filtered to completed calls");
189
190 let path = vec![call.name.clone()];
192
193 let counter = call_counts.entry(call.name.clone()).or_insert(0);
195 *counter += 1;
196 let call_index = *counter;
197
198 let remaining_calls = (total_llm - llm_pos - 1) as u32;
200
201 let time_to_next_ms = run
203 .calls
204 .iter()
205 .skip(orig_idx + 1)
206 .find(|c| c.kind == CallKind::Llm)
207 .map(|next_llm| {
208 next_llm
209 .started_at
210 .signed_duration_since(ended_at)
211 .num_milliseconds() as f64
212 });
213
214 let output_tokens = call.output_tokens.unwrap_or(0);
216
217 let call_duration_s = (ended_at - call.started_at).num_milliseconds() as f64 / 1000.0;
219
220 let span_start_time = call.started_at.timestamp() as f64;
222 let span_end_time = ended_at.timestamp() as f64;
223
224 contexts.push(LlmCallContext {
225 path,
226 call_index,
227 remaining_calls,
228 time_to_next_ms,
229 output_tokens,
230 call_duration_s,
231 workflow_duration_s,
232 parallel_slack_ratio: 0.0,
233 sensitivity_score: 0.0,
234 span_start_time,
235 span_end_time,
236 });
237 }
238
239 contexts
240}
241
242fn compute_sensitivity_scores(contexts: &mut [LlmCallContext], config: &SensitivityConfig) {
248 if contexts.is_empty() {
249 return;
250 }
251
252 let logical_positions = compute_logical_positions(contexts);
253 let num_logical_steps = logical_step_count(&logical_positions);
254 let max_logical_remaining = num_logical_steps.saturating_sub(1);
255 let group_sizes = logical_group_sizes(&logical_positions);
256 let raw_scores = compute_raw_sensitivity_scores(
257 contexts,
258 &logical_positions,
259 &group_sizes,
260 num_logical_steps,
261 max_logical_remaining,
262 config,
263 );
264 normalize_sensitivity_scores(contexts, &raw_scores);
265}
266
/// Number of logical steps implied by `logical_positions`: the highest
/// position plus one, or 1 when the slice is empty.
fn logical_step_count(logical_positions: &[usize]) -> usize {
    logical_positions
        .iter()
        .max()
        .map_or(1, |&highest| highest + 1)
}
275
/// Counts how many calls share each logical position.
fn logical_group_sizes(logical_positions: &[usize]) -> HashMap<usize, usize> {
    logical_positions
        .iter()
        .fold(HashMap::new(), |mut tally, &position| {
            *tally.entry(position).or_insert(0) += 1;
            tally
        })
}
283
284fn compute_raw_sensitivity_scores(
285 contexts: &[LlmCallContext],
286 logical_positions: &[usize],
287 group_sizes: &HashMap<usize, usize>,
288 num_logical_steps: usize,
289 max_logical_remaining: usize,
290 config: &SensitivityConfig,
291) -> Vec<f64> {
292 contexts
293 .iter()
294 .enumerate()
295 .map(|(index, ctx)| {
296 let logical_position = logical_positions[index];
297 let critical_path_weight = critical_path_weight(ctx);
298 let fanout_score = fanout_score(logical_position, max_logical_remaining);
299 let position_score = position_score(logical_position, num_logical_steps);
300 let parallel_penalty =
301 parallel_penalty(ctx.parallel_slack_ratio, group_sizes, logical_position);
302
303 config.w_critical * critical_path_weight
304 + config.w_fanout * fanout_score
305 + config.w_position * position_score
306 - config.w_parallel * parallel_penalty
307 })
308 .collect()
309}
310
311fn critical_path_weight(ctx: &LlmCallContext) -> f64 {
312 if ctx.workflow_duration_s > 0.0 {
313 (ctx.call_duration_s / ctx.workflow_duration_s).min(1.0)
314 } else {
315 1.0
316 }
317}
318
/// Fraction of logical steps that still follow `logical_position`.
///
/// Ranges from 1.0 at the first step down to 0.0 at (or past) the last one;
/// returns 0.0 when there is only a single step.
fn fanout_score(logical_position: usize, max_logical_remaining: usize) -> f64 {
    match max_logical_remaining {
        0 => 0.0,
        span => span.saturating_sub(logical_position) as f64 / span as f64,
    }
}
326
/// Scores a step by its distance from the middle of the workflow.
///
/// The first and last steps score 1.0 and the midpoint scores 0.5. A
/// workflow with at most one step always scores 1.0.
fn position_score(logical_position: usize, num_logical_steps: usize) -> f64 {
    if num_logical_steps <= 1 {
        return 1.0;
    }

    let fraction = logical_position as f64 / (num_logical_steps - 1) as f64;
    let from_end = 1.0 - fraction;
    from_end.max(fraction)
}
335
/// Penalty for calls that run alongside others in the same logical step.
///
/// For a step shared by `n > 1` calls the result is the average of
/// `parallel_slack_ratio` and `(n - 1) / n`. A call alone in its step (or a
/// step missing from `group_sizes`) just yields `parallel_slack_ratio`.
fn parallel_penalty(
    parallel_slack_ratio: f64,
    group_sizes: &HashMap<usize, usize>,
    logical_position: usize,
) -> f64 {
    match group_sizes.get(&logical_position) {
        Some(&members) if members > 1 => {
            let crowding = (members - 1) as f64 / members as f64;
            (parallel_slack_ratio + crowding) / 2.0
        }
        _ => parallel_slack_ratio,
    }
}
349
350fn normalize_sensitivity_scores(contexts: &mut [LlmCallContext], raw_scores: &[f64]) {
351 let min_score = raw_scores.iter().copied().fold(f64::INFINITY, f64::min);
352 let max_score = raw_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
353 let score_range = max_score - min_score;
354
355 for (ctx, &raw) in contexts.iter_mut().zip(raw_scores.iter()) {
356 ctx.sensitivity_score = if score_range > 0.0 {
357 (raw - min_score) / score_range
358 } else {
359 0.5
360 };
361 }
362}
363
364fn compute_logical_positions(contexts: &[LlmCallContext]) -> Vec<usize> {
370 if contexts.is_empty() {
371 return vec![];
372 }
373
374 let n = contexts.len();
375
376 let mut sorted_indices: Vec<usize> = (0..n).collect();
378 sorted_indices.sort_by(|&a, &b| {
379 contexts[a]
380 .span_start_time
381 .partial_cmp(&contexts[b].span_start_time)
382 .unwrap_or(std::cmp::Ordering::Equal)
383 });
384
385 let mut group_assignments = vec![0usize; n];
386 let mut current_group = 0usize;
387 let mut group_max_end = contexts[sorted_indices[0]].span_end_time;
388
389 group_assignments[sorted_indices[0]] = current_group;
390
391 for &idx in &sorted_indices[1..] {
392 if contexts[idx].span_start_time < group_max_end {
393 group_assignments[idx] = current_group;
395 group_max_end = group_max_end.max(contexts[idx].span_end_time);
396 } else {
397 current_group += 1;
399 group_assignments[idx] = current_group;
400 group_max_end = contexts[idx].span_end_time;
401 }
402 }
403
404 group_assignments
405}
406
407fn add_to_accumulators(accs: &mut NodeAccumulators, ctx: &LlmCallContext, has_sensitivity: bool) {
411 accs.remaining_calls
413 .entry(ctx.call_index)
414 .or_default()
415 .add_sample(ctx.remaining_calls as f64);
416 accs.output_tokens
417 .entry(ctx.call_index)
418 .or_default()
419 .add_sample(ctx.output_tokens as f64);
420 if let Some(ttm) = ctx.time_to_next_ms {
421 accs.interarrival_ms
422 .entry(ctx.call_index)
423 .or_default()
424 .add_sample(ttm);
425 }
426
427 accs.all_remaining_calls
429 .add_sample(ctx.remaining_calls as f64);
430 accs.all_output_tokens.add_sample(ctx.output_tokens as f64);
431 if let Some(ttm) = ctx.time_to_next_ms {
432 accs.all_interarrival_ms.add_sample(ttm);
433 }
434
435 if has_sensitivity {
437 accs.sensitivity
438 .entry(ctx.call_index)
439 .or_default()
440 .add_sample(ctx.sensitivity_score);
441 accs.all_sensitivity.add_sample(ctx.sensitivity_score);
442 }
443}
444
445fn get_or_create_node<'a>(
447 root: &'a mut PredictionTrieNode,
448 path_key: &str,
449) -> &'a mut PredictionTrieNode {
450 if path_key.is_empty() {
451 return root;
452 }
453
454 let mut current = root;
455 for name in path_key.split('/') {
456 current = current
457 .children
458 .entry(name.to_string())
459 .or_insert_with(|| PredictionTrieNode::new(name));
460 }
461 current
462}
463
464fn populate_node_predictions(
466 node: &mut PredictionTrieNode,
467 accs: &NodeAccumulators,
468 sensitivity_config: &Option<SensitivityConfig>,
469) {
470 let mut all_indices: std::collections::HashSet<u32> = std::collections::HashSet::new();
472 all_indices.extend(accs.remaining_calls.keys());
473 all_indices.extend(accs.interarrival_ms.keys());
474 all_indices.extend(accs.output_tokens.keys());
475
476 let scale = sensitivity_config.as_ref().map(|c| c.sensitivity_scale);
477
478 for idx in all_indices {
479 let remaining = accs
480 .remaining_calls
481 .get(&idx)
482 .map(|s| s.compute_metrics())
483 .unwrap_or_default();
484 let interarrival = accs
485 .interarrival_ms
486 .get(&idx)
487 .map(|s| s.compute_metrics())
488 .unwrap_or_default();
489 let output_tok = accs
490 .output_tokens
491 .get(&idx)
492 .map(|s| s.compute_metrics())
493 .unwrap_or_default();
494 let sensitivity = match (scale, accs.sensitivity.get(&idx)) {
495 (Some(s), Some(acc)) => score_to_sensitivity(acc, s),
496 _ => None,
497 };
498
499 node.predictions_by_call_index.insert(
500 idx,
501 LlmCallPrediction {
502 remaining_calls: remaining,
503 interarrival_ms: interarrival,
504 output_tokens: output_tok,
505 latency_sensitivity: sensitivity,
506 },
507 );
508 }
509
510 if accs.all_remaining_calls.has_samples() {
512 let sensitivity = match scale {
513 Some(s) if accs.all_sensitivity.has_samples() => {
514 score_to_sensitivity(&accs.all_sensitivity, s)
515 }
516 _ => None,
517 };
518
519 node.predictions_any_index = Some(LlmCallPrediction {
520 remaining_calls: accs.all_remaining_calls.compute_metrics(),
521 interarrival_ms: accs.all_interarrival_ms.compute_metrics(),
522 output_tokens: accs.all_output_tokens.compute_metrics(),
523 latency_sensitivity: sensitivity,
524 });
525 }
526}
527
528fn score_to_sensitivity(acc: &RunningStats, scale: u32) -> Option<u32> {
532 if !acc.has_samples() {
533 return None;
534 }
535 let mean_score = acc.compute_metrics().mean;
536 let raw = (mean_score * (scale as f64 - 1.0)).round() as i64 + 1;
537 Some(raw.clamp(1, scale as i64) as u32)
538}
539
540#[cfg(test)]
541#[path = "../../tests/unit/trie/builder_tests.rs"]
542mod tests;