use std::collections::BTreeMap;
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use serde::Deserialize;
use crate::error::{DciError, Result};
/// Summary of what [`prepare`] wrote to the output directory.
#[derive(Debug, Clone)]
pub struct PreparedDataset {
/// Directory (`<out_dir>/corpus`) holding one `.txt` file per corpus record.
pub corpus_dir: PathBuf,
/// Path of the generated `<out_dir>/qrels.jsonl` file.
pub qrels_path: PathBuf,
/// Number of corpus records written into `corpus_dir`.
pub documents: usize,
/// Number of query records written to `qrels_path`.
pub queries: usize,
}
/// One JSON object parsed from a line of `corpus.jsonl`.
#[derive(Deserialize)]
struct CorpusRecord {
/// Document identifier, read from the JSON key `_id` (required).
#[serde(rename = "_id")]
id: String,
/// Document title; falls back to the empty string when the key is absent.
#[serde(default)]
title: String,
/// Document body; falls back to the empty string when the key is absent.
#[serde(default)]
text: String,
}
/// One JSON object parsed from a line of `queries.jsonl`.
#[derive(Deserialize)]
struct QueryRecord {
/// Query identifier, read from the JSON key `_id` (required).
#[serde(rename = "_id")]
id: String,
/// Query text; falls back to the empty string when the key is absent.
#[serde(default)]
text: String,
}
pub fn prepare(dataset_dir: &Path, out_dir: &Path, split: &str) -> Result<PreparedDataset> {
let corpus_dir = out_dir.join("corpus");
std::fs::create_dir_all(&corpus_dir).map_err(|e| DciError::Io {
path: corpus_dir.clone(),
source: e,
})?;
let corpus_jsonl = dataset_dir.join("corpus.jsonl");
let mut id_to_file: BTreeMap<String, String> = BTreeMap::new();
let mut documents = 0usize;
for line in read_lines(&corpus_jsonl)? {
let line = line.map_err(|e| DciError::Io {
path: corpus_jsonl.clone(),
source: e,
})?;
if line.trim().is_empty() {
continue;
}
let record: CorpusRecord = serde_json::from_str(&line)
.map_err(|e| DciError::InvalidPattern(format!("corpus.jsonl: {e}")))?;
let file_name = format!("{}.txt", sanitize_id(&record.id));
let body = if record.title.is_empty() {
record.text
} else {
format!("{}\n\n{}", record.title, record.text)
};
let path = corpus_dir.join(&file_name);
std::fs::write(&path, body).map_err(|e| DciError::Io {
path: path.clone(),
source: e,
})?;
id_to_file.insert(record.id, file_name);
documents += 1;
}
let queries_jsonl = dataset_dir.join("queries.jsonl");
let mut query_text: BTreeMap<String, String> = BTreeMap::new();
for line in read_lines(&queries_jsonl)? {
let line = line.map_err(|e| DciError::Io {
path: queries_jsonl.clone(),
source: e,
})?;
if line.trim().is_empty() {
continue;
}
let record: QueryRecord = serde_json::from_str(&line)
.map_err(|e| DciError::InvalidPattern(format!("queries.jsonl: {e}")))?;
query_text.insert(record.id, record.text);
}
let qrels_tsv = dataset_dir.join("qrels").join(format!("{split}.tsv"));
let mut grouped: BTreeMap<String, BTreeMap<String, u8>> = BTreeMap::new();
for (lineno, line) in read_lines(&qrels_tsv)?.enumerate() {
let line = line.map_err(|e| DciError::Io {
path: qrels_tsv.clone(),
source: e,
})?;
if lineno == 0 || line.trim().is_empty() {
continue;
}
let cols: Vec<&str> = line.split('\t').collect();
let (qid, doc_id, rel) = match cols.as_slice() {
[qid, doc_id, rel] => (*qid, *doc_id, *rel),
[qid, _iter, doc_id, rel] => (*qid, *doc_id, *rel),
_ => continue,
};
let grade: u8 = rel.trim().parse().unwrap_or(0);
if grade == 0 {
continue;
}
if let Some(file_name) = id_to_file.get(doc_id.trim()) {
grouped
.entry(qid.trim().to_string())
.or_default()
.insert(file_name.clone(), grade);
}
}
let qrels_path = out_dir.join("qrels.jsonl");
let mut out = std::fs::File::create(&qrels_path).map_err(|e| DciError::Io {
path: qrels_path.clone(),
source: e,
})?;
let mut queries = 0usize;
for (qid, relevant) in &grouped {
let Some(text) = query_text.get(qid) else {
continue;
};
let record = serde_json::json!({
"query_id": qid,
"query": text,
"relevant_docs": relevant,
});
writeln!(out, "{record}").map_err(|e| DciError::Io {
path: qrels_path.clone(),
source: e,
})?;
queries += 1;
}
Ok(PreparedDataset {
corpus_dir,
qrels_path,
documents,
queries,
})
}
/// Maps a document id to a file-name-safe stem.
///
/// ASCII letters, ASCII digits, `-` and `_` are kept as they are; every other character
/// (including each non-ASCII character) is replaced by a single `_`.
fn sanitize_id(id: &str) -> String {
    id.chars()
        .map(|c| {
            let keep = c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
            if keep {
                c
            } else {
                '_'
            }
        })
        .collect()
}
/// Opens `path` and returns a buffered iterator over its lines.
///
/// # Errors
///
/// Returns [`DciError::Io`] carrying `path` when the file cannot be opened. Errors that occur
/// while reading are yielded by the returned iterator, one per item.
fn read_lines(path: &Path) -> Result<std::io::Lines<BufReader<std::fs::File>>> {
    let reader = std::fs::File::open(path)
        .map(BufReader::new)
        .map_err(|source| DciError::Io {
            path: path.to_path_buf(),
            source,
        })?;
    Ok(reader.lines())
}
#[cfg(test)]
mod tests {
    // Test code is allowed to panic on failure, so the panic-related lints are relaxed here.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]

