1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use chrono::Utc;
5use serde::{Deserialize, Serialize};
6
7use crate::app_paths;
8
/// Public mirrors for the cleaned LongMemEval-S dataset; tried in order
/// by `download_from_sources` until one downloads and validates.
const LONGMEMEVAL_DATASET_URLS: &[&str] = &[
    "https://huggingface.co/datasets/LIXINYI33/longmemeval-s/resolve/main/longmemeval_s_cleaned.json",
    "https://huggingface.co/datasets/kellyhongg/cleaned-longmemeval-s/resolve/main/longmemeval_s_cleaned.json",
    "https://github.com/xiaowu0162/longmemeval-cleaned/raw/main/longmemeval_s_cleaned.json",
];
/// Public mirrors for the LoCoMo-10 dataset; same fallback behavior.
const LOCOMO_DATASET_URLS: &[&str] = &[
    "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json",
    "https://github.com/snap-research/locomo/raw/main/data/locomo10.json",
];
18
/// Provenance record serialized alongside benchmark results.
#[derive(Debug, Clone, Serialize)]
pub struct BenchmarkMetadata {
    /// Benchmark name as passed to `benchmark_metadata*`.
    pub benchmark: String,
    /// Sanitized, shell-quoted command line that launched the run
    /// (path-like arguments are redacted).
    pub command: String,
    /// RFC 3339 UTC timestamp taken when the metadata was built.
    pub date: String,
    /// `git rev-parse HEAD` output, or `None` when git is unavailable.
    pub commit: Option<String>,
    /// Human-readable machine description (OS, arch, CPU count, model when available).
    pub machine: String,
    /// URL the dataset was fetched from, or `"user-supplied"`.
    pub dataset_source: String,
    /// Dataset file name only (directory components redacted).
    pub dataset_path: String,
}
29
/// A resolved dataset file on disk plus where it came from.
#[derive(Debug)]
pub struct DatasetArtifact {
    /// Download URL, or `"user-supplied"` for an explicit `--dataset-path`.
    pub source_url: String,
    /// Local filesystem location of the dataset JSON.
    pub path: PathBuf,
    /// When true, the file is removed by `cleanup()` or on `Drop`.
    pub temporary: bool,
}
36
/// Benchmark datasets this module knows how to fetch and cache.
#[derive(Debug, Clone, Copy)]
pub enum DatasetKind {
    /// LongMemEval-S, cleaned variant (`longmemeval_s_cleaned.json`).
    LongMemEval,
    /// LoCoMo-10 (`locomo10.json`).
    LoCoMo10,
}
42
43impl DatasetKind {
44 fn cache_subdir(self) -> &'static str {
45 match self {
46 Self::LongMemEval => "longmemeval",
47 Self::LoCoMo10 => "locomo",
48 }
49 }
50
51 fn filename(self) -> &'static str {
52 match self {
53 Self::LongMemEval => "longmemeval_s_cleaned.json",
54 Self::LoCoMo10 => "locomo10.json",
55 }
56 }
57
58 fn source_urls(self) -> &'static [&'static str] {
59 match self {
60 Self::LongMemEval => LONGMEMEVAL_DATASET_URLS,
61 Self::LoCoMo10 => LOCOMO_DATASET_URLS,
62 }
63 }
64}
65
66pub async fn resolve_dataset(
67 kind: DatasetKind,
68 dataset_path: Option<PathBuf>,
69 force_refresh: bool,
70 temporary: bool,
71) -> Result<DatasetArtifact> {
72 if dataset_path.is_some() && (force_refresh || temporary) {
73 return Err(anyhow!(
74 "--dataset-path cannot be combined with --force-refresh or --temp-dataset"
75 ));
76 }
77 if let Some(path) = dataset_path {
78 validate_json_file(&path)?;
79 return Ok(DatasetArtifact {
80 source_url: "user-supplied".to_string(),
81 path,
82 temporary: false,
83 });
84 }
85
86 let cache_path = if temporary {
87 temporary_dataset_path(kind)
88 } else {
89 benchmark_cache_path(kind)?
90 };
91 let cache_is_valid = cache_path.exists() && validate_json_file(&cache_path).is_ok();
92 if force_refresh || !cache_is_valid {
93 let source_url = download_from_sources(kind.source_urls(), &cache_path).await?;
94 if !temporary {
95 write_source_metadata(&cache_path, &source_url)?;
96 }
97 return Ok(DatasetArtifact {
98 source_url,
99 path: cache_path,
100 temporary,
101 });
102 }
103 Ok(DatasetArtifact {
104 source_url: read_source_metadata(&cache_path)
105 .unwrap_or_else(|| kind.source_urls()[0].to_string()),
106 path: cache_path,
107 temporary,
108 })
109}
110
111pub fn benchmark_cache_path(kind: DatasetKind) -> Result<PathBuf> {
112 let cache_root = app_paths::resolve_app_paths()?.benchmark_root;
113 Ok(cache_root.join(kind.cache_subdir()).join(kind.filename()))
114}
115
116pub fn benchmark_metadata(benchmark: &str, dataset: &DatasetArtifact) -> BenchmarkMetadata {
117 benchmark_metadata_from_parts(
118 benchmark,
119 &dataset.source_url,
120 &dataset.path.display().to_string(),
121 )
122}
123
124pub fn benchmark_metadata_from_parts(
125 benchmark: &str,
126 dataset_source: &str,
127 dataset_path: &str,
128) -> BenchmarkMetadata {
129 BenchmarkMetadata {
130 benchmark: benchmark.to_string(),
131 command: sanitize_command(std::env::args()),
132 date: Utc::now().to_rfc3339(),
133 commit: git_commit(),
134 machine: machine_descriptor(),
135 dataset_source: dataset_source.to_string(),
136 dataset_path: sanitize_dataset_path(dataset_path),
137 }
138}
139
140fn validate_json_file(path: &Path) -> Result<()> {
141 let file = std::fs::File::open(path)
142 .with_context(|| format!("failed to open dataset at {}", path.display()))?;
143 let mut reader = std::io::BufReader::new(file);
144 let mut de = serde_json::Deserializer::from_reader(&mut reader);
145 serde::de::IgnoredAny::deserialize(&mut de)
146 .with_context(|| format!("failed to parse JSON dataset at {}", path.display()))?;
147 de.end()
148 .with_context(|| format!("trailing content in JSON dataset at {}", path.display()))?;
149 Ok(())
150}
151
152async fn download_from_sources(urls: &[&str], path: &Path) -> Result<String> {
153 let mut failures = Vec::new();
154 for url in urls {
155 match download_file(url, path).await {
156 Ok(()) => return Ok((*url).to_string()),
157 Err(err) => failures.push(format!("{url}: {err}")),
158 }
159 }
160 Err(anyhow!(
161 "failed to download benchmark dataset from any public source:\n{}",
162 failures.join("\n")
163 ))
164}
165
/// Download `url` into `path`, validating the payload before it replaces
/// any existing file.
///
/// The body is staged into a sibling `<name>.part` file, parsed as JSON,
/// and only then renamed over `path` — so a failed or truncated download
/// never clobbers a previously good dataset. On any failure the `.part`
/// file is removed (best effort) and the original error is returned.
async fn download_file(url: &str, path: &Path) -> Result<()> {
    if let Some(parent) = path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .with_context(|| format!("failed to create dataset cache dir {}", parent.display()))?;
    }

    // Generous total timeout: dataset files can be large and mirrors slow.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(600))
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .context("failed to build benchmark download client")?;

    let response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("failed to download benchmark dataset from {url}"))?
        .error_for_status()
        .with_context(|| format!("benchmark dataset download failed for {url}"))?;
    // NOTE: buffers the entire response body in memory before writing to disk.
    let bytes = response
        .bytes()
        .await
        .with_context(|| format!("failed to read benchmark dataset body from {url}"))?;

    // Stage into "<filename>.part" next to the final destination so the
    // finalizing rename stays on the same filesystem.
    let mut part_name = path
        .file_name()
        .ok_or_else(|| anyhow!("invalid dataset file name for {}", path.display()))?
        .to_os_string();
    part_name.push(".part");
    let part_path = path.with_file_name(part_name);
    // The async block groups the fallible steps so one cleanup path below
    // can remove the .part file regardless of which step failed.
    let write_result: Result<()> = async {
        tokio::fs::write(&part_path, &bytes)
            .await
            .with_context(|| format!("failed to write {}", part_path.display()))?;
        let validate_path = part_path.clone();
        // Validation is blocking file I/O + parsing; keep it off the async runtime.
        tokio::task::spawn_blocking(move || validate_json_file(&validate_path))
            .await
            .context("dataset validation task failed")??;
        // Clear the destination first, then install the validated file.
        remove_file_if_exists(path).await?;
        tokio::fs::rename(&part_path, path)
            .await
            .with_context(|| format!("failed to finalize {}", path.display()))?;
        Ok(())
    }
    .await;
    // Best-effort cleanup of the partial file; the original error wins.
    if let Err(err) = write_result {
        let _ = tokio::fs::remove_file(&part_path).await;
        return Err(err);
    }
    Ok(())
}
218
219async fn remove_file_if_exists(path: &Path) -> Result<()> {
220 match tokio::fs::remove_file(path).await {
221 Ok(()) => Ok(()),
222 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
223 Err(err) => Err(err)
224 .with_context(|| format!("failed to remove existing dataset {}", path.display())),
225 }
226}
227
228fn read_source_metadata(path: &Path) -> Option<String> {
229 let meta_path = source_metadata_path(path);
230 let text = std::fs::read_to_string(meta_path).ok()?;
231 let trimmed = text.trim();
232 if trimmed.is_empty() {
233 None
234 } else {
235 Some(trimmed.to_string())
236 }
237}
238
239fn write_source_metadata(path: &Path, source_url: &str) -> Result<()> {
240 std::fs::write(source_metadata_path(path), source_url)
241 .with_context(|| format!("failed to write source metadata for {}", path.display()))
242}
243
/// Sidecar location for a dataset's source URL:
/// `<dir>/<file>.source-url` next to the dataset file itself.
fn source_metadata_path(path: &Path) -> PathBuf {
    let mut meta_name = std::ffi::OsString::new();
    if let Some(base) = path.file_name() {
        meta_name.push(base);
    }
    meta_name.push(".source-url");
    path.with_file_name(meta_name)
}
249
250fn temporary_dataset_path(kind: DatasetKind) -> PathBuf {
251 let stamp = Utc::now().format("%Y%m%d%H%M%S");
252 let filename = format!("{stamp}-{}-{}", std::process::id(), kind.filename());
253 std::env::temp_dir()
254 .join("mag-benchmarks")
255 .join(kind.cache_subdir())
256 .join(filename)
257}
258
259impl DatasetArtifact {
260 pub fn cleanup(&mut self) -> Result<()> {
261 if !self.temporary || !self.path.exists() {
262 return Ok(());
263 }
264 std::fs::remove_file(&self.path).with_context(|| {
265 format!("failed to remove temporary dataset {}", self.path.display())
266 })?;
267 self.temporary = false;
268 Ok(())
269 }
270}
271
272impl Drop for DatasetArtifact {
273 fn drop(&mut self) {
274 if self.temporary {
275 let _ = std::fs::remove_file(&self.path);
276 }
277 }
278}
279
280fn sanitize_command(args: impl IntoIterator<Item = String>) -> String {
281 args.into_iter()
282 .map(|arg| quote_shell_arg(&sanitize_arg(&arg)))
283 .collect::<Vec<_>>()
284 .join(" ")
285}
286
287fn sanitize_arg(arg: &str) -> String {
288 if looks_like_path(arg) {
289 "<redacted_path>".to_string()
290 } else {
291 arg.to_string()
292 }
293}
294
/// Quote `arg` for a POSIX-style shell command line.
///
/// Fixes over the previous version:
/// - `$` and `` ` `` previously triggered quoting but were NOT escaped, so
///   the "quoted" argument still underwent variable/command substitution
///   when pasted into a shell; they are now backslash-escaped (backslash
///   escapes work for `\`, `"`, `$`, `` ` `` inside double quotes).
/// - An empty argument is now rendered as `""` instead of vanishing from
///   the reconstructed command line.
fn quote_shell_arg(arg: &str) -> String {
    if arg.is_empty() {
        return "\"\"".to_string();
    }
    let needs_quoting = arg
        .chars()
        .any(|ch| ch.is_whitespace() || matches!(ch, '"' | '\\' | '$' | '`'));
    if !needs_quoting {
        return arg.to_string();
    }
    let mut quoted = String::with_capacity(arg.len() + 2);
    quoted.push('"');
    for ch in arg.chars() {
        // These four are the characters with special meaning inside
        // double quotes in POSIX shells.
        if matches!(ch, '"' | '\\' | '$' | '`') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
305
/// Reduce a dataset path to its file name so directory structure never
/// leaks into recorded metadata; falls back to a placeholder when no
/// file name (or non-UTF-8 name) is present.
fn sanitize_dataset_path(dataset_path: &str) -> String {
    match Path::new(dataset_path).file_name().and_then(std::ffi::OsStr::to_str) {
        Some(name) => name.to_string(),
        None => "<redacted_path>".to_string(),
    }
}
313
/// Heuristic: does `arg` look like a filesystem path (absolute, home-relative,
/// or containing a path separator)?
fn looks_like_path(arg: &str) -> bool {
    if Path::new(arg).is_absolute() || arg.starts_with('~') {
        return true;
    }
    arg.contains('/') || arg.contains('\\')
}
317
/// Current git HEAD commit hash, or `None` when git is missing or the
/// process is not running inside a checkout.
fn git_commit() -> Option<String> {
    command_stdout("git", &["rev-parse", "HEAD"])
}
321
322fn machine_descriptor() -> String {
323 let mut parts = vec![
324 format!("{} {}", std::env::consts::OS, std::env::consts::ARCH),
325 format!("{} CPU", num_cpus::get()),
326 ];
327 if let Some(model) = command_stdout("sysctl", &["-n", "hw.model"]) {
328 parts.insert(0, model);
329 }
330 parts.join(", ")
331}
332
/// Run `cmd` with `args` and return its trimmed stdout.
/// `None` on spawn failure, nonzero exit, or empty output.
fn command_stdout(cmd: &str, args: &[&str]) -> Option<String> {
    let output = std::process::Command::new(cmd)
        .args(args)
        .output()
        .ok()
        .filter(|out| out.status.success())?;
    let trimmed = String::from_utf8_lossy(&output.stdout).trim().to_string();
    (!trimmed.is_empty()).then_some(trimmed)
}
341
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // An explicit --dataset-path must be used verbatim (no download, no cache),
    // with the source recorded as "user-supplied".
    #[tokio::test]
    async fn resolve_dataset_uses_explicit_path_without_downloading() {
        // Minimal valid JSON fixture at a collision-free temp location.
        let path =
            std::env::temp_dir().join(format!("mag-benchmark-fixture-{}.json", Uuid::new_v4()));
        std::fs::write(&path, r#"[{"question":"hi"}]"#).unwrap();

        let dataset = resolve_dataset(DatasetKind::LongMemEval, Some(path.clone()), false, false)
            .await
            .unwrap();
        assert_eq!(dataset.path, path);
        assert_eq!(dataset.source_url, "user-supplied");

        let _ = std::fs::remove_file(path);
    }

    // The sidecar `.source-url` file must round-trip the URL it records.
    #[test]
    fn source_metadata_round_trip() {
        let path = std::env::temp_dir().join(format!("mag-benchmark-meta-{}.json", Uuid::new_v4()));
        write_source_metadata(&path, "https://example.com/dataset.json").unwrap();
        assert_eq!(
            read_source_metadata(&path).as_deref(),
            Some("https://example.com/dataset.json")
        );
        let _ = std::fs::remove_file(source_metadata_path(&path));
    }
}
372}