1use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use crate::perf::RecordedStream;
21use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
22
/// Tags how a set of log-probabilities is scaled.
///
/// NOTE(review): neither variant is consumed in the visible part of this file,
/// so the precise contract of each variant below is an assumption to confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogprobType {
    /// Presumably log-probabilities whose exponentials form a proper
    /// distribution (sum to 1) — TODO confirm against the producer.
    Normalized,

    /// Presumably raw log-scores that have not been renormalized
    /// — TODO confirm against the producer.
    Unnormalized,
}
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub struct TokenLogprob {
36 pub token: String,
38 pub logprob: f32,
40 pub bytes: Option<Vec<u8>>,
42}
43
44#[derive(Debug, Clone)]
46pub struct TokenLogProbs {
47 selected: TokenLogprob,
48 alternatives: Vec<TokenLogprob>,
49 all_sorted: Vec<TokenLogprob>,
50}
51
52impl TokenLogProbs {
53 pub fn new(selected: TokenLogprob, mut alternatives: Vec<TokenLogprob>) -> Self {
55 alternatives.sort_by(|a, b| b.logprob.partial_cmp(&a.logprob).unwrap());
57
58 let mut all_sorted = Vec::new();
60 let mut added_selected = false;
61
62 let selected_in_alternatives = alternatives.iter().any(|alt| {
64 alt.token == selected.token && (alt.logprob - selected.logprob).abs() < 1e-6
65 });
66
67 if !selected_in_alternatives {
69 let mut insert_position = alternatives.len();
71 for (i, alt) in alternatives.iter().enumerate() {
72 if selected.logprob > alt.logprob {
73 insert_position = i;
74 break;
75 }
76 }
77
78 for (i, alt) in alternatives.iter().enumerate() {
80 if i == insert_position && !added_selected {
81 all_sorted.push(selected.clone());
82 added_selected = true;
83 }
84 all_sorted.push(alt.clone());
85 }
86
87 if !added_selected {
89 all_sorted.push(selected.clone());
90 }
91 } else {
92 all_sorted = alternatives.clone();
94 }
95
96 Self {
97 selected,
98 alternatives,
99 all_sorted,
100 }
101 }
102
103 pub fn selected_token(&self) -> &TokenLogprob {
105 &self.selected
106 }
107
108 pub fn alternative_tokens(&self) -> &[TokenLogprob] {
110 &self.alternatives
111 }
112
113 pub fn all_tokens(&self) -> &[TokenLogprob] {
115 &self.all_sorted
116 }
117}
118
119pub trait LogprobExtractor {
121 fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>>;
124}
125
126impl LogprobExtractor for NvCreateChatCompletionStreamResponse {
128 fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> {
129 let mut result = HashMap::new();
130
131 for choice in &self.choices {
132 let choice_index = choice.index;
133
134 let choice_logprobs = choice
135 .logprobs
136 .as_ref()
137 .and_then(|logprobs| logprobs.content.as_ref())
138 .map(|content| {
139 content
140 .iter()
141 .map(|token_logprob| {
142 let selected_token = TokenLogprob {
143 token: token_logprob.token.clone(),
144 logprob: token_logprob.logprob,
145 bytes: token_logprob.bytes.clone(),
146 };
147
148 let alternatives: Vec<TokenLogprob> = token_logprob
150 .top_logprobs
151 .iter()
152 .map(|top_logprob| TokenLogprob {
153 token: top_logprob.token.clone(),
154 logprob: top_logprob.logprob,
155 bytes: top_logprob.bytes.clone(),
156 })
157 .collect();
158
159 TokenLogProbs::new(selected_token, alternatives)
160 })
161 .collect::<Vec<_>>()
162 })
163 .unwrap_or_default();
164
165 result.insert(choice_index, choice_logprobs);
166 }
167
168 result
169 }
170}
171
172pub fn validate_and_flatten_choices(
175 choice_logprobs: HashMap<u32, Vec<TokenLogProbs>>,
176) -> Result<Vec<Vec<TokenLogProbs>>, String> {
177 if choice_logprobs.is_empty() {
178 return Ok(Vec::new());
179 }
180
181 let max_choice = *choice_logprobs.keys().max().unwrap();
182 let expected_count = (max_choice + 1) as usize;
183
184 if choice_logprobs.len() != expected_count {
185 return Err(format!(
186 "Missing choice indices: expected {} choices [0, {}), but found {} choices: {:?}",
187 expected_count,
188 max_choice + 1,
189 choice_logprobs.len(),
190 choice_logprobs.keys().collect::<Vec<_>>()
191 ));
192 }
193
194 for i in 0..=max_choice {
196 if !choice_logprobs.contains_key(&i) {
197 return Err(format!(
198 "Missing choice index {}: expected [0, {}), found {:?}",
199 i,
200 max_choice + 1,
201 choice_logprobs.keys().collect::<Vec<_>>()
202 ));
203 }
204 }
205
206 let mut result = Vec::with_capacity(expected_count);
208 for i in 0..=max_choice {
209 result.push(choice_logprobs[&i].clone());
210 }
211
212 Ok(result)
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct SensitivityAnalysis {
218 pub total_responses: usize,
220 pub choice_analyses: HashMap<u32, ChoiceAnalysis>,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct ChoiceAnalysis {
227 pub choice_index: u32,
229 pub position_closeness: Vec<PositionCloseness>,
231 pub positions_analyzed: usize,
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct PositionCloseness {
238 pub stream_position: usize,
240 pub token_position: usize,
242 pub logprob_difference: f32,
244 pub probability_difference: f32,
246 pub probability_remaining: f32,
248 pub candidates: Vec<TokenLogprob>,
250}
251
252#[derive(Debug, Clone)]
254pub struct ClosePosition {
255 pub stream_position: usize,
257 pub token_position: usize,
259 pub logprob_difference: f32,
261 pub probability_difference: f32,
263 pub probability_remaining: f32,
265 pub top_candidates: Vec<TokenLogprob>,
267}
268
269pub fn analyze_logprob_sensitivity(
271 recorded_stream: Arc<RecordedStream<impl LogprobExtractor>>,
272) -> SensitivityAnalysis {
273 let mut choice_analyses: HashMap<u32, ChoiceAnalysis> = HashMap::new();
274 let mut choice_sequence_positions: HashMap<u32, usize> = HashMap::new();
276
277 for (stream_pos, timestamped_response) in recorded_stream.responses().iter().enumerate() {
278 let response = ×tamped_response.response;
279 let logprobs_by_choice = response.extract_logprobs_by_choice();
280
281 for (choice_index, choice_logprobs) in logprobs_by_choice {
282 let choice_analysis =
284 choice_analyses
285 .entry(choice_index)
286 .or_insert_with(|| ChoiceAnalysis {
287 choice_index,
288 position_closeness: Vec::new(),
289 positions_analyzed: 0,
290 });
291
292 let current_seq_pos = choice_sequence_positions.entry(choice_index).or_insert(0);
294
295 for token_logprobs in choice_logprobs {
296 let all_tokens = token_logprobs.all_tokens();
297
298 if all_tokens.len() < 2 {
299 *current_seq_pos += 1;
300 continue;
301 }
302
303 let sorted_candidates = all_tokens.to_vec();
305
306 let logprob_difference =
308 sorted_candidates[0].logprob - sorted_candidates[1].logprob;
309
310 let prob1 = sorted_candidates[0].logprob.exp();
312 let prob2 = sorted_candidates[1].logprob.exp();
313 let probability_difference = prob1 - prob2;
314
315 let total_prob_sum: f32 = sorted_candidates.iter().map(|t| t.logprob.exp()).sum();
317 let probability_remaining = 1.0 - total_prob_sum;
318
319 choice_analysis.position_closeness.push(PositionCloseness {
320 stream_position: stream_pos,
321 token_position: *current_seq_pos,
322 logprob_difference,
323 probability_difference,
324 probability_remaining,
325 candidates: sorted_candidates,
326 });
327
328 choice_analysis.positions_analyzed += 1;
329 *current_seq_pos += 1;
330 }
331 }
332 }
333
334 for choice_analysis in choice_analyses.values_mut() {
336 choice_analysis.position_closeness.sort_by(|a, b| {
337 a.probability_difference
338 .partial_cmp(&b.probability_difference)
339 .unwrap()
340 });
341 }
342
343 SensitivityAnalysis {
344 total_responses: recorded_stream.responses().len(),
345 choice_analyses,
346 }
347}
348
349impl SensitivityAnalysis {
350 pub fn get_close_positions_for_choice(
353 &self,
354 choice_index: u32,
355 threshold: f32,
356 ) -> Vec<&PositionCloseness> {
357 self.choice_analyses
358 .get(&choice_index)
359 .map(|analysis| {
360 analysis
361 .position_closeness
362 .iter()
363 .filter(|pos| pos.probability_difference <= threshold)
364 .collect()
365 })
366 .unwrap_or_default()
367 }
368
369 pub fn get_closest_positions_for_choice(
371 &self,
372 choice_index: u32,
373 count: usize,
374 ) -> Vec<&PositionCloseness> {
375 self.choice_analyses
376 .get(&choice_index)
377 .map(|analysis| analysis.position_closeness.iter().take(count).collect())
378 .unwrap_or_default()
379 }
380
381 pub fn print_summary(&self) {
383 println!("=== Logprob Sensitivity Analysis Summary ===");
384 println!("Total stream responses analyzed: {}", self.total_responses);
385 println!("Number of choices: {}", self.choice_analyses.len());
386 println!();
387
388 for (choice_index, choice_analysis) in &self.choice_analyses {
389 println!(
390 "Choice {}: {} positions analyzed",
391 choice_index, choice_analysis.positions_analyzed
392 );
393
394 if !choice_analysis.position_closeness.is_empty() {
395 println!(" Closest positions (smallest probability differences):");
396 for (j, pos) in choice_analysis
397 .position_closeness
398 .iter()
399 .take(3)
400 .enumerate()
401 {
402 let top_token = &pos.candidates[0].token;
403 let second_token = &pos.candidates[1].token;
404 let prob1 = pos.candidates[0].logprob.exp();
405 let prob2 = pos.candidates[1].logprob.exp();
406 println!(
407 " {}: Stream pos {}, token pos {} - '{}' ({:.1}%) vs '{}' ({:.1}%) (prob diff: {:.4})",
408 j + 1,
409 pos.stream_position,
410 pos.token_position,
411 top_token,
412 prob1 * 100.0,
413 second_token,
414 prob2 * 100.0,
415 pos.probability_difference
416 );
417 }
418 }
419 println!();
420 }
421 }
422
423 pub fn close_position_percentage_for_choice(&self, choice_index: u32, threshold: f32) -> f32 {
426 if let Some(analysis) = self.choice_analyses.get(&choice_index) {
427 if analysis.positions_analyzed == 0 {
428 return 0.0;
429 }
430 let close_count = analysis
431 .position_closeness
432 .iter()
433 .filter(|pos| pos.probability_difference <= threshold)
434 .count();
435 (close_count as f32 / analysis.positions_analyzed as f32) * 100.0
436 } else {
437 0.0
438 }
439 }
440
441 pub fn detect_multiple_close_tokens(
443 &self,
444 choice_index: u32,
445 threshold: f32,
446 ) -> Vec<MultipleCloseTokens> {
447 let mut results = Vec::new();
448
449 if let Some(analysis) = self.choice_analyses.get(&choice_index) {
450 for pos in &analysis.position_closeness {
451 let close_tokens = self.count_close_tokens_at_position(pos, threshold);
452 if close_tokens.close_count > 2 {
453 results.push(close_tokens);
454 }
455 }
456 }
457
458 results
459 }
460
461 pub fn detect_likely_greedy_decoding(&self, choice_index: u32) -> bool {
464 if let Some(analysis) = self.choice_analyses.get(&choice_index) {
465 if analysis.positions_analyzed == 0 {
466 return true; }
468
469 let likely_greedy_positions = analysis
473 .position_closeness
474 .iter()
475 .filter(|pos| {
476 if pos.candidates.is_empty() {
477 return true; }
479
480 pos.probability_difference < 0.01 || pos.probability_difference > 0.05
482 })
483 .count();
484
485 (likely_greedy_positions as f32 / analysis.positions_analyzed as f32) > 0.5
487 } else {
488 false
489 }
490 }
491
492 pub fn greedy_selection_percentage(&self, choice_index: u32) -> f32 {
494 if let Some(analysis) = self.choice_analyses.get(&choice_index) {
495 if analysis.positions_analyzed == 0 {
496 return 0.0;
497 }
498
499 let greedy_like_positions = analysis
500 .position_closeness
501 .iter()
502 .filter(|pos| {
503 pos.probability_difference < 0.01 || pos.probability_difference > 0.05
505 })
506 .count();
507
508 (greedy_like_positions as f32 / analysis.positions_analyzed as f32) * 100.0
509 } else {
510 0.0
511 }
512 }
513
514 fn count_close_tokens_at_position(
517 &self,
518 position: &PositionCloseness,
519 threshold: f32,
520 ) -> MultipleCloseTokens {
521 let top_prob = position.candidates[0].logprob.exp();
522 let mut close_count = 1; let mut close_tokens = vec![position.candidates[0].clone()];
524
525 for candidate in &position.candidates[1..] {
526 let candidate_prob = candidate.logprob.exp();
527 let prob_diff = top_prob - candidate_prob;
528 if prob_diff <= threshold {
529 close_count += 1;
530 close_tokens.push(candidate.clone());
531 } else {
532 break; }
534 }
535
536 let max_difference = if close_count > 1 {
537 let last_prob = close_tokens.last().unwrap().logprob.exp();
538 top_prob - last_prob
539 } else {
540 0.0
541 };
542
543 MultipleCloseTokens {
544 stream_position: position.stream_position,
545 token_position: position.token_position,
546 close_count,
547 close_tokens,
548 max_difference,
549 }
550 }
551}
552
553#[derive(Debug, Clone)]
555pub struct MultipleCloseTokens {
556 pub stream_position: usize,
557 pub token_position: usize,
558 pub close_count: usize,
559 pub close_tokens: Vec<TokenLogprob>,
560 pub max_difference: f32,
561}
562
563#[cfg(test)]
564mod tests {
565 use super::*;
566
567 type TestTokenAlternative = (&'static str, f32);
569 type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
570 type TestTokenDataVec = Vec<TestTokenData>;
571 use crate::perf::{RecordingMode, TimestampedResponse, record_stream_with_context};
572 use crate::protocols::codec::create_message_stream;
573 use crate::protocols::convert_sse_stream;
574 use approx::assert_abs_diff_eq;
575 use dynamo_async_openai::types::{
576 ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
577 ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
578 };
579 use futures::StreamExt;
580 use std::sync::Arc;
581 use std::time::Instant;
582
583 const FLOAT_EPSILON: f32 = 1e-6;
584
/// Two candidates one percentage point apart count as a close position, but a
/// pair alone is never reported as "multiple close tokens".
#[test]
fn test_two_tokens_close() {
    let entry = create_token_logprob_from_linear_probs("hello", 0.45, vec![("world", 0.44)]);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let close_positions = analysis.get_close_positions_for_choice(0, 0.1);
    assert_eq!(close_positions.len(), 1);

    let only = close_positions[0];
    assert_abs_diff_eq!(only.probability_difference, 0.01, epsilon = FLOAT_EPSILON);
    // ln(0.45) - ln(0.44) is roughly 0.0225.
    assert_abs_diff_eq!(only.logprob_difference, 0.023, epsilon = 0.001);

    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
    assert!(multiple_close.is_empty());
}
615
/// Three candidates within three percentage points are reported as one cluster.
#[test]
fn test_three_tokens_close() {
    let alternatives = vec![("world", 0.33), ("there", 0.32)];
    let entry = create_token_logprob_from_linear_probs("hello", 0.35, alternatives);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let close_positions = analysis.get_close_positions_for_choice(0, 0.025);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.02,
        epsilon = FLOAT_EPSILON
    );

    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.04);
    assert_eq!(multiple_close.len(), 1);

    let cluster = &multiple_close[0];
    assert_eq!(cluster.close_count, 3);
    assert_abs_diff_eq!(cluster.max_difference, 0.03, epsilon = FLOAT_EPSILON);
}
649
/// Four candidates within five percentage points are reported as one cluster.
#[test]
fn test_four_tokens_close() {
    let alternatives = vec![("world", 0.26), ("there", 0.25), ("friend", 0.22)];
    let entry = create_token_logprob_from_linear_probs("hello", 0.27, alternatives);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let close_positions = analysis.get_close_positions_for_choice(0, 0.02);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );

    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.06);
    assert_eq!(multiple_close.len(), 1);

    let cluster = &multiple_close[0];
    assert_eq!(cluster.close_count, 4);
    assert_abs_diff_eq!(cluster.max_difference, 0.05, epsilon = FLOAT_EPSILON);
}
684
/// Each choice of a response is analyzed on its own.
#[test]
fn test_multiple_choices_analysis() {
    let wide_gap = create_token_logprob_from_linear_probs("hello", 0.7, vec![("world", 0.25)]);
    let narrow_gap = create_token_logprob_from_linear_probs("hi", 0.505, vec![("there", 0.495)]);
    let analysis = create_analysis_with_multiple_choices(vec![vec![wide_gap], vec![narrow_gap]]);

    assert_eq!(analysis.choice_analyses.len(), 2);

    let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
    assert_eq!(choice0_close.len(), 1);
    let gap0 = choice0_close[0].probability_difference;
    assert_abs_diff_eq!(gap0, 0.45, epsilon = FLOAT_EPSILON);

    let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);
    assert_eq!(choice1_close.len(), 1);
    let gap1 = choice1_close[0].probability_difference;
    assert_abs_diff_eq!(gap1, 0.01, epsilon = FLOAT_EPSILON);

    assert!(gap1 < gap0);
}
725
/// A position with a single candidate has nothing to compare and is not reported.
#[test]
fn test_edge_case_single_token() {
    let lone = create_token_logprob_from_linear_probs("hello", 1.0, vec![]);
    let analysis = create_analysis_with_logprobs(vec![lone]);

    let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
    assert!(close_positions.is_empty());
}
738
/// The threshold decides which positions are returned; results come closest first.
#[test]
fn test_threshold_filtering() {
    let narrow = create_token_logprob_from_linear_probs("token1", 0.55, vec![("token2", 0.45)]);
    let wide = create_token_logprob_from_linear_probs("token3", 0.8, vec![("token4", 0.2)]);
    let analysis = create_analysis_with_logprobs(vec![narrow, wide]);

    let close_strict = analysis.get_close_positions_for_choice(0, 0.15);
    assert_eq!(close_strict.len(), 1);
    assert_abs_diff_eq!(
        close_strict[0].probability_difference,
        0.1,
        epsilon = FLOAT_EPSILON
    );

    let close_permissive = analysis.get_close_positions_for_choice(0, 0.7);
    assert_eq!(close_permissive.len(), 2);

    let first_gap = close_permissive[0].probability_difference;
    let second_gap = close_permissive[1].probability_difference;
    assert!(first_gap < second_gap);
}
766
/// Two of three positions fall within the threshold, i.e. about 66.67 %.
#[test]
fn test_percentage_calculation() {
    let entries = vec![
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("token2", 0.4)]),
        create_token_logprob_from_linear_probs("token3", 0.9, vec![("token4", 0.1)]),
        create_token_logprob_from_linear_probs("token5", 0.52, vec![("token6", 0.48)]),
    ];
    let analysis = create_analysis_with_logprobs(entries);

    let percentage = analysis.close_position_percentage_for_choice(0, 0.25);
    let deviation = (percentage - 66.67).abs();
    assert!(deviation < 0.01);
}
781
/// Two different tokens with identical probability produce a zero gap and both
/// remain listed as candidates.
#[test]
fn test_real_vllm_equal_logprobs() {
    let tied = create_token_logprob_from_linear_probs("Ġblock", 0.403, vec![("Ġchunk", 0.403)]);
    let analysis = create_analysis_with_logprobs(vec![tied]);

    let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
    assert_eq!(close_positions.len(), 1);

    let position = &close_positions[0];
    assert_abs_diff_eq!(position.probability_difference, 0.0, epsilon = FLOAT_EPSILON);
    assert_eq!(position.candidates.len(), 2);

    let tokens: Vec<&str> = position
        .candidates
        .iter()
        .map(|candidate| candidate.token.as_str())
        .collect();
    assert!(tokens.contains(&"Ġblock"));
    assert!(tokens.contains(&"Ġchunk"));

    let first = &position.candidates[0];
    let second = &position.candidates[1];
    assert_abs_diff_eq!(first.logprob, second.logprob, epsilon = FLOAT_EPSILON);

    let prob1 = first.logprob.exp();
    let prob2 = second.logprob.exp();
    assert_abs_diff_eq!(prob1, 0.403, epsilon = 0.001);
    assert_abs_diff_eq!(prob2, 0.403, epsilon = 0.001);
}
828
829 fn create_analysis_with_logprobs(
831 token_logprobs: Vec<ChatCompletionTokenLogprob>,
832 ) -> SensitivityAnalysis {
833 let start_time = Instant::now();
834 let response = create_mock_response_with_logprobs(token_logprobs);
835 let responses = vec![TimestampedResponse::new(response, 0)];
836 let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
837 let arc_stream = Arc::new(recorded_stream);
838
839 analyze_logprob_sensitivity(arc_stream)
840 }
841
842 fn create_analysis_with_multiple_choices(
843 choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
844 ) -> SensitivityAnalysis {
845 let start_time = Instant::now();
846 let response = create_mock_response_with_multiple_choices(choices_logprobs);
847 let responses = vec![TimestampedResponse::new(response, 0)];
848 let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
849 let arc_stream = Arc::new(recorded_stream);
850
851 analyze_logprob_sensitivity(arc_stream)
852 }
853
854 fn create_analysis_with_mixed_sampling(mixed_data: TestTokenDataVec) -> SensitivityAnalysis {
855 let start_time = Instant::now();
856 let token_logprobs: Vec<ChatCompletionTokenLogprob> = mixed_data
857 .into_iter()
858 .map(|(selected_token, selected_prob, alternatives)| {
859 create_token_logprob_from_linear_probs(selected_token, selected_prob, alternatives)
860 })
861 .collect();
862
863 let response = create_mock_response_with_logprobs(token_logprobs);
864 let responses = vec![TimestampedResponse::new(response, 0)];
865 let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
866 let arc_stream = Arc::new(recorded_stream);
867
868 analyze_logprob_sensitivity(arc_stream)
869 }
870
871 fn create_analysis_with_missing_selected_token() -> SensitivityAnalysis {
872 let start_time = Instant::now();
873
874 let token_logprobs = vec![ChatCompletionTokenLogprob {
877 token: "unlikely_selection".to_string(),
878 logprob: (0.15_f32).ln(), bytes: None,
880 top_logprobs: vec![
881 TopLogprobs {
882 token: "best_option".to_string(),
883 logprob: (0.4_f32).ln(), bytes: None,
885 },
886 TopLogprobs {
887 token: "second_best".to_string(),
888 logprob: (0.3_f32).ln(), bytes: None,
890 },
891 ],
892 }];
893
894 let response = create_mock_response_with_logprobs(token_logprobs);
895 let responses = vec![TimestampedResponse::new(response, 0)];
896 let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
897 let arc_stream = Arc::new(recorded_stream);
898
899 analyze_logprob_sensitivity(arc_stream)
900 }
901
902 fn create_token_logprob_from_linear_probs(
905 token: &str,
906 prob: f32,
907 top_probs: Vec<(&str, f32)>,
908 ) -> ChatCompletionTokenLogprob {
909 assert!(
911 (0.0..=1.0).contains(&prob),
912 "Probability must be in [0, 1]: {}",
913 prob
914 );
915
916 let total_prob = prob + top_probs.iter().map(|(_, p)| p).sum::<f32>();
918 assert!(
919 total_prob <= 1.001,
920 "Total probability mass exceeds 1: {}",
921 total_prob
922 ); for (_, p) in &top_probs {
925 assert!(
926 *p >= 0.0 && *p <= 1.0,
927 "Probability must be in [0, 1]: {}",
928 p
929 );
930 }
931
932 ChatCompletionTokenLogprob {
933 token: token.to_string(),
934 logprob: prob.ln(),
935 bytes: None,
936 top_logprobs: top_probs
937 .into_iter()
938 .map(|(t, p)| TopLogprobs {
939 token: t.to_string(),
940 logprob: p.ln(),
941 bytes: None,
942 })
943 .collect(),
944 }
945 }
946
947 fn create_mock_response_with_logprobs(
948 token_logprobs: Vec<ChatCompletionTokenLogprob>,
949 ) -> NvCreateChatCompletionStreamResponse {
950 #[expect(deprecated)]
951 NvCreateChatCompletionStreamResponse {
952 id: "test_id".to_string(),
953 choices: vec![ChatChoiceStream {
954 index: 0,
955 delta: ChatCompletionStreamResponseDelta {
956 content: Some("test".to_string()),
957 function_call: None,
958 tool_calls: None,
959 role: Some(Role::Assistant),
960 refusal: None,
961 reasoning_content: None,
962 },
963 finish_reason: Some(FinishReason::Stop),
964 logprobs: Some(ChatChoiceLogprobs {
965 content: Some(token_logprobs),
966 refusal: None,
967 }),
968 }],
969 created: 1234567890,
970 model: "test-model".to_string(),
971 service_tier: None,
972 system_fingerprint: None,
973 object: "chat.completion.chunk".to_string(),
974 usage: None,
975 }
976 }
977
978 fn create_mock_response_with_multiple_choices(
979 choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
980 ) -> NvCreateChatCompletionStreamResponse {
981 #[expect(deprecated)]
982 let choices = choices_logprobs
983 .into_iter()
984 .enumerate()
985 .map(|(i, token_logprobs)| ChatChoiceStream {
986 index: i as u32,
987 delta: ChatCompletionStreamResponseDelta {
988 content: Some("test".to_string()),
989 function_call: None,
990 tool_calls: None,
991 role: Some(Role::Assistant),
992 refusal: None,
993 reasoning_content: None,
994 },
995 finish_reason: Some(FinishReason::Stop),
996 logprobs: Some(ChatChoiceLogprobs {
997 content: Some(token_logprobs),
998 refusal: None,
999 }),
1000 })
1001 .collect();
1002
1003 NvCreateChatCompletionStreamResponse {
1004 id: "test_id".to_string(),
1005 choices,
1006 created: 1234567890,
1007 model: "test-model".to_string(),
1008 service_tier: None,
1009 system_fingerprint: None,
1010 object: "chat.completion.chunk".to_string(),
1011 usage: None,
1012 }
1013 }
1014
/// Smoke test: analyzing a single mock response reports one response and a
/// non-negative percentage.
#[test]
fn test_sensitivity_analysis() {
    let started = Instant::now();
    let recorded = vec![TimestampedResponse::new(create_mock_response(), 0)];
    let stream = Arc::new(RecordedStream::new(recorded, started, Instant::now()));

    let analysis = analyze_logprob_sensitivity(stream);

    assert_eq!(analysis.total_responses, 1);
    let percentage = analysis.close_position_percentage_for_choice(0, 0.5);
    assert!(percentage >= 0.0);
}
1028
/// The basic mock response yields either no entries at all or at least one
/// choice with an empty logprob list.
#[test]
fn test_extract_logprobs_by_choice_empty() {
    let logprobs = create_mock_response().extract_logprobs_by_choice();
    let has_empty_choice = logprobs.values().any(|tokens| tokens.is_empty());
    assert!(logprobs.is_empty() || has_empty_choice);
}
1035
/// `TokenLogProbs::new` keeps the selected token, sorts the alternatives and
/// merges both into a list ordered by descending logprob.
#[test]
fn test_token_logprobs_struct() {
    let make = |name: &str, prob: f32| TokenLogprob {
        token: name.to_string(),
        logprob: prob.ln(),
        bytes: None,
    };

    let selected = make("selected", 0.7);
    let alternatives = vec![make("alt1", 0.2), make("alt2", 0.1)];

    let token_logprobs = TokenLogProbs::new(selected.clone(), alternatives.clone());

    assert_eq!(token_logprobs.selected_token(), &selected);
    assert_eq!(token_logprobs.alternative_tokens().len(), 2);
    assert_eq!(token_logprobs.all_tokens().len(), 3);

    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(all_tokens[0].token, "selected");
    assert_eq!(all_tokens[1].token, "alt1");
    assert_eq!(all_tokens[2].token, "alt2");

    let alt_tokens = token_logprobs.alternative_tokens();
    assert_eq!(alt_tokens[0].token, "alt1");
    assert_eq!(alt_tokens[1].token, "alt2");
}
1076
/// A selected token that already appears among the alternatives is not duplicated.
#[test]
fn test_token_logprobs_selected_in_alternatives() {
    let make = |name: &str, prob: f32| TokenLogprob {
        token: name.to_string(),
        logprob: prob.ln(),
        bytes: None,
    };

    let selected = make("token", 0.4);
    let alternatives = vec![make("token", 0.4), make("other", 0.3)];

    let token_logprobs = TokenLogProbs::new(selected, alternatives.clone());

    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(all_tokens.len(), 2);
    assert_eq!(all_tokens[0].token, "token");
    assert_eq!(all_tokens[1].token, "other");
}
1107
/// Contiguous indices flatten, a gap is rejected, and an empty map is accepted.
#[test]
fn test_validate_and_flatten_choices() {
    // Contiguous indices 0, 1, 2.
    let mut contiguous = HashMap::new();
    for index in 0..3 {
        contiguous.insert(index, vec![]);
    }
    let result = validate_and_flatten_choices(contiguous);
    assert!(result.is_ok());
    assert_eq!(result.unwrap().len(), 3);

    // Index 1 is missing.
    let mut with_gap = HashMap::new();
    with_gap.insert(0, vec![]);
    with_gap.insert(2, vec![]);
    let result = validate_and_flatten_choices(with_gap);
    assert!(result.is_err());
    let error_msg = result.unwrap_err();
    assert!(error_msg.contains("Missing choice indices"));
    assert!(error_msg.contains("expected 3 choices"));

    // No choices at all.
    let result = validate_and_flatten_choices(HashMap::new());
    assert!(result.is_ok());
    assert_eq!(result.unwrap().len(), 0);
}
1140
/// `probability_remaining` is the mass not covered by the listed candidates.
#[test]
fn test_probability_remaining_calculation() {
    // 0.4 + 0.3 + 0.1 leaves 0.2 uncovered.
    let partial = create_token_logprob_from_linear_probs(
        "token",
        0.4,
        vec![("alt1", 0.3), ("alt2", 0.1)],
    );
    let analysis = create_analysis_with_logprobs(vec![partial]);

    let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_remaining,
        0.2,
        epsilon = 0.01
    );

    // 0.5 + 0.3 + 0.2 covers the whole distribution.
    let full = create_token_logprob_from_linear_probs(
        "token",
        0.5,
        vec![("alt1", 0.3), ("alt2", 0.2)],
    );
    let analysis_complete = create_analysis_with_logprobs(vec![full]);

    let complete_positions = analysis_complete.get_close_positions_for_choice(0, 1.0);
    assert_eq!(complete_positions.len(), 1);
    assert_abs_diff_eq!(
        complete_positions[0].probability_remaining,
        0.0,
        epsilon = 0.01
    );
}
1183
/// Positions are stored closest first, whatever their order in the stream.
#[test]
fn test_position_closeness_ordering() {
    let entries = vec![
        create_token_logprob_from_linear_probs("far", 0.85, vec![("alt", 0.15)]),
        create_token_logprob_from_linear_probs("close", 0.51, vec![("alt", 0.49)]),
        create_token_logprob_from_linear_probs("medium", 0.7, vec![("alt", 0.3)]),
    ];
    let analysis = create_analysis_with_logprobs(entries);

    let positions = &analysis.choice_analyses.get(&0).unwrap().position_closeness;
    assert_eq!(positions.len(), 3);

    let gaps: Vec<f32> = positions
        .iter()
        .map(|position| position.probability_difference)
        .collect();

    assert!(gaps[0] <= gaps[1]);
    assert!(gaps[1] <= gaps[2]);

    assert_abs_diff_eq!(gaps[0], 0.02, epsilon = FLOAT_EPSILON);
    assert_abs_diff_eq!(gaps[1], 0.4, epsilon = FLOAT_EPSILON);
    assert_abs_diff_eq!(gaps[2], 0.7, epsilon = FLOAT_EPSILON);
}
1219
/// Only the leading run of close candidates is counted; the distant fourth
/// candidate is left out.
#[test]
fn test_multiple_close_tokens_edge_cases() {
    let alternatives = vec![("alt1", 0.33), ("alt2", 0.32), ("alt3", 0.01)];
    let entry = create_token_logprob_from_linear_probs("token", 0.34, alternatives);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.025);
    assert_eq!(multiple_close.len(), 1);
    assert_eq!(multiple_close[0].close_count, 3);
}
1237
/// Position counts and closeness results of one choice do not leak into another.
#[test]
fn test_choice_analysis_independence() {
    let first_choice = vec![
        create_token_logprob_from_linear_probs("token1", 0.55, vec![("alt1", 0.45)]),
        create_token_logprob_from_linear_probs("token2", 0.9, vec![("alt2", 0.1)]),
    ];
    let second_choice = vec![create_token_logprob_from_linear_probs(
        "token3",
        0.501,
        vec![("alt3", 0.499)],
    )];
    let analysis = create_analysis_with_multiple_choices(vec![first_choice, second_choice]);

    assert_eq!(analysis.choice_analyses.len(), 2);

    let analyzed_in_first = analysis.choice_analyses.get(&0).unwrap().positions_analyzed;
    let analyzed_in_second = analysis.choice_analyses.get(&1).unwrap().positions_analyzed;
    assert_eq!(analyzed_in_first, 2);
    assert_eq!(analyzed_in_second, 1);

    let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
    let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);

    assert_eq!(choice0_close.len(), 1);
    assert_eq!(choice1_close.len(), 1);

    assert!(choice1_close[0].probability_difference < choice0_close[0].probability_difference);
}
1272
/// Requesting more positions than exist returns all of them; smaller counts truncate.
#[test]
fn test_get_closest_positions_boundary() {
    let entries = vec![
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("alt1", 0.4)]),
        create_token_logprob_from_linear_probs("token2", 0.75, vec![("alt2", 0.25)]),
    ];
    let analysis = create_analysis_with_logprobs(entries);

    let more_than_available = analysis.get_closest_positions_for_choice(0, 10);
    assert_eq!(more_than_available.len(), 2);

    let exactly_available = analysis.get_closest_positions_for_choice(0, 2);
    assert_eq!(exactly_available.len(), 2);

    let fewer_than_available = analysis.get_closest_positions_for_choice(0, 1);
    assert_eq!(fewer_than_available.len(), 1);
}
1292
/// A zero threshold still matches an exact tie.
#[test]
fn test_zero_threshold() {
    let tie = create_token_logprob_from_linear_probs("token", 0.5, vec![("alt", 0.5)]);
    let analysis = create_analysis_with_logprobs(vec![tie]);

    let close_positions = analysis.get_close_positions_for_choice(0, 0.0);
    assert_eq!(close_positions.len(), 1);

    let gap = close_positions[0].probability_difference;
    assert_abs_diff_eq!(gap, 0.0, epsilon = FLOAT_EPSILON);
}
1307
/// Queries a choice index (5) that the single-position fixture does not
/// populate and expects empty results and a `0.0` percentage rather than a panic.
#[test]
fn test_nonexistent_choice() {
    let only_position = create_token_logprob_from_linear_probs("token", 0.6, vec![("alt", 0.4)]);
    let analysis = create_analysis_with_logprobs(vec![only_position]);

    // Threshold-based lookup on the unknown choice returns nothing.
    assert!(analysis.get_close_positions_for_choice(5, 0.1).is_empty());

    // Top-N lookup on the unknown choice returns nothing.
    assert!(analysis.get_closest_positions_for_choice(5, 3).is_empty());

    // The percentage for the unknown choice is reported as zero.
    assert_eq!(analysis.close_position_percentage_for_choice(5, 0.1), 0.0);
}
1326
/// Runs the logprob extractor on a chunk whose only choice has `logprobs: None`.
///
/// The result is expected to hold exactly one entry, and that entry's list of
/// token logprobs is expected to be empty.
#[test]
fn test_logprob_extractor_with_missing_data() {
    // The literal is kept as one statement so the lint expectation below covers
    // every field initialised inside it.
    #[expect(deprecated)]
    let chunk = NvCreateChatCompletionStreamResponse {
        id: String::from("test_id"),
        choices: vec![ChatChoiceStream {
            index: 0,
            delta: ChatCompletionStreamResponseDelta {
                content: Some(String::from("test")),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            // The condition under test: no logprob payload on the choice.
            logprobs: None,
        }],
        created: 1234567890,
        model: String::from("test-model"),
        service_tier: None,
        system_fingerprint: None,
        object: String::from("chat.completion.chunk"),
        usage: None,
    };

    let extracted = chunk.extract_logprobs_by_choice();
    assert_eq!(extracted.len(), 1);
    assert!(extracted.values().any(|per_choice| per_choice.is_empty()));
}
1358
/// Smoke test: `print_summary` must complete on a one-position analysis.
/// Nothing is asserted about the printed text.
#[test]
fn test_print_summary_no_panic() {
    let only_position = create_token_logprob_from_linear_probs("token", 0.6, vec![("alt", 0.4)]);
    let analysis = create_analysis_with_logprobs(vec![only_position]);

    // Reaching the end of the test without a panic is the success criterion.
    analysis.print_summary();
}
1370
/// Expects greedy decoding to be detected when, at every position, the
/// first-listed token ("best" at 0.8, "optimal" at 0.7) has a higher
/// probability than all of its listed alternatives.
#[test]
fn test_greedy_decoding_detection() {
    let positions = vec![
        create_token_logprob_from_linear_probs(
            "best",
            0.8,
            vec![("second", 0.15), ("third", 0.05)],
        ),
        create_token_logprob_from_linear_probs(
            "optimal",
            0.7,
            vec![("suboptimal", 0.2), ("bad", 0.1)],
        ),
    ];
    let analysis = create_analysis_with_logprobs(positions);

    // The boolean detector must flag choice 0 as greedy.
    assert!(analysis.detect_likely_greedy_decoding(0));

    // The share of greedy-looking selections must exceed 90 percent.
    assert!(analysis.greedy_selection_percentage(0) > 90.0);
}
1396
/// Runs greedy detection on a mixed-sampling fixture that includes a position
/// with three nearly equal candidates (0.35 / 0.33 / 0.32).
///
/// Only the range of the reported percentage is asserted; the boolean
/// detector is invoked but its result is deliberately not checked.
#[test]
fn test_non_greedy_decoding_detection() {
    let mixed_positions = vec![
        ("selected_best", 0.6, vec![("alternative", 0.4)]),
        (
            "close_choice",
            0.35,
            vec![("very_close", 0.33), ("also_close", 0.32)],
        ),
    ];
    let analysis = create_analysis_with_mixed_sampling(mixed_positions);

    // Called for coverage only; the outcome is intentionally ignored.
    let _is_greedy = analysis.detect_likely_greedy_decoding(0);

    let greedy_share = analysis.greedy_selection_percentage(0);
    assert!((0.0..=100.0).contains(&greedy_share));
}
1417
/// Uses the `create_analysis_with_missing_selected_token` fixture and checks
/// that the greedy percentage for choice 0 is still a valid percentage.
#[test]
fn test_selected_token_not_in_top_logprobs() {
    let analysis = create_analysis_with_missing_selected_token();

    let greedy_share = analysis.greedy_selection_percentage(0);
    // Only the range is asserted; no specific value is expected here.
    assert!((0.0..=100.0).contains(&greedy_share));
}
1428
/// Covers a position whose two candidates ("Ġblock" and "Ġchunk") share the
/// same probability (0.403): it must be reported as close at a 0.001
/// threshold and must still be classified as greedy.
#[test]
fn test_equal_logprobs_greedy_detection() {
    let tied_position =
        create_token_logprob_from_linear_probs("Ġblock", 0.403, vec![("Ġchunk", 0.403)]);
    let analysis = create_analysis_with_logprobs(vec![tied_position]);

    // The tie is within the 0.001 threshold, so exactly one close position is expected.
    assert_eq!(analysis.get_close_positions_for_choice(0, 0.001).len(), 1);

    // A tie must not be counted as evidence against greedy decoding.
    assert!(analysis.detect_likely_greedy_decoding(0));
}
1446
1447 #[tokio::test]
1448 async fn test_real_sse_stream_analysis() {
1449 let data = std::fs::read_to_string(
1451 "tests/data/replays/deepseek-r1-distill-llama-8b/chat-completions.stream.1",
1452 )
1453 .expect("Failed to read test data file");
1454
1455 let sse_stream = create_message_stream(&data);
1457
1458 let response_stream =
1460 convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(sse_stream));
1461
1462 let filtered_stream = response_stream.filter_map(|annotated| async move { annotated.data });
1464
1465 let ctx = Arc::new(MockContext::new());
1467
1468 let (recorded_stream, recording_rx) =
1470 record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
1471
1472 let _collected: Vec<_> = recorded_stream.collect().await;
1474
1475 let recorded = recording_rx
1477 .await
1478 .expect("Failed to receive recorded stream");
1479
1480 assert!(recorded.response_count() > 0, "No responses recorded");
1482 println!("Recorded {} responses", recorded.response_count());
1483
1484 let arc_recorded = Arc::new(recorded);
1486 let analysis = analyze_logprob_sensitivity(arc_recorded);
1487
1488 analysis.print_summary();
1490
1491 assert!(
1493 !analysis.choice_analyses.is_empty(),
1494 "No choice analyses found"
1495 );
1496 assert!(
1497 analysis
1498 .choice_analyses
1499 .values()
1500 .any(|a| a.positions_analyzed > 0),
1501 "No positions analyzed"
1502 );
1503
1504 let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
1506
1507 assert!(!close_positions.is_empty(), "No close positions found");
1509
1510 let equal_positions = close_positions
1512 .iter()
1513 .filter(|pos| pos.probability_difference < 0.0001)
1514 .count();
1515 if equal_positions > 0 {
1516 println!(
1517 "Found {} positions with nearly equal probabilities",
1518 equal_positions
1519 );
1520 }
1521
1522 let closest_3 = analysis.get_closest_positions_for_choice(0, 3);
1524 assert!(
1525 closest_3.len() <= 3,
1526 "Should return at most 3 closest positions"
1527 );
1528
1529 let percentage = analysis.close_position_percentage_for_choice(0, 0.1);
1530 assert!(
1531 (0.0..=100.0).contains(&percentage),
1532 "Percentage should be valid"
1533 );
1534
1535 let is_greedy = analysis.detect_likely_greedy_decoding(0);
1537 let greedy_percentage = analysis.greedy_selection_percentage(0);
1538 println!(
1539 "Greedy detection: {} ({}% greedy-like)",
1540 is_greedy, greedy_percentage
1541 );
1542
1543 let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
1545 if !multiple_close.is_empty() {
1546 println!(
1547 "Found {} positions with multiple close tokens",
1548 multiple_close.len()
1549 );
1550 }
1551 }
1552
1553 fn create_mock_response() -> NvCreateChatCompletionStreamResponse {
1554 NvCreateChatCompletionStreamResponse {
1558 id: "test_id".to_string(),
1559 choices: vec![],
1560 created: 1234567890,
1561 model: "test-model".to_string(),
1562 service_tier: None,
1563 system_fingerprint: None,
1564 object: "chat.completion.chunk".to_string(),
1565 usage: None,
1566 }
1567 }
1568
/// Minimal engine context used by the tests in this module.
///
/// It stores only the identifier handed back by its `id()` implementation.
#[derive(Debug)]
struct MockContext {
    /// Identifier string reported for this context.
    id: String,
}
1574
1575 impl MockContext {
1576 fn new() -> Self {
1577 Self {
1578 id: "test-context".to_string(),
1579 }
1580 }
1581 }
1582
1583 #[async_trait::async_trait]
1584 impl dynamo_runtime::engine::AsyncEngineContext for MockContext {
1585 fn id(&self) -> &str {
1586 &self.id
1587 }
1588
1589 fn stop(&self) {
1590 }
1592
1593 fn stop_generating(&self) {
1594 }
1596
1597 fn kill(&self) {
1598 }
1600
1601 fn is_stopped(&self) -> bool {
1602 false
1603 }
1604
1605 fn is_killed(&self) -> bool {
1606 false
1607 }
1608
1609 async fn stopped(&self) {
1610 }
1612
1613 async fn killed(&self) {
1614 }
1616
1617 fn link_child(&self, _: Arc<dyn dynamo_runtime::engine::AsyncEngineContext>) {
1618 }
1620 }
1621}