use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::perf::RecordedStream;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
/// How the reported token logprobs should be interpreted.
///
/// NOTE(review): not referenced elsewhere in this file — presumably consumed
/// by sibling modules; confirm before changing.
///
/// Fix: added the standard derives; every other public type in this module
/// derives at least `Debug`, and a field-less enum is trivially `Copy`/`Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogprobType {
    /// Logprobs normalized over the reported candidate set.
    Normalized,
    /// Raw logprobs as emitted by the model.
    Unnormalized,
}
/// A single token with its log-probability, mirroring an OpenAI-style
/// logprob entry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenLogprob {
    /// Token text as produced by the tokenizer.
    pub token: String,
    /// Natural-log probability of this token.
    pub logprob: f32,
    /// Raw bytes of the token, when the backend reports them.
    pub bytes: Option<Vec<u8>>,
}
/// The sampled token at one sequence position together with its reported
/// alternatives, plus a pre-merged view of all candidates.
#[derive(Debug, Clone)]
pub struct TokenLogProbs {
    /// Token actually emitted at this position.
    selected: TokenLogprob,
    /// Alternative candidates, sorted by descending logprob (see `new`).
    alternatives: Vec<TokenLogprob>,
    /// Selected token merged into the alternatives in descending-logprob
    /// order; the selected token is not duplicated when it already appears
    /// among the alternatives.
    all_sorted: Vec<TokenLogprob>,
}
impl TokenLogProbs {
    /// Builds the container from the sampled token and its alternatives.
    ///
    /// Alternatives are sorted by descending logprob. `all_sorted` merges the
    /// selected token into that ordering, unless an alternative already
    /// matches it (same token text, logprob within `1e-6`), in which case the
    /// alternatives alone serve as the combined list.
    ///
    /// Panics if any logprob is NaN (the descending sort unwraps
    /// `partial_cmp`).
    pub fn new(selected: TokenLogprob, mut alternatives: Vec<TokenLogprob>) -> Self {
        alternatives.sort_by(|lhs, rhs| rhs.logprob.partial_cmp(&lhs.logprob).unwrap());

        let already_listed = alternatives.iter().any(|candidate| {
            candidate.token == selected.token
                && (candidate.logprob - selected.logprob).abs() < 1e-6
        });

        let all_sorted = if already_listed {
            alternatives.clone()
        } else {
            // Insert before the first strictly-smaller alternative, so an
            // alternative that ties on logprob stays ahead of the selected
            // token.
            let slot = alternatives
                .iter()
                .position(|candidate| selected.logprob > candidate.logprob)
                .unwrap_or(alternatives.len());
            let mut merged = alternatives.clone();
            merged.insert(slot, selected.clone());
            merged
        };

        Self {
            selected,
            alternatives,
            all_sorted,
        }
    }

    /// The token actually emitted at this position.
    pub fn selected_token(&self) -> &TokenLogprob {
        &self.selected
    }

    /// Alternative candidates, sorted by descending logprob.
    pub fn alternative_tokens(&self) -> &[TokenLogprob] {
        &self.alternatives
    }

