Skip to main content

mnem_bench/datasets/
convomem.rs

1//! ConvoMem (Salesforce) dataset spec + loader.
2//!
3//! Source: HuggingFace `Salesforce/ConvoMem`,
4//! `core_benchmark/evidence_questions/<category>/1_evidence/<file>.json`.
5//!
//! Layout differs from LongMemEval / LoCoMo: there is no single
//! download. Instead the fetcher queries the HuggingFace tree API to
//! discover the shard files under each headline category's
//! `1_evidence` directory. On `fetch` we download each shard into the
//! cache and merge the `evidence_items` arrays (capped per category)
//! into a single `convomem_evidence.json` blob the scorer consumes.
12//!
13//! Cache layout:
14//!
15//! ```text
//! ~/.mnem/bench-data/convomem/
//!   convomem_evidence.json      <- merged blob, what the scorer loads
//!   <category>_filelist.json    <- cached tree-API listing, one per category
//!   shards/<category>/<file>    <- per-shard cache (idempotent fetches)
19//! ```
20//!
21//! The merged blob is treated as the canonical artefact for
22//! [`crate::datasets::is_cached`] / [`crate::datasets::cached_path`].
23
24use std::fs;
25use std::path::{Path, PathBuf};
26
27use anyhow::{Context, Result, anyhow};
28use serde::{Deserialize, Serialize};
29
30use super::DatasetSpec;
31use crate::bench::Bench;
32
/// Root URL under which every ConvoMem `evidence_questions` shard is
/// downloaded; callers append `/<category>/1_evidence/<file>`.
pub const HF_BASE: &str = concat!(
    "https://huggingface.co/datasets/Salesforce/ConvoMem",
    "/resolve/main/core_benchmark/evidence_questions",
);
35
/// The five categories that make up the MemPalace headline numbers.
///
/// The order follows the default in `convomem.py`. The set equals the
/// upstream `CATEGORIES` keys with `changing_evidence` removed, since
/// that category is unstable across MemPalace runs and therefore left
/// out of the headline.
pub const HEADLINE_CATEGORIES: &[&str] = &[
    "assistant_facts_evidence",
    "implicit_connection_evidence",
    "preference_evidence",
    "user_evidence",
    "abstention_evidence",
];
47
/// Root URL of the HuggingFace tree API for the ConvoMem
/// `evidence_questions` directory.
///
/// Shard discovery requests `<TREE_API>/<category>/1_evidence` to learn
/// the shard filenames; the shards themselves are then downloaded
/// through [`HF_BASE`].
pub const TREE_API: &str = concat!(
    "https://huggingface.co/api/datasets/Salesforce/ConvoMem",
    "/tree/main/core_benchmark/evidence_questions",
);
52
/// Default cap on the number of evidence items merged per category.
///
/// The fetcher stops pulling shards for a category once this many
/// items have been collected. Set the `MNEM_BENCH_CONVOMEM_PER_CAT`
/// environment variable to raise it for the full sweep.
pub const DEFAULT_PER_CATEGORY_CAP: usize = 50;
57
58/// Static spec. The ConvoMem fetcher does NOT use this URL directly
59/// (it walks the HF tree API instead) but keeping the field
60/// non-empty keeps the [`crate::datasets::DatasetSpec`] surface
61/// uniform with LongMemEval / LoCoMo. `sha256` is empty because the
62/// merged blob is composed at fetch-time.
63pub const SPEC: DatasetSpec = DatasetSpec {
64    bench: Bench::Convomem,
65    filename: "convomem_evidence.json",
66    url: "https://huggingface.co/datasets/Salesforce/ConvoMem",
67    sha256: "",
68    bytes: 5 * 1024 * 1024,
69};
70
71/// One conversation message inside an evidence item.
72#[derive(Clone, Debug, Deserialize, Serialize)]
73pub struct Message {
74    /// Speaker tag. Free-form, usually `"user"` or `"assistant"`.
75    #[serde(default)]
76    pub speaker: String,
77    /// Message body.
78    #[serde(default)]
79    pub text: String,
80}
81
82/// One conversation (a chat thread).
83#[derive(Clone, Debug, Deserialize, Serialize)]
84pub struct Conversation {
85    /// Messages within the conversation.
86    #[serde(default)]
87    pub messages: Vec<Message>,
88}
89
90/// One piece of gold evidence: a substring expected to appear in
91/// retrieved candidates.
92#[derive(Clone, Debug, Deserialize, Serialize)]
93pub struct MessageEvidence {
94    /// Substring that must appear in any retrieved candidate (either
95    /// direction) for the evidence to count as "found".
96    #[serde(default)]
97    pub text: String,
98}
99
100/// One evidence item the scorer processes.
101#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct EvidenceItem {
103    /// Question text.
104    pub question: String,
105    /// Conversation history to ingest.
106    #[serde(default)]
107    pub conversations: Vec<Conversation>,
108    /// Gold evidence substrings.
109    #[serde(default)]
110    pub message_evidences: Vec<MessageEvidence>,
111    /// Category bucket (filled in by the loader; never present in
112    /// the upstream JSON).
113    #[serde(default, rename = "_category_key")]
114    pub category_key: String,
115}
116
117/// File format on disk. Mirrors the merged-blob shape produced by
118/// [`fetch_into`].
119#[derive(Clone, Debug, Deserialize, Serialize)]
120pub struct EvidenceFile {
121    /// Items belonging to this file. Each carries its own
122    /// `_category_key` (set at fetch time).
123    pub evidence_items: Vec<EvidenceItem>,
124}
125
126/// Load + parse the merged blob at `path`.
127pub fn load(path: &Path) -> Result<Vec<EvidenceItem>> {
128    let bytes = fs::read(path).with_context(|| format!("reading {}", path.display()))?;
129    let parsed: EvidenceFile =
130        serde_json::from_slice(&bytes).with_context(|| format!("parsing {}", path.display()))?;
131    Ok(parsed.evidence_items)
132}
133
134/// One discovered shard.
135#[derive(Clone, Debug, Deserialize)]
136struct TreeEntry {
137    #[serde(default)]
138    path: String,
139    #[serde(default, rename = "type")]
140    entry_type: String,
141}
142
143/// Walk the HF tree API to discover `1_evidence/*.json` filenames
144/// for `category`. Cached at `cache_dir/<category>_filelist.json` so
145/// reruns skip the API hop.
146fn discover_files(cache_dir: &Path, category: &str) -> Result<Vec<String>> {
147    let cache_path = cache_dir.join(format!("{category}_filelist.json"));
148    if cache_path.is_file() {
149        let bytes =
150            fs::read(&cache_path).with_context(|| format!("reading {}", cache_path.display()))?;
151        let v: Vec<String> = serde_json::from_slice(&bytes)
152            .with_context(|| format!("parsing {}", cache_path.display()))?;
153        if !v.is_empty() {
154            return Ok(v);
155        }
156    }
157    let url = format!("{TREE_API}/{category}/1_evidence");
158    let resp = ureq::get(&url)
159        .call()
160        .with_context(|| format!("GET {url}"))?;
161    let mut body = String::new();
162    resp.into_reader()
163        .read_to_string(&mut body)
164        .context("read tree body")?;
165    let entries: Vec<TreeEntry> = serde_json::from_str(&body)
166        .with_context(|| format!("parsing tree response for {category}"))?;
167    let mut out = Vec::new();
168    for e in entries {
169        if e.entry_type == "file" && e.path.ends_with(".json") {
170            // Path looks like `<...>/1_evidence/<file>`; we only need
171            // the filename.
172            if let Some(name) = e.path.rsplit('/').next() {
173                out.push(name.to_string());
174            }
175        }
176    }
177    let bytes = serde_json::to_vec(&out).context("serialize filelist")?;
178    fs::write(&cache_path, &bytes).with_context(|| format!("writing {}", cache_path.display()))?;
179    Ok(out)
180}
181
182/// Fetch ConvoMem shards for every headline category, merge into a
183/// single canonical blob under `cache_dir`, return the path.
184///
185/// Discovery is dynamic (HF tree API) so we never have to ship a
186/// stale manifest. Per-category cap defaults to
187/// [`DEFAULT_PER_CATEGORY_CAP`]; override via the
188/// `MNEM_BENCH_CONVOMEM_PER_CAT` env var.
189pub fn fetch_into(cache_dir: &Path) -> Result<PathBuf> {
190    fs::create_dir_all(cache_dir.join("shards"))
191        .with_context(|| format!("mkdir {}", cache_dir.display()))?;
192    let cap = std::env::var("MNEM_BENCH_CONVOMEM_PER_CAT")
193        .ok()
194        .and_then(|s| s.parse::<usize>().ok())
195        .unwrap_or(DEFAULT_PER_CATEGORY_CAP);
196
197    let mut merged = EvidenceFile {
198        evidence_items: Vec::new(),
199    };
200    for cat in HEADLINE_CATEGORIES {
201        let files = discover_files(cache_dir, cat).with_context(|| format!("discover {cat}"))?;
202        if files.is_empty() {
203            eprintln!("[convomem] no shards discovered for {cat}; skipping");
204            continue;
205        }
206        let mut taken = 0usize;
207        for fname in &files {
208            if taken >= cap {
209                break;
210            }
211            let shard_path = cache_dir.join("shards").join(cat).join(fname);
212            if !shard_path.is_file() {
213                fs::create_dir_all(shard_path.parent().unwrap()).ok();
214                let url = format!("{HF_BASE}/{cat}/1_evidence/{fname}");
215                let resp = match ureq::get(&url).call() {
216                    Ok(r) => r,
217                    Err(e) => {
218                        eprintln!("[convomem] GET {url}: {e}; skipping");
219                        continue;
220                    }
221                };
222                let mut body: Vec<u8> = Vec::new();
223                if let Err(e) = resp.into_reader().read_to_end(&mut body) {
224                    eprintln!("[convomem] read body for {fname}: {e}; skipping");
225                    continue;
226                }
227                if let Err(e) = fs::write(&shard_path, &body) {
228                    eprintln!("[convomem] write {}: {e}; skipping", shard_path.display());
229                    continue;
230                }
231            }
232            let bytes = match fs::read(&shard_path) {
233                Ok(b) => b,
234                Err(e) => {
235                    eprintln!("[convomem] reading {}: {e}", shard_path.display());
236                    continue;
237                }
238            };
239            let parsed: EvidenceFile = match serde_json::from_slice(&bytes) {
240                Ok(p) => p,
241                Err(e) => {
242                    eprintln!("[convomem] parsing {}: {e}", shard_path.display());
243                    continue;
244                }
245            };
246            for mut it in parsed.evidence_items {
247                it.category_key = (*cat).to_string();
248                merged.evidence_items.push(it);
249                taken += 1;
250                if taken >= cap {
251                    break;
252                }
253            }
254        }
255        eprintln!("[convomem] {cat}: {taken} items");
256    }
257    if merged.evidence_items.is_empty() {
258        return Err(anyhow!(
259            "convomem fetch yielded zero evidence_items; check network reachability for {HF_BASE}"
260        ));
261    }
262    let dst = cache_dir.join("convomem_evidence.json");
263    let bytes = serde_json::to_vec(&merged).context("serialize merged blob")?;
264    fs::write(&dst, &bytes).with_context(|| format!("writing {}", dst.display()))?;
265    Ok(dst)
266}
267
268// Pull `Read` into scope for `into_reader().read_to_end/string`.
269use std::io::Read;