mnem_bench/datasets/
convomem.rs1use std::fs;
25use std::path::{Path, PathBuf};
26
27use anyhow::{Context, Result, anyhow};
28use serde::{Deserialize, Serialize};
29
30use super::DatasetSpec;
31use crate::bench::Bench;
32
/// Base URL for raw file downloads from the ConvoMem evidence-question tree on
/// Hugging Face. Shard URLs are built as `{HF_BASE}/{category}/1_evidence/{file}`.
pub const HF_BASE: &str = concat!(
    "https://huggingface.co/datasets/Salesforce/ConvoMem",
    "/resolve/main/core_benchmark/evidence_questions"
);

/// Evidence categories fetched by [`fetch_into`], listed in the order their
/// items appear in the merged blob.
pub const HEADLINE_CATEGORIES: &[&str] = &[
    "assistant_facts_evidence",
    "implicit_connection_evidence",
    "preference_evidence",
    "user_evidence",
    "abstention_evidence",
];

/// Hugging Face tree-listing API endpoint for the same directory as
/// [`HF_BASE`]. Shard listings are requested as `{TREE_API}/{category}/1_evidence`.
pub const TREE_API: &str = concat!(
    "https://huggingface.co/api/datasets/Salesforce/ConvoMem",
    "/tree/main/core_benchmark/evidence_questions"
);