    use super::*;

    /// Writes `contents` to `path`, creating the parent directory chain first.
    fn write_fixture(path: &Path, contents: &str) {
        let parent = path.parent().unwrap();
        std::fs::create_dir_all(parent).unwrap();
        std::fs::write(path, contents).unwrap();
    }

    /// End-to-end check on a two-document, two-query dataset with a header row in the qrels file.
    #[test]
    fn prepares_a_minimal_beir_dataset() {
        let dataset = tempfile::tempdir().unwrap();
        let out = tempfile::tempdir().unwrap();
        let root = dataset.path();

        let corpus_file = root.join("corpus.jsonl");
        write_fixture(
            &corpus_file,
            "{\"_id\":\"doc/1\",\"title\":\"Orwell\",\"text\":\"wrote 1984\"}\n\
             {\"_id\":\"doc2\",\"title\":\"\",\"text\":\"capital of france is paris\"}\n",
        );
        let queries_file = root.join("queries.jsonl");
        write_fixture(
            &queries_file,
            "{\"_id\":\"q1\",\"text\":\"who wrote 1984\"}\n\
             {\"_id\":\"q2\",\"text\":\"capital of france\"}\n",
        );
        let qrels_file = root.join("qrels").join("test.tsv");
        write_fixture(
            &qrels_file,
            "query-id\tcorpus-id\tscore\nq1\tdoc/1\t2\nq2\tdoc2\t1\n",
        );

        let prepared = prepare(root, out.path(), "test").expect("prepare");

        assert_eq!(prepared.documents, 2);
        assert_eq!(prepared.queries, 2);
        // `doc/1` is not a valid file stem, so it is stored under its sanitized name.
        assert!(prepared.corpus_dir.join("doc_1.txt").is_file());
        assert!(prepared.corpus_dir.join("doc2.txt").is_file());

        let qrels =
            rig_retrieval_evals::dataset::Qrels::load_jsonl(&prepared.qrels_path).expect("load");
        assert_eq!(qrels.queries.len(), 2);

        let first = qrels
            .queries
            .iter()
            .find(|entry| entry.query_id == "q1")
            .expect("q1");
        assert_eq!(first.grade("doc_1.txt"), 2);
    }
}