use std::fs;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use super::DatasetSpec;
use crate::bench::Bench;
/// Base URL from which individual evidence shards are downloaded.
///
/// `fetch_into` appends `/{category}/1_evidence/{file_name}` to this prefix.
pub const HF_BASE: &str = "https://huggingface.co/datasets/Salesforce/ConvoMem/resolve/main/core_benchmark/evidence_questions";
/// Category directory names that `fetch_into` walks, in this order.
///
/// Each name is used both as a URL path segment and as the value stored in
/// `EvidenceItem::category_key` for items taken from that category.
pub const HEADLINE_CATEGORIES: &[&str] = &[
"assistant_facts_evidence",
"implicit_connection_evidence",
"preference_evidence",
"user_evidence",
"abstention_evidence",
];
/// URL prefix of the directory-listing endpoint queried by `discover_files`.
///
/// `discover_files` appends `/{category}/1_evidence` and expects a JSON array
/// of entries carrying `path` and `type` fields.
pub const TREE_API: &str = "https://huggingface.co/api/datasets/Salesforce/ConvoMem/tree/main/core_benchmark/evidence_questions";
/// Maximum number of evidence items `fetch_into` keeps per category when the
/// `MNEM_BENCH_CONVOMEM_PER_CAT` environment variable is unset or does not
/// parse as a `usize`.
pub const DEFAULT_PER_CATEGORY_CAP: usize = 50;
/// Dataset descriptor for the ConvoMem benchmark.
///
/// `filename` is the same name `fetch_into` uses for the merged file it writes
/// into the cache directory.
///
/// NOTE(review): `sha256` is empty and `bytes` is 5 MiB — presumably "no
/// checksum pinned" and an approximate size for a merged file whose contents
/// depend on the per-category cap; confirm against the `DatasetSpec` consumers.
pub const SPEC: DatasetSpec = DatasetSpec {
bench: Bench::Convomem,
filename: "convomem_evidence.json",
url: "https://huggingface.co/datasets/Salesforce/ConvoMem",
sha256: "",
bytes: 5 * 1024 * 1024,
};
/// One message inside a [`Conversation`].
///
/// Both fields fall back to an empty string when absent from the JSON input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
/// Label of whoever produced the message (empty if missing).
#[serde(default)]
pub speaker: String,
/// Message body (empty if missing).
#[serde(default)]
pub text: String,
}
/// A sequence of [`Message`]s belonging to one evidence item.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Conversation {
/// Messages of the conversation; an absent field deserializes to an empty list.
#[serde(default)]
pub messages: Vec<Message>,
}
/// A single evidence entry attached to an [`EvidenceItem`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MessageEvidence {
/// Evidence text (empty if the field is missing).
// NOTE(review): presumably the text of the message that supports the
// answer — confirm against the dataset documentation.
#[serde(default)]
pub text: String,
}
/// One benchmark item: a question plus its conversations and evidence.
///
/// `question` is the only required field; the others default to empty values
/// when absent from the JSON input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EvidenceItem {
/// The question text; deserialization fails if it is missing.
pub question: String,
/// Conversations associated with the question.
#[serde(default)]
pub conversations: Vec<Conversation>,
/// Evidence entries associated with the question.
#[serde(default)]
pub message_evidences: Vec<MessageEvidence>,
/// Category the item was taken from, stored in JSON as `_category_key`.
/// `fetch_into` overwrites it with the category name while merging shards.
#[serde(default, rename = "_category_key")]
pub category_key: String,
}
/// Top-level JSON document shape shared by individual shards and by the merged
/// file produced by `fetch_into`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EvidenceFile {
/// All items in the file; the field is required when deserializing.
pub evidence_items: Vec<EvidenceItem>,
}
/// Reads an [`EvidenceFile`] JSON document from `path` and returns its items.
///
/// # Errors
///
/// Fails with a `reading <path>` context when the file cannot be read and with
/// a `parsing <path>` context when its contents are not a valid
/// [`EvidenceFile`].
pub fn load(path: &Path) -> Result<Vec<EvidenceItem>> {
    let raw = fs::read(path).with_context(|| format!("reading {}", path.display()))?;
    serde_json::from_slice::<EvidenceFile>(&raw)
        .map(|file| file.evidence_items)
        .with_context(|| format!("parsing {}", path.display()))
}
/// One element of the JSON array returned by the `TREE_API` listing endpoint,
/// as consumed by `discover_files`. Missing fields deserialize to empty strings.
#[derive(Clone, Debug, Deserialize)]
struct TreeEntry {
/// Path of the entry; `discover_files` keeps only its last `/`-separated segment.
#[serde(default)]
path: String,
/// Entry kind taken from the JSON `type` field; `discover_files` keeps only `"file"`.
#[serde(default, rename = "type")]
entry_type: String,
}
/// Returns the `.json` shard file names available for `category`.
///
/// The listing is cached in `<cache_dir>/<category>_filelist.json`. A cached
/// list is used only when it can be read, parses as a JSON array of strings and
/// is non-empty; otherwise the listing is fetched from
/// `{TREE_API}/{category}/1_evidence` and the cache file is rewritten.
///
/// The cache is treated as best-effort: a damaged cache file is reported on
/// stderr and refreshed from the network instead of failing the call, so one
/// bad file cannot block every later run.
///
/// # Errors
///
/// Fails when the HTTP request fails, the response body cannot be read or
/// parsed, or the refreshed listing cannot be serialized or written.
fn discover_files(cache_dir: &Path, category: &str) -> Result<Vec<String>> {
    let cache_path = cache_dir.join(format!("{category}_filelist.json"));
    if cache_path.is_file() {
        match fs::read(&cache_path) {
            Ok(bytes) => match serde_json::from_slice::<Vec<String>>(&bytes) {
                Ok(names) if !names.is_empty() => return Ok(names),
                // An empty cached list is ignored so the listing is retried.
                Ok(_) => {}
                Err(e) => eprintln!(
                    "[convomem] parsing {}: {e}; refreshing from network",
                    cache_path.display()
                ),
            },
            Err(e) => eprintln!(
                "[convomem] reading {}: {e}; refreshing from network",
                cache_path.display()
            ),
        }
    }

    let url = format!("{TREE_API}/{category}/1_evidence");
    let resp = ureq::get(&url)
        .call()
        .with_context(|| format!("GET {url}"))?;
    let mut body = String::new();
    resp.into_reader()
        .read_to_string(&mut body)
        .context("read tree body")?;
    let entries: Vec<TreeEntry> = serde_json::from_str(&body)
        .with_context(|| format!("parsing tree response for {category}"))?;

    // Keep only JSON files and reduce each path to its final segment.
    let out: Vec<String> = entries
        .iter()
        .filter(|e| e.entry_type == "file" && e.path.ends_with(".json"))
        .filter_map(|e| e.path.rsplit('/').next().map(str::to_string))
        .collect();

    let bytes = serde_json::to_vec(&out).context("serialize filelist")?;
    fs::write(&cache_path, &bytes).with_context(|| format!("writing {}", cache_path.display()))?;
    Ok(out)
}
/// Returns the parsed contents of one shard, preferring the on-disk copy in
/// `cat_dir` and downloading it from `HF_BASE` otherwise.
///
/// Every failure is logged to stderr and reported as `None` so the caller can
/// move on to the next shard. A downloaded body is parsed *before* it is
/// cached and is written through a temporary file followed by a rename, so an
/// invalid or partially written shard never ends up under the final name. A
/// cached shard that no longer parses is downloaded again instead of being
/// skipped on every run.
fn load_or_fetch_shard(cat_dir: &Path, cat: &str, fname: &str) -> Option<EvidenceFile> {
    let shard_path = cat_dir.join(fname);
    if shard_path.is_file() {
        match fs::read(&shard_path) {
            Ok(bytes) => match serde_json::from_slice::<EvidenceFile>(&bytes) {
                Ok(parsed) => return Some(parsed),
                // Fall through to the download path to repair the cache.
                Err(e) => eprintln!(
                    "[convomem] parsing {}: {e}; re-downloading",
                    shard_path.display()
                ),
            },
            Err(e) => {
                eprintln!("[convomem] reading {}: {e}", shard_path.display());
                return None;
            }
        }
    }

    let url = format!("{HF_BASE}/{cat}/1_evidence/{fname}");
    let resp = match ureq::get(&url).call() {
        Ok(r) => r,
        Err(e) => {
            eprintln!("[convomem] GET {url}: {e}; skipping");
            return None;
        }
    };
    let mut body: Vec<u8> = Vec::new();
    if let Err(e) = resp.into_reader().read_to_end(&mut body) {
        eprintln!("[convomem] read body for {fname}: {e}; skipping");
        return None;
    }
    let parsed: EvidenceFile = match serde_json::from_slice(&body) {
        Ok(p) => p,
        Err(e) => {
            eprintln!("[convomem] parsing {url}: {e}; skipping");
            return None;
        }
    };

    // Caching is best-effort: the shard is already parsed, so a failed write
    // only costs a re-download on the next run.
    let tmp_path = cat_dir.join(format!("{fname}.part"));
    let cached = fs::create_dir_all(cat_dir)
        .and_then(|_| fs::write(&tmp_path, &body))
        .and_then(|_| fs::rename(&tmp_path, &shard_path));
    if let Err(e) = cached {
        eprintln!(
            "[convomem] write {}: {e}; continuing without cache",
            shard_path.display()
        );
        let _ = fs::remove_file(&tmp_path);
    }
    Some(parsed)
}

