Skip to main content

mag/
benchmarking.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use chrono::Utc;
5use serde::{Deserialize, Serialize};
6
7use crate::app_paths;
8
/// Download locations for the `longmemeval_s_cleaned.json` dataset.
///
/// The entries are attempted in the order listed here; the first entry is also
/// used as the reported source when a cached copy has no recorded origin.
const LONGMEMEVAL_DATASET_URLS: &[&str] = &[
    "https://huggingface.co/datasets/LIXINYI33/longmemeval-s/resolve/main/longmemeval_s_cleaned.json",
    "https://huggingface.co/datasets/kellyhongg/cleaned-longmemeval-s/resolve/main/longmemeval_s_cleaned.json",
    "https://github.com/xiaowu0162/longmemeval-cleaned/raw/main/longmemeval_s_cleaned.json",
];
/// Download locations for the `locomo10.json` dataset.
///
/// The entries are attempted in the order listed here; the first entry is also
/// used as the reported source when a cached copy has no recorded origin.
const LOCOMO_DATASET_URLS: &[&str] = &[
    "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json",
    "https://github.com/snap-research/locomo/raw/main/data/locomo10.json",
];
18
/// Serializable provenance record describing one benchmark run.
#[derive(Debug, Clone, Serialize)]
pub struct BenchmarkMetadata {
    /// Name of the benchmark, as supplied by the caller.
    pub benchmark: String,
    /// Command line of the current process with path-like arguments redacted.
    pub command: String,
    /// UTC timestamp of when the metadata was built, in RFC 3339 format.
    pub date: String,
    /// Output of `git rev-parse HEAD`, or `None` when it could not be obtained.
    pub commit: Option<String>,
    /// Short description of the host (OS, architecture, CPU count).
    pub machine: String,
    /// Where the dataset came from, as supplied by the caller.
    pub dataset_source: String,
    /// Dataset location reduced to its file name so local directories are not recorded.
    pub dataset_path: String,
}
29
/// A dataset file on disk that is ready to be consumed by a benchmark.
#[derive(Debug)]
pub struct DatasetArtifact {
    /// URL the dataset was downloaded from, or `"user-supplied"` for an explicit path.
    pub source_url: String,
    /// Location of the dataset file on disk.
    pub path: PathBuf,
    /// When `true`, the file is deleted by `cleanup` or when the artifact is dropped.
    pub temporary: bool,
}
36
/// The benchmark datasets this module knows how to locate and download.
#[derive(Debug, Clone, Copy)]
pub enum DatasetKind {
    /// The dataset stored as `longmemeval_s_cleaned.json`.
    LongMemEval,
    /// The dataset stored as `locomo10.json`.
    LoCoMo10,
}
42
43impl DatasetKind {
44    fn cache_subdir(self) -> &'static str {
45        match self {
46            Self::LongMemEval => "longmemeval",
47            Self::LoCoMo10 => "locomo",
48        }
49    }
50
51    fn filename(self) -> &'static str {
52        match self {
53            Self::LongMemEval => "longmemeval_s_cleaned.json",
54            Self::LoCoMo10 => "locomo10.json",
55        }
56    }
57
58    fn source_urls(self) -> &'static [&'static str] {
59        match self {
60            Self::LongMemEval => LONGMEMEVAL_DATASET_URLS,
61            Self::LoCoMo10 => LOCOMO_DATASET_URLS,
62        }
63    }
64}
65
66pub async fn resolve_dataset(
67    kind: DatasetKind,
68    dataset_path: Option<PathBuf>,
69    force_refresh: bool,
70    temporary: bool,
71) -> Result<DatasetArtifact> {
72    if dataset_path.is_some() && (force_refresh || temporary) {
73        return Err(anyhow!(
74            "--dataset-path cannot be combined with --force-refresh or --temp-dataset"
75        ));
76    }
77    if let Some(path) = dataset_path {
78        validate_json_file(&path)?;
79        return Ok(DatasetArtifact {
80            source_url: "user-supplied".to_string(),
81            path,
82            temporary: false,
83        });
84    }
85
86    let cache_path = if temporary {
87        temporary_dataset_path(kind)
88    } else {
89        benchmark_cache_path(kind)?
90    };
91    let cache_is_valid = cache_path.exists() && validate_json_file(&cache_path).is_ok();
92    if force_refresh || !cache_is_valid {
93        let source_url = download_from_sources(kind.source_urls(), &cache_path).await?;
94        if !temporary {
95            write_source_metadata(&cache_path, &source_url)?;
96        }
97        return Ok(DatasetArtifact {
98            source_url,
99            path: cache_path,
100            temporary,
101        });
102    }
103    Ok(DatasetArtifact {
104        source_url: read_source_metadata(&cache_path)
105            .unwrap_or_else(|| kind.source_urls()[0].to_string()),
106        path: cache_path,
107        temporary,
108    })
109}
110
111pub fn benchmark_cache_path(kind: DatasetKind) -> Result<PathBuf> {
112    let cache_root = app_paths::resolve_app_paths()?.benchmark_root;
113    Ok(cache_root.join(kind.cache_subdir()).join(kind.filename()))
114}
115
116pub fn benchmark_metadata(benchmark: &str, dataset: &DatasetArtifact) -> BenchmarkMetadata {
117    benchmark_metadata_from_parts(
118        benchmark,
119        &dataset.source_url,
120        &dataset.path.display().to_string(),
121    )
122}
123
124pub fn benchmark_metadata_from_parts(
125    benchmark: &str,
126    dataset_source: &str,
127    dataset_path: &str,
128) -> BenchmarkMetadata {
129    BenchmarkMetadata {
130        benchmark: benchmark.to_string(),
131        command: sanitize_command(std::env::args()),
132        date: Utc::now().to_rfc3339(),
133        commit: git_commit(),
134        machine: machine_descriptor(),
135        dataset_source: dataset_source.to_string(),
136        dataset_path: sanitize_dataset_path(dataset_path),
137    }
138}
139
140fn validate_json_file(path: &Path) -> Result<()> {
141    let file = std::fs::File::open(path)
142        .with_context(|| format!("failed to open dataset at {}", path.display()))?;
143    let mut reader = std::io::BufReader::new(file);
144    let mut de = serde_json::Deserializer::from_reader(&mut reader);
145    serde::de::IgnoredAny::deserialize(&mut de)
146        .with_context(|| format!("failed to parse JSON dataset at {}", path.display()))?;
147    de.end()
148        .with_context(|| format!("trailing content in JSON dataset at {}", path.display()))?;
149    Ok(())
150}
151
152async fn download_from_sources(urls: &[&str], path: &Path) -> Result<String> {
153    let mut failures = Vec::new();
154    for url in urls {
155        match download_file(url, path).await {
156            Ok(()) => return Ok((*url).to_string()),
157            Err(err) => failures.push(format!("{url}: {err}")),
158        }
159    }
160    Err(anyhow!(
161        "failed to download benchmark dataset from any public source:\n{}",
162        failures.join("\n")
163    ))
164}
165
166async fn download_file(url: &str, path: &Path) -> Result<()> {
167    if let Some(parent) = path.parent() {
168        tokio::fs::create_dir_all(parent)
169            .await
170            .with_context(|| format!("failed to create dataset cache dir {}", parent.display()))?;
171    }
172
173    let client = reqwest::Client::builder()
174        .timeout(std::time::Duration::from_secs(600))
175        .connect_timeout(std::time::Duration::from_secs(30))
176        .build()
177        .context("failed to build benchmark download client")?;
178
179    let response = client
180        .get(url)
181        .send()
182        .await
183        .with_context(|| format!("failed to download benchmark dataset from {url}"))?
184        .error_for_status()
185        .with_context(|| format!("benchmark dataset download failed for {url}"))?;
186    let bytes = response
187        .bytes()
188        .await
189        .with_context(|| format!("failed to read benchmark dataset body from {url}"))?;
190
191    let mut part_name = path
192        .file_name()
193        .ok_or_else(|| anyhow!("invalid dataset file name for {}", path.display()))?
194        .to_os_string();
195    part_name.push(".part");
196    let part_path = path.with_file_name(part_name);
197    let write_result: Result<()> = async {
198        tokio::fs::write(&part_path, &bytes)
199            .await
200            .with_context(|| format!("failed to write {}", part_path.display()))?;
201        let validate_path = part_path.clone();
202        tokio::task::spawn_blocking(move || validate_json_file(&validate_path))
203            .await
204            .context("dataset validation task failed")??;
205        remove_file_if_exists(path).await?;
206        tokio::fs::rename(&part_path, path)
207            .await
208            .with_context(|| format!("failed to finalize {}", path.display()))?;
209        Ok(())
210    }
211    .await;
212    if let Err(err) = write_result {
213        let _ = tokio::fs::remove_file(&part_path).await;
214        return Err(err);
215    }
216    Ok(())
217}
218
219async fn remove_file_if_exists(path: &Path) -> Result<()> {
220    match tokio::fs::remove_file(path).await {
221        Ok(()) => Ok(()),
222        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
223        Err(err) => Err(err)
224            .with_context(|| format!("failed to remove existing dataset {}", path.display())),
225    }
226}
227
228fn read_source_metadata(path: &Path) -> Option<String> {
229    let meta_path = source_metadata_path(path);
230    let text = std::fs::read_to_string(meta_path).ok()?;
231    let trimmed = text.trim();
232    if trimmed.is_empty() {
233        None
234    } else {
235        Some(trimmed.to_string())
236    }
237}
238
239fn write_source_metadata(path: &Path, source_url: &str) -> Result<()> {
240    std::fs::write(source_metadata_path(path), source_url)
241        .with_context(|| format!("failed to write source metadata for {}", path.display()))
242}
243
/// Returns the sidecar path for `path`: the same location with `.source-url`
/// appended to the file name.
fn source_metadata_path(path: &Path) -> PathBuf {
    let mut sidecar_name = match path.file_name() {
        Some(name) => name.to_os_string(),
        None => std::ffi::OsString::new(),
    };
    sidecar_name.push(".source-url");
    path.with_file_name(sidecar_name)
}
249
250fn temporary_dataset_path(kind: DatasetKind) -> PathBuf {
251    let stamp = Utc::now().format("%Y%m%d%H%M%S");
252    let filename = format!("{stamp}-{}-{}", std::process::id(), kind.filename());
253    std::env::temp_dir()
254        .join("mag-benchmarks")
255        .join(kind.cache_subdir())
256        .join(filename)
257}
258
259impl DatasetArtifact {
260    pub fn cleanup(&mut self) -> Result<()> {
261        if !self.temporary || !self.path.exists() {
262            return Ok(());
263        }
264        std::fs::remove_file(&self.path).with_context(|| {
265            format!("failed to remove temporary dataset {}", self.path.display())
266        })?;
267        self.temporary = false;
268        Ok(())
269    }
270}
271
272impl Drop for DatasetArtifact {
273    fn drop(&mut self) {
274        if self.temporary {
275            let _ = std::fs::remove_file(&self.path);
276        }
277    }
278}
279
280fn sanitize_command(args: impl IntoIterator<Item = String>) -> String {
281    args.into_iter()
282        .map(|arg| quote_shell_arg(&sanitize_arg(&arg)))
283        .collect::<Vec<_>>()
284        .join(" ")
285}
286
287fn sanitize_arg(arg: &str) -> String {
288    if looks_like_path(arg) {
289        "<redacted_path>".to_string()
290    } else {
291        arg.to_string()
292    }
293}
294
/// Quotes `arg` so it can be pasted into a POSIX shell as a single argument.
///
/// Arguments that are non-empty and contain no whitespace, `"`, `\`, `$`, or
/// backtick are returned unchanged. Everything else is wrapped in double quotes
/// with `"`, `\`, `$`, and backtick backslash-escaped, since those four
/// characters keep their special meaning inside double quotes. An empty
/// argument is rendered as `""` so it does not vanish from the command line.
fn quote_shell_arg(arg: &str) -> String {
    let needs_escape = |ch: char| matches!(ch, '"' | '\\' | '$' | '`');
    let is_plain = !arg.is_empty()
        && !arg
            .chars()
            .any(|ch| ch.is_whitespace() || needs_escape(ch));
    if is_plain {
        return arg.to_string();
    }

    let mut quoted = String::with_capacity(arg.len() + 2);
    quoted.push('"');
    for ch in arg.chars() {
        if needs_escape(ch) {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
305
/// Reduces a dataset path to its final file name.
///
/// Returns `<redacted_path>` when the path has no file name or the name is not valid UTF-8.
fn sanitize_dataset_path(dataset_path: &str) -> String {
    let file_name = Path::new(dataset_path)
        .file_name()
        .and_then(|name| name.to_str());
    match file_name {
        Some(name) => name.to_owned(),
        None => "<redacted_path>".to_string(),
    }
}
313
/// Heuristically decides whether `arg` is a filesystem path.
///
/// An argument counts as a path when it contains `/` or `\`, starts with `~`,
/// or is an absolute path.
fn looks_like_path(arg: &str) -> bool {
    let has_separator = arg.chars().any(|ch| matches!(ch, '/' | '\\'));
    has_separator || arg.starts_with('~') || Path::new(arg).is_absolute()
}
317
318fn git_commit() -> Option<String> {
319    command_stdout("git", &["rev-parse", "HEAD"])
320}
321
322fn machine_descriptor() -> String {
323    let mut parts = vec![
324        format!("{} {}", std::env::consts::OS, std::env::consts::ARCH),
325        format!("{} CPU", num_cpus::get()),
326    ];
327    if let Some(model) = command_stdout("sysctl", &["-n", "hw.model"]) {
328        parts.insert(0, model);
329    }
330    parts.join(", ")
331}
332
/// Runs `cmd` with `args` and returns its trimmed standard output.
///
/// Returns `None` when the process cannot be run, exits with a failure status,
/// or produces only whitespace. Output that is not valid UTF-8 is converted lossily.
fn command_stdout(cmd: &str, args: &[&str]) -> Option<String> {
    let finished = std::process::Command::new(cmd).args(args).output().ok()?;
    if !finished.status.success() {
        return None;
    }
    let decoded = String::from_utf8_lossy(&finished.stdout);
    let trimmed = decoded.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
341
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds a unique `<prefix>-<uuid>.json` path inside the system temp directory.
    fn unique_temp_json(prefix: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{prefix}-{}.json", Uuid::new_v4()))
    }

    /// An explicit dataset path is validated and returned as-is, labelled as user-supplied.
    #[tokio::test]
    async fn resolve_dataset_uses_explicit_path_without_downloading() {
        let fixture = unique_temp_json("mag-benchmark-fixture");
        std::fs::write(&fixture, r#"[{"question":"hi"}]"#).unwrap();

        let resolved = resolve_dataset(
            DatasetKind::LongMemEval,
            Some(fixture.clone()),
            false,
            false,
        )
        .await
        .unwrap();
        assert_eq!(resolved.path, fixture);
        assert_eq!(resolved.source_url, "user-supplied");

        let _ = std::fs::remove_file(fixture);
    }

    /// A URL written to the sidecar file is read back unchanged.
    #[test]
    fn source_metadata_round_trip() {
        let dataset = unique_temp_json("mag-benchmark-meta");
        let url = "https://example.com/dataset.json";

        write_source_metadata(&dataset, url).unwrap();
        assert_eq!(read_source_metadata(&dataset).as_deref(), Some(url));

        let _ = std::fs::remove_file(source_metadata_path(&dataset));
    }
}