use std::collections::HashMap;
use std::time::Instant;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
use crate::error::MemoryError;
/// Category of a probe question; used to bucket and weight scores.
///
/// Serialized in lowercase (`"recall"`, `"artifact"`, …) to match the JSON
/// schema sent to the LLM in `generate_probe_questions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProbeCategory {
    /// Specific facts that survived (file paths, function names, values).
    Recall,
    /// Which files/tools/URLs the agent used.
    Artifact,
    /// Next steps, blockers, open questions.
    Continuation,
    /// Past reasoning traces (why X over Y, trade-offs).
    Decision,
}
/// Aggregated probe score for a single category.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category the probes belong to.
    pub category: ProbeCategory,
    /// Mean per-question score within the category (0.0–1.0).
    pub score: f32,
    /// Number of probe questions scored in this category.
    pub probes_run: u32,
}
/// One generated probe question together with its expected answer.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// The question posed against the summary.
    pub question: String,
    /// Ground-truth answer, extracted from the original conversation.
    pub expected_answer: String,
    /// Defaults to `Recall` when absent in serialized data.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
/// Serde default for `ProbeQuestion::category`.
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
impl Default for ProbeQuestion {
fn default() -> Self {
Self {
question: String::new(),
expected_answer: String::new(),
category: ProbeCategory::Recall,
}
}
}
/// Classification of an overall probe score against the configured thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// `score >= threshold`.
    Pass,
    /// `hard_fail_threshold <= score < threshold`.
    SoftFail,
    /// `score < hard_fail_threshold`.
    HardFail,
    /// NOTE(review): never constructed in this file — presumably set by
    /// callers when the probe itself errors; confirm against call sites.
    Error,
}
/// Full outcome of one compaction probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall (category-weighted) score in 0.0–1.0.
    pub score: f32,
    /// Per-category breakdown; empty when deserializing older payloads
    /// that predate category scoring.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Questions that were asked.
    pub questions: Vec<ProbeQuestion>,
    /// Answers produced from the summary alone.
    pub answers: Vec<String>,
    /// Score per question, parallel to `questions`.
    pub per_question_scores: Vec<f32>,
    /// Classification of `score` against the thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold in effect when the probe ran.
    pub threshold: f32,
    /// Hard-fail threshold in effect when the probe ran.
    pub hard_fail_threshold: f32,
    /// Name of the provider that ran the probe.
    pub model: String,
    /// Wall-clock duration of the probe in milliseconds.
    pub duration_ms: u64,
}
/// JSON shape the LLM returns from `generate_probe_questions`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
/// Groups per-question scores by category and computes a weighted overall score.
///
/// Each category's score is the mean of its question scores. The overall
/// score is the weighted mean over categories using `category_weights`
/// (missing entries default to 1.0; negative weights are clamped to 0.0 with
/// a warning). When every effective weight is zero, falls back to the
/// unweighted mean of the category scores.
///
/// Returns the per-category scores — sorted by category so the output is
/// deterministic despite `HashMap` iteration order — and the overall score;
/// `(vec![], 0.0)` when there are no questions.
#[must_use]
fn compute_category_scores(
    questions: &[ProbeQuestion],
    per_question_scores: &[f32],
    category_weights: Option<&HashMap<ProbeCategory, f32>>,
) -> (Vec<CategoryScore>, f32) {
    let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
    for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
        by_cat.entry(q.category).or_default().push(s);
    }
    #[allow(clippy::cast_precision_loss)]
    let mut category_scores: Vec<CategoryScore> = by_cat
        .into_iter()
        .map(|(category, scores)| {
            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
            CategoryScore {
                category,
                score: avg,
                #[allow(clippy::cast_possible_truncation)]
                probes_run: scores.len() as u32,
            }
        })
        .collect();
    // HashMap iteration order is nondeterministic; sort so serialized probe
    // results are stable across runs. Keyed on Debug text to avoid requiring
    // Ord on ProbeCategory.
    category_scores.sort_by_cached_key(|cs| format!("{:?}", cs.category));
    if category_scores.is_empty() {
        return (category_scores, 0.0);
    }
    let mut weighted_sum = 0.0_f32;
    let mut weight_total = 0.0_f32;
    for cs in &category_scores {
        let raw_w = category_weights
            .and_then(|m| m.get(&cs.category).copied())
            .unwrap_or(1.0);
        if raw_w < 0.0 {
            tracing::warn!(
                category = ?cs.category,
                weight = raw_w,
                "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
            );
        }
        let w = raw_w.max(0.0);
        weighted_sum += cs.score * w;
        weight_total += w;
    }
    let overall = if weight_total > 0.0 {
        weighted_sum / weight_total
    } else {
        // All-zero weights: fall back to the plain category average.
        #[allow(clippy::cast_precision_loss)]
        let n = category_scores.len() as f32;
        category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
    };
    (category_scores, overall)
}
/// JSON shape the LLM returns from `answer_probe_questions`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
/// Substrings (lowercase) that indicate the model could not find the answer
/// in the summary; such answers score 0.0.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns true when `text` contains any refusal pattern, case-insensitively.
fn is_refusal(text: &str) -> bool {
    let lowered = text.to_lowercase();
    for pattern in REFUSAL_PATTERNS {
        if lowered.contains(pattern) {
            return true;
        }
    }
    false
}
/// Lowercases `text`, splits on non-alphanumeric characters, and keeps only
/// tokens of at least 3 bytes (drops short noise words like "a", "is", "to").
fn normalize_tokens(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split(|c: char| !c.is_alphanumeric()) {
        if piece.len() >= 3 {
            tokens.push(piece.to_owned());
        }
    }
    tokens
}
/// Jaccard similarity of two token lists (as sets); defined as 1.0 when both
/// lists are empty.
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    let shared = left.intersection(&right).count();
    let total = left.union(&right).count();
    if total == 0 {
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / total as f32
    }
}
/// Scores a single expected/actual answer pair in 0.0–1.0.
///
/// Refusals score 0.0. If every expected token appears in the answer the
/// score is 1.0. Otherwise the score is the best of: full Jaccard
/// similarity, recall (overlap / expected-token count), and precision
/// (overlap / actual-token count).
fn score_pair(expected: &str, actual: &str) -> f32 {
    if is_refusal(actual) {
        return 0.0;
    }
    let tokens_e = normalize_tokens(expected);
    let tokens_a = normalize_tokens(actual);
    // Build each token set once (the previous version constructed both sets
    // twice and materialized the intersection only to take its length).
    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
    // Full containment of the expected tokens is a perfect score.
    if !tokens_e.is_empty() && set_e.is_subset(&set_a) {
        return 1.0;
    }
    let j_full = jaccard(&tokens_e, &tokens_a);
    let overlap = set_e.intersection(&set_a).count();
    // Recall/precision denominators use the raw token-list lengths
    // (duplicates counted), matching the original scoring behavior.
    #[allow(clippy::cast_precision_loss)]
    let j_e = if tokens_e.is_empty() {
        0.0_f32
    } else {
        overlap as f32 / tokens_e.len() as f32
    };
    #[allow(clippy::cast_precision_loss)]
    let j_a = if tokens_a.is_empty() {
        0.0_f32
    } else {
        overlap as f32 / tokens_a.len() as f32
    };
    j_full.max(j_e).max(j_a)
}
/// Scores each answer against its question's expected answer.
///
/// Missing answers (fewer answers than questions) are scored as empty
/// strings; surplus answers are ignored. Returns the per-question scores and
/// their average, or `(vec![], 0.0)` when there are no questions.
#[must_use]
pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
    if questions.is_empty() {
        return (vec![], 0.0);
    }
    let blank = String::new();
    let scores: Vec<f32> = questions
        .iter()
        .enumerate()
        .map(|(i, q)| score_pair(&q.expected_answer, answers.get(i).unwrap_or(&blank)))
        .collect();
    // `scores` is non-empty here (guarded above), so the previous
    // `scores.is_empty()` fallback was dead code.
    #[allow(clippy::cast_precision_loss)]
    let avg = scores.iter().sum::<f32>() / scores.len() as f32;
    (scores, avg)
}
/// Clones `messages`, clipping each tool-output body to at most 500 bytes
/// (with a trailing ellipsis) so large tool results don't bloat the probe
/// prompt.
fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
    const MAX_BODY_BYTES: usize = 500;
    messages
        .iter()
        .map(|m| {
            let mut msg = m.clone();
            for part in &mut msg.parts {
                if let MessagePart::ToolOutput { body, .. } = part {
                    if body.len() <= MAX_BODY_BYTES {
                        continue;
                    }
                    // `String::truncate` panics when the cut point is not a
                    // UTF-8 char boundary; back off to the previous boundary
                    // so multi-byte tool output cannot panic.
                    let mut cut = MAX_BODY_BYTES;
                    while !body.is_char_boundary(cut) {
                        cut -= 1;
                    }
                    body.truncate(cut);
                    body.push('\u{2026}');
                }
            }
            msg.rebuild_content();
            msg
        })
        .collect()
}
/// Asks `provider` to generate up to `max_questions` factual probe questions
/// from the original conversation, spread across the four `ProbeCategory`
/// buckets.
///
/// Tool-output bodies are truncated to 500 bytes before rendering so the
/// prompt stays small.
///
/// # Errors
/// Returns `MemoryError::Llm` when the provider call fails.
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "memory.compaction_probe", skip_all)
)]
pub async fn generate_probe_questions(
    provider: &AnyProvider,
    messages: &[Message],
    max_questions: usize,
) -> Result<Vec<ProbeQuestion>, MemoryError> {
    let truncated = truncate_tool_bodies(messages);
    // Render the conversation as plain "role: content" lines.
    let mut history = String::new();
    for msg in &truncated {
        let role = match msg.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        history.push_str(role);
        history.push_str(": ");
        history.push_str(&msg.content);
        history.push('\n');
    }
    let prompt = format!(
        "Given the following conversation excerpt, generate {max_questions} factual questions \
        that test whether a summary preserves the most important concrete details.\n\
        \n\
        You MUST generate at least one question per category when max_questions >= 4. \
        If the conversation lacks information for a category, generate a question noting that absence.\n\
        \n\
        Categories:\n\
        - recall: Specific facts that survived (file paths, function names, values). \
        Example: \"What file was modified?\"\n\
        - artifact: Which files/tools/URLs the agent used. \
        Example: \"Which tool was executed?\"\n\
        - continuation: Next steps, blockers, open questions. \
        Example: \"What is the next step?\"\n\
        - decision: Past reasoning traces (why X over Y, trade-offs). \
        Example: \"Why was X chosen over Y?\"\n\
        \n\
        Do NOT generate questions about:\n\
        - Raw tool output content (compiler warnings, test output line numbers)\n\
        - Intermediate debugging steps that were superseded\n\
        - Opinions or reasoning that cannot be verified\n\
        \n\
        Each question must have a single unambiguous expected answer extractable from the text.\n\
        \n\
        Conversation:\n{history}\n\
        \n\
        Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
        \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
    );
    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];
    // Structured call: the provider must return JSON matching ProbeQuestionsOutput.
    let mut output: ProbeQuestionsOutput = provider
        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;
    // The model may over-generate; enforce the cap locally.
    output.questions.truncate(max_questions);
    Ok(output.questions)
}
/// Asks the provider to answer each probe question using ONLY the summary.
///
/// Questions are numbered `1..N` in the prompt; the model is instructed to
/// reply "UNKNOWN" when the summary lacks the answer, which `is_refusal`
/// later scores as 0.0.
///
/// # Errors
/// Returns `MemoryError::Llm` when the provider call fails.
pub async fn answer_probe_questions(
    provider: &AnyProvider,
    summary: &str,
    questions: &[ProbeQuestion],
) -> Result<Vec<String>, MemoryError> {
    let mut numbered = String::new();
    for (i, q) in questions.iter().enumerate() {
        use std::fmt::Write as _;
        // Writing into a String cannot fail; the Result is safely ignored.
        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
    }
    let prompt = format!(
        "Given the following summary of a conversation, answer each question using ONLY \
        information present in the summary. If the answer is not in the summary, respond \
        with \"UNKNOWN\".\n\
        \n\
        Summary:\n{summary}\n\
        \n\
        Questions:\n{numbered}\n\
        \n\
        Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
    );
    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];
    let output: ProbeAnswersOutput = provider
        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;
    Ok(output.answers)
}
/// Configuration for the post-compaction quality probe.
///
/// `#[serde(default)]` on the struct means any field missing from serialized
/// input falls back to its `Default::default()` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CompactionProbeConfig {
    /// Master switch; `validate_compaction` returns `Ok(None)` when false.
    pub enabled: bool,
    /// Name of the probe provider. NOTE(review): presumably resolved to an
    /// `AnyProvider` by the caller — not used in this file; confirm.
    pub probe_provider: String,
    /// Minimum overall score for a `Pass` verdict.
    pub threshold: f32,
    /// Scores below this are `HardFail`; between this and `threshold` is `SoftFail`.
    pub hard_fail_threshold: f32,
    /// Upper bound on the number of generated probe questions.
    pub max_questions: usize,
    /// Wall-clock budget for the entire probe run.
    pub timeout_secs: u64,
    /// Optional per-category weights for the overall score (1.0 when absent).
    #[serde(default)]
    pub category_weights: Option<HashMap<ProbeCategory, f32>>,
}
impl Default for CompactionProbeConfig {
fn default() -> Self {
Self {
enabled: false,
probe_provider: String::new(),
threshold: 0.6,
hard_fail_threshold: 0.35,
max_questions: 5,
timeout_secs: 15,
category_weights: None,
}
}
}
/// Entry point: validates a compaction `summary` against the original
/// `messages`, honoring `config.enabled` and the configured timeout.
///
/// A timeout is logged and treated as non-fatal (`Ok(None)`) so compaction
/// still proceeds. On success, `duration_ms` is stamped with the observed
/// wall-clock time.
///
/// # Errors
/// Propagates `MemoryError` from the underlying probe.
pub async fn validate_compaction(
    provider: AnyProvider,
    messages: Vec<Message>,
    summary: String,
    config: &CompactionProbeConfig,
) -> Result<Option<CompactionProbeResult>, MemoryError> {
    if !config.enabled {
        return Ok(None);
    }
    let budget = std::time::Duration::from_secs(config.timeout_secs);
    let started = Instant::now();
    match tokio::time::timeout(budget, run_probe(provider, messages, summary, config)).await {
        Ok(outcome) => outcome.map(|maybe_result| {
            maybe_result.map(|mut probe| {
                probe.duration_ms =
                    u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
                probe
            })
        }),
        Err(_elapsed) => {
            tracing::warn!(
                timeout_secs = config.timeout_secs,
                "compaction probe timed out — proceeding with compaction"
            );
            Ok(None)
        }
    }
}
/// Runs the probe pipeline: generate questions from the original messages,
/// answer them from the summary alone, score the answers, and classify the
/// result against the configured thresholds.
///
/// Returns `Ok(None)` (probe skipped) when the summary is shorter than 10
/// bytes or fewer than 2 questions were generated. `duration_ms` in the
/// result is left at 0 and is stamped by `validate_compaction`.
///
/// # Errors
/// Propagates `MemoryError` from the LLM calls.
async fn run_probe(
    provider: AnyProvider,
    messages: Vec<Message>,
    summary: String,
    config: &CompactionProbeConfig,
) -> Result<Option<CompactionProbeResult>, MemoryError> {
    // Degenerate summaries are not worth probing.
    if summary.len() < 10 {
        tracing::warn!(
            len = summary.len(),
            "compaction probe: summary too short — skipping probe"
        );
        return Ok(None);
    }
    let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
    if questions.len() < 2 {
        tracing::debug!(
            count = questions.len(),
            "compaction probe: fewer than 2 questions generated — skipping probe"
        );
        return Ok(None);
    }
    // Category-coverage check is advisory only: missing categories are
    // logged but never fail or skip the probe.
    if config.max_questions >= 4 {
        use std::collections::HashSet;
        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            if !covered.contains(&cat) {
                tracing::warn!(
                    category = ?cat,
                    "compaction probe: LLM did not generate questions for category"
                );
            }
        }
    }
    let answers = answer_probe_questions(&provider, &summary, &questions).await?;
    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
    // The category-weighted overall score supersedes the simple average.
    let (category_scores, score) = compute_category_scores(
        &questions,
        &per_question_scores,
        config.category_weights.as_ref(),
    );
    let verdict = if score >= config.threshold {
        ProbeVerdict::Pass
    } else if score >= config.hard_fail_threshold {
        ProbeVerdict::SoftFail
    } else {
        ProbeVerdict::HardFail
    };
    let model = provider.name().to_owned();
    Ok(Some(CompactionProbeResult {
        score,
        category_scores,
        questions,
        answers,
        per_question_scores,
        verdict,
        threshold: config.threshold,
        hard_fail_threshold: config.hard_fail_threshold,
        model,
        // Placeholder; stamped by validate_compaction.
        duration_ms: 0,
    }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the verdict classification performed in `run_probe`, so the
    /// threshold tests below don't each re-inline the same if-chain.
    fn verdict_for(score: f32, config: &CompactionProbeConfig) -> ProbeVerdict {
        if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        }
    }

    #[test]
    fn score_perfect_match() {
        let q = vec![ProbeQuestion {
            question: "What crate is used?".into(),
            expected_answer: "thiserror".into(),
            category: ProbeCategory::Recall,
        }];
        let a = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    #[test]
    fn score_complete_mismatch() {
        let q = vec![ProbeQuestion {
            question: "What file was modified?".into(),
            expected_answer: "src/auth.rs".into(),
            ..Default::default()
        }];
        let a = vec!["definitely not in the summary".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    #[test]
    fn score_refusal_is_zero() {
        let q = vec![ProbeQuestion {
            question: "What was the decision?".into(),
            expected_answer: "Use thiserror for typed errors".into(),
            ..Default::default()
        }];
        for refusal in &[
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let a = vec![(*refusal).to_owned()];
            let (_, avg) = score_answers(&q, &a);
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    #[test]
    fn score_paraphrased_answer_above_half() {
        let q = vec![ProbeQuestion {
            question: "What error handling crate was chosen?".into(),
            expected_answer: "Use thiserror for typed errors in library crates".into(),
            ..Default::default()
        }];
        let a = vec!["thiserror was chosen for error types in library crates".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    #[test]
    fn score_empty_strings() {
        let q = vec![ProbeQuestion {
            question: "What?".into(),
            expected_answer: String::new(),
            ..Default::default()
        }];
        let a = vec![String::new()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_file_path_exact() {
        let q = vec![ProbeQuestion {
            question: "Which file was modified?".into(),
            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
            ..Default::default()
        }];
        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    #[test]
    fn score_unicode_input() {
        let q = vec![ProbeQuestion {
            question: "Что было изменено?".into(),
            expected_answer: "файл config.toml".into(),
            ..Default::default()
        }];
        let a = vec!["config.toml был изменён".into()];
        let (scores, _) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
    }

    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();
        assert_eq!(verdict_for(0.7, &config), ProbeVerdict::Pass);
        assert_eq!(verdict_for(0.5, &config), ProbeVerdict::SoftFail);
        assert_eq!(verdict_for(0.2, &config), ProbeVerdict::HardFail);
    }

    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
        assert!(c.category_weights.is_none());
    }

    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            probe_provider: "fast".into(),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
            category_weights: None,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(restored.enabled);
        assert_eq!(restored.probe_provider, "fast");
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            category_scores: vec![CategoryScore {
                category: ProbeCategory::Recall,
                score: 0.75,
                probes_run: 1,
            }],
            questions: vec![ProbeQuestion {
                question: "What?".into(),
                expected_answer: "thiserror".into(),
                category: ProbeCategory::Recall,
            }],
            answers: vec!["thiserror".into()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".into(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
        assert_eq!(restored.category_scores.len(), 1);
    }

    #[test]
    fn probe_result_backward_compat_no_category_scores() {
        // Payload written before `category_scores` existed must still load.
        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
        assert!(restored.category_scores.is_empty());
    }

    #[test]
    fn score_fewer_answers_than_questions() {
        let questions = vec![
            ProbeQuestion {
                question: "What crate?".into(),
                expected_answer: "thiserror".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What file?".into(),
                expected_answer: "src/lib.rs".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What decision?".into(),
                expected_answer: "use async traits".into(),
                ..Default::default()
            },
        ];
        let answers = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    #[test]
    fn verdict_boundary_at_threshold() {
        // Thresholds are inclusive on the passing side.
        let config = CompactionProbeConfig::default();
        assert_eq!(verdict_for(config.threshold, &config), ProbeVerdict::Pass);
        assert_eq!(
            verdict_for(config.threshold - f32::EPSILON, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            verdict_for(config.hard_fail_threshold, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            verdict_for(config.hard_fail_threshold - f32::EPSILON, &config),
            ProbeVerdict::HardFail
        );
    }

    #[test]
    fn config_partial_json_uses_defaults() {
        let json = r#"{"enabled": true}"#;
        let c: CompactionProbeConfig =
            serde_json::from_str(json).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
    }

    #[test]
    fn probe_category_serde_lowercase() {
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
            r#""recall""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
            r#""artifact""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
            r#""continuation""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
            r#""decision""#
        );
        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
        assert_eq!(cat, ProbeCategory::Recall);
    }

    #[test]
    fn category_weights_toml_round_trip() {
        let toml_str = r#"
enabled = true
probe_provider = "fast"
threshold = 0.6
hard_fail_threshold = 0.35
max_questions = 5
timeout_secs = 15
[category_weights]
recall = 1.5
artifact = 1.0
continuation = 1.0
decision = 0.8
"#;
        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
        let weights = c.category_weights.as_ref().unwrap();
        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
    }

    #[test]
    fn category_scores_equal_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2);
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_missing_category_excluded() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.6_f32];
        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2, "only categories with questions present");
        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
        assert!(!categories.contains(&ProbeCategory::Artifact));
        assert!(!categories.contains(&ProbeCategory::Continuation));
    }

    #[test]
    fn category_scores_custom_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 2.0_f32);
        weights.insert(ProbeCategory::Decision, 1.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!(
            (overall - 2.0 / 3.0).abs() < 0.001,
            "expected ~0.667, got {overall}"
        );
    }

    #[test]
    fn category_scores_all_zero_weights_fallback() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 0.0_f32);
        weights.insert(ProbeCategory::Artifact, 0.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_empty_questions() {
        let (cats, overall) = compute_category_scores(&[], &[], None);
        assert!(cats.is_empty());
        assert!((overall - 0.0).abs() < 0.001);
    }

    #[test]
    fn category_scores_multi_probe_single_category_averages() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q3".into(),
                expected_answer: "A3".into(),
                category: ProbeCategory::Recall,
            },
        ];
        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 1, "only one category present");
        assert_eq!(cats[0].category, ProbeCategory::Recall);
        assert_eq!(cats[0].probes_run, 3);
        assert!(
            (cats[0].score - 0.5).abs() < 0.001,
            "cat score={}",
            cats[0].score
        );
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn probe_question_serde_default_category() {
        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
        assert_eq!(q.category, ProbeCategory::Recall);
        assert_eq!(q.question, "What file?");
        assert_eq!(q.expected_answer, "src/lib.rs");
    }

    #[test]
    fn probe_question_serde_all_categories_round_trip() {
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            let q = ProbeQuestion {
                question: "test?".into(),
                expected_answer: "answer".into(),
                category: cat,
            };
            let json = serde_json::to_string(&q).expect("serialize");
            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(restored.category, cat);
        }
    }
}