/// Builds the merged ConvoMem evidence file inside `cache_dir` and returns its path.
///
/// For every entry of `HEADLINE_CATEGORIES` the shard names come from
/// `discover_files`; shards are read from `<cache_dir>/shards/<category>/`
/// (downloaded on demand) until the per-category cap is reached. The cap is
/// read from the `MNEM_BENCH_CONVOMEM_PER_CAT` environment variable and falls
/// back to `DEFAULT_PER_CATEGORY_CAP` when it is unset or not a valid `usize`.
/// Each kept item has its `category_key` set to the category name. Shards that
/// cannot be obtained or parsed are logged and skipped.
///
/// The result is written to `<cache_dir>/convomem_evidence.json`.
///
/// # Errors
///
/// Fails when the shard directory cannot be created, when `discover_files`
/// fails for a category, when no items at all were collected, or when the
/// merged file cannot be serialized or written.
pub fn fetch_into(cache_dir: &Path) -> Result<PathBuf> {
    let shards_root = cache_dir.join("shards");
    fs::create_dir_all(&shards_root)
        .with_context(|| format!("mkdir {}", cache_dir.display()))?;
    let cap = std::env::var("MNEM_BENCH_CONVOMEM_PER_CAT")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(DEFAULT_PER_CATEGORY_CAP);
    let mut merged = EvidenceFile {
        evidence_items: Vec::new(),
    };

    for cat in HEADLINE_CATEGORIES {
        let files = discover_files(cache_dir, cat).with_context(|| format!("discover {cat}"))?;
        if files.is_empty() {
            eprintln!("[convomem] no shards discovered for {cat}; skipping");
            continue;
        }
        let cat_dir = shards_root.join(cat);
        let mut taken = 0usize;
        for fname in &files {
            if taken >= cap {
                break;
            }
            let parsed = match load_or_fetch_shard(&cat_dir, cat, fname) {
                Some(p) => p,
                None => continue,
            };
            // `taken < cap` holds here, so the subtraction cannot underflow.
            let room = cap - taken;
            let before = merged.evidence_items.len();
            merged
                .evidence_items
                .extend(parsed.evidence_items.into_iter().take(room).map(|mut it| {
                    it.category_key = (*cat).to_string();
                    it
                }));
            taken += merged.evidence_items.len() - before;
        }
        eprintln!("[convomem] {cat}: {taken} items");
    }

    if merged.evidence_items.is_empty() {
        return Err(anyhow!(
            "convomem fetch yielded zero evidence_items; check network reachability for {HF_BASE}"
        ));
    }
    let dst = cache_dir.join("convomem_evidence.json");
    let bytes = serde_json::to_vec(&merged).context("serialize merged blob")?;
    fs::write(&dst, &bytes).with_context(|| format!("writing {}", dst.display()))?;
    Ok(dst)
}
use std::io::Read;