    /// Selected token plus alternatives as one descending-logprob list.
    pub fn all_tokens(&self) -> &[TokenLogprob] {
        &self.all_sorted
    }
}
/// Extracts per-choice token logprob data from a streamed response.
pub trait LogprobExtractor {
    /// Maps choice index -> per-token logprob entries for that choice.
    /// A choice carrying no logprobs maps to an empty `Vec`.
    fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>>;
}
impl LogprobExtractor for NvCreateChatCompletionStreamResponse {
    /// Converts each choice's OpenAI-style logprob content into
    /// `TokenLogProbs`, keyed by the choice index. Choices without logprob
    /// content map to an empty `Vec`.
    fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> {
        let mut by_choice = HashMap::new();
        for choice in &self.choices {
            let mut entries = Vec::new();
            let content = choice
                .logprobs
                .as_ref()
                .and_then(|logprobs| logprobs.content.as_ref());
            if let Some(token_entries) = content {
                for entry in token_entries {
                    let selected = TokenLogprob {
                        token: entry.token.clone(),
                        logprob: entry.logprob,
                        bytes: entry.bytes.clone(),
                    };
                    let mut alternatives = Vec::with_capacity(entry.top_logprobs.len());
                    for top in &entry.top_logprobs {
                        alternatives.push(TokenLogprob {
                            token: top.token.clone(),
                            logprob: top.logprob,
                            bytes: top.bytes.clone(),
                        });
                    }
                    entries.push(TokenLogProbs::new(selected, alternatives));
                }
            }
            by_choice.insert(choice.index, entries);
        }
        by_choice
    }
}
/// Validates that the map covers the contiguous choice-index range
/// `[0, max_index]` and flattens it into a dense `Vec` ordered by index.
///
/// Returns `Ok(Vec::new())` for an empty map, and `Err` with a descriptive
/// message when indices are missing.
///
/// Fix: the original ran a second "is every index present" loop that is
/// implied by the length check (HashMap keys are unique and all <= max), and
/// cloned every per-choice `Vec`. We now consume the map with `remove`,
/// avoiding both the redundant pass and the clones.
pub fn validate_and_flatten_choices(
    mut choice_logprobs: HashMap<u32, Vec<TokenLogProbs>>,
) -> Result<Vec<Vec<TokenLogProbs>>, String> {
    if choice_logprobs.is_empty() {
        return Ok(Vec::new());
    }
    // `unwrap` is safe: the map is non-empty.
    let max_choice = *choice_logprobs.keys().max().unwrap();
    let expected_count = (max_choice + 1) as usize;
    if choice_logprobs.len() != expected_count {
        return Err(format!(
            "Missing choice indices: expected {} choices [0, {}), but found {} choices: {:?}",
            expected_count,
            max_choice + 1,
            choice_logprobs.len(),
            choice_logprobs.keys().collect::<Vec<_>>()
        ));
    }
    let mut result = Vec::with_capacity(expected_count);
    for i in 0..=max_choice {
        // Defensive: unreachable given the length check above, but kept so a
        // logic error surfaces as an Err rather than a panic.
        match choice_logprobs.remove(&i) {
            Some(entries) => result.push(entries),
            None => {
                return Err(format!(
                    "Missing choice index {}: expected [0, {}), found {:?}",
                    i,
                    max_choice + 1,
                    choice_logprobs.keys().collect::<Vec<_>>()
                ));
            }
        }
    }
    Ok(result)
}
/// Output of `analyze_logprob_sensitivity`: per-choice closeness data over a
/// whole recorded stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityAnalysis {
    /// Number of responses in the recorded stream.
    pub total_responses: usize,
    /// Per-choice analysis, keyed by choice index.
    pub choice_analyses: HashMap<u32, ChoiceAnalysis>,
}
/// Closeness measurements for a single choice across the stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChoiceAnalysis {
    /// The choice this analysis belongs to.
    pub choice_index: u32,
    /// One entry per analyzed token position, sorted by ascending
    /// probability difference after analysis completes.
    pub position_closeness: Vec<PositionCloseness>,
    /// Count of token positions that had at least two candidates.
    pub positions_analyzed: usize,
}
/// How close the top two candidates were at one token position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionCloseness {
    /// Index of the response within the recorded stream.
    pub stream_position: usize,
    /// Running token position within this choice's sequence.
    pub token_position: usize,
    /// Logprob gap between the top two candidates.
    pub logprob_difference: f32,
    /// Linear-probability gap between the top two candidates.
    pub probability_difference: f32,
    /// 1.0 minus the probability mass of all reported candidates; may be
    /// slightly negative due to float rounding.
    pub probability_remaining: f32,
    /// All candidates at this position, sorted by descending logprob.
    pub candidates: Vec<TokenLogprob>,
}
/// A position flagged as close, with its top candidates.
///
/// NOTE(review): not referenced anywhere in this file — presumably part of
/// the public API for downstream consumers; confirm before removing.
#[derive(Debug, Clone)]
pub struct ClosePosition {
    /// Index of the response within the recorded stream.
    pub stream_position: usize,
    /// Running token position within the choice's sequence.
    pub token_position: usize,
    /// Logprob gap between the top two candidates.
    pub logprob_difference: f32,
    /// Linear-probability gap between the top two candidates.
    pub probability_difference: f32,
    /// 1.0 minus the probability mass of the reported candidates.
    pub probability_remaining: f32,
    /// Leading candidates at this position.
    pub top_candidates: Vec<TokenLogprob>,
}
/// Walks every recorded response and measures, per choice, how close the top
/// two candidates were at each token position.
///
/// Positions with fewer than two candidates are skipped but still advance the
/// per-choice token counter, keeping `token_position` aligned with the
/// stream. Each choice's `position_closeness` is finally sorted by ascending
/// probability difference so the most sampling-sensitive positions come
/// first.
///
/// Fix: the borrow of the timestamped response was mojibake-corrupted
/// (`×tamped_response.response`, a mangled `&times…`), which does not
/// compile; restored to `&timestamped_response.response`.
pub fn analyze_logprob_sensitivity(
    recorded_stream: Arc<RecordedStream<impl LogprobExtractor>>,
) -> SensitivityAnalysis {
    let mut choice_analyses: HashMap<u32, ChoiceAnalysis> = HashMap::new();
    // Running token position per choice, accumulated across responses.
    let mut choice_sequence_positions: HashMap<u32, usize> = HashMap::new();
    for (stream_pos, timestamped_response) in recorded_stream.responses().iter().enumerate() {
        let response = &timestamped_response.response;
        let logprobs_by_choice = response.extract_logprobs_by_choice();
        for (choice_index, choice_logprobs) in logprobs_by_choice {
            let choice_analysis =
                choice_analyses
                    .entry(choice_index)
                    .or_insert_with(|| ChoiceAnalysis {
                        choice_index,
                        position_closeness: Vec::new(),
                        positions_analyzed: 0,
                    });
            let current_seq_pos = choice_sequence_positions.entry(choice_index).or_insert(0);
            for token_logprobs in choice_logprobs {
                let all_tokens = token_logprobs.all_tokens();
                // Closeness needs at least two candidates; still advance the
                // counter so later positions keep the right index.
                if all_tokens.len() < 2 {
                    *current_seq_pos += 1;
                    continue;
                }
                // `all_tokens` is already sorted descending by logprob, so
                // [0] and [1] are the top two candidates.
                let sorted_candidates = all_tokens.to_vec();
                let logprob_difference =
                    sorted_candidates[0].logprob - sorted_candidates[1].logprob;
                let prob1 = sorted_candidates[0].logprob.exp();
                let prob2 = sorted_candidates[1].logprob.exp();
                let probability_difference = prob1 - prob2;
                // Mass not covered by the reported candidates; may dip
                // slightly below zero from float rounding.
                let total_prob_sum: f32 =
                    sorted_candidates.iter().map(|t| t.logprob.exp()).sum();
                let probability_remaining = 1.0 - total_prob_sum;
                choice_analysis.position_closeness.push(PositionCloseness {
                    stream_position: stream_pos,
                    token_position: *current_seq_pos,
                    logprob_difference,
                    probability_difference,
                    probability_remaining,
                    candidates: sorted_candidates,
                });
                choice_analysis.positions_analyzed += 1;
                *current_seq_pos += 1;
            }
        }
    }
    // Most sensitive positions (smallest probability gap) first.
    for choice_analysis in choice_analyses.values_mut() {
        choice_analysis.position_closeness.sort_by(|a, b| {
            a.probability_difference
                .partial_cmp(&b.probability_difference)
                .unwrap()
        });
    }
    SensitivityAnalysis {
        total_responses: recorded_stream.responses().len(),
        choice_analyses,
    }
}
impl SensitivityAnalysis {
    /// Positions for `choice_index` whose top-two probability gap is at most
    /// `threshold`. Empty when the choice is unknown.
    pub fn get_close_positions_for_choice(
        &self,
        choice_index: u32,
        threshold: f32,
    ) -> Vec<&PositionCloseness> {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return Vec::new();
        };
        analysis
            .position_closeness
            .iter()
            .filter(|pos| pos.probability_difference <= threshold)
            .collect()
    }

    /// Up to `count` positions with the smallest probability gaps; relies on
    /// `position_closeness` being pre-sorted ascending by gap.
    pub fn get_closest_positions_for_choice(
        &self,
        choice_index: u32,
        count: usize,
    ) -> Vec<&PositionCloseness> {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return Vec::new();
        };
        analysis.position_closeness.iter().take(count).collect()
    }

    /// Prints a human-readable overview of the analysis to stdout.
    pub fn print_summary(&self) {
        println!("=== Logprob Sensitivity Analysis Summary ===");
        println!("Total stream responses analyzed: {}", self.total_responses);
        println!("Number of choices: {}", self.choice_analyses.len());
        println!();
        for (choice_index, choice_analysis) in &self.choice_analyses {
            println!(
                "Choice {}: {} positions analyzed",
                choice_index, choice_analysis.positions_analyzed
            );
            if !choice_analysis.position_closeness.is_empty() {
                println!(" Closest positions (smallest probability differences):");
                for (rank, pos) in choice_analysis
                    .position_closeness
                    .iter()
                    .take(3)
                    .enumerate()
                {
                    let leader = &pos.candidates[0];
                    let runner_up = &pos.candidates[1];
                    println!(
                        " {}: Stream pos {}, token pos {} - '{}' ({:.1}%) vs '{}' ({:.1}%) (prob diff: {:.4})",
                        rank + 1,
                        pos.stream_position,
                        pos.token_position,
                        leader.token,
                        leader.logprob.exp() * 100.0,
                        runner_up.token,
                        runner_up.logprob.exp() * 100.0,
                        pos.probability_difference
                    );
                }
            }
            println!();
        }
    }

    /// Percentage of analyzed positions whose top-two gap is within
    /// `threshold`. Zero for unknown choices or empty analyses.
    pub fn close_position_percentage_for_choice(&self, choice_index: u32, threshold: f32) -> f32 {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return 0.0;
        };
        if analysis.positions_analyzed == 0 {
            return 0.0;
        }
        let close_count = analysis
            .position_closeness
            .iter()
            .filter(|pos| pos.probability_difference <= threshold)
            .count();
        (close_count as f32 / analysis.positions_analyzed as f32) * 100.0
    }

    /// Positions where more than two candidates sit within `threshold` of
    /// the leader's linear probability.
    pub fn detect_multiple_close_tokens(
        &self,
        choice_index: u32,
        threshold: f32,
    ) -> Vec<MultipleCloseTokens> {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return Vec::new();
        };
        analysis
            .position_closeness
            .iter()
            .map(|pos| self.count_close_tokens_at_position(pos, threshold))
            .filter(|group| group.close_count > 2)
            .collect()
    }

    /// Heuristic: true when more than half the positions look greedy-like —
    /// top-two gap either very small (< 0.01, ties resolved greedily) or
    /// clearly decisive (> 0.05). Positions with no candidates also count as
    /// greedy-like; an empty analysis defaults to true.
    pub fn detect_likely_greedy_decoding(&self, choice_index: u32) -> bool {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return false;
        };
        if analysis.positions_analyzed == 0 {
            return true;
        }
        let greedy_like = analysis
            .position_closeness
            .iter()
            .filter(|pos| {
                pos.candidates.is_empty()
                    || pos.probability_difference < 0.01
                    || pos.probability_difference > 0.05
            })
            .count();
        (greedy_like as f32 / analysis.positions_analyzed as f32) > 0.5
    }

    /// Percentage of positions that look greedy-like under the same
    /// small-or-decisive gap heuristic as `detect_likely_greedy_decoding`.
    pub fn greedy_selection_percentage(&self, choice_index: u32) -> f32 {
        let Some(analysis) = self.choice_analyses.get(&choice_index) else {
            return 0.0;
        };
        if analysis.positions_analyzed == 0 {
            return 0.0;
        }
        let greedy_like = analysis
            .position_closeness
            .iter()
            .filter(|pos| {
                pos.probability_difference < 0.01 || pos.probability_difference > 0.05
            })
            .count();
        (greedy_like as f32 / analysis.positions_analyzed as f32) * 100.0
    }

    /// Counts how many candidates (leader included) sit within `threshold`
    /// of the leader's linear probability at one position.
    fn count_close_tokens_at_position(
        &self,
        position: &PositionCloseness,
        threshold: f32,
    ) -> MultipleCloseTokens {
        let leader = &position.candidates[0];
        let leader_prob = leader.logprob.exp();
        let mut close_tokens = vec![leader.clone()];
        // Candidates are sorted descending, so stop at the first one whose
        // gap to the leader exceeds the threshold.
        for candidate in &position.candidates[1..] {
            if leader_prob - candidate.logprob.exp() > threshold {
                break;
            }
            close_tokens.push(candidate.clone());
        }
        let close_count = close_tokens.len();
        // Spread between the leader and the last close candidate.
        let max_difference = if close_count > 1 {
            leader_prob - close_tokens.last().unwrap().logprob.exp()
        } else {
            0.0
        };
        MultipleCloseTokens {
            stream_position: position.stream_position,
            token_position: position.token_position,
            close_count,
            close_tokens,
            max_difference,
        }
    }
}
/// A position where several candidates are nearly tied with the leader.
#[derive(Debug, Clone)]
pub struct MultipleCloseTokens {
    /// Index of the response within the recorded stream.
    pub stream_position: usize,
    /// Running token position within the choice's sequence.
    pub token_position: usize,
    /// Number of close candidates, leader included.
    pub close_count: usize,
    /// The close candidates, leader first.
    pub close_tokens: Vec<TokenLogprob>,
    /// Probability gap between the leader and the last close candidate;
    /// zero when only the leader qualifies.
    pub max_difference: f32,
}
#[cfg(test)]
mod tests {
use super::*;
type TestTokenAlternative = (&'static str, f32);
type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
type TestTokenDataVec = Vec<TestTokenData>;
use crate::perf::{RecordingMode, TimestampedResponse, record_stream_with_context};
use crate::protocols::codec::create_message_stream;
use crate::protocols::convert_sse_stream;
use approx::assert_abs_diff_eq;
use dynamo_async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
};
use futures::StreamExt;
use std::sync::Arc;
use std::time::Instant;
const FLOAT_EPSILON: f32 = 1e-6;
/// Two candidates (0.45 vs 0.44) count as close at threshold 0.1, but do not
/// trigger `detect_multiple_close_tokens` (which needs more than two).
#[test]
fn test_two_tokens_close() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "hello",
        0.45,
        vec![("world", 0.44)],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.1);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );
    // ln(0.45) - ln(0.44) ~= 0.0225
    assert_abs_diff_eq!(
        close_positions[0].logprob_difference,
        0.023,
        epsilon = 0.001
    );
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
    assert_eq!(multiple_close.len(), 0);
}
/// Three candidates within 0.03 of the leader: flagged close at 0.025, and
/// reported as a three-way near-tie at threshold 0.04.
#[test]
fn test_three_tokens_close() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "hello",
        0.35,
        vec![
            ("world", 0.33),
            ("there", 0.32),
        ],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.025);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.02,
        epsilon = FLOAT_EPSILON
    );
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.04);
    assert_eq!(multiple_close.len(), 1);
    assert_eq!(multiple_close[0].close_count, 3);
    assert_abs_diff_eq!(
        multiple_close[0].max_difference,
        0.03,
        epsilon = FLOAT_EPSILON
    );
}
/// Four candidates within 0.05 of the leader: all four count as close at
/// threshold 0.06.
#[test]
fn test_four_tokens_close() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "hello",
        0.27,
        vec![
            ("world", 0.26),
            ("there", 0.25),
            ("friend", 0.22),
        ],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.02);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.06);
    assert_eq!(multiple_close.len(), 1);
    assert_eq!(multiple_close[0].close_count, 4);
    assert_abs_diff_eq!(
        multiple_close[0].max_difference,
        0.05,
        epsilon = FLOAT_EPSILON
    );
}
/// Two choices are analyzed independently; the tighter pair (0.505 vs 0.495)
/// yields the smaller probability difference.
#[test]
fn test_multiple_choices_analysis() {
    let analysis = create_analysis_with_multiple_choices(vec![
        vec![create_token_logprob_from_linear_probs(
            "hello",
            0.7,
            vec![("world", 0.25)],
        )],
        vec![create_token_logprob_from_linear_probs(
            "hi",
            0.505,
            vec![("there", 0.495)],
        )],
    ]);
    assert_eq!(analysis.choice_analyses.len(), 2);
    let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
    assert_eq!(choice0_close.len(), 1);
    assert_abs_diff_eq!(
        choice0_close[0].probability_difference,
        0.45,
        epsilon = FLOAT_EPSILON
    );
    let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);
    assert_eq!(choice1_close.len(), 1);
    assert_abs_diff_eq!(
        choice1_close[0].probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );
    assert!(choice1_close[0].probability_difference < choice0_close[0].probability_difference);
}
/// A position with a single candidate is skipped entirely (closeness needs
/// at least two candidates).
#[test]
fn test_edge_case_single_token() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "hello",
        1.0,
        vec![],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
    assert_eq!(close_positions.len(), 0);
}
/// A strict threshold keeps only the tight position; a permissive one keeps
/// both, sorted by ascending probability difference.
#[test]
fn test_threshold_filtering() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("token1", 0.55, vec![("token2", 0.45)]),
        create_token_logprob_from_linear_probs("token3", 0.8, vec![("token4", 0.2)]),
    ]);
    let close_strict = analysis.get_close_positions_for_choice(0, 0.15);
    assert_eq!(close_strict.len(), 1);
    assert_abs_diff_eq!(
        close_strict[0].probability_difference,
        0.1,
        epsilon = FLOAT_EPSILON
    );
    let close_permissive = analysis.get_close_positions_for_choice(0, 0.7);
    assert_eq!(close_permissive.len(), 2);
    assert!(
        close_permissive[0].probability_difference < close_permissive[1].probability_difference
    );
}
/// Two of three positions fall within threshold 0.25 -> ~66.67%.
#[test]
fn test_percentage_calculation() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("token2", 0.4)]),
        create_token_logprob_from_linear_probs("token3", 0.9, vec![("token4", 0.1)]),
        create_token_logprob_from_linear_probs("token5", 0.52, vec![("token6", 0.48)]),
    ]);
    let percentage = analysis.close_position_percentage_for_choice(0, 0.25);
    assert!((percentage - 66.67).abs() < 0.01);
}
/// Mirrors real vLLM output where the selected token and an alternative have
/// identical probabilities: the pair shows a zero probability difference and
/// both tokens appear among the candidates.
#[test]
fn test_real_vllm_equal_logprobs() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "Ġblock",
        0.403,
        vec![("Ġchunk", 0.403)],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.0,
        epsilon = FLOAT_EPSILON
    );
    let position = &close_positions[0];
    assert_eq!(position.candidates.len(), 2);
    let tokens: Vec<&str> = position
        .candidates
        .iter()
        .map(|c| c.token.as_str())
        .collect();
    assert!(tokens.contains(&"Ġblock"));
    assert!(tokens.contains(&"Ġchunk"));
    assert_abs_diff_eq!(
        position.candidates[0].logprob,
        position.candidates[1].logprob,
        epsilon = FLOAT_EPSILON
    );
    let prob1 = position.candidates[0].logprob.exp();
    let prob2 = position.candidates[1].logprob.exp();
    assert_abs_diff_eq!(prob1, 0.403, epsilon = 0.001);
    assert_abs_diff_eq!(prob2, 0.403, epsilon = 0.001);
}
/// Wraps `token_logprobs` in a single-choice mock response, records it as a
/// one-response stream, and runs the sensitivity analysis on it.
fn create_analysis_with_logprobs(
    token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> SensitivityAnalysis {
    let start_time = Instant::now();
    let response = create_mock_response_with_logprobs(token_logprobs);
    let responses = vec![TimestampedResponse::new(response, 0)];
    let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
    let arc_stream = Arc::new(recorded_stream);
    analyze_logprob_sensitivity(arc_stream)
}
/// Like `create_analysis_with_logprobs`, but builds one choice per entry of
/// `choices_logprobs` (choice index = position in the outer Vec).
fn create_analysis_with_multiple_choices(
    choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> SensitivityAnalysis {
    let start_time = Instant::now();
    let response = create_mock_response_with_multiple_choices(choices_logprobs);
    let responses = vec![TimestampedResponse::new(response, 0)];
    let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
    let arc_stream = Arc::new(recorded_stream);
    analyze_logprob_sensitivity(arc_stream)
}
/// Builds an analysis from (token, linear-prob, alternatives) tuples — a
/// convenience over `create_analysis_with_logprobs` for mixed scenarios.
fn create_analysis_with_mixed_sampling(mixed_data: TestTokenDataVec) -> SensitivityAnalysis {
    let start_time = Instant::now();
    let token_logprobs: Vec<ChatCompletionTokenLogprob> = mixed_data
        .into_iter()
        .map(|(selected_token, selected_prob, alternatives)| {
            create_token_logprob_from_linear_probs(selected_token, selected_prob, alternatives)
        })
        .collect();
    let response = create_mock_response_with_logprobs(token_logprobs);
    let responses = vec![TimestampedResponse::new(response, 0)];
    let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
    let arc_stream = Arc::new(recorded_stream);
    analyze_logprob_sensitivity(arc_stream)
}
/// Builds an analysis where the selected token (p = 0.15) deliberately does
/// NOT appear among the top_logprobs, exercising the merge path in
/// `TokenLogProbs::new`.
fn create_analysis_with_missing_selected_token() -> SensitivityAnalysis {
    let start_time = Instant::now();
    let token_logprobs = vec![ChatCompletionTokenLogprob {
        token: "unlikely_selection".to_string(),
        logprob: (0.15_f32).ln(),
        bytes: None,
        top_logprobs: vec![
            TopLogprobs {
                token: "best_option".to_string(),
                logprob: (0.4_f32).ln(),
                bytes: None,
            },
            TopLogprobs {
                token: "second_best".to_string(),
                logprob: (0.3_f32).ln(),
                bytes: None,
            },
        ],
    }];
    let response = create_mock_response_with_logprobs(token_logprobs);
    let responses = vec![TimestampedResponse::new(response, 0)];
    let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
    let arc_stream = Arc::new(recorded_stream);
    analyze_logprob_sensitivity(arc_stream)
}
/// Builds a `ChatCompletionTokenLogprob` from linear probabilities,
/// converting each to its natural log. Panics if any probability falls
/// outside [0, 1] or the combined mass exceeds 1 (small float tolerance).
fn create_token_logprob_from_linear_probs(
    token: &str,
    prob: f32,
    top_probs: Vec<(&str, f32)>,
) -> ChatCompletionTokenLogprob {
    assert!(
        (0.0..=1.0).contains(&prob),
        "Probability must be in [0, 1]: {}",
        prob
    );
    let mass = prob + top_probs.iter().map(|(_, p)| p).sum::<f32>();
    assert!(
        mass <= 1.001,
        "Total probability mass exceeds 1: {}",
        mass
    );
    for (_, p) in &top_probs {
        assert!(
            (0.0..=1.0).contains(p),
            "Probability must be in [0, 1]: {}",
            p
        );
    }
    let top_logprobs = top_probs
        .into_iter()
        .map(|(alt_token, alt_prob)| TopLogprobs {
            token: alt_token.to_string(),
            logprob: alt_prob.ln(),
            bytes: None,
        })
        .collect();
    ChatCompletionTokenLogprob {
        token: token.to_string(),
        logprob: prob.ln(),
        bytes: None,
        top_logprobs,
    }
}
/// Builds a single-choice stream response carrying `token_logprobs` as its
/// logprob content.
fn create_mock_response_with_logprobs(
    token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> NvCreateChatCompletionStreamResponse {
    // Some delta fields are deprecated upstream but still required here.
    #[expect(deprecated)]
    NvCreateChatCompletionStreamResponse {
        id: "test_id".to_string(),
        choices: vec![ChatChoiceStream {
            index: 0,
            delta: ChatCompletionStreamResponseDelta {
                content: Some(
                    dynamo_async_openai::types::ChatCompletionMessageContent::Text(
                        "test".to_string(),
                    ),
                ),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            stop_reason: None,
            logprobs: Some(ChatChoiceLogprobs {
                content: Some(token_logprobs),
                refusal: None,
            }),
        }],
        created: 1234567890,
        model: "test-model".to_string(),
        service_tier: None,
        system_fingerprint: None,
        object: "chat.completion.chunk".to_string(),
        usage: None,
        nvext: None,
    }
}
/// Builds a stream response with one choice per entry of `choices_logprobs`,
/// indexed in order.
fn create_mock_response_with_multiple_choices(
    choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> NvCreateChatCompletionStreamResponse {
    // Some delta fields are deprecated upstream but still required here.
    #[expect(deprecated)]
    let choices = choices_logprobs
        .into_iter()
        .enumerate()
        .map(|(i, token_logprobs)| ChatChoiceStream {
            index: i as u32,
            delta: ChatCompletionStreamResponseDelta {
                content: Some(
                    dynamo_async_openai::types::ChatCompletionMessageContent::Text(
                        "test".to_string(),
                    ),
                ),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            stop_reason: None,
            logprobs: Some(ChatChoiceLogprobs {
                content: Some(token_logprobs),
                refusal: None,
            }),
        })
        .collect();
    NvCreateChatCompletionStreamResponse {
        id: "test_id".to_string(),
        choices,
        created: 1234567890,
        model: "test-model".to_string(),
        service_tier: None,
        system_fingerprint: None,
        object: "chat.completion.chunk".to_string(),
        usage: None,
        nvext: None,
    }
}
/// Smoke test: analysis over a response with no choices still reports one
/// stream response and a valid (zero) percentage.
#[test]
fn test_sensitivity_analysis() {
    let start_time = Instant::now();
    let responses = vec![TimestampedResponse::new(create_mock_response(), 0)];
    let recorded_stream = RecordedStream::new(responses, start_time, Instant::now());
    let arc_stream = Arc::new(recorded_stream);
    let analysis = analyze_logprob_sensitivity(arc_stream);
    assert_eq!(analysis.total_responses, 1);
    assert!(analysis.close_position_percentage_for_choice(0, 0.5) >= 0.0);
}
/// A response with no choices extracts to an empty (or all-empty) map.
#[test]
fn test_extract_logprobs_by_choice_empty() {
    let response = create_mock_response();
    let logprobs = response.extract_logprobs_by_choice();
    assert!(logprobs.is_empty() || logprobs.values().any(|v| v.is_empty()));
}
/// `TokenLogProbs::new` keeps the selected token first when it has the
/// highest logprob, with alternatives following in descending order.
#[test]
fn test_token_logprobs_struct() {
    let selected = TokenLogprob {
        token: "selected".to_string(),
        logprob: 0.7_f32.ln(),
        bytes: None,
    };
    let alternatives = vec![
        TokenLogprob {
            token: "alt1".to_string(),
            logprob: 0.2_f32.ln(),
            bytes: None,
        },
        TokenLogprob {
            token: "alt2".to_string(),
            logprob: 0.1_f32.ln(),
            bytes: None,
        },
    ];
    let token_logprobs = TokenLogProbs::new(selected.clone(), alternatives.clone());
    assert_eq!(token_logprobs.selected_token(), &selected);
    assert_eq!(token_logprobs.alternative_tokens().len(), 2);
    assert_eq!(token_logprobs.all_tokens().len(), 3);
    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(all_tokens[0].token, "selected");
    assert_eq!(all_tokens[1].token, "alt1");
    assert_eq!(all_tokens[2].token, "alt2");
    let alt_tokens = token_logprobs.alternative_tokens();
    assert_eq!(alt_tokens[0].token, "alt1");
    assert_eq!(alt_tokens[1].token, "alt2");
}
/// When the selected token already appears in the alternatives (same text
/// and logprob), `all_tokens` does not duplicate it.
#[test]
fn test_token_logprobs_selected_in_alternatives() {
    let selected = TokenLogprob {
        token: "token".to_string(),
        logprob: 0.4_f32.ln(),
        bytes: None,
    };
    let alternatives = vec![
        TokenLogprob {
            token: "token".to_string(),
            logprob: 0.4_f32.ln(),
            bytes: None,
        },
        TokenLogprob {
            token: "other".to_string(),
            logprob: 0.3_f32.ln(),
            bytes: None,
        },
    ];
    let token_logprobs = TokenLogProbs::new(selected, alternatives.clone());
    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(all_tokens.len(), 2);
    assert_eq!(all_tokens[0].token, "token");
    assert_eq!(all_tokens[1].token, "other");
}
/// Covers the three validation paths: contiguous indices succeed, a gap in
/// the index range errors, and an empty map flattens to an empty Vec.
#[test]
fn test_validate_and_flatten_choices() {
    let mut choices = HashMap::new();
    choices.insert(0, vec![]);
    choices.insert(1, vec![]);
    choices.insert(2, vec![]);
    let result = validate_and_flatten_choices(choices);
    assert!(result.is_ok());
    let flattened = result.unwrap();
    assert_eq!(flattened.len(), 3);
    let mut choices = HashMap::new();
    choices.insert(0, vec![]);
    choices.insert(2, vec![]);
    let result = validate_and_flatten_choices(choices);
    assert!(result.is_err());
    let error_msg = result.unwrap_err();
    assert!(
        error_msg.contains("Missing choice indices")
            && error_msg.contains("expected 3 choices")
    );
    let choices = HashMap::new();
    let result = validate_and_flatten_choices(choices);
    assert!(result.is_ok());
    assert_eq!(result.unwrap().len(), 0);
}
/// `probability_remaining` reflects the mass not covered by the candidates:
/// 0.2 when candidates sum to 0.8, and ~0 when they sum to 1.
#[test]
fn test_probability_remaining_calculation() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "token",
        0.4,
        vec![
            ("alt1", 0.3),
            ("alt2", 0.1),
        ],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
    assert_eq!(close_positions.len(), 1);
    let position = &close_positions[0];
    assert_abs_diff_eq!(position.probability_remaining, 0.2, epsilon = 0.01);
    let analysis_complete =
        create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
            "token",
            0.5,
            vec![
                ("alt1", 0.3),
                ("alt2", 0.2),
            ],
        )]);
    let complete_positions = analysis_complete.get_close_positions_for_choice(0, 1.0);
    assert_eq!(complete_positions.len(), 1);
    let complete_position = &complete_positions[0];
    assert_abs_diff_eq!(complete_position.probability_remaining, 0.0, epsilon = 0.01);
}
/// `position_closeness` ends up sorted ascending by probability difference,
/// regardless of input order.
#[test]
fn test_position_closeness_ordering() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("far", 0.85, vec![("alt", 0.15)]),
        create_token_logprob_from_linear_probs("close", 0.51, vec![("alt", 0.49)]),
        create_token_logprob_from_linear_probs("medium", 0.7, vec![("alt", 0.3)]),
    ]);
    let positions = &analysis.choice_analyses.get(&0).unwrap().position_closeness;
    assert_eq!(positions.len(), 3);
    assert!(positions[0].probability_difference <= positions[1].probability_difference);
    assert!(positions[1].probability_difference <= positions[2].probability_difference);
    assert_abs_diff_eq!(
        positions[0].probability_difference,
        0.02,
        epsilon = FLOAT_EPSILON
    );
    assert_abs_diff_eq!(
        positions[1].probability_difference,
        0.4,
        epsilon = FLOAT_EPSILON
    );
    assert_abs_diff_eq!(
        positions[2].probability_difference,
        0.7,
        epsilon = FLOAT_EPSILON
    );
}
/// Only the candidates within the threshold count as close: the distant
/// fourth candidate (0.01) is excluded, leaving a three-way tie.
#[test]
fn test_multiple_close_tokens_edge_cases() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "token",
        0.34,
        vec![
            ("alt1", 0.33),
            ("alt2", 0.32),
            ("alt3", 0.01),
        ],
    )]);
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.025);
    assert_eq!(multiple_close.len(), 1);
    assert_eq!(multiple_close[0].close_count, 3);
}
/// Position counters and closeness lists are tracked per choice, not
/// globally.
#[test]
fn test_choice_analysis_independence() {
    let analysis = create_analysis_with_multiple_choices(vec![
        vec![
            create_token_logprob_from_linear_probs("token1", 0.55, vec![("alt1", 0.45)]),
            create_token_logprob_from_linear_probs("token2", 0.9, vec![("alt2", 0.1)]),
        ],
        vec![
            create_token_logprob_from_linear_probs("token3", 0.501, vec![("alt3", 0.499)]),
        ],
    ]);
    assert_eq!(analysis.choice_analyses.len(), 2);
    assert_eq!(
        analysis.choice_analyses.get(&0).unwrap().positions_analyzed,
        2
    );
    assert_eq!(
        analysis.choice_analyses.get(&1).unwrap().positions_analyzed,
        1
    );
    let choice0_close = analysis.get_close_positions_for_choice(0, 0.5);
    let choice1_close = analysis.get_close_positions_for_choice(1, 0.5);
    assert_eq!(choice0_close.len(), 1);
    assert_eq!(choice1_close.len(), 1);
    assert!(choice1_close[0].probability_difference < choice0_close[0].probability_difference);
}
/// `get_closest_positions_for_choice` caps at the number of available
/// positions and honors smaller counts.
#[test]
fn test_get_closest_positions_boundary() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("alt1", 0.4)]),
        create_token_logprob_from_linear_probs("token2", 0.75, vec![("alt2", 0.25)]),
    ]);
    let closest = analysis.get_closest_positions_for_choice(0, 10);
    assert_eq!(closest.len(), 2);
    let closest = analysis.get_closest_positions_for_choice(0, 2);
    assert_eq!(closest.len(), 2);
    let closest = analysis.get_closest_positions_for_choice(0, 1);
    assert_eq!(closest.len(), 1);
}
/// An exact tie passes a zero threshold (the filter is inclusive: <=).
#[test]
fn test_zero_threshold() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("token", 0.5, vec![("alt", 0.5)]),
    ]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.0);
    assert_eq!(close_positions.len(), 1);
    assert_abs_diff_eq!(
        close_positions[0].probability_difference,
        0.0,
        epsilon = FLOAT_EPSILON
    );
}
/// Queries against an unknown choice index return empty results / zero
/// percentages rather than panicking.
#[test]
fn test_nonexistent_choice() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "token",
        0.6,
        vec![("alt", 0.4)],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(5, 0.1);
    assert!(close_positions.is_empty());
    let closest = analysis.get_closest_positions_for_choice(5, 3);
    assert!(closest.is_empty());
    let percentage = analysis.close_position_percentage_for_choice(5, 0.1);
    assert_eq!(percentage, 0.0);
}
/// A choice with `logprobs: None` still appears in the extraction result,
/// mapped to an empty Vec.
#[test]
fn test_logprob_extractor_with_missing_data() {
    // Some delta fields are deprecated upstream but still required here.
    #[expect(deprecated)]
    let response = NvCreateChatCompletionStreamResponse {
        id: "test_id".to_string(),
        choices: vec![ChatChoiceStream {
            index: 0,
            delta: ChatCompletionStreamResponseDelta {
                content: Some(
                    dynamo_async_openai::types::ChatCompletionMessageContent::Text(
                        "test".to_string(),
                    ),
                ),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            stop_reason: None,
            logprobs: None,
        }],
        created: 1234567890,
        model: "test-model".to_string(),
        service_tier: None,
        system_fingerprint: None,
        object: "chat.completion.chunk".to_string(),
        usage: None,
        nvext: None,
    };
    let logprobs = response.extract_logprobs_by_choice();
    assert_eq!(logprobs.len(), 1);
    assert!(logprobs.values().any(|v| v.is_empty()));
}
/// `print_summary` must not panic on a populated analysis.
#[test]
fn test_print_summary_no_panic() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "token",
        0.6,
        vec![("alt", 0.4)],
    )]);
    analysis.print_summary();
}
/// Decisive gaps (>0.05) at every position read as greedy-like decoding.
#[test]
fn test_greedy_decoding_detection() {
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs(
            "best",
            0.8,
            vec![("second", 0.15), ("third", 0.05)],
        ),
        create_token_logprob_from_linear_probs(
            "optimal",
            0.7,
            vec![("suboptimal", 0.2), ("bad", 0.1)],
        ),
    ]);
    let is_greedy = analysis.detect_likely_greedy_decoding(0);
    assert!(is_greedy);
    let greedy_percentage = analysis.greedy_selection_percentage(0);
    assert!(greedy_percentage > 90.0);
}
/// Mixed gap sizes: only sanity-checks that the heuristics stay in range
/// (the ambiguous case has no single right answer).
#[test]
fn test_non_greedy_decoding_detection() {
    let analysis = create_analysis_with_mixed_sampling(vec![
        ("selected_best", 0.6, vec![("alternative", 0.4)]),
        (
            "close_choice",
            0.35,
            vec![("very_close", 0.33), ("also_close", 0.32)],
        ),
    ]);
    let _is_greedy = analysis.detect_likely_greedy_decoding(0);
    let greedy_percentage = analysis.greedy_selection_percentage(0);
    assert!((0.0..=100.0).contains(&greedy_percentage));
}
/// Analysis stays well-formed when the selected token is absent from the
/// reported top_logprobs.
#[test]
fn test_selected_token_not_in_top_logprobs() {
    let analysis = create_analysis_with_missing_selected_token();
    let greedy_percentage = analysis.greedy_selection_percentage(0);
    assert!((0.0..=100.0).contains(&greedy_percentage));
}
/// An exact probability tie (gap < 0.01) also counts as greedy-like —
/// greedy decoders break ties deterministically.
#[test]
fn test_equal_logprobs_greedy_detection() {
    let analysis = create_analysis_with_logprobs(vec![create_token_logprob_from_linear_probs(
        "Ġblock",
        0.403,
        vec![("Ġchunk", 0.403)],
    )]);
    let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
    assert_eq!(close_positions.len(), 1);
    let is_greedy = analysis.detect_likely_greedy_decoding(0);
    assert!(is_greedy);
}
/// End-to-end: replays a recorded SSE stream from disk through the codec,
/// records it, and checks every public analysis entry point on real data.
#[tokio::test]
async fn test_real_sse_stream_analysis() {
    let data = std::fs::read_to_string(
        "tests/data/replays/deepseek-r1-distill-llama-8b/chat-completions.stream.1",
    )
    .expect("Failed to read test data file");
    let sse_stream = create_message_stream(&data);
    let response_stream =
        convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(sse_stream));
    // Drop annotation wrappers, keeping only the payloads.
    let filtered_stream = response_stream.filter_map(|annotated| async move { annotated.data });
    let ctx = Arc::new(MockContext::new());
    let (recorded_stream, recording_rx) =
        record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
    // Drain the stream so the recording completes.
    let _collected: Vec<_> = recorded_stream.collect().await;
    let recorded = recording_rx
        .await
        .expect("Failed to receive recorded stream");
    assert!(recorded.response_count() > 0, "No responses recorded");
    println!("Recorded {} responses", recorded.response_count());
    let arc_recorded = Arc::new(recorded);
    let analysis = analyze_logprob_sensitivity(arc_recorded);
    analysis.print_summary();
    assert!(
        !analysis.choice_analyses.is_empty(),
        "No choice analyses found"
    );
    assert!(
        analysis
            .choice_analyses
            .values()
            .any(|a| a.positions_analyzed > 0),
        "No positions analyzed"
    );
    let close_positions = analysis.get_close_positions_for_choice(0, 0.001);
    assert!(!close_positions.is_empty(), "No close positions found");
    let equal_positions = close_positions
        .iter()
        .filter(|pos| pos.probability_difference < 0.0001)
        .count();
    if equal_positions > 0 {
        println!(
            "Found {} positions with nearly equal probabilities",
            equal_positions
        );
    }
    let closest_3 = analysis.get_closest_positions_for_choice(0, 3);
    assert!(
        closest_3.len() <= 3,
        "Should return at most 3 closest positions"
    );
    let percentage = analysis.close_position_percentage_for_choice(0, 0.1);
    assert!(
        (0.0..=100.0).contains(&percentage),
        "Percentage should be valid"
    );
    let is_greedy = analysis.detect_likely_greedy_decoding(0);
    let greedy_percentage = analysis.greedy_selection_percentage(0);
    println!(
        "Greedy detection: {} ({}% greedy-like)",
        is_greedy, greedy_percentage
    );
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
    if !multiple_close.is_empty() {
        println!(
            "Found {} positions with multiple close tokens",
            multiple_close.len()
        );
    }
}
/// A minimal stream response with no choices at all.
fn create_mock_response() -> NvCreateChatCompletionStreamResponse {
    NvCreateChatCompletionStreamResponse {
        id: "test_id".to_string(),
        choices: vec![],
        created: 1234567890,
        model: "test-model".to_string(),
        service_tier: None,
        system_fingerprint: None,
        object: "chat.completion.chunk".to_string(),
        usage: None,
        nvext: None,
    }
}
/// Minimal stand-in engine context used to drive
/// `record_stream_with_context` in tests.
#[derive(Debug)]
struct MockContext {
    /// Fixed identifier returned by `AsyncEngineContext::id`.
    id: String,
}
impl MockContext {
    /// Creates a context with a fixed test identifier.
    fn new() -> Self {
        Self {
            id: "test-context".to_string(),
        }
    }
}
// No-op context implementation: tests never stop, kill, or link the stream,
// so every lifecycle hook is inert and the status queries report "running".
#[async_trait::async_trait]
impl dynamo_runtime::engine::AsyncEngineContext for MockContext {
    fn id(&self) -> &str {
        &self.id
    }
    fn stop(&self) {
        // no-op for tests
    }
    fn stop_generating(&self) {
        // no-op for tests
    }
    fn kill(&self) {
        // no-op for tests
    }
    fn is_stopped(&self) -> bool {
        false
    }
    fn is_killed(&self) -> bool {
        false
    }
    async fn stopped(&self) {
        // resolves immediately; tests never wait on shutdown
    }
    async fn killed(&self) {
        // resolves immediately; tests never wait on kill
    }
    fn link_child(&self, _: Arc<dyn dynamo_runtime::engine::AsyncEngineContext>) {
        // no-op for tests
    }
}
}