use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use crate::math::clamp;
/// Confidence reported when there is no signal at all to base an estimate on.
pub const NEUTRAL_CONFIDENCE: f64 = 0.5;

/// Default number of most-recent valid logprobs kept in the rolling window.
pub const DEFAULT_WINDOW_SIZE: usize = 128;
/// Sentinel logprob meaning "no logprob was supplied for this token".
///
/// The accumulator treats every non-negative or non-finite value the same way.
pub const LOGPROB_UNAVAILABLE: f64 = 0.0;
/// Mean negative log-likelihood at which the logprob confidence is exactly `0.5`.
pub const MEAN_NLL_CONFIDENCE_MIDPOINT: f64 = 4.0;
/// Distance in mean NLL from the midpoint at which confidence reaches `1.0` (below) or
/// `0.0` (above).
pub const MEAN_NLL_CONFIDENCE_HALF_RANGE: f64 = 3.0;

/// Responses with fewer `char`s than this count as "short" for the structural heuristic.
pub const STRUCT_MIN_LENGTH_CHARS: usize = 40;
/// Share of `?` characters at or above which a response counts as question-heavy.
pub const STRUCT_HIGH_QUESTION_RATIO: f64 = 0.02;
/// Structural confidence for a response that is neither short nor question-heavy.
pub const STRUCT_FALLBACK_DEFAULT: f64 = 0.7;
/// Structural confidence for a response that is short or question-heavy, but not both.
pub const STRUCT_FALLBACK_WEAK: f64 = 0.4;
/// Structural confidence for a response that is both short and question-heavy.
pub const STRUCT_FALLBACK_STRONG: f64 = 0.2;
/// Per-turn accumulator of token log-probabilities used to estimate response confidence.
///
/// Keeps a rolling window of the most recent valid logprobs plus whole-turn counters for
/// total and unavailable tokens. Build it with `TokenStatsAccumulator::new` or
/// `TokenStatsAccumulator::with_window`, and reset it per turn with `begin_turn`.
///
/// NOTE(review): the derived `Deserialize` restores these fields verbatim, bypassing the
/// `window_size >= 1` clamp applied by `with_window` — confirm whether restored state
/// needs validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenStatsAccumulator {
/// Most recent valid (finite, negative) logprobs, oldest at the front; bounded by
/// `window_size` when the accumulator is built through its constructors.
logprobs: VecDeque<f64>,
/// Every token recorded this turn, including those whose logprob was unavailable.
total_tokens: usize,
/// Tokens recorded this turn whose logprob was non-negative or non-finite.
unavailable_count: usize,
/// Maximum number of entries kept in `logprobs` (at least 1 via `with_window`).
window_size: usize,
}
impl TokenStatsAccumulator {
    /// Creates an empty accumulator whose rolling window holds up to
    /// [`DEFAULT_WINDOW_SIZE`] logprobs.
    pub fn new() -> Self {
        Self::with_window(DEFAULT_WINDOW_SIZE)
    }

    /// Creates an empty accumulator whose rolling window holds up to `window_size` logprobs.
    ///
    /// A `window_size` of `0` is clamped to `1`, so the most recent valid logprob is always
    /// retained.
    pub fn with_window(window_size: usize) -> Self {
        Self {
            logprobs: VecDeque::new(),
            total_tokens: 0,
            unavailable_count: 0,
            window_size: window_size.max(1),
        }
    }

    /// Clears the window and both token counters so a new turn starts from scratch.
    ///
    /// The configured window size is kept.
    pub fn begin_turn(&mut self) {
        self.logprobs.clear();
        self.total_tokens = 0;
        self.unavailable_count = 0;
    }

    /// Records one generated token.
    ///
    /// Every call increments [`token_count`](Self::token_count). A `logprob` that is
    /// non-negative (including the [`LOGPROB_UNAVAILABLE`] sentinel) or non-finite
    /// (NaN, ±∞) is counted as unavailable and kept out of the window. Any other value is
    /// appended, and the oldest entries are evicted until the window fits its size.
    pub fn on_token(&mut self, logprob: f64) {
        self.total_tokens += 1;
        // A real log-probability is finite and strictly negative; `0.0` is the
        // "not supplied" sentinel, and NaN fails `>= 0.0` but is caught by `is_finite`.
        if logprob >= 0.0 || !logprob.is_finite() {
            self.unavailable_count += 1;
            return;
        }
        self.logprobs.push_back(logprob);
        // Re-apply the `with_window` clamp and trim in a loop: state restored through the
        // derived `Deserialize` may carry `window_size == 0` or an over-full window.
        let limit = self.window_size.max(1);
        while self.logprobs.len() > limit {
            self.logprobs.pop_front();
        }
    }

