mnem_bench/datasets/
mod.rs1pub mod convomem;
10pub mod locomo;
11pub mod longmemeval;
12pub mod membench;
13
14use std::fs;
15use std::io::{self, Read, Write};
16use std::path::{Path, PathBuf};
17
18use anyhow::{Context, Result, anyhow, bail};
19use sha2::{Digest, Sha256};
20
21use crate::bench::Bench;
22
23#[derive(Clone, Debug)]
25pub struct DatasetSpec {
26 pub bench: Bench,
28 pub filename: &'static str,
30 pub url: &'static str,
34 pub sha256: &'static str,
37 pub bytes: u64,
39}
40
41#[must_use]
43pub fn spec_for(bench: Bench) -> Option<DatasetSpec> {
44 match bench {
45 Bench::LongMemEval => Some(longmemeval::SPEC),
46 Bench::Locomo => Some(locomo::SPEC),
47 Bench::Convomem => Some(convomem::SPEC),
48 Bench::MembenchSimpleRoles => Some(membench::SIMPLE_ROLES_SPEC),
49 Bench::MembenchHighlevelMovie => Some(membench::HIGHLEVEL_MOVIE_SPEC),
50 _ => None,
51 }
52}
53
54pub fn cache_dir_for(bench: Bench) -> Result<PathBuf> {
56 let base = bench_data_root()?;
57 let dir = base.join(bench.metadata().id);
58 fs::create_dir_all(&dir).with_context(|| format!("creating {}", dir.display()))?;
59 Ok(dir)
60}
61
62pub fn bench_data_root() -> Result<PathBuf> {
64 if let Ok(p) = std::env::var("MNEM_BENCH_DATA") {
65 return Ok(PathBuf::from(p));
66 }
67 let dirs = directories::BaseDirs::new()
68 .ok_or_else(|| anyhow!("HOME / USERPROFILE unset; cannot resolve ~/.mnem"))?;
69 Ok(dirs.home_dir().join(".mnem").join("bench-data"))
70}
71
72pub fn cached_path(bench: Bench) -> Result<PathBuf> {
75 let spec =
76 spec_for(bench).ok_or_else(|| anyhow!("no dataset spec for {}", bench.metadata().id))?;
77 Ok(cache_dir_for(bench)?.join(spec.filename))
78}
79
80pub fn is_cached(bench: Bench) -> bool {
84 let Ok(p) = cached_path(bench) else {
85 return false;
86 };
87 let Some(spec) = spec_for(bench) else {
88 return false;
89 };
90 if !p.is_file() {
91 return false;
92 }
93 if spec.sha256.is_empty() {
94 return true;
95 }
96 sha256_file(&p)
97 .map(|h| h.eq_ignore_ascii_case(spec.sha256))
98 .unwrap_or(false)
99}
100
101pub fn fetch<F: FnMut(u64, u64)>(
107 bench: Bench,
108 skip_cached: bool,
109 mut progress_cb: F,
110) -> Result<PathBuf> {
111 if matches!(bench, Bench::Convomem) {
115 let dir = cache_dir_for(bench)?;
116 let dst = dir.join(convomem::SPEC.filename);
117 if skip_cached && dst.is_file() {
118 return Ok(dst);
119 }
120 return convomem::fetch_into(&dir);
121 }
122 let spec =
123 spec_for(bench).ok_or_else(|| anyhow!("no dataset spec for {}", bench.metadata().id))?;
124 let dst = cached_path(bench)?;
125 if skip_cached && is_cached(bench) {
126 return Ok(dst);
127 }
128 if dst.is_file() && !spec.sha256.is_empty() {
129 let actual = sha256_file(&dst)?;
130 if actual.eq_ignore_ascii_case(spec.sha256) {
131 return Ok(dst);
132 }
133 let bad = dst.with_extension("bad");
135 let _ = fs::rename(&dst, &bad);
136 }
137
138 let tmp = dst.with_extension("part");
141 let resp = ureq::get(spec.url)
142 .call()
143 .with_context(|| format!("GET {}", spec.url))?;
144 let total: u64 = resp
145 .header("content-length")
146 .and_then(|s| s.parse().ok())
147 .unwrap_or(spec.bytes);
148
149 let mut reader = resp.into_reader();
150 let mut file = fs::File::create(&tmp).with_context(|| format!("creating {}", tmp.display()))?;
151 let mut hasher = Sha256::new();
152 let mut buf = vec![0u8; 64 * 1024];
153 let mut done = 0u64;
154 loop {
155 let n = match reader.read(&mut buf) {
156 Ok(0) => break,
157 Ok(n) => n,
158 Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
159 Err(e) => return Err(e).context("download read"),
160 };
161 file.write_all(&buf[..n]).context("write to disk")?;
162 hasher.update(&buf[..n]);
163 done = done.saturating_add(n as u64);
164 progress_cb(done, total);
165 }
166 file.flush().ok();
167 drop(file);
168
169 let actual = hex::encode(hasher.finalize());
170 if !spec.sha256.is_empty() && !actual.eq_ignore_ascii_case(spec.sha256) {
171 let bad = dst.with_extension("bad");
172 fs::rename(&tmp, &bad).ok();
173 bail!(
174 "sha256 mismatch for {}: expected {}, got {}. file kept at {}",
175 spec.filename,
176 spec.sha256,
177 actual,
178 bad.display()
179 );
180 }
181 fs::rename(&tmp, &dst)
182 .with_context(|| format!("renaming {} -> {}", tmp.display(), dst.display()))?;
183 Ok(dst)
184}
185
186pub fn sha256_file(p: &Path) -> Result<String> {
188 let mut f = fs::File::open(p).with_context(|| format!("opening {}", p.display()))?;
189 let mut hasher = Sha256::new();
190 let mut buf = vec![0u8; 64 * 1024];
191 loop {
192 let n = f
193 .read(&mut buf)
194 .with_context(|| format!("reading {}", p.display()))?;
195 if n == 0 {
196 break;
197 }
198 hasher.update(&buf[..n]);
199 }
200 Ok(hex::encode(hasher.finalize()))
201}