Skip to main content

mnem_bench/datasets/
mod.rs

1//! Dataset cache + sha256 verification + download.
2//!
3//! Datasets land under `~/.mnem/bench-data/<bench>/<filename>`.
4//! On `fetch`, the URL is downloaded with `ureq`, hashed with
//! sha256, and compared against a hardcoded expected digest. On a
//! mismatch the file is kept on disk with its extension replaced by
//! `.bad` so the operator can inspect it before deleting.
8
9pub mod convomem;
10pub mod locomo;
11pub mod longmemeval;
12pub mod membench;
13
14use std::fs;
15use std::io::{self, Read, Write};
16use std::path::{Path, PathBuf};
17
18use anyhow::{Context, Result, anyhow, bail};
19use sha2::{Digest, Sha256};
20
21use crate::bench::Bench;
22
23/// Static descriptor for one dataset file.
24#[derive(Clone, Debug)]
25pub struct DatasetSpec {
26    /// Bench this file feeds.
27    pub bench: Bench,
28    /// Filename under `~/.mnem/bench-data/<bench>/`.
29    pub filename: &'static str,
30    /// Direct download URL. Picked to be a bare-bytes endpoint
31    /// (HuggingFace `resolve/main/...`) so we do not need a JSON
32    /// parser to find the artefact.
33    pub url: &'static str,
34    /// Expected sha256 of the downloaded bytes (lower-case hex).
35    /// Empty string disables the check (used during dev only).
36    pub sha256: &'static str,
37    /// Approximate bytes (for the progress bar baseline).
38    pub bytes: u64,
39}
40
41/// Look up the canonical spec for a bench.
42#[must_use]
43pub fn spec_for(bench: Bench) -> Option<DatasetSpec> {
44    match bench {
45        Bench::LongMemEval => Some(longmemeval::SPEC),
46        Bench::Locomo => Some(locomo::SPEC),
47        Bench::Convomem => Some(convomem::SPEC),
48        Bench::MembenchSimpleRoles => Some(membench::SIMPLE_ROLES_SPEC),
49        Bench::MembenchHighlevelMovie => Some(membench::HIGHLEVEL_MOVIE_SPEC),
50        _ => None,
51    }
52}
53
54/// Resolve the cache directory for `bench`. Creates it if missing.
55pub fn cache_dir_for(bench: Bench) -> Result<PathBuf> {
56    let base = bench_data_root()?;
57    let dir = base.join(bench.metadata().id);
58    fs::create_dir_all(&dir).with_context(|| format!("creating {}", dir.display()))?;
59    Ok(dir)
60}
61
62/// `~/.mnem/bench-data/`. Honours `MNEM_BENCH_DATA` for tests.
63pub fn bench_data_root() -> Result<PathBuf> {
64    if let Ok(p) = std::env::var("MNEM_BENCH_DATA") {
65        return Ok(PathBuf::from(p));
66    }
67    let dirs = directories::BaseDirs::new()
68        .ok_or_else(|| anyhow!("HOME / USERPROFILE unset; cannot resolve ~/.mnem"))?;
69    Ok(dirs.home_dir().join(".mnem").join("bench-data"))
70}
71
72/// Resolve the path to the cached dataset file for `bench`. Does
73/// NOT download; use [`fetch`] for that.
74pub fn cached_path(bench: Bench) -> Result<PathBuf> {
75    let spec =
76        spec_for(bench).ok_or_else(|| anyhow!("no dataset spec for {}", bench.metadata().id))?;
77    Ok(cache_dir_for(bench)?.join(spec.filename))
78}
79
80/// Whether the cached dataset for `bench` exists AND verifies
81/// against its expected sha256. False if the file is absent, the
82/// hash does not match, or the spec is empty.
83pub fn is_cached(bench: Bench) -> bool {
84    let Ok(p) = cached_path(bench) else {
85        return false;
86    };
87    let Some(spec) = spec_for(bench) else {
88        return false;
89    };
90    if !p.is_file() {
91        return false;
92    }
93    if spec.sha256.is_empty() {
94        return true;
95    }
96    sha256_file(&p)
97        .map(|h| h.eq_ignore_ascii_case(spec.sha256))
98        .unwrap_or(false)
99}
100
101/// Fetch the dataset for `bench` into the cache. Idempotent: if
102/// the cached file already verifies, returns the path immediately.
103///
104/// `progress_cb` is called with `(downloaded_bytes, total_bytes)`
105/// every ~64KB. Pass a no-op closure when running headless.
106pub fn fetch<F: FnMut(u64, u64)>(
107    bench: Bench,
108    skip_cached: bool,
109    mut progress_cb: F,
110) -> Result<PathBuf> {
111    // ConvoMem is multi-shard. Walk the bundled manifest, merge,
112    // emit a single canonical blob. The per-shard cache lives under
113    // `cache_dir/shards/`.
114    if matches!(bench, Bench::Convomem) {
115        let dir = cache_dir_for(bench)?;
116        let dst = dir.join(convomem::SPEC.filename);
117        if skip_cached && dst.is_file() {
118            return Ok(dst);
119        }
120        return convomem::fetch_into(&dir);
121    }
122    let spec =
123        spec_for(bench).ok_or_else(|| anyhow!("no dataset spec for {}", bench.metadata().id))?;
124    let dst = cached_path(bench)?;
125    if skip_cached && is_cached(bench) {
126        return Ok(dst);
127    }
128    if dst.is_file() && !spec.sha256.is_empty() {
129        let actual = sha256_file(&dst)?;
130        if actual.eq_ignore_ascii_case(spec.sha256) {
131            return Ok(dst);
132        }
133        // Stale or corrupt - keep the bytes for forensics.
134        let bad = dst.with_extension("bad");
135        let _ = fs::rename(&dst, &bad);
136    }
137
138    // Stream the download to a temp file, sha as we go, then
139    // rename atomically.
140    let tmp = dst.with_extension("part");
141    let resp = ureq::get(spec.url)
142        .call()
143        .with_context(|| format!("GET {}", spec.url))?;
144    let total: u64 = resp
145        .header("content-length")
146        .and_then(|s| s.parse().ok())
147        .unwrap_or(spec.bytes);
148
149    let mut reader = resp.into_reader();
150    let mut file = fs::File::create(&tmp).with_context(|| format!("creating {}", tmp.display()))?;
151    let mut hasher = Sha256::new();
152    let mut buf = vec![0u8; 64 * 1024];
153    let mut done = 0u64;
154    loop {
155        let n = match reader.read(&mut buf) {
156            Ok(0) => break,
157            Ok(n) => n,
158            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
159            Err(e) => return Err(e).context("download read"),
160        };
161        file.write_all(&buf[..n]).context("write to disk")?;
162        hasher.update(&buf[..n]);
163        done = done.saturating_add(n as u64);
164        progress_cb(done, total);
165    }
166    file.flush().ok();
167    drop(file);
168
169    let actual = hex::encode(hasher.finalize());
170    if !spec.sha256.is_empty() && !actual.eq_ignore_ascii_case(spec.sha256) {
171        let bad = dst.with_extension("bad");
172        fs::rename(&tmp, &bad).ok();
173        bail!(
174            "sha256 mismatch for {}: expected {}, got {}. file kept at {}",
175            spec.filename,
176            spec.sha256,
177            actual,
178            bad.display()
179        );
180    }
181    fs::rename(&tmp, &dst)
182        .with_context(|| format!("renaming {} -> {}", tmp.display(), dst.display()))?;
183    Ok(dst)
184}
185
186/// Hash a file with sha256, returning lower-case hex.
187pub fn sha256_file(p: &Path) -> Result<String> {
188    let mut f = fs::File::open(p).with_context(|| format!("opening {}", p.display()))?;
189    let mut hasher = Sha256::new();
190    let mut buf = vec![0u8; 64 * 1024];
191    loop {
192        let n = f
193            .read(&mut buf)
194            .with_context(|| format!("reading {}", p.display()))?;
195        if n == 0 {
196            break;
197        }
198        hasher.update(&buf[..n]);
199    }
200    Ok(hex::encode(hasher.finalize()))
201}