/// Default maximum number of evidence items kept per category. It can be
/// overridden through the `MNEM_BENCH_CONVOMEM_PER_CAT` environment variable.
pub const DEFAULT_PER_CATEGORY_CAP: usize = 50;
57
58pub const SPEC: DatasetSpec = DatasetSpec {
64 bench: Bench::Convomem,
65 filename: "convomem_evidence.json",
66 url: "https://huggingface.co/datasets/Salesforce/ConvoMem",
67 sha256: "",
68 bytes: 5 * 1024 * 1024,
69};
70
71#[derive(Clone, Debug, Deserialize, Serialize)]
73pub struct Message {
74 #[serde(default)]
76 pub speaker: String,
77 #[serde(default)]
79 pub text: String,
80}
81
82#[derive(Clone, Debug, Deserialize, Serialize)]
84pub struct Conversation {
85 #[serde(default)]
87 pub messages: Vec<Message>,
88}
89
90#[derive(Clone, Debug, Deserialize, Serialize)]
93pub struct MessageEvidence {
94 #[serde(default)]
97 pub text: String,
98}
99
100#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct EvidenceItem {
103 pub question: String,
105 #[serde(default)]
107 pub conversations: Vec<Conversation>,
108 #[serde(default)]
110 pub message_evidences: Vec<MessageEvidence>,
111 #[serde(default, rename = "_category_key")]
114 pub category_key: String,
115}
116
117#[derive(Clone, Debug, Deserialize, Serialize)]
120pub struct EvidenceFile {
121 pub evidence_items: Vec<EvidenceItem>,
124}
125
126pub fn load(path: &Path) -> Result<Vec<EvidenceItem>> {
128 let bytes = fs::read(path).with_context(|| format!("reading {}", path.display()))?;
129 let parsed: EvidenceFile =
130 serde_json::from_slice(&bytes).with_context(|| format!("parsing {}", path.display()))?;
131 Ok(parsed.evidence_items)
132}
133
134#[derive(Clone, Debug, Deserialize)]
136struct TreeEntry {
137 #[serde(default)]
138 path: String,
139 #[serde(default, rename = "type")]
140 entry_type: String,
141}
142
143fn discover_files(cache_dir: &Path, category: &str) -> Result<Vec<String>> {
147 let cache_path = cache_dir.join(format!("{category}_filelist.json"));
148 if cache_path.is_file() {
149 let bytes =
150 fs::read(&cache_path).with_context(|| format!("reading {}", cache_path.display()))?;
151 let v: Vec<String> = serde_json::from_slice(&bytes)
152 .with_context(|| format!("parsing {}", cache_path.display()))?;
153 if !v.is_empty() {
154 return Ok(v);
155 }
156 }
157 let url = format!("{TREE_API}/{category}/1_evidence");
158 let resp = ureq::get(&url)
159 .call()
160 .with_context(|| format!("GET {url}"))?;
161 let mut body = String::new();
162 resp.into_reader()
163 .read_to_string(&mut body)
164 .context("read tree body")?;
165 let entries: Vec<TreeEntry> = serde_json::from_str(&body)
166 .with_context(|| format!("parsing tree response for {category}"))?;
167 let mut out = Vec::new();
168 for e in entries {
169 if e.entry_type == "file" && e.path.ends_with(".json") {
170 if let Some(name) = e.path.rsplit('/').next() {
173 out.push(name.to_string());
174 }
175 }
176 }
177 let bytes = serde_json::to_vec(&out).context("serialize filelist")?;
178 fs::write(&cache_path, &bytes).with_context(|| format!("writing {}", cache_path.display()))?;
179 Ok(out)
180}
181
182pub fn fetch_into(cache_dir: &Path) -> Result<PathBuf> {
190 fs::create_dir_all(cache_dir.join("shards"))
191 .with_context(|| format!("mkdir {}", cache_dir.display()))?;
192 let cap = std::env::var("MNEM_BENCH_CONVOMEM_PER_CAT")
193 .ok()
194 .and_then(|s| s.parse::<usize>().ok())
195 .unwrap_or(DEFAULT_PER_CATEGORY_CAP);
196
197 let mut merged = EvidenceFile {
198 evidence_items: Vec::new(),
199 };
200 for cat in HEADLINE_CATEGORIES {
201 let files = discover_files(cache_dir, cat).with_context(|| format!("discover {cat}"))?;
202 if files.is_empty() {
203 eprintln!("[convomem] no shards discovered for {cat}; skipping");
204 continue;
205 }
206 let mut taken = 0usize;
207 for fname in &files {
208 if taken >= cap {
209 break;
210 }
211 let shard_path = cache_dir.join("shards").join(cat).join(fname);
212 if !shard_path.is_file() {
213 fs::create_dir_all(shard_path.parent().unwrap()).ok();
214 let url = format!("{HF_BASE}/{cat}/1_evidence/{fname}");
215 let resp = match ureq::get(&url).call() {
216 Ok(r) => r,
217 Err(e) => {
218 eprintln!("[convomem] GET {url}: {e}; skipping");
219 continue;
220 }
221 };
222 let mut body: Vec<u8> = Vec::new();
223 if let Err(e) = resp.into_reader().read_to_end(&mut body) {
224 eprintln!("[convomem] read body for {fname}: {e}; skipping");
225 continue;
226 }
227 if let Err(e) = fs::write(&shard_path, &body) {
228 eprintln!("[convomem] write {}: {e}; skipping", shard_path.display());
229 continue;
230 }
231 }
232 let bytes = match fs::read(&shard_path) {
233 Ok(b) => b,
234 Err(e) => {
235 eprintln!("[convomem] reading {}: {e}", shard_path.display());
236 continue;
237 }
238 };
239 let parsed: EvidenceFile = match serde_json::from_slice(&bytes) {
240 Ok(p) => p,
241 Err(e) => {
242 eprintln!("[convomem] parsing {}: {e}", shard_path.display());
243 continue;
244 }
245 };
246 for mut it in parsed.evidence_items {
247 it.category_key = (*cat).to_string();
248 merged.evidence_items.push(it);
249 taken += 1;
250 if taken >= cap {
251 break;
252 }
253 }
254 }
255 eprintln!("[convomem] {cat}: {taken} items");
256 }
257 if merged.evidence_items.is_empty() {
258 return Err(anyhow!(
259 "convomem fetch yielded zero evidence_items; check network reachability for {HF_BASE}"
260 ));
261 }
262 let dst = cache_dir.join("convomem_evidence.json");
263 let bytes = serde_json::to_vec(&merged).context("serialize merged blob")?;
264 fs::write(&dst, &bytes).with_context(|| format!("writing {}", dst.display()))?;
265 Ok(dst)
266}
267
268use std::io::Read;