use std::collections::{BTreeMap, HashMap};
use std::path::Path;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::error::{Error, Result};
/// One record of a BEIR `queries.jsonl` file.
///
/// Only the id and text are read; serde ignores any other keys in the record
/// because unknown fields are not denied here.
#[derive(Deserialize)]
struct BeirQuery {
    /// Query identifier, stored under the `_id` key in the JSON record.
    #[serde(rename = "_id")]
    id: String,
    /// Query text; an empty string when the record has no `text` key.
    #[serde(default)]
    text: String,
}
/// A query paired with its graded relevance judgments.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct GoldQuery {
    /// Identifier of the query.
    pub query_id: String,
    /// Query text.
    pub query: String,
    /// Document id -> relevance grade. A grade of 0 marks a document judged
    /// non-relevant; documents absent from the map are treated as grade 0 by
    /// [`GoldQuery::grade`].
    pub relevant_docs: HashMap<String, u8>,
    /// Optional reference answer. Deserializes to `None` when the key is
    /// missing and is omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reference_answer: Option<String>,
}
impl GoldQuery {
    /// Returns `true` when `doc_id` carries a relevance grade of at least 1.
    ///
    /// Documents that do not appear in `relevant_docs` are not relevant.
    #[must_use]
    pub fn is_relevant(&self, doc_id: &str) -> bool {
        self.grade(doc_id) >= 1
    }

    /// Returns the grade recorded for `doc_id`, or `0` when it has no entry.
    #[must_use]
    pub fn grade(&self, doc_id: &str) -> u8 {
        let recorded = self.relevant_docs.get(doc_id);
        recorded.copied().unwrap_or_default()
    }