    /// Returns `true` when the window holds at least one valid logprob.
    pub fn has_logprobs(&self) -> bool {
        !self.logprobs.is_empty()
    }

    /// Maps the window's mean negative log-likelihood (NLL) to a confidence in `[0.0, 1.0]`.
    ///
    /// The mapping is linear: a mean NLL equal to [`MEAN_NLL_CONFIDENCE_MIDPOINT`] yields
    /// `0.5`, and moving [`MEAN_NLL_CONFIDENCE_HALF_RANGE`] below/above the midpoint yields
    /// `1.0`/`0.0`. Values beyond that range are clamped. Returns [`NEUTRAL_CONFIDENCE`]
    /// when the window is empty.
    pub fn logprob_confidence(&self) -> f64 {
        if self.logprobs.is_empty() {
            return NEUTRAL_CONFIDENCE;
        }
        let mean_nll: f64 =
            self.logprobs.iter().map(|lp| -lp).sum::<f64>() / self.logprobs.len() as f64;
        let offset = MEAN_NLL_CONFIDENCE_MIDPOINT - mean_nll;
        let confidence = 0.5 + 0.5 * (offset / MEAN_NLL_CONFIDENCE_HALF_RANGE);
        // NOTE(review): assumes `crate::math::clamp(v, lo, hi)` bounds `v` to `[lo, hi]`.
        clamp(confidence, 0.0, 1.0)
    }

    /// Total tokens recorded this turn, including those with unavailable logprobs.
    pub fn token_count(&self) -> usize {
        self.total_tokens
    }

