use std::sync::Arc;
use std::time::Instant;
use dynamo_llm::perf::logprobs::analyze_logprob_sensitivity;
use dynamo_llm::perf::{RecordedStream, TimestampedResponse};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_protocols::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageContent,
ChatCompletionStreamResponseDelta, ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
};
// Helper aliases that keep the fixture-builder signatures readable.
// An alternative token at one position: (token, linear probability).
type TokenAlternative = (&'static str, f32);
// One sampled position: (token, linear probability, alternatives considered).
type TokenData = (&'static str, f32, Vec<TokenAlternative>);
// All positions for a single-choice response (static-str variant).
type TokenDataVec = Vec<TokenData>;
// Owned-string counterparts of the aliases above, used by the
// multi-choice builders where tokens are generated with `format!`.
type StringTokenAlternative = (String, f32);
type StringTokenData = (String, f32, Vec<StringTokenAlternative>);
// All token positions for one choice.
type ChoiceTokenData = Vec<StringTokenData>;
// One `ChoiceTokenData` per choice in a multi-choice response.
type MultiChoiceData = Vec<ChoiceTokenData>;
#[test]
fn test_realistic_streaming_analysis() {
    // Analyze a small three-response, single-choice stream and sanity-check
    // the overall shape of the report.
    let analysis = analyze_logprob_sensitivity(create_realistic_stream());

    assert_eq!(analysis.total_responses, 3);
    assert_eq!(analysis.choice_analyses.len(), 1);

    let choice = analysis.choice_analyses.get(&0).unwrap();
    assert_eq!(choice.positions_analyzed, 3);

    // Positions must come back sorted by ascending probability difference.
    for pair in choice.position_closeness.windows(2) {
        assert!(pair[0].probability_difference <= pair[1].probability_difference);
    }

    // At a loose 0.2 threshold, at least one position qualifies as "close".
    assert!(!analysis.get_close_positions_for_choice(0, 0.2).is_empty());

    // The percentage view must stay within [0, 100].
    let percentage = analysis.close_position_percentage_for_choice(0, 0.2);
    assert!((0.0..=100.0).contains(&percentage));
}
#[test]
fn test_multiple_choices_independent_analysis() {
    // Every choice in a multi-choice stream gets its own analysis entry.
    let analysis = analyze_logprob_sensitivity(create_multi_choice_stream());
    assert_eq!(analysis.choice_analyses.len(), 2);

    // Both choices saw the same two positions.
    for idx in 0..2u32 {
        let per_choice = analysis.choice_analyses.get(&idx).unwrap();
        assert_eq!(per_choice.positions_analyzed, 2);
    }

    // Choice 1's fixture tokens are deliberately closer in probability, so
    // at the same threshold it should flag at least as many positions.
    let close0 = analysis.get_close_positions_for_choice(0, 0.3);
    let close1 = analysis.get_close_positions_for_choice(1, 0.3);
    assert!(close1.len() >= close0.len());
}
#[test]
fn test_multiple_close_tokens_detection() {
    // A position where several alternatives sit within 0.05 (linear
    // probability) of the top token must be flagged.
    let stream = create_stream_with_multiple_close_tokens();
    let analysis = analyze_logprob_sensitivity(stream);

    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
    assert!(!multiple_close.is_empty());

    let first_multiple = &multiple_close[0];
    assert!(first_multiple.close_count >= 3);
    assert!(first_multiple.max_difference <= 0.05);

    // The top token's probability is loop-invariant — compute it once
    // instead of re-deriving it on every iteration.
    let prob_top = first_multiple.close_tokens[0].logprob.exp();
    for close_token in &first_multiple.close_tokens[1..] {
        let diff = prob_top - close_token.logprob.exp();
        assert!(
            diff <= 0.05,
            "probability gap {} exceeds closeness threshold 0.05",
            diff
        );
    }
}
#[test]
fn test_edge_cases() {
    // An empty stream produces an empty analysis.
    let analysis = analyze_logprob_sensitivity(create_empty_stream());
    assert_eq!(analysis.total_responses, 0);
    assert!(analysis.choice_analyses.is_empty());

    // A lone token with probability 1.0 and no alternatives can never be
    // "close", even at the most permissive threshold.
    let analysis = analyze_logprob_sensitivity(create_single_token_stream());
    assert!(analysis.get_close_positions_for_choice(0, 1.0).is_empty());
}
#[test]
fn test_threshold_sensitivity() {
    // Tightening the closeness threshold can only shrink the result set.
    let analysis = analyze_logprob_sensitivity(create_graduated_closeness_stream());

    let strict = analysis.get_close_positions_for_choice(0, 0.01);
    let permissive = analysis.get_close_positions_for_choice(0, 0.1);
    let very_permissive = analysis.get_close_positions_for_choice(0, 0.5);

    assert!(strict.len() <= permissive.len());
    assert!(permissive.len() <= very_permissive.len());

    // The same monotonicity must hold for the percentage view.
    let strict_pct = analysis.close_position_percentage_for_choice(0, 0.01);
    let permissive_pct = analysis.close_position_percentage_for_choice(0, 0.1);
    assert!(strict_pct <= permissive_pct);
}
#[test]
fn test_large_dataset_performance() {
    // 100 positions x 5 choices: verify both the analysis contents and a
    // rough runtime bound.
    let stream = create_large_stream(100, 5);

    let start_time = Instant::now();
    let analysis = analyze_logprob_sensitivity(stream);
    let elapsed = start_time.elapsed();

    // Wall-clock assertions are inherently flaky on loaded CI machines, and
    // debug builds are an order of magnitude slower than release — only
    // enforce the timing bound in optimized builds.
    if !cfg!(debug_assertions) {
        assert!(
            elapsed.as_millis() < 100,
            "analysis took {}ms, expected < 100ms",
            elapsed.as_millis()
        );
    }

    assert_eq!(analysis.total_responses, 100);
    assert_eq!(analysis.choice_analyses.len(), 5);
    for i in 0..5u32 {
        let choice_analysis = analysis.choice_analyses.get(&i).unwrap();
        assert_eq!(choice_analysis.choice_index, i);
        assert_eq!(choice_analysis.positions_analyzed, 100);
    }
}
/// Three-token "Hello world!" stream with one choice and a plausible
/// spread of probability mass at every position.
fn create_realistic_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    // (token, probability, alternatives) for each streamed position.
    let positions = [
        ("Hello", 0.6, vec![("Hi", 0.3), ("Hey", 0.1)]),
        (" world", 0.55, vec![(" there", 0.4), (" everyone", 0.05)]),
        ("!", 0.8, vec![(".", 0.15), ("?", 0.05)]),
    ];

    let responses: Vec<_> = positions
        .into_iter()
        .enumerate()
        .map(|(seq, (token, prob, alts))| {
            TimestampedResponse::new(
                create_response_with_linear_probs(token, vec![(token, prob, alts)]),
                seq,
            )
        })
        .collect();

    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Two responses, each carrying two choices. Choice 1's tokens are
/// deliberately closer in probability than choice 0's.
fn create_multi_choice_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    let first = create_multi_choice_response(vec![
        vec![("token1".to_string(), 0.65, vec![("alt1".to_string(), 0.35)])],
        vec![("token2".to_string(), 0.51, vec![("alt2".to_string(), 0.49)])],
    ]);
    let second = create_multi_choice_response(vec![
        vec![("token3".to_string(), 0.8, vec![("alt3".to_string(), 0.2)])],
        vec![("token4".to_string(), 0.53, vec![("alt4".to_string(), 0.47)])],
    ]);

    let responses = vec![
        TimestampedResponse::new(first, 0),
        TimestampedResponse::new(second, 1),
    ];

    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// A single position where three alternatives all sit within 0.05 of the
/// sampled token's probability (0.27 vs 0.26 / 0.25 / 0.22).
fn create_stream_with_multiple_close_tokens()
-> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    let alternatives = vec![("best", 0.26), ("rest", 0.25), ("nest", 0.22)];
    let response = create_response_with_linear_probs("test", vec![("test", 0.27, alternatives)]);

    let responses = vec![TimestampedResponse::new(response, 0)];
    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// A recorded stream containing no responses at all.
fn create_empty_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let started_at = Instant::now();
    Arc::new(RecordedStream::new(Vec::new(), started_at, Instant::now()))
}
/// One position with probability 1.0 and no alternative tokens.
fn create_single_token_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    let response = create_response_with_linear_probs("only", vec![("only", 1.0, Vec::new())]);
    let responses = vec![TimestampedResponse::new(response, 0)];

    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Four positions whose top-two probability gaps step from 0.002 up to 0.8,
/// so different closeness thresholds select different subsets.
fn create_graduated_closeness_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>>
{
    let start_time = Instant::now();

    let positions = vec![
        ("very_close", 0.501, vec![("alt1", 0.499)]),
        ("close", 0.55, vec![("alt2", 0.45)]),
        ("medium", 0.7, vec![("alt3", 0.3)]),
        ("far", 0.9, vec![("alt4", 0.1)]),
    ];
    let responses = vec![TimestampedResponse::new(
        create_response_with_linear_probs("test", positions),
        0,
    )];

    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
fn create_large_stream(
positions: usize,
choices: usize,
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let mut responses = Vec::new();
for i in 0..positions {
let mut choice_data = Vec::new();
for j in 0..choices {
let token = format!("token_{}_{}", i, j);
let alt = format!("alt_{}_{}", i, j);
let prob = 0.5 + (i as f32 * 0.001) + (j as f32 * 0.01); let alt_prob = 1.0 - prob - 0.05; let alt_prob = alt_prob.max(0.1);
choice_data.push(vec![(token, prob, vec![(alt, alt_prob)])]);
}
responses.push(TimestampedResponse::new(
create_multi_choice_response(choice_data),
i,
));
}
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
/// Build a single-choice streaming response whose logprobs are derived from
/// *linear* probabilities: each `(token, prob, alternatives)` entry becomes
/// a `ChatCompletionTokenLogprob` with `logprob = prob.ln()`.
///
/// `content` is placed in the delta as the message text; `token_data` drives
/// the per-position logprob payload.
///
/// # Panics
/// Panics if any probability lies outside `[0, 1]`, or if a position's total
/// probability mass (token plus alternatives) exceeds 1 beyond a small f32
/// rounding tolerance.
fn create_response_with_linear_probs(
    // NOTE(fix): was `_content`, but the parameter is used below — the
    // underscore prefix incorrectly signaled "unused" (clippy
    // `used_underscore_binding`).
    content: &str,
    token_data: TokenDataVec,
) -> NvCreateChatCompletionStreamResponse {
    let token_logprobs = token_data
        .into_iter()
        .map(|(token, prob, alternatives)| {
            assert!(
                (0.0..=1.0).contains(&prob),
                "Probability must be in [0, 1]: {}",
                prob
            );
            // The sampled token plus all alternatives may not exceed a full
            // probability mass (1.001 allows for f32 rounding).
            let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
            assert!(
                total_prob <= 1.001,
                "Total probability mass exceeds 1: {}",
                total_prob
            );
            let top_logprobs = alternatives
                .into_iter()
                .map(|(alt_token, alt_prob)| {
                    assert!(
                        (0.0..=1.0).contains(&alt_prob),
                        "Probability must be in [0, 1]: {}",
                        alt_prob
                    );
                    TopLogprobs {
                        token: alt_token.to_string(),
                        logprob: alt_prob.ln(),
                        bytes: None,
                    }
                })
                .collect();
            ChatCompletionTokenLogprob {
                token: token.to_string(),
                logprob: prob.ln(),
                bytes: None,
                top_logprobs,
            }
        })
        .collect();
    let choice = ChatChoiceStream {
        index: 0,
        delta: ChatCompletionStreamResponseDelta {
            content: Some(ChatCompletionMessageContent::Text(content.to_string())),
            function_call: None,
            tool_calls: None,
            role: Some(Role::Assistant),
            refusal: None,
            reasoning_content: None,
        },
        finish_reason: Some(FinishReason::Stop),
        stop_reason: None,
        logprobs: Some(ChatChoiceLogprobs {
            content: Some(token_logprobs),
            refusal: None,
        }),
    };
    NvCreateChatCompletionStreamResponse {
        inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices: vec![choice],
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
        nvext: None,
    }
}
/// Build a streaming response with one `ChatChoiceStream` per entry of
/// `choices_data`. Linear probabilities are converted to logprobs via
/// `ln()`, with the same validity assertions as the single-choice builder.
fn create_multi_choice_response(
    choices_data: MultiChoiceData,
) -> NvCreateChatCompletionStreamResponse {
    let mut choices = Vec::with_capacity(choices_data.len());

    for (choice_idx, token_data) in choices_data.into_iter().enumerate() {
        let mut token_logprobs = Vec::with_capacity(token_data.len());

        for (token, prob, alternatives) in token_data {
            assert!(
                (0.0..=1.0).contains(&prob),
                "Probability must be in [0, 1]: {}",
                prob
            );
            // Token plus alternatives may not exceed a full probability mass
            // (1.001 tolerates f32 rounding).
            let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
            assert!(
                total_prob <= 1.001,
                "Total probability mass exceeds 1: {}",
                total_prob
            );

            let mut top_logprobs = Vec::with_capacity(alternatives.len());
            for (alt_token, alt_prob) in alternatives {
                assert!(
                    (0.0..=1.0).contains(&alt_prob),
                    "Probability must be in [0, 1]: {}",
                    alt_prob
                );
                top_logprobs.push(TopLogprobs {
                    token: alt_token,
                    logprob: alt_prob.ln(),
                    bytes: None,
                });
            }

            token_logprobs.push(ChatCompletionTokenLogprob {
                token,
                logprob: prob.ln(),
                bytes: None,
                top_logprobs,
            });
        }

        choices.push(ChatChoiceStream {
            index: choice_idx as u32,
            delta: ChatCompletionStreamResponseDelta {
                content: Some(ChatCompletionMessageContent::Text("test".to_string())),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            stop_reason: None,
            logprobs: Some(ChatChoiceLogprobs {
                content: Some(token_logprobs),
                refusal: None,
            }),
        });
    }

    NvCreateChatCompletionStreamResponse {
        inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices,
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
        nvext: None,
    }
}