    /// Counts the documents whose grade is non-zero (i.e. relevant ones).
    #[must_use]
    pub fn relevant_count(&self) -> usize {
        self.relevant_docs
            .values()
            .filter(|&&grade| grade > 0)
            .count()
    }
}
/// A set of [`GoldQuery`] records used as evaluation ground truth.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct Qrels {
    /// The gold queries held by this set.
    pub queries: Vec<GoldQuery>,
}
impl Qrels {
pub fn load_jsonl<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
debug!(?path, "loading qrels");
let text = std::fs::read_to_string(path)?;
Self::from_jsonl_str(&text)
}
pub fn from_jsonl_str(text: &str) -> Result<Self> {
let mut queries = Vec::new();
for (idx, raw_line) in text.lines().enumerate() {
let line = raw_line.trim();
if line.is_empty() {
continue;
}
let q: GoldQuery =
serde_json::from_str(line).map_err(|source| Error::DatasetParse {
line: idx + 1,
source,
})?;
queries.push(q);
}
Ok(Self { queries })
}
pub fn from_beir<P: AsRef<Path>>(dataset_dir: P, split: &str) -> Result<Self> {
let dir = dataset_dir.as_ref();
debug!(?dir, %split, "loading BEIR dataset");
let queries_path = dir.join("queries.jsonl");
let queries_text = std::fs::read_to_string(&queries_path)?;
let mut query_text: HashMap<String, String> = HashMap::new();
for (idx, raw_line) in queries_text.lines().enumerate() {
let line = raw_line.trim();
if line.is_empty() {
continue;
}
let record: BeirQuery =
serde_json::from_str(line).map_err(|source| Error::DatasetParse {
line: idx + 1,
source,
})?;
query_text.insert(record.id, record.text);
}
let qrels_path = dir.join("qrels").join(format!("{split}.tsv"));
let qrels_text = std::fs::read_to_string(&qrels_path)?;
let mut grouped: BTreeMap<String, HashMap<String, u8>> = BTreeMap::new();
for raw_line in qrels_text.lines() {
let line = raw_line.trim();
if line.is_empty() {
continue;
}
let cols: Vec<&str> = line.split('\t').collect();
let (qid, doc_id, rel) = match cols.as_slice() {
[qid, doc_id, rel] => (*qid, *doc_id, *rel),
[qid, _iter, doc_id, rel] => (*qid, *doc_id, *rel),
_ => continue,
};
let grade: u8 = rel.trim().parse().unwrap_or(0);
if grade == 0 {
continue;
}
grouped
.entry(qid.trim().to_string())
.or_default()
.insert(doc_id.trim().to_string(), grade);
}
let mut queries = Vec::with_capacity(grouped.len());
for (qid, relevant) in grouped {
let Some(text) = query_text.get(&qid) else {
continue;
};
queries.push(GoldQuery {
query_id: qid.clone(),
query: text.clone(),
relevant_docs: relevant,
reference_answer: None,
});
}
Ok(Self { queries })
}
#[must_use]
pub fn len(&self) -> usize {
self.queries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.queries.is_empty()
}
}
/// Retrieval results for a single query.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RetrievedSet {
    /// Identifier of the query these results answer. Presumably this matches
    /// [`GoldQuery::query_id`]; confirm with producers.
    pub query_id: String,
    /// Retrieved documents. Presumably ordered best-first, as the name
    /// suggests, but nothing here enforces an order (TODO confirm).
    pub ranked: Vec<RetrievedDoc>,
}
/// One retrieved document together with the score assigned to it.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RetrievedDoc {
    /// Identifier of the retrieved document.
    pub doc_id: String,
    /// Retriever score. The scale and direction (is higher better?) are not
    /// defined here. NOTE(review): confirm with the retriever.
    pub score: f64,
}
#[cfg(test)]
// Tests may unwrap, panic and index freely: a failure here is a test failure,
// not a library bug.
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn parses_well_formed_jsonl() {
        let text = r#"{"query_id":"q1","query":"a","relevant_docs":{"d1":2,"d2":1}}
{"query_id":"q2","query":"b","relevant_docs":{"d3":1},"reference_answer":"yes"}
"#;
        let q = Qrels::from_jsonl_str(text).unwrap();
        assert_eq!(q.len(), 2);
        assert!(q.queries[0].is_relevant("d1"));
        assert_eq!(q.queries[0].grade("d2"), 1);
        assert_eq!(q.queries[0].grade("missing"), 0);
        assert_eq!(q.queries[1].reference_answer.as_deref(), Some("yes"));
    }

    #[test]
    fn reports_line_on_parse_error() {
        let text = "{\"query_id\":\"q1\",\"query\":\"a\",\"relevant_docs\":{}}\nnot json\n";
        let err = Qrels::from_jsonl_str(text).unwrap_err();
        match err {
            Error::DatasetParse { line, .. } => assert_eq!(line, 2),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn skips_blank_and_whitespace_only_lines() {
        let text = "\n   \n{\"query_id\":\"q1\",\"query\":\"a\",\"relevant_docs\":{\"d1\":1}}\n\t\n";
        let q = Qrels::from_jsonl_str(text).unwrap();
        assert_eq!(q.len(), 1);
        assert_eq!(q.queries[0].query_id, "q1");
    }

    #[test]
    fn blank_lines_count_toward_error_line_numbers() {
        let err = Qrels::from_jsonl_str("\n\nnot json\n").unwrap_err();
        match err {
            Error::DatasetParse { line, .. } => assert_eq!(line, 3),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn empty_input_yields_empty_qrels() {
        let q = Qrels::from_jsonl_str("").unwrap();
        assert!(q.is_empty());
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn relevant_count_excludes_zero_grades() {
        let q = GoldQuery {
            query_id: "q".into(),
            query: String::new(),
            relevant_docs: HashMap::from([
                ("a".to_string(), 2u8),
                ("b".to_string(), 0u8),
                ("c".to_string(), 1u8),
            ]),
            reference_answer: None,
        };
        assert_eq!(q.relevant_count(), 2);
        // A judged zero grade is reported as-is but is not relevant.
        assert_eq!(q.grade("b"), 0);
        assert!(!q.is_relevant("b"));
        assert!(q.is_relevant("c"));
    }
}