use std::collections::HashMap;
use anyhow::{Context, Result};
use serde::Deserialize;
use crate::traits::BenchmarkDataset;
use crate::types::{BenchmarkQuestion, ConversationSession, Turn};
use super::download::download_dataset;
/// Raw GitHub URL of the `locomo10.json` file in the snap-research/locomo repository.
const DOWNLOAD_URL: &str = "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json";
/// File name handed to `download_dataset` together with the URL
/// (presumably the name of the locally cached copy — confirm in `download_dataset`).
const FILENAME: &str = "locomo10.json";
/// The LoCoMo long-conversation QA benchmark, flattened into one [`BenchmarkQuestion`]
/// per QA pair.
pub struct LoCoMoDataset {
    /// Every question from every sample, in the order they appear in the source JSON.
    questions: Vec<BenchmarkQuestion>,
}
/// One element of the top-level LoCoMo JSON array: a conversation plus its QA pairs.
/// All three fields are required (none has `#[serde(default)]`).
#[derive(Debug, Deserialize)]
struct RawSample {
    /// Used as the prefix of generated session and question ids.
    sample_id: String,
    /// Kept as raw JSON because the map mixes value kinds: the loader reads `session_*`
    /// entries as utterance arrays and `*_date_time` entries as date strings.
    conversation: HashMap<String, serde_json::Value>,
    /// QA pairs asked about this conversation.
    qa: Vec<RawQA>,
}
/// One question/answer pair from a sample's `qa` array.
#[derive(Debug, Deserialize)]
struct RawQA {
    /// Question text; required.
    question: String,

    /// Reference answer, kept as raw JSON because it is not always a string
    /// (numbers occur). Becomes `null` when the field is absent.
    #[serde(default)]
    answer: serde_json::Value,

    /// Evidence references; parsed but not currently consumed by this loader.
    #[serde(default)]
    evidence: Vec<String>,

    /// Numeric category code; `0` when absent.
    #[serde(default)]
    category: u32,
}
/// A single utterance inside a `session_*` array. Every field defaults to an empty
/// string when absent, so partially-populated entries still deserialize.
#[derive(Debug, Deserialize)]
struct RawUtterance {
    /// Speaker label; inspected to decide the turn's `user`/`assistant` role.
    #[serde(default)]
    speaker: String,
    /// Utterance text; becomes the turn's content.
    #[serde(default)]
    text: String,
    /// Dialogue id; parsed but not read by this loader.
    #[serde(default)]
    dia_id: String,
}
impl LoCoMoDataset {
pub async fn load(force_download: bool) -> Result<Self> {
let path = download_dataset(DOWNLOAD_URL, FILENAME, force_download).await?;
let content = tokio::fs::read_to_string(&path).await?;
Self::from_json(&content)
}
pub fn from_json(json: &str) -> Result<Self> {
let samples: Vec<RawSample> = serde_json::from_str(json)
.context("Failed to parse LoCoMo JSON")?;
let mut questions = Vec::new();
for sample in &samples {
let mut sessions = Vec::new();
let mut session_keys: Vec<String> = sample.conversation.keys()
.filter(|k| k.starts_with("session_"))
.cloned()
.collect();
session_keys.sort();
let date_keys: HashMap<String, String> = sample.conversation.iter()
.filter(|(k, _)| k.contains("date_time"))
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect();
for session_key in &session_keys {
if let Some(turns_val) = sample.conversation.get(session_key) {
let turns: Vec<RawUtterance> = serde_json::from_value(turns_val.clone())
.unwrap_or_default();
let date_key = format!("{}_date_time", session_key);
let date = date_keys.get(&date_key).cloned();
sessions.push(ConversationSession {
id: format!("{}_{}", sample.sample_id, session_key),
date,
turns: turns.iter().map(|u| Turn {
role: if u.speaker.contains("speaker_a") || u.speaker.contains("Speaker A") {
"user".to_string()
} else {
"assistant".to_string()
},
content: u.text.clone(),
}).collect(),
});
}
}
for (qi, qa) in sample.qa.iter().enumerate() {
let qtype = match qa.category {
1 => "single-hop",
2 => "temporal",
3 => "multi-hop",
4 => "open-domain",
5 => "unanswerable",
_ => "unknown",
};
let answer_str = match &qa.answer {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
other => other.to_string(),
};
questions.push(BenchmarkQuestion {
id: format!("{}_q{}", sample.sample_id, qi),
question_type: qtype.to_string(),
question: qa.question.clone(),
ground_truth: vec![answer_str],
question_date: None,
sessions: sessions.clone(),
is_abstention: qa.category == 5,
metadata: std::collections::HashMap::new(),
});
}
}
Ok(Self { questions })
}
}
impl BenchmarkDataset for LoCoMoDataset {
    fn name(&self) -> &str {
        "locomo"
    }

    fn variant(&self) -> &str {
        "default"
    }

    fn description(&self) -> &str {
        "LoCoMo (Snap Research) — 1,986 QA pairs across 10 long conversations, 5 categories"
    }

    fn questions(&self) -> &[BenchmarkQuestion] {
        &self.questions
    }

    /// Distinct `question_type` labels present in the loaded questions, in ascending order.
    fn question_types(&self) -> Vec<String> {
        // A BTreeSet both de-duplicates and keeps the labels sorted.
        let distinct: std::collections::BTreeSet<&str> = self
            .questions
            .iter()
            .map(|q| q.question_type.as_str())
            .collect();
        distinct.into_iter().map(str::to_owned).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-sample LoCoMo document with a single QA pair of the given `category`.
    fn single_question_json(category: u32) -> String {
        format!(
            r#"[{{
    "sample_id": "test",
    "conversation": {{}},
    "qa": [{{"question": "Q?", "answer": "A", "category": {category}}}]
}}]"#
        )
    }

    #[test]
    fn category_mapping() {
        let cases = [
            (1, "single-hop"),
            (2, "temporal"),
            (3, "multi-hop"),
            (4, "open-domain"),
            (5, "unanswerable"),
        ];
        for (category, label) in cases {
            let dataset = LoCoMoDataset::from_json(&single_question_json(category)).unwrap();
            let question = &dataset.questions()[0];
            assert_eq!(question.question_type, label);
            if category == 5 {
                assert!(question.is_abstention);
            }
        }
    }
}