Skip to main content

rig_retrieval_evals/
dataset.rs

1//! Labeled retrieval datasets (qrels) and accompanying corpus / answer files.
2//!
3//! The crate adopts a JSONL line format compatible with the BEIR ecosystem so
4//! that public datasets (NQ, HotpotQA, FiQA, MS-MARCO subsets, …) can be
5//! consumed directly. The canonical shape for `qrels.jsonl` is:
6//!
7//! ```jsonl
8//! {"query_id":"q1","query":"…","relevant_docs":{"doc-7":2,"doc-9":1}}
9//! {"query_id":"q2","query":"…","relevant_docs":{"doc-3":1},"reference_answer":"…"}
10//! ```
11//!
12//! Grades in `relevant_docs` are integers 1–N where higher = more relevant.
13//! Documents not listed are treated as **non-relevant** (grade 0). This matches
14//! the standard TREC / BEIR qrels semantics.
15
16use std::collections::{BTreeMap, HashMap};
17use std::path::Path;
18
19use serde::{Deserialize, Serialize};
20use tracing::debug;
21
22use crate::error::{Error, Result};
23
24/// A single record in a BEIR `queries.jsonl` file.
25#[derive(Deserialize)]
26struct BeirQuery {
27    #[serde(rename = "_id")]
28    id: String,
29    #[serde(default)]
30    text: String,
31}
32
33/// A single labeled query in a retrieval dataset.
34///
35/// Field order is stable for downstream serialization; do not reorder without
36/// bumping the dataset schema version in [`Qrels`].
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GoldQuery {
39    /// Stable opaque identifier for the query.
40    pub query_id: String,
41    /// Natural-language query text to send to the retriever.
42    pub query: String,
43    /// Map of `doc_id -> graded_relevance`. Documents not listed are treated
44    /// as non-relevant (grade 0). Grades are typically 1–3.
45    pub relevant_docs: HashMap<String, u8>,
46    /// Optional reference / "gold" answer used by answer-level evaluators.
47    /// Retrieval-only metrics ignore this.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub reference_answer: Option<String>,
50}
51
52impl GoldQuery {
53    /// Returns `true` if `doc_id` is labeled relevant (grade ≥ 1).
54    #[must_use]
55    pub fn is_relevant(&self, doc_id: &str) -> bool {
56        self.relevant_docs
57            .get(doc_id)
58            .copied()
59            .is_some_and(|g| g >= 1)
60    }
61
62    /// Returns the graded relevance for `doc_id`, or 0 if unlabeled.
63    #[must_use]
64    pub fn grade(&self, doc_id: &str) -> u8 {
65        self.relevant_docs.get(doc_id).copied().unwrap_or(0)
66    }
67
68    /// Number of distinct documents labeled relevant (grade ≥ 1).
69    #[must_use]
70    pub fn relevant_count(&self) -> usize {
71        self.relevant_docs.values().filter(|g| **g >= 1).count()
72    }
73}
74
75/// A collection of [`GoldQuery`] forming a complete retrieval dataset.
76#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77pub struct Qrels {
78    /// All labeled queries.
79    pub queries: Vec<GoldQuery>,
80}
81
82impl Qrels {
83    /// Load a JSONL qrels file from disk. Each line must deserialize into
84    /// [`GoldQuery`]; empty lines are skipped.
85    pub fn load_jsonl<P: AsRef<Path>>(path: P) -> Result<Self> {
86        let path = path.as_ref();
87        debug!(?path, "loading qrels");
88        let text = std::fs::read_to_string(path)?;
89        Self::from_jsonl_str(&text)
90    }
91
92    /// Parse a JSONL qrels payload from a string. Each non-empty line is
93    /// decoded into a [`GoldQuery`]. The 1-indexed line number is included in
94    /// any parse error.
95    pub fn from_jsonl_str(text: &str) -> Result<Self> {
96        let mut queries = Vec::new();
97        for (idx, raw_line) in text.lines().enumerate() {
98            let line = raw_line.trim();
99            if line.is_empty() {
100                continue;
101            }
102            let q: GoldQuery =
103                serde_json::from_str(line).map_err(|source| Error::DatasetParse {
104                    line: idx + 1,
105                    source,
106                })?;
107            queries.push(q);
108        }
109        Ok(Self { queries })
110    }
111
112    /// Load a downloaded BEIR / BRIGHT dataset directory into [`Qrels`].
113    ///
114    /// Standard IR benchmarks ship a `queries.jsonl` (`{"_id","text"}` records)
115    /// and TREC-style relevance judgments at `qrels/<split>.tsv`. This reads
116    /// both and produces one [`GoldQuery`] per query that has at least one
117    /// positive judgment, using the BEIR **corpus ids** directly as
118    /// [`GoldQuery::relevant_docs`] keys — so the retriever under test must
119    /// report those same ids as its `doc_id`s.
120    ///
121    /// The qrels TSV is accepted in both BEIR 3-column
122    /// (`query-id <tab> corpus-id <tab> score`) and TREC 4-column
123    /// (`query-id <tab> iteration <tab> doc-id <tab> relevance`) layouts. A
124    /// header row, blank lines, and any row whose relevance does not parse to a
125    /// positive integer (grade 0) are skipped, so a leading header is handled
126    /// whether or not it is present.
127    ///
128    /// `split` selects the file under `qrels/`, e.g. `"test"` or `"dev"`.
129    pub fn from_beir<P: AsRef<Path>>(dataset_dir: P, split: &str) -> Result<Self> {
130        let dir = dataset_dir.as_ref();
131        debug!(?dir, %split, "loading BEIR dataset");
132
133        // 1. query id -> text.
134        let queries_path = dir.join("queries.jsonl");
135        let queries_text = std::fs::read_to_string(&queries_path)?;
136        let mut query_text: HashMap<String, String> = HashMap::new();
137        for (idx, raw_line) in queries_text.lines().enumerate() {
138            let line = raw_line.trim();
139            if line.is_empty() {
140                continue;
141            }
142            let record: BeirQuery =
143                serde_json::from_str(line).map_err(|source| Error::DatasetParse {
144                    line: idx + 1,
145                    source,
146                })?;
147            query_text.insert(record.id, record.text);
148        }
149
150        // 2. Group positive judgments by query, preserving deterministic order.
151        let qrels_path = dir.join("qrels").join(format!("{split}.tsv"));
152        let qrels_text = std::fs::read_to_string(&qrels_path)?;
153        let mut grouped: BTreeMap<String, HashMap<String, u8>> = BTreeMap::new();
154        for raw_line in qrels_text.lines() {
155            let line = raw_line.trim();
156            if line.is_empty() {
157                continue;
158            }
159            let cols: Vec<&str> = line.split('\t').collect();
160            let (qid, doc_id, rel) = match cols.as_slice() {
161                [qid, doc_id, rel] => (*qid, *doc_id, *rel),
162                [qid, _iter, doc_id, rel] => (*qid, *doc_id, *rel),
163                _ => continue,
164            };
165            // A non-numeric / non-positive relevance (e.g. the header's
166            // "score") yields grade 0 and is dropped.
167            let grade: u8 = rel.trim().parse().unwrap_or(0);
168            if grade == 0 {
169                continue;
170            }
171            grouped
172                .entry(qid.trim().to_string())
173                .or_default()
174                .insert(doc_id.trim().to_string(), grade);
175        }
176
177        // 3. Emit a GoldQuery per query that has both text and a positive judgment.
178        let mut queries = Vec::with_capacity(grouped.len());
179        for (qid, relevant) in grouped {
180            let Some(text) = query_text.get(&qid) else {
181                continue;
182            };
183            queries.push(GoldQuery {
184                query_id: qid.clone(),
185                query: text.clone(),
186                relevant_docs: relevant,
187                reference_answer: None,
188            });
189        }
190        Ok(Self { queries })
191    }
192
193    /// Number of queries in the dataset.
194    #[must_use]
195    pub fn len(&self) -> usize {
196        self.queries.len()
197    }
198
199    /// True if the dataset is empty.
200    #[must_use]
201    pub fn is_empty(&self) -> bool {
202        self.queries.is_empty()
203    }
204}
205
206/// A single retrieval observation produced by a vector store for one gold
207/// query. `ranked` is sorted by descending similarity score.
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct RetrievedSet {
210    /// The [`GoldQuery::query_id`] this retrieval corresponds to.
211    pub query_id: String,
212    /// Hits in ranked order (highest score first).
213    pub ranked: Vec<RetrievedDoc>,
214}
215
216/// A single ranked retrieval hit.
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct RetrievedDoc {
219    /// Backend-assigned document id used to match against
220    /// [`GoldQuery::relevant_docs`].
221    pub doc_id: String,
222    /// Similarity score reported by the backend.
223    pub score: f64,
224}
225
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Two records, one of them indented and followed by blank lines, parse
    /// into exactly two queries with the expected labels.
    #[test]
    fn parses_well_formed_jsonl() {
        let text = r#"{"query_id":"q1","query":"a","relevant_docs":{"d1":2,"d2":1}}
        {"query_id":"q2","query":"b","relevant_docs":{"d3":1},"reference_answer":"yes"}

        "#;
        let parsed = Qrels::from_jsonl_str(text).unwrap();
        assert_eq!(parsed.len(), 2);

        let first = &parsed.queries[0];
        assert!(first.is_relevant("d1"));
        assert_eq!(first.grade("d2"), 1);
        assert_eq!(first.grade("missing"), 0);

        let second = &parsed.queries[1];
        assert_eq!(second.reference_answer.as_deref(), Some("yes"));
    }

    /// A malformed second line is reported with its 1-indexed line number.
    #[test]
    fn reports_line_on_parse_error() {
        let text = "{\"query_id\":\"q1\",\"query\":\"a\",\"relevant_docs\":{}}\nnot json\n";
        let err = Qrels::from_jsonl_str(text).unwrap_err();
        if let Error::DatasetParse { line, .. } = &err {
            assert_eq!(*line, 2);
        } else {
            panic!("unexpected error: {err:?}");
        }
    }

    /// An entry that is present with grade 0 does not count as relevant.
    #[test]
    fn relevant_count_excludes_zero_grades() {
        let labels = HashMap::from([
            ("a".to_string(), 2u8),
            ("b".to_string(), 0u8),
            ("c".to_string(), 1u8),
        ]);
        let gold = GoldQuery {
            query_id: "q".into(),
            query: String::new(),
            relevant_docs: labels,
            reference_answer: None,
        };
        assert_eq!(gold.relevant_count(), 2);
    }
}