Skip to main content

mnem_bench/
runner.rs

//! Top-level dispatch: take a [`RunPlan`], run each shipped bench, and
//! emit RESULTS.md plus per-bench JSON / JSONL files into the output dir.
3
4use std::fs;
5use std::io::Write;
6use std::path::{Path, PathBuf};
7use std::time::Instant;
8
9use anyhow::{Context, Result, anyhow, bail};
10
11use crate::adapters::MnemAdapter;
12use crate::bench::{AdapterKind, Bench, EmbedderChoice, RunMode};
13use crate::datasets;
14use crate::embed::{BenchEmbedder, DEFAULT_DIM};
15use crate::output;
16use crate::score::{PerQuestionRow, ScoreReport};
17
/// Plan for one `mnem bench run` invocation.
///
/// Read-only input to [`run`]: selects the benches and adapters to
/// iterate over, where results are written, and how each bench is capped.
#[derive(Clone, Debug)]
pub struct RunPlan {
    /// Benches to attempt.
    pub benches: Vec<Bench>,
    /// Adapters (systems-under-test) to run.
    ///
    /// NOTE(review): the runner constructs a `MnemAdapter` for every
    /// entry; the kind is only used to label the timing line and the
    /// resulting outcome.
    pub adapters: Vec<AdapterKind>,
    /// Run mode.
    ///
    /// NOTE(review): not read by the runner code in this file.
    pub mode: RunMode,
    /// Embedder choice. A fresh embedder is built from this for every
    /// `(adapter, bench)` pair.
    pub embedder: EmbedderChoice,
    /// Output directory. Created if missing, together with a `logs/`
    /// subdirectory.
    pub out: PathBuf,
    /// Top-K depth for per-bench retrieves.
    pub top_k: usize,
    /// Per-bench question / conversation cap. `None` = no cap, except
    /// for the MemBench slices, which fall back to
    /// `MEMBENCH_HEADLINE_CAP` (100) when no limit is given.
    pub limit: Option<usize>,
    /// Skip the cached download check; force re-download.
    pub no_cache: bool,
    /// Suppress stderr progress output (dataset sizes, fetch notices).
    /// The missing-feature notice and bench failure messages are still
    /// printed.
    pub quiet: bool,
}
40
/// Per-bench outcome.
///
/// [`run`] pushes exactly one of these for every `(adapter, bench)` pair
/// it processes, whether the bench succeeded or failed.
#[derive(Clone, Debug)]
pub struct BenchOutcome {
    /// Bench identity.
    pub bench: Bench,
    /// Adapter run for this outcome.
    pub adapter: AdapterKind,
    /// Final score report. `None` when the bench was skipped or its run
    /// returned an error.
    pub report: Option<ScoreReport>,
    /// Free-form skip reason. Empty when the bench succeeded; prefixed
    /// with `error: ` when the bench run failed.
    pub skipped_reason: String,
}
53
54/// Dispatch the plan. Returns one [`BenchOutcome`] per
55/// `(bench, adapter)` pair.
56pub fn run(plan: &RunPlan) -> Result<Vec<BenchOutcome>> {
57    fs::create_dir_all(&plan.out).with_context(|| format!("creating {}", plan.out.display()))?;
58    let logs_dir = plan.out.join("logs");
59    fs::create_dir_all(&logs_dir).context("creating logs/ subdir")?;
60
61    // OnnxMiniLm needs the (default-on) `onnx-minilm` feature; if the
62    // crate was built without it, [`build_embedder`] silently falls back
63    // to bag-of-tokens. Surface the notice up-front.
64    if matches!(plan.embedder, EmbedderChoice::OnnxMiniLm) {
65        #[cfg(not(feature = "onnx-minilm"))]
66        eprintln!(
67            "[mnem bench] embedder 'onnx-minilm' was selected but mnem-bench was \
68             built without the `onnx-minilm` feature; falling back to bag-of-tokens. \
69             Rebuild with `cargo build -p mnem-bench --features onnx-minilm`."
70        );
71    }
72
73    let timing_log_path = plan.out.join("timing.log");
74    let mut timing_log = fs::File::create(&timing_log_path)
75        .with_context(|| format!("creating {}", timing_log_path.display()))?;
76    let t_total = Instant::now();
77
78    let mut outcomes = Vec::new();
79
80    for adapter_kind in &plan.adapters {
81        for bench in &plan.benches {
82            let meta = bench.metadata();
83            let embedder =
84                build_embedder(plan.embedder).map_err(|e| anyhow!("constructing embedder: {e}"))?;
85            let mut adapter = MnemAdapter::with_embedder(embedder)
86                .map_err(|e| anyhow!("constructing mnem adapter: {e}"))?;
87
88            let outcome = match *bench {
89                Bench::LongMemEval => run_longmemeval(&mut adapter, plan, &logs_dir),
90                Bench::Locomo => run_locomo(&mut adapter, plan, &logs_dir),
91                Bench::Convomem => run_convomem(&mut adapter, plan, &logs_dir),
92                Bench::MembenchSimpleRoles => {
93                    run_membench(&mut adapter, plan, &logs_dir, MembenchSlice::SimpleRoles)
94                }
95                Bench::MembenchHighlevelMovie => {
96                    run_membench(&mut adapter, plan, &logs_dir, MembenchSlice::HighlevelMovie)
97                }
98                Bench::LongMemEvalHybridV4 => {
99                    run_longmemeval_hybrid_v4(&mut adapter, plan, &logs_dir)
100                }
101            };
102            match outcome {
103                Ok((report, rows)) => {
104                    write_outputs(plan, *bench, &report, &rows)?;
105                    writeln!(
106                        timing_log,
107                        "[{}] {} runtime_s={:.2} ingest_s={:.2} retrieve_s={:.2} score_s={:.2}",
108                        adapter_kind.id(),
109                        meta.id,
110                        report.runtime_seconds,
111                        report.timing.ingest_s,
112                        report.timing.retrieve_s,
113                        report.timing.score_s,
114                    )
115                    .ok();
116                    outcomes.push(BenchOutcome {
117                        bench: *bench,
118                        adapter: *adapter_kind,
119                        report: Some(report),
120                        skipped_reason: String::new(),
121                    });
122                }
123                Err(e) => {
124                    let msg = format!("bench {} failed: {e:#}", meta.id);
125                    eprintln!("[mnem bench] {msg}");
126                    let log_path = logs_dir.join(format!("{}.log", meta.id));
127                    let _ = fs::write(&log_path, msg.as_bytes());
128                    outcomes.push(BenchOutcome {
129                        bench: *bench,
130                        adapter: *adapter_kind,
131                        report: None,
132                        skipped_reason: format!("error: {e}"),
133                    });
134                }
135            }
136        }
137    }
138
139    writeln!(
140        timing_log,
141        "[total] elapsed_s={:.2}",
142        t_total.elapsed().as_secs_f64()
143    )
144    .ok();
145    output::write_results_md(&plan.out, &outcomes)?;
146    Ok(outcomes)
147}
148
149fn run_longmemeval(
150    adapter: &mut MnemAdapter,
151    plan: &RunPlan,
152    logs_dir: &Path,
153) -> Result<(ScoreReport, Vec<PerQuestionRow>)> {
154    let path = resolve_dataset(Bench::LongMemEval, plan)?;
155    let mut all = crate::datasets::longmemeval::load(&path)?;
156    if let Some(n) = plan.limit
157        && all.len() > n
158    {
159        all.truncate(n);
160    }
161    if all.is_empty() {
162        bail!(
163            "longmemeval dataset at {} contained no questions",
164            path.display()
165        );
166    }
167    if !plan.quiet {
168        eprintln!("[mnem bench] longmemeval: {} questions", all.len());
169    }
170    let log_path = logs_dir.join("longmemeval.log");
171    let _ = fs::write(
172        &log_path,
173        format!(
174            "longmemeval dataset={} n_questions={} top_k={}\n",
175            path.display(),
176            all.len(),
177            plan.top_k
178        ),
179    );
180    crate::score::longmemeval::run(adapter, &all, plan.top_k, &path)
181}
182
183fn run_locomo(
184    adapter: &mut MnemAdapter,
185    plan: &RunPlan,
186    logs_dir: &Path,
187) -> Result<(ScoreReport, Vec<PerQuestionRow>)> {
188    let path = resolve_dataset(Bench::Locomo, plan)?;
189    let mut all = crate::datasets::locomo::load(&path)?;
190    if let Some(n) = plan.limit
191        && all.len() > n
192    {
193        all.truncate(n);
194    }
195    if all.is_empty() {
196        bail!(
197            "locomo dataset at {} contained no conversations",
198            path.display()
199        );
200    }
201    if !plan.quiet {
202        eprintln!("[mnem bench] locomo: {} conversations", all.len());
203    }
204    let log_path = logs_dir.join("locomo.log");
205    let _ = fs::write(
206        &log_path,
207        format!(
208            "locomo dataset={} n_conversations={} top_k={}\n",
209            path.display(),
210            all.len(),
211            plan.top_k
212        ),
213    );
214    crate::score::locomo::run(adapter, &all, plan.top_k, &path)
215}
216
217fn run_convomem(
218    adapter: &mut MnemAdapter,
219    plan: &RunPlan,
220    logs_dir: &Path,
221) -> Result<(ScoreReport, Vec<PerQuestionRow>)> {
222    let path = resolve_dataset(Bench::Convomem, plan)?;
223    let mut items = crate::datasets::convomem::load(&path)?;
224    if let Some(n) = plan.limit
225        && items.len() > n
226    {
227        items.truncate(n);
228    }
229    if items.is_empty() {
230        bail!("convomem dataset at {} contained no items", path.display());
231    }
232    if !plan.quiet {
233        eprintln!("[mnem bench] convomem: {} items", items.len());
234    }
235    let log_path = logs_dir.join("convomem.log");
236    let _ = fs::write(
237        &log_path,
238        format!(
239            "convomem dataset={} n_items={} top_k={}\n",
240            path.display(),
241            items.len(),
242            plan.top_k
243        ),
244    );
245    crate::score::convomem::run(adapter, &items, plan.top_k, &path)
246}
247
/// Which MemBench slice the runner should execute.
///
/// `run_membench` maps each variant to a `(Bench, category, topic)`
/// triple used to filter the dataset, and to the matching scorer.
#[derive(Clone, Copy, Debug)]
enum MembenchSlice {
    /// `simple.json` filtered by topic=`roles` (category `simple`).
    SimpleRoles,
    /// `highlevel.json` filtered by topic=`movie` (category `highlevel`).
    HighlevelMovie,
}
256
/// Default headline-slice cap for MemBench (matches the n=100 used by
/// MemPalace's published numbers, so the comparison is apples-to-
/// apples). User-supplied `--limit` overrides: the constant is only
/// used when `RunPlan::limit` is `None`.
const MEMBENCH_HEADLINE_CAP: usize = 100;
261
262fn run_membench(
263    adapter: &mut MnemAdapter,
264    plan: &RunPlan,
265    logs_dir: &Path,
266    slice: MembenchSlice,
267) -> Result<(ScoreReport, Vec<PerQuestionRow>)> {
268    let (bench, category, topic) = match slice {
269        MembenchSlice::SimpleRoles => (Bench::MembenchSimpleRoles, "simple", "roles"),
270        MembenchSlice::HighlevelMovie => (Bench::MembenchHighlevelMovie, "highlevel", "movie"),
271    };
272    let path = resolve_dataset(bench, plan)?;
273    let mut items = crate::datasets::membench::load_filtered(&path, category, Some(topic))?;
274    let effective_limit = plan.limit.unwrap_or(MEMBENCH_HEADLINE_CAP);
275    if items.len() > effective_limit {
276        items.truncate(effective_limit);
277    }
278    if items.is_empty() {
279        bail!(
280            "membench dataset at {} contained no items for topic={}",
281            path.display(),
282            topic
283        );
284    }
285    if !plan.quiet {
286        eprintln!(
287            "[mnem bench] {}: {} items (topic={})",
288            bench.metadata().id,
289            items.len(),
290            topic
291        );
292    }
293    let log_path = logs_dir.join(format!("{}.log", bench.metadata().id));
294    let _ = fs::write(
295        &log_path,
296        format!(
297            "{} dataset={} n_items={} top_k={} topic={}\n",
298            bench.metadata().id,
299            path.display(),
300            items.len(),
301            plan.top_k,
302            topic,
303        ),
304    );
305    match slice {
306        MembenchSlice::SimpleRoles => {
307            crate::score::membench::run_simple_roles(adapter, &items, plan.top_k, &path)
308        }
309        MembenchSlice::HighlevelMovie => {
310            crate::score::membench::run_highlevel_movie(adapter, &items, plan.top_k, &path)
311        }
312    }
313}
314
315fn run_longmemeval_hybrid_v4(
316    adapter: &mut MnemAdapter,
317    plan: &RunPlan,
318    logs_dir: &Path,
319) -> Result<(ScoreReport, Vec<PerQuestionRow>)> {
320    // Reuses the LongMemEval cache. No separate dataset blob exists.
321    let path = resolve_dataset(Bench::LongMemEval, plan)?;
322    let mut all = crate::datasets::longmemeval::load(&path)?;
323    if let Some(n) = plan.limit
324        && all.len() > n
325    {
326        all.truncate(n);
327    }
328    if all.is_empty() {
329        bail!(
330            "longmemeval dataset at {} contained no questions (hybrid-v4)",
331            path.display()
332        );
333    }
334    if !plan.quiet {
335        eprintln!(
336            "[mnem bench] longmemeval-hybrid-v4: {} questions (boost_weight={})",
337            all.len(),
338            crate::score::hybrid_v4::DEFAULT_BOOST_WEIGHT
339        );
340    }
341    let log_path = logs_dir.join("longmemeval-hybrid-v4.log");
342    let _ = fs::write(
343        &log_path,
344        format!(
345            "longmemeval-hybrid-v4 dataset={} n_questions={} top_k={} boost_weight={}\n",
346            path.display(),
347            all.len(),
348            plan.top_k,
349            crate::score::hybrid_v4::DEFAULT_BOOST_WEIGHT,
350        ),
351    );
352    crate::score::hybrid_v4::run(
353        adapter,
354        &all,
355        plan.top_k,
356        &path,
357        crate::score::hybrid_v4::DEFAULT_BOOST_WEIGHT,
358    )
359}
360
361/// Locate the dataset for `bench`. If a copy exists in the cache
362/// dir we use it; otherwise call into [`datasets::fetch`] to
363/// download. Mirrors the upstream Python "fail-fast if file
364/// missing" UX when network is unavailable.
365fn resolve_dataset(bench: Bench, plan: &RunPlan) -> Result<PathBuf> {
366    let cached = datasets::cached_path(bench)?;
367    if cached.is_file() && !plan.no_cache {
368        return Ok(cached);
369    }
370    if !plan.quiet {
371        eprintln!("[mnem bench] fetching {} dataset...", bench.metadata().id);
372    }
373    datasets::fetch(bench, !plan.no_cache, |_d, _t| {})
374}
375
376/// Resolve the runtime embedder for a given [`EmbedderChoice`].
377/// `OnnxMiniLm` needs the `onnx-minilm` feature; absent the feature
378/// we silently fall back to bag-of-tokens (the runner already printed
379/// a notice).
380fn build_embedder(choice: EmbedderChoice) -> Result<BenchEmbedder> {
381    match choice {
382        EmbedderChoice::BagOfTokens => Ok(BenchEmbedder::bag_of_tokens(DEFAULT_DIM)),
383        EmbedderChoice::OnnxMiniLm => {
384            #[cfg(feature = "onnx-minilm")]
385            {
386                BenchEmbedder::onnx_minilm().map_err(|e| anyhow!("onnx-minilm init: {e}"))
387            }
388            #[cfg(not(feature = "onnx-minilm"))]
389            {
390                Ok(BenchEmbedder::bag_of_tokens(DEFAULT_DIM))
391            }
392        }
393    }
394}
395
396fn write_outputs(
397    plan: &RunPlan,
398    bench: Bench,
399    report: &ScoreReport,
400    rows: &[PerQuestionRow],
401) -> Result<()> {
402    let id = bench.metadata().id;
403    let json_path = plan.out.join(format!("{id}.json"));
404    fs::write(&json_path, serde_json::to_vec_pretty(report)?)
405        .with_context(|| format!("writing {}", json_path.display()))?;
406    let jsonl_path = plan.out.join(format!("{id}.jsonl"));
407    let mut jsonl = fs::File::create(&jsonl_path)
408        .with_context(|| format!("creating {}", jsonl_path.display()))?;
409    for row in rows {
410        serde_json::to_writer(&mut jsonl, row)?;
411        jsonl.write_all(b"\n")?;
412    }
413    Ok(())
414}