use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use indicatif::{ProgressBar, ProgressStyle};
use tokio::sync::Semaphore;
use crate::judge;
use crate::resume;
use crate::sampling;
use crate::traits::{BenchmarkDataset, LLMClient, MemorySystem};
use crate::types::{BenchmarkQuestion, EvalResult};
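/// Options for a single benchmark run: concurrency limit, per-question
/// retrieval token budget, output location, optional question-type filter,
/// resume support, and an optional "quick mode" sample size.
///
/// Illustrative construction (the output filename is arbitrary):
///
/// ```ignore
/// let config = RunConfig {
///     concurrency: 4,
///     token_budget: 16_384,
///     output_path: "results.jsonl".into(),
///     filter_types: None,
///     resume: false,
///     quick_size: None,
/// };
/// ```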
pub struct RunConfig {
pub concurrency: usize,
pub token_budget: usize,
pub output_path: std::path::PathBuf,
pub filter_types: Option<Vec<String>>,
pub resume: bool,
pub quick_size: Option<usize>,
}
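/// Runs the selected questions through `system` and returns the evaluation
/// results. Each result is appended to `config.output_path` as soon as it is
/// produced, so an interrupted run can be restarted with `config.resume` set
/// and already-completed questions will be skipped.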
pub async fn run_benchmark(
system: &dyn MemorySystem,
dataset: &dyn BenchmarkDataset,
gen_llm: Arc<dyn LLMClient>,
judge_llm: Arc<dyn LLMClient>,
config: &RunConfig,
) -> Result<Vec<EvalResult>> {
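    // In resume mode, collect the ids already recorded in the output file so
    // those questions can be skipped below.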
let completed_ids = if config.resume {
resume::load_completed_ids(&config.output_path)?
} else {
std::collections::HashSet::new()
};
let all_questions = dataset.questions();
    // Optionally down-sample to a stratified quick subset (seeded for
    // reproducibility), then drop questions that were already completed or
    // are excluded by the type filter.
    let base: Vec<&BenchmarkQuestion> = match config.quick_size {
        Some(quick_size) => sampling::stratified_sample(all_questions, quick_size, 42),
        None => all_questions.iter().collect(),
    };
    let questions: Vec<&BenchmarkQuestion> = base
        .into_iter()
        .filter(|q| {
            if completed_ids.contains(&q.id) {
                return false;
            }
            match &config.filter_types {
                Some(types) => types.iter().any(|t| t == &q.question_type),
                None => true,
            }
        })
        .collect();
if questions.is_empty() {
tracing::info!("No questions to evaluate (all completed or filtered).");
if config.resume {
return resume::load_results(&config.output_path);
}
return Ok(Vec::new());
}
tracing::info!(
"Evaluating {} questions against {} ({})",
questions.len(),
system.name(),
system.version(),
);
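    // Progress bar whose message carries a live running-accuracy readout.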
let pb = ProgressBar::new(questions.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{msg}\n{wide_bar:.green/dim} {pos}/{len} ({eta})")
.unwrap()
.progress_chars("##-"),
);
pb.set_message(format!("Running {} on {}", dataset.name(), system.name()));
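    // Shared state for the loop below: a results buffer plus running
    // correct/total counters that feed the progress-bar message.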
let semaphore = Arc::new(Semaphore::new(config.concurrency));
let results = Arc::new(tokio::sync::Mutex::new(Vec::new()));
let output_path = config.output_path.clone();
let correct_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let total_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
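    // Note: questions are evaluated sequentially. Each iteration awaits the
    // full evaluation while holding its permit, and every evaluation resets
    // the shared memory system, so running questions concurrently against
    // the same `system` would interfere with each other's state.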
for question in &questions {
let _permit = semaphore.acquire().await?;
let result = evaluate_question(
system,
question,
gen_llm.as_ref(),
judge_llm.as_ref(),
config.token_budget,
).await;
match result {
Ok(eval_result) => {
if eval_result.is_correct {
correct_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
total_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if let Err(e) = resume::append_result(&output_path, &eval_result) {
tracing::error!("Failed to write result: {e}");
}
results.lock().await.push(eval_result);
}
Err(e) => {
tracing::error!("Error evaluating question {}: {e}", question.id);
total_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let total = total_count.load(std::sync::atomic::Ordering::Relaxed);
let correct = correct_count.load(std::sync::atomic::Ordering::Relaxed);
let accuracy = if total > 0 { correct as f64 / total as f64 * 100.0 } else { 0.0 };
pb.set_message(format!(
"Running {} on {} — {:.1}% ({}/{})",
dataset.name(), system.name(), accuracy, correct, total,
));
pb.inc(1);
}
pb.finish_with_message("Evaluation complete");
    // When resuming, earlier results already live in the output file; reload
    // the full set so the caller sees previous and current runs together.
    if config.resume && !completed_ids.is_empty() {
        return resume::load_results(&config.output_path);
    }
    let final_results = results.lock().await.clone();
    Ok(final_results)
}
}
async fn evaluate_question(
system: &dyn MemorySystem,
question: &BenchmarkQuestion,
gen_llm: &dyn LLMClient,
judge_llm: &dyn LLMClient,
token_budget: usize,
) -> Result<EvalResult> {
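    // Start from a clean slate, then ingest every session for this question.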
system.reset().await?;
let ingest_start = Instant::now();
for session in &question.sessions {
system.ingest_session(session).await?;
}
let ingest_ms = ingest_start.elapsed().as_millis() as u64;
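    // Retrieve context for the question within the token budget.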
let retrieval_start = Instant::now();
let retrieval = system.retrieve_context(
&question.question,
question.question_date.as_deref(),
token_budget,
).await?;
let retrieval_ms = retrieval_start.elapsed().as_millis() as u64;
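    // Generate a candidate answer from the retrieved context.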
let gen_start = Instant::now();
let prompt = build_generation_prompt(
&retrieval.context,
&question.question,
question.question_date.as_deref(),
);
let hypothesis = gen_llm.generate(&prompt, 256).await?;
let generation_ms = gen_start.elapsed().as_millis() as u64;
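    // Judge the hypothesis against the (comma-joined) ground truth.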
let judge_start = Instant::now();
let ground_truth = question.ground_truth.join(", ");
let is_correct = judge::judge_answer(
&question.question_type,
&question.question,
&ground_truth,
&hypothesis,
question.is_abstention,
judge_llm,
).await?;
let judge_ms = judge_start.elapsed().as_millis() as u64;
Ok(EvalResult {
question_id: question.id.clone(),
system_name: system.name().to_string(),
question_type: question.question_type.clone(),
is_abstention: question.is_abstention,
hypothesis,
ground_truth,
is_correct,
ingest_latency_ms: ingest_ms,
retrieval_latency_ms: retrieval_ms,
generation_latency_ms: generation_ms,
judge_latency_ms: judge_ms,
tokens_used: retrieval.tokens_used as u32,
        // Generated-token accounting is not implemented; recorded as zero.
        tokens_generated: 0,
        timestamp: Some(chrono::Utc::now()),
})
}
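/// Builds the answer-generation prompt from the retrieved context, the
/// question, and an optional current date; the model is instructed to say it
/// lacks information rather than guess when the context is insufficient.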
pub fn build_generation_prompt(context: &str, question: &str, date: Option<&str>) -> String {
let date_line = date.map(|d| format!("\nCurrent date: {d}")).unwrap_or_default();
format!(
r#"You are a helpful assistant with access to conversation history. Use the following context to answer the question accurately and concisely.
Context from memory:
{context}
{date_line}
Question: {question}
Answer the question based on the context above. If the information is not available in the context, say "I don't have enough information to answer this question.""#
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generation_prompt_with_date() {
let prompt = build_generation_prompt("some context", "What happened?", Some("2024/03/10"));
assert!(prompt.contains("some context"));
assert!(prompt.contains("What happened?"));
assert!(prompt.contains("2024/03/10"));
}
#[test]
fn generation_prompt_without_date() {
let prompt = build_generation_prompt("ctx", "Q?", None);
assert!(prompt.contains("ctx"));
assert!(prompt.contains("Q?"));
assert!(!prompt.contains("Current date"));
}
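    // End-to-end smoke test: the echo system plus a mock LLM that answers
    // "yes" for both generation and judging.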
#[tokio::test]
async fn evaluate_with_echo_and_mock() {
use crate::systems::echo::EchoSystem;
use crate::types::{ConversationSession, Turn};
struct MockLLM;
#[async_trait::async_trait]
impl LLMClient for MockLLM {
fn name(&self) -> &str { "mock" }
async fn generate(&self, _: &str, _: usize) -> Result<String> {
Ok("yes".to_string())
}
}
let system = EchoSystem::new();
let question = BenchmarkQuestion {
id: "q001".to_string(),
question_type: "default".to_string(),
question: "What is the answer?".to_string(),
ground_truth: vec!["yes".to_string()],
question_date: None,
sessions: vec![ConversationSession {
id: "s1".to_string(),
date: None,
turns: vec![Turn { role: "user".to_string(), content: "test".to_string() }],
}],
is_abstention: false,
metadata: std::collections::HashMap::new(),
};
let mock = MockLLM;
let result = evaluate_question(&system, &question, &mock, &mock, 16384).await.unwrap();
assert_eq!(result.question_id, "q001");
assert!(result.is_correct);
assert!(result.ingest_latency_ms < 1000);
}
}