use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;
use serde_json::Value;
use super::DatasetSpec;
use crate::bench::Bench;
/// Dataset descriptor for the LoCoMo benchmark: which bench it belongs to,
/// the local file name, and the URL the file is published at.
pub const SPEC: DatasetSpec = DatasetSpec {
bench: Bench::Locomo,
filename: "locomo10.json",
url: "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json",
// NOTE(review): the digest is empty — presumably "checksum not pinned";
// confirm how `DatasetSpec` consumers treat an empty string.
sha256: "",
// 3 MiB. NOTE(review): presumably an approximate/expected size rather than
// an exact byte count — confirm against how `bytes` is used.
bytes: 3 * 1024 * 1024,
};
/// One top-level record of the dataset file, as returned by [`load`].
///
/// Every field is `#[serde(default)]`, so a record missing any of them still
/// deserializes (to `None` / empty collections).
#[derive(Clone, Debug, Deserialize)]
pub struct Conversation {
/// Identifier of the record, if the input provides one.
#[serde(default)]
pub sample_id: Option<String>,
/// Question/answer items attached to this record.
#[serde(default)]
pub qa: Vec<Qa>,
/// Raw conversation object, kept as untyped JSON values keyed by name.
/// [`iter_sessions`] reads the `session_<n>` and `session_<n>_date_time`
/// keys from this map.
#[serde(default)]
pub conversation: BTreeMap<String, Value>,
}
/// A single question/answer item of a [`Conversation`].
///
/// Every field is `#[serde(default)]`; absent fields become the empty string,
/// JSON `null`, an empty vector, or `0` respectively.
#[derive(Clone, Debug, Deserialize)]
pub struct Qa {
/// The question text.
#[serde(default)]
pub question: String,
/// The answer, kept as an untyped JSON value (it is not forced to be a
/// string here).
#[serde(default)]
pub answer: Value,
/// Evidence references for the answer.
/// NOTE(review): presumably these are `Dialog::dia_id` values — confirm
/// against the dataset.
#[serde(default)]
pub evidence: Vec<String>,
/// Numeric question category; see [`category_name`] for the label mapping.
/// Defaults to `0` when absent, which `category_name` reports as "unknown".
#[serde(default)]
pub category: u32,
}
/// One element of a `session_<n>` array, as produced by [`iter_sessions`].
///
/// Every field is `#[serde(default)]`, so missing fields deserialize to the
/// empty string. Unknown extra fields in the input are ignored.
#[derive(Clone, Debug, Deserialize)]
pub struct Dialog {
/// Name of the speaker of this turn.
#[serde(default)]
pub speaker: String,
/// Text of this turn.
#[serde(default)]
pub text: String,
/// Identifier of this turn.
/// NOTE(review): presumably unique within a conversation — confirm.
#[serde(default)]
pub dia_id: String,
}
/// Returns the human-readable label for a numeric QA category.
///
/// Categories `1..=6` map to a fixed label; every other value (including `0`,
/// the serde default of [`Qa::category`]) yields `"unknown"`.
///
/// NOTE(review): the number-to-label mapping is reproduced as-is; confirm it
/// against the upstream dataset documentation.
#[must_use]
pub fn category_name(c: u32) -> &'static str {
    // Lookup table of (category id, label) pairs.
    const LABELS: [(u32, &str); 6] = [
        (1, "single-hop"),
        (2, "multi-hop"),
        (3, "open-domain"),
        (4, "temporal"),
        (5, "common-sense"),
        (6, "adversarial"),
    ];

    LABELS
        .iter()
        .find(|&&(id, _)| id == c)
        .map_or("unknown", |&(_, label)| label)
}
/// Reads the file at `path` and parses it as a JSON array of [`Conversation`]s.
///
/// # Errors
///
/// Returns an error with context `reading <path>` if the file cannot be read,
/// or `parsing <path>` if its contents are not valid JSON of the expected
/// shape.
pub fn load(path: &Path) -> Result<Vec<Conversation>> {
    let shown = path.display();

    // The whole file is read into memory before parsing.
    let raw = fs::read(path).with_context(|| format!("reading {shown}"))?;
    let conversations =
        serde_json::from_slice(&raw).with_context(|| format!("parsing {shown}"))?;

    Ok(conversations)
}
/// Iterates over the sessions stored in a raw conversation map.
///
/// Sessions are looked up under the keys `session_1`, `session_2`, … and
/// iteration ends at the first number whose key is absent. Each item is
/// `(session number, date string, dialogs)`; the date comes from
/// `session_<n>_date_time` and is empty when that key is missing or not a
/// string.
pub fn iter_sessions(
    conv: &BTreeMap<String, Value>,
) -> impl Iterator<Item = (usize, String, Vec<Dialog>)> + '_ {
    // Session keys are numbered starting from 1, not 0.
    const FIRST_SESSION: usize = 1;

    SessionIter {
        conv,
        idx: FIRST_SESSION,
    }
}
/// Iterator state behind [`iter_sessions`]: walks `session_<idx>` keys of a
/// borrowed conversation map in increasing numeric order.
struct SessionIter<'a> {
/// The raw conversation map being read; never modified.
conv: &'a BTreeMap<String, Value>,
/// Number of the next session to look up (`session_<idx>`).
idx: usize,
}
impl Iterator for SessionIter<'_> {
type Item = (usize, String, Vec<Dialog>);
fn next(&mut self) -> Option<Self::Item> {
let key = format!("session_{}", self.idx);
let raw = self.conv.get(&key)?;
let date = self
.conv
.get(&format!("{key}_date_time"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let dialogs: Vec<Dialog> = match raw {
Value::Array(arr) => arr
.iter()
.filter_map(|v| serde_json::from_value(v.clone()).ok())
.collect(),
_ => Vec::new(),
};
let i = self.idx;
self.idx += 1;
Some((i, date, dialogs))
}
}