use super::session::InferenceSession;
use crate::error::{AmbiError, Result};
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
impl InferenceSession {
    /// Computes the mean per-token entropy (in nats) of the model's
    /// next-token distributions over `sentence`.
    ///
    /// Tokenizes `sentence` (always prepending BOS), decodes all tokens in a
    /// single batch starting at `session.pos`, and averages the entropy of
    /// the logit distribution produced at each position. `session.pos` is
    /// advanced by the number of tokens consumed so subsequent calls continue
    /// from the correct KV-cache position.
    ///
    /// # Errors
    ///
    /// Returns `AmbiError::EngineError` if tokenization, batch construction,
    /// or decoding fails.
    pub(crate) fn evaluate_entropy(
        sentence: &str,
        model: &LlamaModel,
        context: &mut LlamaContext,
        batch: &mut LlamaBatch,
        session: &mut InferenceSession,
    ) -> Result<f32> {
        // `str_to_token` already yields an owned `Vec<LlamaToken>`; the
        // previous `.to_vec()` cloned the whole buffer for nothing.
        let tokens = model
            .str_to_token(sentence, AddBos::Always)
            .map_err(|e| AmbiError::EngineError(format!("Tokenize failed: {}", e)))?;
        // Guards the division below (and an empty decode). With
        // AddBos::Always this is unlikely, but cheap to check.
        if tokens.is_empty() {
            return Ok(0.0);
        }
        batch.clear();
        for (i, &t) in tokens.iter().enumerate() {
            // `true` requests logits at every position so each token
            // contributes an entropy term below; `&[0]` = sequence id 0.
            batch
                .add(t, session.pos + i as i32, &[0], true)
                .map_err(|e| AmbiError::EngineError(format!("Batch add failed: {}", e)))?;
        }
        // `batch` is already `&mut LlamaBatch`; pass it through directly
        // instead of re-borrowing (`&mut batch` was `&mut &mut LlamaBatch`).
        context
            .decode(batch)
            .map_err(|e| AmbiError::EngineError(format!("Decoding failed: {}", e)))?;
        // NOTE(review): logits at index i are the distribution for the token
        // *following* position i; averaging over every position (including
        // the last, which predicts beyond `sentence`) is assumed intentional.
        let mut total_entropy = 0.0_f32;
        for i in 0..tokens.len() {
            total_entropy += Self::token_entropy(context.get_logits_ith(i as i32));
        }
        session.pos += tokens.len() as i32;
        // Safe division: `tokens` is non-empty past the early return.
        Ok(total_entropy / tokens.len() as f32)
    }

    /// Shannon entropy (nats) of the softmax distribution over `logits`.
    ///
    /// Subtracts the maximum logit before exponentiating for numerical
    /// stability. Probabilities at or below 1e-7 are skipped: their
    /// `-p ln p` contribution is negligible and this avoids `ln` of
    /// near-zero values.
    fn token_entropy(logits: &[f32]) -> f32 {
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
        logits
            .iter()
            .map(|&l| {
                let p = (l - max_logit).exp() / sum_exp;
                if p > 1e-7 {
                    -p * p.ln()
                } else {
                    0.0
                }
            })
            .sum()
    }
}