dynamo_llm/perf/
logprobs.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Module for recording logprobs from a streaming response.
5//!
6//! Logprobs are a bit easier than token counting and timing because they are
7//! fully self-contained in the response chunk.
8//!
9//! In fact, if logprobs are given, they are a good way to count tokens; however,
10//! the emission of logprobs is also more costly and generally not available unless
11//! explicitly requested.
12//!
13//! The primary reason to record logprobs is to analyze the possible outputs of
14//! a model as a function of sequence position.
15
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use crate::perf::RecordedStream;
21use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
22
/// The type of logprobs observed in the response.
///
/// Logprobs are natural-log probabilities, so "normalized" is a statement about
/// the probabilities `exp(logprob)` of the reported candidates, not about the
/// raw logprob values themselves (which are `<= 0` for valid probabilities).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogprobType {
    /// The probabilities `exp(logprob)` of the reported "top_logprobs" sum to 1
    /// (equivalently, their log-sum-exp is 0).
    Normalized,

    /// The reported "top_logprobs" are not normalized: the probabilities
    /// `exp(logprob)` of the reported candidates do not sum to 1.
    Unnormalized,
}
32
/// Represents a token with its logprob information.
///
/// This is the module's own representation of a single candidate token. The
/// chat-completion extractor in this module builds it from each
/// `logprobs.content` entry and from that entry's `top_logprobs`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenLogprob {
    /// The token as a string
    pub token: String,
    /// The log probability of this token. The analysis code converts it back
    /// to a linear probability with `logprob.exp()`.
    pub logprob: f32,
    /// Optional byte representation of the token
    pub bytes: Option<Vec<u8>>,
}
43
/// Represents logprob information for a single position with selected and alternative tokens.
///
/// Build instances with [`TokenLogProbs::new`], which establishes the ordering
/// invariants described on each field.
#[derive(Debug, Clone)]
pub struct TokenLogProbs {
    /// The selected token at this position. The chat-completion extractor sets
    /// this to the `logprobs.content` entry itself.
    selected: TokenLogprob,
    /// The reported alternative candidates, sorted by logprob, highest first.
    alternatives: Vec<TokenLogprob>,
    /// `alternatives` with `selected` merged in at its sorted position. If
    /// `selected` already appears among the alternatives, it is not duplicated.
    all_sorted: Vec<TokenLogprob>,
}
51
52impl TokenLogProbs {
53    /// Create a new TokenLogProbs from a selected token and alternatives
54    pub fn new(selected: TokenLogprob, mut alternatives: Vec<TokenLogprob>) -> Self {
55        // Sort alternatives by logprob (highest first)
56        alternatives.sort_by(|a, b| b.logprob.partial_cmp(&a.logprob).unwrap());
57
58        // Create all_sorted by merging selected with alternatives (ensuring uniqueness)
59        let mut all_sorted = Vec::new();
60        let mut added_selected = false;
61
62        // Check if selected token appears in alternatives
63        let selected_in_alternatives = alternatives.iter().any(|alt| {
64            alt.token == selected.token && (alt.logprob - selected.logprob).abs() < 1e-6
65        });
66
67        // If selected is not in alternatives, we need to insert it in the right position
68        if !selected_in_alternatives {
69            // Find the correct position to insert selected token
70            let mut insert_position = alternatives.len();
71            for (i, alt) in alternatives.iter().enumerate() {
72                if selected.logprob > alt.logprob {
73                    insert_position = i;
74                    break;
75                }
76            }
77
78            // Build all_sorted by merging at the correct position
79            for (i, alt) in alternatives.iter().enumerate() {
80                if i == insert_position && !added_selected {
81                    all_sorted.push(selected.clone());
82                    added_selected = true;
83                }
84                all_sorted.push(alt.clone());
85            }
86
87            // If we haven't added selected yet, it goes at the end
88            if !added_selected {
89                all_sorted.push(selected.clone());
90            }
91        } else {
92            // Selected is already in alternatives, just use alternatives
93            all_sorted = alternatives.clone();
94        }
95
96        Self {
97            selected,
98            alternatives,
99            all_sorted,
100        }
101    }
102
103    /// Get the selected token
104    pub fn selected_token(&self) -> &TokenLogprob {
105        &self.selected
106    }
107
108    /// Get alternative tokens sorted by most likely first
109    pub fn alternative_tokens(&self) -> &[TokenLogprob] {
110        &self.alternatives
111    }
112
113    /// Get all tokens (selected merged with alternatives, unique) sorted by most likely first
114    pub fn all_tokens(&self) -> &[TokenLogprob] {
115        &self.all_sorted
116    }
117}
118
/// Trait for extracting logprob information from various response types
pub trait LogprobExtractor {
    /// Extract logprobs organized by choice index.
    ///
    /// Returns a map from choice index to the per-token logprobs carried by this
    /// single response. `analyze_logprob_sensitivity` assigns sequence positions
    /// in the vector's iteration order, so implementations should keep tokens in
    /// generation order.
    ///
    /// A choice may map to an empty vector. The chat-completion implementation
    /// in this module does this when a choice carries no logprob content.
    fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>>;
}
125
126/// Implementation for NvCreateChatCompletionStreamResponse (our main streaming response type)
127impl LogprobExtractor for NvCreateChatCompletionStreamResponse {
128    fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> {
129        let mut result = HashMap::new();
130
131        for choice in &self.choices {
132            let choice_index = choice.index;
133
134            let choice_logprobs = choice
135                .logprobs
136                .as_ref()
137                .and_then(|logprobs| logprobs.content.as_ref())
138                .map(|content| {
139                    content
140                        .iter()
141                        .map(|token_logprob| {
142                            let selected_token = TokenLogprob {
143                                token: token_logprob.token.clone(),
144                                logprob: token_logprob.logprob,
145                                bytes: token_logprob.bytes.clone(),
146                            };
147
148                            // Convert top alternatives to our format
149                            let alternatives: Vec<TokenLogprob> = token_logprob
150                                .top_logprobs
151                                .iter()
152                                .map(|top_logprob| TokenLogprob {
153                                    token: top_logprob.token.clone(),
154                                    logprob: top_logprob.logprob,
155                                    bytes: top_logprob.bytes.clone(),
156                                })
157                                .collect();
158
159                            TokenLogProbs::new(selected_token, alternatives)
160                        })
161                        .collect::<Vec<_>>()
162                })
163                .unwrap_or_default();
164
165            result.insert(choice_index, choice_logprobs);
166        }
167
168        result
169    }
170}
171
172/// Validate and flatten choice logprobs HashMap to Vec
173/// Ensures all expected choice indices [0, max_choice) are present
174pub fn validate_and_flatten_choices(
175    choice_logprobs: HashMap<u32, Vec<TokenLogProbs>>,
176) -> Result<Vec<Vec<TokenLogProbs>>, String> {
177    if choice_logprobs.is_empty() {
178        return Ok(Vec::new());
179    }
180
181    let max_choice = *choice_logprobs.keys().max().unwrap();
182    let expected_count = (max_choice + 1) as usize;
183
184    if choice_logprobs.len() != expected_count {
185        return Err(format!(
186            "Missing choice indices: expected {} choices [0, {}), but found {} choices: {:?}",
187            expected_count,
188            max_choice + 1,
189            choice_logprobs.len(),
190            choice_logprobs.keys().collect::<Vec<_>>()
191        ));
192    }
193
194    // Validate all indices from 0 to max_choice are present
195    for i in 0..=max_choice {
196        if !choice_logprobs.contains_key(&i) {
197            return Err(format!(
198                "Missing choice index {}: expected [0, {}), found {:?}",
199                i,
200                max_choice + 1,
201                choice_logprobs.keys().collect::<Vec<_>>()
202            ));
203        }
204    }
205
206    // Flatten to Vec ordered by keys
207    let mut result = Vec::with_capacity(expected_count);
208    for i in 0..=max_choice {
209        result.push(choice_logprobs[&i].clone());
210    }
211
212    Ok(result)
213}
214
/// Analysis focused on detecting close logprobs indicating model uncertainty
///
/// Built by [`analyze_logprob_sensitivity`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityAnalysis {
    /// Total number of responses (stream chunks) in the recorded stream,
    /// including chunks that contributed no analyzed positions.
    pub total_responses: usize,
    /// Analysis results keyed by choice index. A choice gets an entry as soon
    /// as it appears in any response, even if none of its positions qualified.
    pub choice_analyses: HashMap<u32, ChoiceAnalysis>,
}
223
/// Analysis for a single choice
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChoiceAnalysis {
    /// Choice index
    pub choice_index: u32,
    /// All analyzed positions with their closeness values. When built by
    /// [`analyze_logprob_sensitivity`], these are sorted by
    /// `probability_difference`, smallest (most uncertain) first.
    pub position_closeness: Vec<PositionCloseness>,
    /// Number of positions analyzed for this choice. When built by
    /// [`analyze_logprob_sensitivity`], only positions with at least two
    /// candidates are counted, which matches `position_closeness.len()`.
    pub positions_analyzed: usize,
}
234
/// Closeness information for a position
///
/// [`analyze_logprob_sensitivity`] creates one of these for each token position
/// that reported at least two candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionCloseness {
    /// Position in the stream (response index)
    pub stream_position: usize,
    /// Position within this choice's token sequence. The counter advances for
    /// every token of the choice, including positions skipped for having fewer
    /// than two candidates.
    pub token_position: usize,
    /// Logprob difference between top 2 candidates (deprecated - use probability_difference)
    pub logprob_difference: f32,
    /// Probability difference between top 2 candidates (in linear space 0-1),
    /// computed as `exp(top1.logprob) - exp(top2.logprob)`
    pub probability_difference: f32,
    /// Probability mass not accounted for by `candidates` (1 - sum of their
    /// probabilities). May be slightly negative if the reported probabilities
    /// add up to more than 1, e.g. through rounding or unnormalized inputs.
    pub probability_remaining: f32,
    /// All candidates at this position, sorted by logprob (highest first)
    pub candidates: Vec<TokenLogprob>,
}
251
/// A position where top candidates have close probabilities
///
/// Unlike [`PositionCloseness`], this type keeps only the top candidates and does
/// not derive `Serialize`/`Deserialize`.
#[derive(Debug, Clone)]
pub struct ClosePosition {
    /// Position in the stream (response index)
    pub stream_position: usize,
    /// Position within the token sequence
    pub token_position: usize,
    /// Logprob difference between top 2 candidates (deprecated - use probability_difference)
    pub logprob_difference: f32,
    /// Probability difference between top 2 candidates (in linear space 0-1)
    pub probability_difference: f32,
    /// Probability mass not accounted for by top_candidates (1 - sum of top_candidates probabilities)
    pub probability_remaining: f32,
    /// Top 2 candidates at this position
    pub top_candidates: Vec<TokenLogprob>,
}
268
269/// Analyzes logprobs from a recorded stream focusing on token similarity/closeness
270pub fn analyze_logprob_sensitivity(
271    recorded_stream: Arc<RecordedStream<impl LogprobExtractor>>,
272) -> SensitivityAnalysis {
273    let mut choice_analyses: HashMap<u32, ChoiceAnalysis> = HashMap::new();
274    // Track cumulative sequence position per choice
275    let mut choice_sequence_positions: HashMap<u32, usize> = HashMap::new();
276
277    for (stream_pos, timestamped_response) in recorded_stream.responses().iter().enumerate() {
278        let response = &timestamped_response.response;
279        let logprobs_by_choice = response.extract_logprobs_by_choice();
280
281        for (choice_index, choice_logprobs) in logprobs_by_choice {
282            // Ensure we have a ChoiceAnalysis for this choice
283            let choice_analysis =
284                choice_analyses
285                    .entry(choice_index)
286                    .or_insert_with(|| ChoiceAnalysis {
287                        choice_index,
288                        position_closeness: Vec::new(),
289                        positions_analyzed: 0,
290                    });
291
292            // Get current sequence position for this choice
293            let current_seq_pos = choice_sequence_positions.entry(choice_index).or_insert(0);
294
295            for token_logprobs in choice_logprobs {
296                let all_tokens = token_logprobs.all_tokens();
297
298                if all_tokens.len() < 2 {
299                    *current_seq_pos += 1;
300                    continue;
301                }
302
303                // all_tokens is already sorted by logprob (highest first)
304                let sorted_candidates = all_tokens.to_vec();
305
306                // Calculate difference between top 2 in both logprob and probability space
307                let logprob_difference =
308                    sorted_candidates[0].logprob - sorted_candidates[1].logprob;
309
310                // Convert to probability space for more intuitive closeness calculation
311                let prob1 = sorted_candidates[0].logprob.exp();
312                let prob2 = sorted_candidates[1].logprob.exp();
313                let probability_difference = prob1 - prob2;
314
315                // Calculate probability_remaining
316                let total_prob_sum: f32 = sorted_candidates.iter().map(|t| t.logprob.exp()).sum();
317                let probability_remaining = 1.0 - total_prob_sum;
318
319                choice_analysis.position_closeness.push(PositionCloseness {
320                    stream_position: stream_pos,
321                    token_position: *current_seq_pos,
322                    logprob_difference,
323                    probability_difference,
324                    probability_remaining,
325                    candidates: sorted_candidates,
326                });
327
328                choice_analysis.positions_analyzed += 1;
329                *current_seq_pos += 1;
330            }
331        }
332    }
333
334    // Sort position closeness by probability difference (smallest first = most uncertain)
335    for choice_analysis in choice_analyses.values_mut() {
336        choice_analysis.position_closeness.sort_by(|a, b| {
337            a.probability_difference
338                .partial_cmp(&b.probability_difference)
339                .unwrap()
340        });
341    }
342
343    SensitivityAnalysis {
344        total_responses: recorded_stream.responses().len(),
345        choice_analyses,
346    }
347}
348
349impl SensitivityAnalysis {
350    /// Get positions below a threshold for a specific choice
351    /// Threshold is in probability space (0-1), where smaller values indicate closer probabilities
352    pub fn get_close_positions_for_choice(
353        &self,
354        choice_index: u32,
355        threshold: f32,
356    ) -> Vec<&PositionCloseness> {
357        self.choice_analyses
358            .get(&choice_index)
359            .map(|analysis| {
360                analysis
361                    .position_closeness
362                    .iter()
363                    .filter(|pos| pos.probability_difference <= threshold)
364                    .collect()
365            })
366            .unwrap_or_default()
367    }
368
369    /// Get the closest N positions for a specific choice
370    pub fn get_closest_positions_for_choice(
371        &self,
372        choice_index: u32,
373        count: usize,
374    ) -> Vec<&PositionCloseness> {
375        self.choice_analyses
376            .get(&choice_index)
377            .map(|analysis| analysis.position_closeness.iter().take(count).collect())
378            .unwrap_or_default()
379    }
380
381    /// Print a summary of the sensitivity analysis
382    pub fn print_summary(&self) {
383        println!("=== Logprob Sensitivity Analysis Summary ===");
384        println!("Total stream responses analyzed: {}", self.total_responses);
385        println!("Number of choices: {}", self.choice_analyses.len());
386        println!();
387
388        for (choice_index, choice_analysis) in &self.choice_analyses {
389            println!(
390                "Choice {}: {} positions analyzed",
391                choice_index, choice_analysis.positions_analyzed
392            );
393
394            if !choice_analysis.position_closeness.is_empty() {
395                println!("  Closest positions (smallest probability differences):");
396                for (j, pos) in choice_analysis
397                    .position_closeness
398                    .iter()
399                    .take(3)
400                    .enumerate()
401                {
402                    let top_token = &pos.candidates[0].token;
403                    let second_token = &pos.candidates[1].token;
404                    let prob1 = pos.candidates[0].logprob.exp();
405                    let prob2 = pos.candidates[1].logprob.exp();
406                    println!(
407                        "    {}: Stream pos {}, token pos {} - '{}' ({:.1}%) vs '{}' ({:.1}%) (prob diff: {:.4})",
408                        j + 1,
409                        pos.stream_position,
410                        pos.token_position,
411                        top_token,
412                        prob1 * 100.0,
413                        second_token,
414                        prob2 * 100.0,
415                        pos.probability_difference
416                    );
417                }
418            }
419            println!();
420        }
421    }
422
423    /// Get percentage of positions with close probabilities for a specific choice
424    /// Threshold is in probability space (0-1)
425    pub fn close_position_percentage_for_choice(&self, choice_index: u32, threshold: f32) -> f32 {
426        if let Some(analysis) = self.choice_analyses.get(&choice_index) {
427            if analysis.positions_analyzed == 0 {
428                return 0.0;
429            }
430            let close_count = analysis
431                .position_closeness
432                .iter()
433                .filter(|pos| pos.probability_difference <= threshold)
434                .count();
435            (close_count as f32 / analysis.positions_analyzed as f32) * 100.0
436        } else {
437            0.0
438        }
439    }
440
441    /// Check if multiple tokens are close (within threshold of each other)
442    pub fn detect_multiple_close_tokens(
443        &self,
444        choice_index: u32,
445        threshold: f32,
446    ) -> Vec<MultipleCloseTokens> {
447        let mut results = Vec::new();
448
449        if let Some(analysis) = self.choice_analyses.get(&choice_index) {
450            for pos in &analysis.position_closeness {
451                let close_tokens = self.count_close_tokens_at_position(pos, threshold);
452                if close_tokens.close_count > 2 {
453                    results.push(close_tokens);
454                }
455            }
456        }
457
458        results
459    }
460
461    /// Detect if greedy decoding was likely used by checking if selected tokens are always the most probable
462    /// Note: This is an approximation since we infer selection from the data structure
463    pub fn detect_likely_greedy_decoding(&self, choice_index: u32) -> bool {
464        if let Some(analysis) = self.choice_analyses.get(&choice_index) {
465            if analysis.positions_analyzed == 0 {
466                return true; // No evidence against greedy
467            }
468
469            // For greedy detection, we're looking for positions with moderate to large differences
470            // Very small differences (< 0.01) suggest equal alternatives - could be greedy or random
471            // Very large differences (> 0.05) suggest clear winners - likely greedy
472            let likely_greedy_positions = analysis
473                .position_closeness
474                .iter()
475                .filter(|pos| {
476                    if pos.candidates.is_empty() {
477                        return true; // No contradiction
478                    }
479
480                    // Either very close (tie - could be greedy) or clear difference (likely greedy)
481                    pos.probability_difference < 0.01 || pos.probability_difference > 0.05
482                })
483                .count();
484
485            // If most positions show greedy-like patterns, consider it greedy
486            (likely_greedy_positions as f32 / analysis.positions_analyzed as f32) > 0.5
487        } else {
488            false
489        }
490    }
491
492    /// Get percentage of positions with greedy-like selection patterns
493    pub fn greedy_selection_percentage(&self, choice_index: u32) -> f32 {
494        if let Some(analysis) = self.choice_analyses.get(&choice_index) {
495            if analysis.positions_analyzed == 0 {
496                return 0.0;
497            }
498
499            let greedy_like_positions = analysis
500                .position_closeness
501                .iter()
502                .filter(|pos| {
503                    // Same logic as detect_likely_greedy_decoding for consistency
504                    pos.probability_difference < 0.01 || pos.probability_difference > 0.05
505                })
506                .count();
507
508            (greedy_like_positions as f32 / analysis.positions_analyzed as f32) * 100.0
509        } else {
510            0.0
511        }
512    }
513
514    /// Count how many tokens are close at a specific position
515    /// Threshold is in probability space (0-1)
516    fn count_close_tokens_at_position(
517        &self,
518        position: &PositionCloseness,
519        threshold: f32,
520    ) -> MultipleCloseTokens {
521        let top_prob = position.candidates[0].logprob.exp();
522        let mut close_count = 1; // Top token is always included
523        let mut close_tokens = vec![position.candidates[0].clone()];
524
525        for candidate in &position.candidates[1..] {
526            let candidate_prob = candidate.logprob.exp();
527            let prob_diff = top_prob - candidate_prob;
528            if prob_diff <= threshold {
529                close_count += 1;
530                close_tokens.push(candidate.clone());
531            } else {
532                break; // Since candidates are sorted, no need to check further
533            }
534        }
535
536        let max_difference = if close_count > 1 {
537            let last_prob = close_tokens.last().unwrap().logprob.exp();
538            top_prob - last_prob
539        } else {
540            0.0
541        };
542
543        MultipleCloseTokens {
544            stream_position: position.stream_position,
545            token_position: position.token_position,
546            close_count,
547            close_tokens,
548            max_difference,
549        }
550    }
551}
552
/// Information about multiple close tokens at a position
///
/// Returned by `SensitivityAnalysis::detect_multiple_close_tokens`.
#[derive(Debug, Clone)]
pub struct MultipleCloseTokens {
    /// Position in the stream (response index)
    pub stream_position: usize,
    /// Position within the choice's token sequence
    pub token_position: usize,
    /// Number of candidates within the threshold of the top candidate,
    /// including the top candidate itself
    pub close_count: usize,
    /// The close candidates, starting with the top candidate
    pub close_tokens: Vec<TokenLogprob>,
    /// Probability-space difference between the top candidate and the last
    /// close candidate (the least likely one, given candidates sorted highest
    /// first); 0.0 when only the top candidate is close
    pub max_difference: f32,
}
562
563#[cfg(test)]
564mod tests {
565    use super::*;
566
    // Type aliases to simplify complex test data structures.
    // NOTE(review): presumably `(token, linear probability)` for one alternative,
    // matching the `(&str, f32)` pairs the tests pass to
    // `create_token_logprob_from_linear_probs`. Confirm against the helpers.
    type TestTokenAlternative = (&'static str, f32);
    // Presumably `(selected token, its probability, alternatives)` for one position.
    type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
    // A sequence of per-position test entries.
    type TestTokenDataVec = Vec<TestTokenData>;
571    use crate::perf::{RecordingMode, TimestampedResponse, record_stream_with_context};
572    use crate::protocols::codec::create_message_stream;
573    use crate::protocols::convert_sse_stream;
574    use approx::assert_abs_diff_eq;
575    use dynamo_async_openai::types::{
576        ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
577        ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
578    };
579    use futures::StreamExt;
580    use std::sync::Arc;
581    use std::time::Instant;
582
    /// Tolerance for comparing `f32` probabilities that the analysis
    /// recomputes from logprobs via `exp`.
    const FLOAT_EPSILON: f32 = 1e-6;
584
585    #[test]
586    fn test_two_tokens_close() {
587        // Two very close tokens: 45% vs 44% (remaining 11% for other tokens)
588        // Linear probs: [0.45, 0.44], difference = 0.01
589        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
590            "hello",
591            0.45,
592            vec![("world", 0.44)], // Very close: 45% vs 44%
593        )]);
594
595        let close_positions = analysis.get_close_positions_for_choice(0, 0.1);
596        assert_eq!(close_positions.len(), 1);
597
598        // Probability difference should be 0.01 (45% - 44%)
599        assert_abs_diff_eq!(
600            close_positions[0].probability_difference,
601            0.01,
602            epsilon = FLOAT_EPSILON
603        );
604
605        // Logprob difference: ln(0.45) - ln(0.44) ≈ -0.798 - (-0.821) ≈ 0.023
606        assert_abs_diff_eq!(
607            close_positions[0].logprob_difference,
608            0.023,
609            epsilon = 0.001
610        );
611
612        let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
613        assert_eq!(multiple_close.len(), 0); // Only 2 tokens, so no "multiple" detected
614    }
615
616    #[test]
617    fn test_three_tokens_close() {
618        // Three close tokens: 35%, 33%, 32% (complete distribution)
619        // Linear probs: [0.35, 0.33, 0.32], differences = [0.02, 0.01]
620        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
621            "hello",
622            0.35,
623            vec![
624                ("world", 0.33), // Close: 35% vs 33% (diff = 0.02)
625                ("there", 0.32), // Close: 33% vs 32% (diff = 0.01)
626            ],
627        )]);
628
629        let close_positions = analysis.get_close_positions_for_choice(0, 0.025);
630        assert_eq!(close_positions.len(), 1);
631
632        // Top 2 probability difference: 0.35 - 0.33 = 0.02
633        assert_abs_diff_eq!(
634            close_positions[0].probability_difference,
635            0.02,
636            epsilon = FLOAT_EPSILON
637        );
638
639        let multiple_close = analysis.detect_multiple_close_tokens(0, 0.04);
640        assert_eq!(multiple_close.len(), 1);
641        assert_eq!(multiple_close[0].close_count, 3);
642        // Max difference: 0.35 - 0.32 = 0.03
643        assert_abs_diff_eq!(
644            multiple_close[0].max_difference,
645            0.03,
646            epsilon = FLOAT_EPSILON
647        );
648    }
649
650    #[test]
651    fn test_four_tokens_close() {
652        // Four close tokens: 27%, 26%, 25%, 22% (complete distribution)
653        // Linear probs: [0.27, 0.26, 0.25, 0.22], all very close
654        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
655            "hello",
656            0.27,
657            vec![
658                ("world", 0.26),  // Close: 27% vs 26% (diff = 0.01)
659                ("there", 0.25),  // Close: 26% vs 25% (diff = 0.01)
660                ("friend", 0.22), // Close: 25% vs 22% (diff = 0.03)
661            ],
662        )]);
663
664        let close_positions = analysis.get_close_positions_for_choice(0, 0.02);
665        assert_eq!(close_positions.len(), 1);
666
667        // Top 2 probability difference: 0.27 - 0.26 = 0.01
668        assert_abs_diff_eq!(
669            close_positions[0].probability_difference,
670            0.01,
671            epsilon = FLOAT_EPSILON
672        );
673
674        let multiple_close = analysis.detect_multiple_close_tokens(0, 0.06);
675        assert_eq!(multiple_close.len(), 1);
676        assert_eq!(multiple_close[0].close_count, 4);
677        // Max difference: 0.27 - 0.22 = 0.05
678        assert_abs_diff_eq!(
679            multiple_close[0].max_difference,
680            0.05,
681            epsilon = FLOAT_EPSILON
682        );
683    }
684
685    #[test]
686    fn test_multiple_choices_analysis() {
687        let analysis = create_analysis_with_multiple_choices(vec![
688            // Choice 0: Moderately close tokens (70% vs 25%, remaining 5%)
689            vec![create_token_logprob_from_linear_probs(
690                "hello",
691                0.7,
692                vec![("world", 0.25)],
693            )],
694            // Choice 1: Very close tokens (50.5% vs 49.5%)
695            vec![create_token_logprob_from_linear_probs(
696                "hi",
697                0.505,
698                vec![("there", 0.495)],
699            )],
700        ]);
701
702        assert_eq!(analysis.choice_analyses.len(), 2);
703
704        // Check choice 0: probability difference = 0.7 - 0.25 = 0.45
705        let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
706        assert_eq!(choice0_close.len(), 1);
707        assert_abs_diff_eq!(
708            choice0_close[0].probability_difference,
709            0.45,
710            epsilon = FLOAT_EPSILON
711        );
712
713        // Check choice 1: probability difference = 0.505 - 0.495 = 0.01
714        let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);
715        assert_eq!(choice1_close.len(), 1);
716        assert_abs_diff_eq!(
717            choice1_close[0].probability_difference,
718            0.01,
719            epsilon = FLOAT_EPSILON
720        );
721
722        // Choice 1 should be much closer than choice 0
723        assert!(choice1_close[0].probability_difference < choice0_close[0].probability_difference);
724    }
725
726    #[test]
727    fn test_edge_case_single_token() {
728        // Position with only one token (100% probability, no alternatives)
729        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
730            "hello",
731            1.0,
732            vec![],
733        )]);
734
735        let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
736        assert_eq!(close_positions.len(), 0); // No close positions when only 1 token
737    }
738
739    #[test]
740    fn test_threshold_filtering() {
741        let analysis = create_analysis_with_logprobs(vec![
742            // Position 1: Close tokens (55% vs 45%)
743            create_token_logprob_from_linear_probs("token1", 0.55, vec![("token2", 0.45)]),
744            // Position 2: Far tokens (80% vs 20%)
745            create_token_logprob_from_linear_probs("token3", 0.8, vec![("token4", 0.2)]),
746        ]);
747
748        // With threshold 0.15, only first position should be close (diff = 0.1)
749        let close_strict = analysis.get_close_positions_for_choice(0, 0.15);
750        assert_eq!(close_strict.len(), 1);
751        assert_abs_diff_eq!(
752            close_strict[0].probability_difference,
753            0.1,
754            epsilon = FLOAT_EPSILON
755        );
756
757        // With threshold 0.7, both positions should be close
758        let close_permissive = analysis.get_close_positions_for_choice(0, 0.7);
759        assert_eq!(close_permissive.len(), 2);
760
761        // Check they're sorted by closeness (0.1 < 0.6)
762        assert!(
763            close_permissive[0].probability_difference < close_permissive[1].probability_difference
764        );
765    }
766
767    #[test]
768    fn test_percentage_calculation() {
769        let analysis = create_analysis_with_logprobs(vec![
770            // Position 1: Close (60% vs 40%, diff = 0.2)
771            create_token_logprob_from_linear_probs("token1", 0.6, vec![("token2", 0.4)]),
772            // Position 2: Far (90% vs 10%, diff = 0.8)
773            create_token_logprob_from_linear_probs("token3", 0.9, vec![("token4", 0.1)]),
774            // Position 3: Close (52% vs 48%, diff = 0.04)
775            create_token_logprob_from_linear_probs("token5", 0.52, vec![("token6", 0.48)]),
776        ]);
777
778        let percentage = analysis.close_position_percentage_for_choice(0, 0.25);
779        assert!((percentage - 66.67).abs() < 0.01); // 2 out of 3 positions are close
780    }
781
782    #[test]
783    fn test_real_vllm_equal_logprobs() {
784        // Real example from vLLM where two tokens have identical logprobs
785        // Both "Ġblock" and "Ġchunk" have logprob -0.9078922271728516
786        // exp(-0.9078922271728516) ≈ 0.403 = 40.3% each (sum = 80.6%, remaining 19.4%)
787        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
788            "Ġblock",
789            0.403,
790            vec![("Ġchunk", 0.403)], // Identical probability = equally likely
791        )]);
792
793        // These should be detected as extremely close (difference = 0.0)
794        let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
795        assert_eq!(close_positions.len(), 1);
796        assert_abs_diff_eq!(
797            close_positions[0].probability_difference,
798            0.0,
799            epsilon = FLOAT_EPSILON
800        );
801
802        // Verify probabilities are exactly equal at 40.3%
803        let position = &close_positions[0];
804        assert_eq!(position.candidates.len(), 2);
805
806        // Check that both tokens are present (order doesn't matter for equal logprobs)
807        let tokens: Vec<&str> = position
808            .candidates
809            .iter()
810            .map(|c| c.token.as_str())
811            .collect();
812        assert!(tokens.contains(&"Ġblock"));
813        assert!(tokens.contains(&"Ġchunk"));
814
815        // Both should have identical logprobs (ln(0.403) ≈ -0.907892)
816        assert_abs_diff_eq!(
817            position.candidates[0].logprob,
818            position.candidates[1].logprob,
819            epsilon = FLOAT_EPSILON
820        );
821
822        // Verify the actual probability values
823        let prob1 = position.candidates[0].logprob.exp();
824        let prob2 = position.candidates[1].logprob.exp();
825        assert_abs_diff_eq!(prob1, 0.403, epsilon = 0.001);
826        assert_abs_diff_eq!(prob2, 0.403, epsilon = 0.001);
827    }
828
829    // Helper functions for creating test data
830    fn create_analysis_with_logprobs(
831        token_logprobs: Vec<ChatCompletionTokenLogprob>,
832    ) -> SensitivityAnalysis {
833        let start_time = Instant::now();
834        let response = create_mock_response_with_logprobs(token_logprobs);
835        let responses = vec![TimestampedResponse::new(response, 0)];
836        let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
837        let arc_stream = Arc::new(recorded_stream);
838
839        analyze_logprob_sensitivity(arc_stream)
840    }
841
842    fn create_analysis_with_multiple_choices(
843        choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
844    ) -> SensitivityAnalysis {
845        let start_time = Instant::now();
846        let response = create_mock_response_with_multiple_choices(choices_logprobs);
847        let responses = vec![TimestampedResponse::new(response, 0)];
848        let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
849        let arc_stream = Arc::new(recorded_stream);
850
851        analyze_logprob_sensitivity(arc_stream)
852    }
853
854    fn create_analysis_with_mixed_sampling(mixed_data: TestTokenDataVec) -> SensitivityAnalysis {
855        let start_time = Instant::now();
856        let token_logprobs: Vec<ChatCompletionTokenLogprob> = mixed_data
857            .into_iter()
858            .map(|(selected_token, selected_prob, alternatives)| {
859                create_token_logprob_from_linear_probs(selected_token, selected_prob, alternatives)
860            })
861            .collect();
862
863        let response = create_mock_response_with_logprobs(token_logprobs);
864        let responses = vec![TimestampedResponse::new(response, 0)];
865        let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
866        let arc_stream = Arc::new(recorded_stream);
867
868        analyze_logprob_sensitivity(arc_stream)
869    }
870
871    fn create_analysis_with_missing_selected_token() -> SensitivityAnalysis {
872        let start_time = Instant::now();
873
874        // Create a scenario where the selected token has a lower probability than alternatives
875        // This simulates non-greedy sampling: selected token 15%, but alternatives are 40% and 30%
876        let token_logprobs = vec![ChatCompletionTokenLogprob {
877            token: "unlikely_selection".to_string(),
878            logprob: (0.15_f32).ln(), // Selected but not optimal: 15%
879            bytes: None,
880            top_logprobs: vec![
881                TopLogprobs {
882                    token: "best_option".to_string(),
883                    logprob: (0.4_f32).ln(), // Much better option: 40%
884                    bytes: None,
885                },
886                TopLogprobs {
887                    token: "second_best".to_string(),
888                    logprob: (0.3_f32).ln(), // Still better than selected: 30%
889                    bytes: None,
890                },
891            ],
892        }];
893
894        let response = create_mock_response_with_logprobs(token_logprobs);
895        let responses = vec![TimestampedResponse::new(response, 0)];
896        let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
897        let arc_stream = Arc::new(recorded_stream);
898
899        analyze_logprob_sensitivity(arc_stream)
900    }
901
902    /// Helper function to create token logprobs from linear probabilities [0, 1]
903    /// This ensures realistic probability distributions that sum to ≤ 1
904    fn create_token_logprob_from_linear_probs(
905        token: &str,
906        prob: f32,
907        top_probs: Vec<(&str, f32)>,
908    ) -> ChatCompletionTokenLogprob {
909        // Validate that probabilities are in [0, 1] range
910        assert!(
911            (0.0..=1.0).contains(&prob),
912            "Probability must be in [0, 1]: {}",
913            prob
914        );
915
916        // Calculate total probability mass
917        let total_prob = prob + top_probs.iter().map(|(_, p)| p).sum::<f32>();
918        assert!(
919            total_prob <= 1.001,
920            "Total probability mass exceeds 1: {}",
921            total_prob
922        ); // Allow small floating point error
923
924        for (_, p) in &top_probs {
925            assert!(
926                *p >= 0.0 && *p <= 1.0,
927                "Probability must be in [0, 1]: {}",
928                p
929            );
930        }
931
932        ChatCompletionTokenLogprob {
933            token: token.to_string(),
934            logprob: prob.ln(),
935            bytes: None,
936            top_logprobs: top_probs
937                .into_iter()
938                .map(|(t, p)| TopLogprobs {
939                    token: t.to_string(),
940                    logprob: p.ln(),
941                    bytes: None,
942                })
943                .collect(),
944        }
945    }
946
947    fn create_mock_response_with_logprobs(
948        token_logprobs: Vec<ChatCompletionTokenLogprob>,
949    ) -> NvCreateChatCompletionStreamResponse {
950        #[expect(deprecated)]
951        NvCreateChatCompletionStreamResponse {
952            id: "test_id".to_string(),
953            choices: vec![ChatChoiceStream {
954                index: 0,
955                delta: ChatCompletionStreamResponseDelta {
956                    content: Some("test".to_string()),
957                    function_call: None,
958                    tool_calls: None,
959                    role: Some(Role::Assistant),
960                    refusal: None,
961                    reasoning_content: None,
962                },
963                finish_reason: Some(FinishReason::Stop),
964                logprobs: Some(ChatChoiceLogprobs {
965                    content: Some(token_logprobs),
966                    refusal: None,
967                }),
968            }],
969            created: 1234567890,
970            model: "test-model".to_string(),
971            service_tier: None,
972            system_fingerprint: None,
973            object: "chat.completion.chunk".to_string(),
974            usage: None,
975        }
976    }
977
978    fn create_mock_response_with_multiple_choices(
979        choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
980    ) -> NvCreateChatCompletionStreamResponse {
981        #[expect(deprecated)]
982        let choices = choices_logprobs
983            .into_iter()
984            .enumerate()
985            .map(|(i, token_logprobs)| ChatChoiceStream {
986                index: i as u32,
987                delta: ChatCompletionStreamResponseDelta {
988                    content: Some("test".to_string()),
989                    function_call: None,
990                    tool_calls: None,
991                    role: Some(Role::Assistant),
992                    refusal: None,
993                    reasoning_content: None,
994                },
995                finish_reason: Some(FinishReason::Stop),
996                logprobs: Some(ChatChoiceLogprobs {
997                    content: Some(token_logprobs),
998                    refusal: None,
999                }),
1000            })
1001            .collect();
1002
1003        NvCreateChatCompletionStreamResponse {
1004            id: "test_id".to_string(),
1005            choices,
1006            created: 1234567890,
1007            model: "test-model".to_string(),
1008            service_tier: None,
1009            system_fingerprint: None,
1010            object: "chat.completion.chunk".to_string(),
1011            usage: None,
1012        }
1013    }
1014
1015    #[test]
1016    fn test_sensitivity_analysis() {
1017        let start_time = Instant::now();
1018        let responses = vec![TimestampedResponse::new(create_mock_response(), 0)];
1019
1020        let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
1021        let arc_stream = Arc::new(recorded_stream);
1022
1023        let analysis = analyze_logprob_sensitivity(arc_stream);
1024        // Basic validation that analysis was created
1025        assert_eq!(analysis.total_responses, 1);
1026        assert!(analysis.close_position_percentage_for_choice(0, 0.5) >= 0.0);
1027    }
1028
1029    #[test]
1030    fn test_extract_logprobs_by_choice_empty() {
1031        let response = create_mock_response();
1032        let logprobs = response.extract_logprobs_by_choice();
1033        assert!(logprobs.is_empty() || logprobs.values().any(|v| v.is_empty()));
1034    }
1035
1036    #[test]
1037    fn test_token_logprobs_struct() {
1038        // Test TokenLogProbs with selected token not in alternatives
1039        let selected = TokenLogprob {
1040            token: "selected".to_string(),
1041            logprob: 0.7_f32.ln(), // 70%
1042            bytes: None,
1043        };
1044
1045        let alternatives = vec![
1046            TokenLogprob {
1047                token: "alt1".to_string(),
1048                logprob: 0.2_f32.ln(), // 20%
1049                bytes: None,
1050            },
1051            TokenLogprob {
1052                token: "alt2".to_string(),
1053                logprob: 0.1_f32.ln(), // 10%
1054                bytes: None,
1055            },
1056        ];
1057
1058        let token_logprobs = TokenLogProbs::new(selected.clone(), alternatives.clone());
1059
1060        // Test methods
1061        assert_eq!(token_logprobs.selected_token(), &selected);
1062        assert_eq!(token_logprobs.alternative_tokens().len(), 2);
1063        assert_eq!(token_logprobs.all_tokens().len(), 3);
1064
1065        // Test sorting - all_tokens should be sorted by logprob (highest first)
1066        let all_tokens = token_logprobs.all_tokens();
1067        assert_eq!(all_tokens[0].token, "selected"); // 70%
1068        assert_eq!(all_tokens[1].token, "alt1"); // 20%
1069        assert_eq!(all_tokens[2].token, "alt2"); // 10%
1070
1071        // Test that alternatives are sorted
1072        let alt_tokens = token_logprobs.alternative_tokens();
1073        assert_eq!(alt_tokens[0].token, "alt1"); // 20%
1074        assert_eq!(alt_tokens[1].token, "alt2"); // 10%
1075    }
1076
1077    #[test]
1078    fn test_token_logprobs_selected_in_alternatives() {
1079        // Test case where selected token already appears in alternatives
1080        let selected = TokenLogprob {
1081            token: "token".to_string(),
1082            logprob: 0.4_f32.ln(), // 40%
1083            bytes: None,
1084        };
1085
1086        let alternatives = vec![
1087            TokenLogprob {
1088                token: "token".to_string(),
1089                logprob: 0.4_f32.ln(), // Same as selected
1090                bytes: None,
1091            },
1092            TokenLogprob {
1093                token: "other".to_string(),
1094                logprob: 0.3_f32.ln(), // 30%
1095                bytes: None,
1096            },
1097        ];
1098
1099        let token_logprobs = TokenLogProbs::new(selected, alternatives.clone());
1100
1101        // all_tokens should not duplicate the selected token
1102        let all_tokens = token_logprobs.all_tokens();
1103        assert_eq!(all_tokens.len(), 2);
1104        assert_eq!(all_tokens[0].token, "token"); // 40%
1105        assert_eq!(all_tokens[1].token, "other"); // 30%
1106    }
1107
1108    #[test]
1109    fn test_validate_and_flatten_choices() {
1110        // Test successful validation
1111        let mut choices = HashMap::new();
1112        choices.insert(0, vec![]);
1113        choices.insert(1, vec![]);
1114        choices.insert(2, vec![]);
1115
1116        let result = validate_and_flatten_choices(choices);
1117        assert!(result.is_ok());
1118        let flattened = result.unwrap();
1119        assert_eq!(flattened.len(), 3);
1120
1121        // Test missing choice index
1122        let mut choices = HashMap::new();
1123        choices.insert(0, vec![]);
1124        choices.insert(2, vec![]); // Missing index 1
1125
1126        let result = validate_and_flatten_choices(choices);
1127        assert!(result.is_err());
1128        let error_msg = result.unwrap_err();
1129        assert!(
1130            error_msg.contains("Missing choice indices")
1131                && error_msg.contains("expected 3 choices")
1132        );
1133
1134        // Test empty choices
1135        let choices = HashMap::new();
1136        let result = validate_and_flatten_choices(choices);
1137        assert!(result.is_ok());
1138        assert_eq!(result.unwrap().len(), 0);
1139    }
1140
1141    #[test]
1142    fn test_probability_remaining_calculation() {
1143        // Test with tokens that don't sum to 1.0 (incomplete distribution)
1144        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1145            "token",
1146            0.4, // 40%
1147            vec![
1148                ("alt1", 0.3), // 30%
1149                ("alt2", 0.1), // 10%
1150                               // Missing 20% probability mass
1151            ],
1152        )]);
1153
1154        let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
1155        assert_eq!(close_positions.len(), 1);
1156
1157        let position = &close_positions[0];
1158
1159        // Should have probability_remaining ≈ 0.2 (20% missing)
1160        // Total: 40% + 30% + 10% = 80%, so remaining = 20%
1161        assert_abs_diff_eq!(position.probability_remaining, 0.2, epsilon = 0.01);
1162
1163        // Test with tokens that nearly sum to 1.0 (complete distribution)
1164        let analysis_complete =
1165            create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1166                "token",
1167                0.5, // 50%
1168                vec![
1169                    ("alt1", 0.3), // 30%
1170                    ("alt2", 0.2), // 20%
1171                                   // Total: 100%
1172                ],
1173            )]);
1174
1175        let complete_positions = analysis_complete.get_close_positions_for_choice(0, 1.0);
1176        assert_eq!(complete_positions.len(), 1);
1177
1178        let complete_position = &complete_positions[0];
1179
1180        // Should have probability_remaining ≈ 0.0 (no missing mass)
1181        assert_abs_diff_eq!(complete_position.probability_remaining, 0.0, epsilon = 0.01);
1182    }
1183
1184    #[test]
1185    fn test_position_closeness_ordering() {
1186        let analysis = create_analysis_with_logprobs(vec![
1187            // Position 1: Far apart (85% vs 15%, diff = 0.7)
1188            create_token_logprob_from_linear_probs("far", 0.85, vec![("alt", 0.15)]),
1189            // Position 2: Close (51% vs 49%, diff = 0.02)
1190            create_token_logprob_from_linear_probs("close", 0.51, vec![("alt", 0.49)]),
1191            // Position 3: Medium (70% vs 30%, diff = 0.4)
1192            create_token_logprob_from_linear_probs("medium", 0.7, vec![("alt", 0.3)]),
1193        ]);
1194
1195        let positions = &analysis.choice_analyses.get(&0).unwrap().position_closeness;
1196        assert_eq!(positions.len(), 3);
1197
1198        // Should be sorted by closeness (smallest difference first)
1199        assert!(positions[0].probability_difference <= positions[1].probability_difference);
1200        assert!(positions[1].probability_difference <= positions[2].probability_difference);
1201
1202        // Check actual values
1203        assert_abs_diff_eq!(
1204            positions[0].probability_difference,
1205            0.02,
1206            epsilon = FLOAT_EPSILON
1207        );
1208        assert_abs_diff_eq!(
1209            positions[1].probability_difference,
1210            0.4,
1211            epsilon = FLOAT_EPSILON
1212        );
1213        assert_abs_diff_eq!(
1214            positions[2].probability_difference,
1215            0.7,
1216            epsilon = FLOAT_EPSILON
1217        );
1218    }
1219
1220    #[test]
1221    fn test_multiple_close_tokens_edge_cases() {
1222        // Test with exactly 3 close tokens: 34%, 33%, 32% (close within 0.02)
1223        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1224            "token",
1225            0.34,
1226            vec![
1227                ("alt1", 0.33), // diff = 0.01
1228                ("alt2", 0.32), // diff = 0.01 from alt1, 0.02 from token
1229                ("alt3", 0.01), // diff = 0.31 (not close)
1230            ],
1231        )]);
1232
1233        let multiple_close = analysis.detect_multiple_close_tokens(0, 0.025);
1234        assert_eq!(multiple_close.len(), 1);
1235        assert_eq!(multiple_close[0].close_count, 3);
1236    }
1237
1238    #[test]
1239    fn test_choice_analysis_independence() {
1240        let analysis = create_analysis_with_multiple_choices(vec![
1241            // Choice 0: 2 positions, 1 close
1242            vec![
1243                create_token_logprob_from_linear_probs("token1", 0.55, vec![("alt1", 0.45)]), // diff = 0.1
1244                create_token_logprob_from_linear_probs("token2", 0.9, vec![("alt2", 0.1)]), // diff = 0.8
1245            ],
1246            // Choice 1: 1 position, very close
1247            vec![
1248                create_token_logprob_from_linear_probs("token3", 0.501, vec![("alt3", 0.499)]), // diff = 0.002
1249            ],
1250        ]);
1251
1252        assert_eq!(analysis.choice_analyses.len(), 2);
1253        assert_eq!(
1254            analysis.choice_analyses.get(&0).unwrap().positions_analyzed,
1255            2
1256        );
1257        assert_eq!(
1258            analysis.choice_analyses.get(&1).unwrap().positions_analyzed,
1259            1
1260        );
1261
1262        // Check independence - each choice should have different closeness patterns
1263        let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
1264        let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);
1265
1266        assert_eq!(choice0_close.len(), 1);
1267        assert_eq!(choice1_close.len(), 1);
1268
1269        // Choice 1 should be much closer
1270        assert!(choice1_close[0].probability_difference < choice0_close[0].probability_difference);
1271    }
1272
1273    #[test]
1274    fn test_get_closest_positions_boundary() {
1275        let analysis = create_analysis_with_logprobs(vec![
1276            create_token_logprob_from_linear_probs("token1", 0.6, vec![("alt1", 0.4)]),
1277            create_token_logprob_from_linear_probs("token2", 0.75, vec![("alt2", 0.25)]),
1278        ]);
1279
1280        // Request more positions than available
1281        let closest = analysis.get_closest_positions_for_choice(0, 10);
1282        assert_eq!(closest.len(), 2);
1283
1284        // Request exactly the number available
1285        let closest = analysis.get_closest_positions_for_choice(0, 2);
1286        assert_eq!(closest.len(), 2);
1287
1288        // Request fewer
1289        let closest = analysis.get_closest_positions_for_choice(0, 1);
1290        assert_eq!(closest.len(), 1);
1291    }
1292
1293    #[test]
1294    fn test_zero_threshold() {
1295        let analysis = create_analysis_with_logprobs(vec![
1296            create_token_logprob_from_linear_probs("token", 0.5, vec![("alt", 0.5)]), // diff = 0.0
1297        ]);
1298
1299        let close_positions = analysis.get_close_positions_for_choice(0, 0.0);
1300        assert_eq!(close_positions.len(), 1);
1301        assert_abs_diff_eq!(
1302            close_positions[0].probability_difference,
1303            0.0,
1304            epsilon = FLOAT_EPSILON
1305        );
1306    }
1307
1308    #[test]
1309    fn test_nonexistent_choice() {
1310        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1311            "token",
1312            0.6,
1313            vec![("alt", 0.4)],
1314        )]);
1315
1316        // Request analysis for non-existent choice
1317        let close_positions = analysis.get_close_positions_for_choice(5, 0.1);
1318        assert!(close_positions.is_empty());
1319
1320        let closest = analysis.get_closest_positions_for_choice(5, 3);
1321        assert!(closest.is_empty());
1322
1323        let percentage = analysis.close_position_percentage_for_choice(5, 0.1);
1324        assert_eq!(percentage, 0.0);
1325    }
1326
1327    #[test]
1328    fn test_logprob_extractor_with_missing_data() {
1329        // Test with choice that has no logprobs
1330        #[expect(deprecated)]
1331        let response = NvCreateChatCompletionStreamResponse {
1332            id: "test_id".to_string(),
1333            choices: vec![ChatChoiceStream {
1334                index: 0,
1335                delta: ChatCompletionStreamResponseDelta {
1336                    content: Some("test".to_string()),
1337                    function_call: None,
1338                    tool_calls: None,
1339                    role: Some(Role::Assistant),
1340                    refusal: None,
1341                    reasoning_content: None,
1342                },
1343                finish_reason: Some(FinishReason::Stop),
1344                logprobs: None, // No logprobs
1345            }],
1346            created: 1234567890,
1347            model: "test-model".to_string(),
1348            service_tier: None,
1349            system_fingerprint: None,
1350            object: "chat.completion.chunk".to_string(),
1351            usage: None,
1352        };
1353
1354        let logprobs = response.extract_logprobs_by_choice();
1355        assert_eq!(logprobs.len(), 1);
1356        assert!(logprobs.values().any(|v| v.is_empty()));
1357    }
1358
1359    #[test]
1360    fn test_print_summary_no_panic() {
1361        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1362            "token",
1363            0.6,
1364            vec![("alt", 0.4)],
1365        )]);
1366
1367        // Should not panic when printing summary
1368        analysis.print_summary();
1369    }
1370
1371    #[test]
1372    fn test_greedy_decoding_detection() {
1373        // Greedy decoding: selected token is always the most probable
1374        // Position 1: Clear winner (80% vs 15% vs 5%)
1375        // Position 2: Another clear winner (70% vs 20% vs 10%)
1376        let analysis = create_analysis_with_logprobs(vec![
1377            create_token_logprob_from_linear_probs(
1378                "best",
1379                0.8,
1380                vec![("second", 0.15), ("third", 0.05)],
1381            ),
1382            create_token_logprob_from_linear_probs(
1383                "optimal",
1384                0.7,
1385                vec![("suboptimal", 0.2), ("bad", 0.1)],
1386            ),
1387        ]);
1388
1389        // Should detect greedy-like behavior (selected tokens have highest probability)
1390        let is_greedy = analysis.detect_likely_greedy_decoding(0);
1391        assert!(is_greedy);
1392
1393        let greedy_percentage = analysis.greedy_selection_percentage(0);
1394        assert!(greedy_percentage > 90.0); // Should be close to 100%
1395    }
1396
1397    #[test]
1398    fn test_non_greedy_decoding_detection() {
1399        // Non-greedy decoding: some positions show sampling behavior
1400        // Position 1: Greedy selection (best token chosen: 60% vs 40%)
1401        // Position 2: Non-greedy-like (close tokens: 35% vs 33% vs 32%)
1402        let analysis = create_analysis_with_mixed_sampling(vec![
1403            ("selected_best", 0.6, vec![("alternative", 0.4)]),
1404            (
1405                "close_choice",
1406                0.35,
1407                vec![("very_close", 0.33), ("also_close", 0.32)],
1408            ),
1409        ]);
1410
1411        let _is_greedy = analysis.detect_likely_greedy_decoding(0);
1412        // This should be detected as greedy since we have some clear differences
1413
1414        let greedy_percentage = analysis.greedy_selection_percentage(0);
1415        assert!((0.0..=100.0).contains(&greedy_percentage)); // Valid percentage range
1416    }
1417
1418    #[test]
1419    fn test_selected_token_not_in_top_logprobs() {
1420        // Edge case: selected token doesn't appear in top_logprobs at all
1421        // Selected: 15%, but alternatives are 40% and 30% (non-greedy sampling)
1422        let analysis = create_analysis_with_missing_selected_token();
1423
1424        // Should still work - the algorithm adapts to different logprob patterns
1425        let greedy_percentage = analysis.greedy_selection_percentage(0);
1426        assert!((0.0..=100.0).contains(&greedy_percentage)); // Valid percentage range
1427    }
1428
1429    #[test]
1430    fn test_equal_logprobs_greedy_detection() {
1431        // Test the original vLLM example - equal logprobs should be detected as close
1432        let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
1433            "Ġblock",
1434            0.403,
1435            vec![("Ġchunk", 0.403)], // Identical probability = equally likely
1436        )]);
1437
1438        // Equal probabilities should be detected as extremely close
1439        let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
1440        assert_eq!(close_positions.len(), 1);
1441
1442        // Should be detected as greedy-like since there's no clear better choice
1443        let is_greedy = analysis.detect_likely_greedy_decoding(0);
1444        assert!(is_greedy);
1445    }
1446
1447    #[tokio::test]
1448    async fn test_real_sse_stream_analysis() {
1449        // Read the real SSE data with logprobs
1450        let data = std::fs::read_to_string(
1451            "tests/data/replays/deepseek-r1-distill-llama-8b/chat-completions.stream.1",
1452        )
1453        .expect("Failed to read test data file");
1454
1455        // Create stream from SSE data
1456        let sse_stream = create_message_stream(&data);
1457
1458        // Convert SSE messages to our stream response format using the existing converter
1459        let response_stream =
1460            convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(sse_stream));
1461
1462        // Filter out errors and extract successful responses
1463        let filtered_stream = response_stream.filter_map(|annotated| async move { annotated.data });
1464
1465        // Create a mock context for recording
1466        let ctx = Arc::new(MockContext::new());
1467
1468        // Record the stream
1469        let (recorded_stream, recording_rx) =
1470            record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
1471
1472        // Consume the stream (it will be recorded)
1473        let _collected: Vec<_> = recorded_stream.collect().await;
1474
1475        // Get the recorded data
1476        let recorded = recording_rx
1477            .await
1478            .expect("Failed to receive recorded stream");
1479
1480        // Verify we have data
1481        assert!(recorded.response_count() > 0, "No responses recorded");
1482        println!("Recorded {} responses", recorded.response_count());
1483
1484        // Perform logprob analysis
1485        let arc_recorded = Arc::new(recorded);
1486        let analysis = analyze_logprob_sensitivity(arc_recorded);
1487
1488        // Print analysis summary
1489        analysis.print_summary();
1490
1491        // Verify the analysis found logprob data
1492        assert!(
1493            !analysis.choice_analyses.is_empty(),
1494            "No choice analyses found"
1495        );
1496        assert!(
1497            analysis
1498                .choice_analyses
1499                .values()
1500                .any(|a| a.positions_analyzed > 0),
1501            "No positions analyzed"
1502        );
1503
1504        // Look for the specific vLLM case with equal logprobs ("Ġblock" vs "Ġchunk")
1505        let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
1506
1507        // Should find at least one very close position (the equal logprob case)
1508        assert!(!close_positions.is_empty(), "No close positions found");
1509
1510        // Check if we found the exact equal case (difference = 0)
1511        let equal_positions = close_positions
1512            .iter()
1513            .filter(|pos| pos.probability_difference < 0.0001)
1514            .count();
1515        if equal_positions > 0 {
1516            println!(
1517                "Found {} positions with nearly equal probabilities",
1518                equal_positions
1519            );
1520        }
1521
1522        // Test other analysis methods
1523        let closest_3 = analysis.get_closest_positions_for_choice(0, 3);
1524        assert!(
1525            closest_3.len() <= 3,
1526            "Should return at most 3 closest positions"
1527        );
1528
1529        let percentage = analysis.close_position_percentage_for_choice(0, 0.1);
1530        assert!(
1531            (0.0..=100.0).contains(&percentage),
1532            "Percentage should be valid"
1533        );
1534
1535        // Test greedy detection
1536        let is_greedy = analysis.detect_likely_greedy_decoding(0);
1537        let greedy_percentage = analysis.greedy_selection_percentage(0);
1538        println!(
1539            "Greedy detection: {} ({}% greedy-like)",
1540            is_greedy, greedy_percentage
1541        );
1542
1543        // Test multiple close tokens detection
1544        let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
1545        if !multiple_close.is_empty() {
1546            println!(
1547                "Found {} positions with multiple close tokens",
1548                multiple_close.len()
1549            );
1550        }
1551    }
1552
1553    fn create_mock_response() -> NvCreateChatCompletionStreamResponse {
1554        // Create a mock response for testing
1555        // In practice, this would have real logprobs data
1556
1557        NvCreateChatCompletionStreamResponse {
1558            id: "test_id".to_string(),
1559            choices: vec![],
1560            created: 1234567890,
1561            model: "test-model".to_string(),
1562            service_tier: None,
1563            system_fingerprint: None,
1564            object: "chat.completion.chunk".to_string(),
1565            usage: None,
1566        }
1567    }
1568
1569    // Mock context for testing
1570    #[derive(Debug)]
1571    struct MockContext {
1572        id: String,
1573    }
1574
1575    impl MockContext {
1576        fn new() -> Self {
1577            Self {
1578                id: "test-context".to_string(),
1579            }
1580        }
1581    }
1582
1583    #[async_trait::async_trait]
1584    impl dynamo_runtime::engine::AsyncEngineContext for MockContext {
1585        fn id(&self) -> &str {
1586            &self.id
1587        }
1588
1589        fn stop(&self) {
1590            // No-op for testing
1591        }
1592
1593        fn stop_generating(&self) {
1594            // No-op for testing
1595        }
1596
1597        fn kill(&self) {
1598            // No-op for testing
1599        }
1600
1601        fn is_stopped(&self) -> bool {
1602            false
1603        }
1604
1605        fn is_killed(&self) -> bool {
1606            false
1607        }
1608
1609        async fn stopped(&self) {
1610            // No-op for testing
1611        }
1612
1613        async fn killed(&self) {
1614            // No-op for testing
1615        }
1616
1617        fn link_child(&self, _: Arc<dyn dynamo_runtime::engine::AsyncEngineContext>) {
1618            // No-op for testing
1619        }
1620    }
1621}