    /// Fraction of this turn's tokens that carried a usable logprob, in `[0.0, 1.0]`.
    ///
    /// Counts the whole turn, not just the window. Returns `0.0` before any token has
    /// been recorded.
    pub fn logprob_coverage(&self) -> f64 {
        if self.total_tokens == 0 {
            return 0.0;
        }
        // `saturating_sub` keeps inconsistent restored state (more unavailable tokens than
        // total tokens) from underflowing and panicking in debug builds.
        let available = self.total_tokens.saturating_sub(self.unavailable_count);
        available as f64 / self.total_tokens as f64
    }
}
impl Default for TokenStatsAccumulator {
fn default() -> Self {
Self::new()
}
}
/// Heuristic confidence for a response when no token logprobs are available.
///
/// Two signals are combined. The first is whether the response has fewer than
/// [`STRUCT_MIN_LENGTH_CHARS`] characters. The second is whether its share of `?`
/// characters reaches [`STRUCT_HIGH_QUESTION_RATIO`].
///
/// - Both signals: [`STRUCT_FALLBACK_STRONG`].
/// - Exactly one: [`STRUCT_FALLBACK_WEAK`].
/// - Neither: [`STRUCT_FALLBACK_DEFAULT`].
/// - Empty string: [`NEUTRAL_CONFIDENCE`].
///
/// Lengths are measured in `char`s (Unicode scalar values), not bytes.
pub fn structural_confidence(response_text: &str) -> f64 {
    // A single pass gathers both the character count and the question-mark count.
    let (char_total, question_marks) =
        response_text
            .chars()
            .fold((0_usize, 0_usize), |(chars, questions), ch| {
                (chars + 1, if ch == '?' { questions + 1 } else { questions })
            });
    if char_total == 0 {
        return NEUTRAL_CONFIDENCE;
    }
    let is_short = char_total < STRUCT_MIN_LENGTH_CHARS;
    let is_question_heavy =
        question_marks as f64 / char_total as f64 >= STRUCT_HIGH_QUESTION_RATIO;
    if is_short && is_question_heavy {
        STRUCT_FALLBACK_STRONG
    } else if is_short || is_question_heavy {
        STRUCT_FALLBACK_WEAK
    } else {
        STRUCT_FALLBACK_DEFAULT
    }
}
pub fn confidence_with_fallback(
stats: &TokenStatsAccumulator,
response_text: Option<&str>,
) -> f64 {
if stats.has_logprobs() {
return stats.logprob_confidence();
}
if let Some(text) = response_text {
return structural_confidence(text);
}
NEUTRAL_CONFIDENCE
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used by the newer tests below.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    #[test]
    fn empty_accumulator_is_neutral() {
        let stats = TokenStatsAccumulator::new();
        assert!(!stats.has_logprobs());
        assert_eq!(stats.token_count(), 0);
        assert_eq!(stats.logprob_coverage(), 0.0);
        assert!(
            (stats.logprob_confidence() - NEUTRAL_CONFIDENCE).abs() < 1e-9,
            "empty window must return neutral confidence"
        );
    }

    #[test]
    fn high_logprob_tokens_raise_confidence() {
        let mut stats = TokenStatsAccumulator::new();
        for _ in 0..20 {
            stats.on_token(-0.1);
        }
        let c = stats.logprob_confidence();
        assert!(c > 0.9, "high-probability tokens should push confidence >0.9 (got {c})");
    }

    #[test]
    fn low_logprob_tokens_lower_confidence() {
        let mut stats = TokenStatsAccumulator::new();
        for _ in 0..20 {
            stats.on_token(-7.0);
        }
        let c = stats.logprob_confidence();
        assert!(
            c < 0.1,
            "high-NLL (gibberish-like) tokens should pull confidence <0.1 (got {c})"
        );
    }

    #[test]
    fn mid_range_logprobs_produce_mid_confidence() {
        let mut stats = TokenStatsAccumulator::new();
        for _ in 0..10 {
            stats.on_token(-4.0);
        }
        let c = stats.logprob_confidence();
        assert!(
            (c - 0.5).abs() < 1e-9,
            "mean NLL at midpoint must map to confidence 0.5 (got {c})"
        );
    }

    #[test]
    fn begin_turn_resets_accumulator() {
        let mut stats = TokenStatsAccumulator::new();
        stats.on_token(-1.0);
        stats.on_token(-2.0);
        assert!(stats.has_logprobs());
        stats.begin_turn();
        assert!(!stats.has_logprobs());
        assert_eq!(stats.token_count(), 0);
        assert_eq!(stats.logprob_coverage(), 0.0);
    }

    #[test]
    fn rolling_window_evicts_oldest() {
        let mut reference = TokenStatsAccumulator::with_window(4);
        reference.on_token(-3.0);
        reference.on_token(-4.0);
        reference.on_token(-5.0);
        reference.on_token(-6.0);
        let expected = reference.logprob_confidence();
        let mut stats = TokenStatsAccumulator::with_window(4);
        stats.on_token(-1.0);
        stats.on_token(-2.0);
        stats.on_token(-3.0);
        stats.on_token(-4.0);
        stats.on_token(-5.0);
        stats.on_token(-6.0);
        assert!(
            (stats.logprob_confidence() - expected).abs() < 1e-9,
            "eviction must match a reference built only from surviving tokens \
             (got {}, expected {})",
            stats.logprob_confidence(),
            expected
        );
        assert_eq!(stats.token_count(), 6);
    }

    #[test]
    fn unavailable_logprob_does_not_enter_window() {
        let mut stats = TokenStatsAccumulator::new();
        stats.on_token(LOGPROB_UNAVAILABLE);
        assert!(!stats.has_logprobs());
        assert_eq!(stats.token_count(), 1);
        assert_eq!(stats.logprob_coverage(), 0.0);
    }

    #[test]
    fn non_finite_or_positive_logprob_treated_as_unavailable() {
        let mut stats = TokenStatsAccumulator::new();
        stats.on_token(1.5);
        stats.on_token(f64::NAN);
        stats.on_token(f64::INFINITY);
        stats.on_token(f64::NEG_INFINITY);
        assert!(!stats.has_logprobs());
        assert_eq!(stats.token_count(), 4);
        assert_eq!(stats.logprob_coverage(), 0.0);
    }

    #[test]
    fn negative_zero_logprob_is_unavailable() {
        // `-0.0 >= 0.0` holds, so negative zero is treated like the sentinel.
        let mut stats = TokenStatsAccumulator::new();
        stats.on_token(-0.0);
        assert!(!stats.has_logprobs());
        assert_eq!(stats.token_count(), 1);
        assert_eq!(stats.logprob_coverage(), 0.0);
    }

    #[test]
    fn logprob_coverage_tracks_available_fraction() {
        let mut stats = TokenStatsAccumulator::new();
        stats.on_token(-1.0);
        stats.on_token(LOGPROB_UNAVAILABLE);
        stats.on_token(-2.0);
        stats.on_token(LOGPROB_UNAVAILABLE);
        assert!((stats.logprob_coverage() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn zero_window_size_clamped_to_one() {
        let mut reference = TokenStatsAccumulator::with_window(1);
        reference.on_token(-2.0);
        let expected = reference.logprob_confidence();
        let mut stats = TokenStatsAccumulator::with_window(0);
        stats.on_token(-1.0);
        stats.on_token(-2.0);
        assert!(
            (stats.logprob_confidence() - expected).abs() < 1e-9,
            "size-0 window should clamp to 1 and evict oldest \
             (got {}, expected {})",
            stats.logprob_confidence(),
            expected
        );
    }

    #[test]
    fn default_matches_new() {
        let mut from_default = TokenStatsAccumulator::default();
        let mut from_new = TokenStatsAccumulator::new();
        from_default.on_token(-7.0);
        from_new.on_token(-7.0);
        for _ in 0..DEFAULT_WINDOW_SIZE {
            from_default.on_token(-1.0);
            from_new.on_token(-1.0);
        }
        // The initial -7.0 has been evicted from both default-sized windows.
        assert!(approx_eq(from_default.logprob_confidence(), 1.0));
        assert!(approx_eq(from_new.logprob_confidence(), 1.0));
        assert_eq!(from_default.token_count(), from_new.token_count());
    }

    #[test]
    fn structural_empty_text_is_neutral() {
        assert!((structural_confidence("") - NEUTRAL_CONFIDENCE).abs() < 1e-9);
    }

    #[test]
    fn structural_short_response_is_low() {
        let c = structural_confidence("I don't know.");
        assert!(
            (c - STRUCT_FALLBACK_WEAK).abs() < 1e-9,
            "short response should return weak fallback (got {c})"
        );
    }

    #[test]
    fn structural_question_heavy_short_is_strongest_low() {
        let c = structural_confidence("What? How? When?");
        assert!(
            (c - STRUCT_FALLBACK_STRONG).abs() < 1e-9,
            "short + question-heavy should return strong-low fallback (got {c})"
        );
    }

    #[test]
    fn structural_question_heavy_long_is_weak() {
        let c = structural_confidence(
            "Which file did you mean? And which function inside it? \
             Also, should the refactor preserve the existing signature?",
        );
        assert!(
            (c - STRUCT_FALLBACK_WEAK).abs() < 1e-9,
            "question-heavy long response should return weak fallback (got {c})"
        );
    }

    #[test]
    fn structural_normal_response_is_default() {
        let c = structural_confidence(
            "Here is the refactored function. It preserves the original \
             signature and moves the body into an async block returning \
             a Future. No behaviour changes for synchronous callers.",
        );
        assert!(
            (c - STRUCT_FALLBACK_DEFAULT).abs() < 1e-9,
            "unremarkable response should return default fallback (got {c})"
        );
    }

    #[test]
    fn structural_length_boundary_counts_chars() {
        let just_short = "a".repeat(STRUCT_MIN_LENGTH_CHARS - 1);
        let just_long = "a".repeat(STRUCT_MIN_LENGTH_CHARS);
        assert!(approx_eq(structural_confidence(&just_short), STRUCT_FALLBACK_WEAK));
        assert!(approx_eq(structural_confidence(&just_long), STRUCT_FALLBACK_DEFAULT));
        // Multi-byte characters count once each, not once per UTF-8 byte.
        let multibyte_short = "é".repeat(STRUCT_MIN_LENGTH_CHARS - 1);
        assert!(approx_eq(
            structural_confidence(&multibyte_short),
            STRUCT_FALLBACK_WEAK
        ));
    }

    #[test]
    fn structural_question_ratio_boundary() {
        // 1 '?' in 50 chars is exactly the 0.02 threshold, which counts as question-heavy.
        let at_threshold = format!("{}?", "a".repeat(49));
        assert!(approx_eq(structural_confidence(&at_threshold), STRUCT_FALLBACK_WEAK));
        let below_threshold = format!("{}?", "a".repeat(99));
        assert!(approx_eq(
            structural_confidence(&below_threshold),
            STRUCT_FALLBACK_DEFAULT
        ));
    }

    #[test]
    fn fallback_prefers_logprobs_when_available() {
        let mut stats = TokenStatsAccumulator::new();
        for _ in 0..10 {
            stats.on_token(-0.5);
        }
        let c = confidence_with_fallback(&stats, Some("???"));
        assert!(c > 0.8, "logprob path should override structural (got {c})");
    }

    #[test]
    fn fallback_uses_structural_when_no_logprobs() {
        let stats = TokenStatsAccumulator::new();
        let c = confidence_with_fallback(
            &stats,
            Some("Here is a clear answer with enough length to pass the minimum."),
        );
        assert!(
            (c - STRUCT_FALLBACK_DEFAULT).abs() < 1e-9,
            "empty-logprobs + clean text should use structural default (got {c})"
        );
    }

    #[test]
    fn fallback_neutral_when_no_signal() {
        let stats = TokenStatsAccumulator::new();
        let c = confidence_with_fallback(&stats, None);
        assert!((c - NEUTRAL_CONFIDENCE).abs() < 1e-9);
    }

    #[test]
    fn fallback_gibberish_path_yields_low_confidence() {
        let mut stats = TokenStatsAccumulator::new();
        for _ in 0..30 {
            stats.on_token(-6.5);
        }
        let c = confidence_with_fallback(&stats, None);
        assert!(
            c < 0.2,
            "gibberish-level mean NLL should land confidence in the low band (got {c})"
        );
    }
}