1use std::collections::{BTreeMap, HashMap};
17use std::path::Path;
18
19use serde::{Deserialize, Serialize};
20use tracing::debug;
21
22use crate::error::{Error, Result};
23
24#[derive(Deserialize)]
26struct BeirQuery {
27 #[serde(rename = "_id")]
28 id: String,
29 #[serde(default)]
30 text: String,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GoldQuery {
39 pub query_id: String,
41 pub query: String,
43 pub relevant_docs: HashMap<String, u8>,
46 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub reference_answer: Option<String>,
50}
51
52impl GoldQuery {
53 #[must_use]
55 pub fn is_relevant(&self, doc_id: &str) -> bool {
56 self.relevant_docs
57 .get(doc_id)
58 .copied()
59 .is_some_and(|g| g >= 1)
60 }
61
62 #[must_use]
64 pub fn grade(&self, doc_id: &str) -> u8 {
65 self.relevant_docs.get(doc_id).copied().unwrap_or(0)
66 }
67
68 #[must_use]
70 pub fn relevant_count(&self) -> usize {
71 self.relevant_docs.values().filter(|g| **g >= 1).count()
72 }
73}
74
75#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77pub struct Qrels {
78 pub queries: Vec<GoldQuery>,
80}
81
82impl Qrels {
83 pub fn load_jsonl<P: AsRef<Path>>(path: P) -> Result<Self> {
86 let path = path.as_ref();
87 debug!(?path, "loading qrels");
88 let text = std::fs::read_to_string(path)?;
89 Self::from_jsonl_str(&text)
90 }
91
92 pub fn from_jsonl_str(text: &str) -> Result<Self> {
96 let mut queries = Vec::new();
97 for (idx, raw_line) in text.lines().enumerate() {
98 let line = raw_line.trim();
99 if line.is_empty() {
100 continue;
101 }
102 let q: GoldQuery =
103 serde_json::from_str(line).map_err(|source| Error::DatasetParse {
104 line: idx + 1,
105 source,
106 })?;
107 queries.push(q);
108 }
109 Ok(Self { queries })
110 }
111
112 pub fn from_beir<P: AsRef<Path>>(dataset_dir: P, split: &str) -> Result<Self> {
130 let dir = dataset_dir.as_ref();
131 debug!(?dir, %split, "loading BEIR dataset");
132
133 let queries_path = dir.join("queries.jsonl");
135 let queries_text = std::fs::read_to_string(&queries_path)?;
136 let mut query_text: HashMap<String, String> = HashMap::new();
137 for (idx, raw_line) in queries_text.lines().enumerate() {
138 let line = raw_line.trim();
139 if line.is_empty() {
140 continue;
141 }
142 let record: BeirQuery =
143 serde_json::from_str(line).map_err(|source| Error::DatasetParse {
144 line: idx + 1,
145 source,
146 })?;
147 query_text.insert(record.id, record.text);
148 }
149
150 let qrels_path = dir.join("qrels").join(format!("{split}.tsv"));
152 let qrels_text = std::fs::read_to_string(&qrels_path)?;
153 let mut grouped: BTreeMap<String, HashMap<String, u8>> = BTreeMap::new();
154 for raw_line in qrels_text.lines() {
155 let line = raw_line.trim();
156 if line.is_empty() {
157 continue;
158 }
159 let cols: Vec<&str> = line.split('\t').collect();
160 let (qid, doc_id, rel) = match cols.as_slice() {
161 [qid, doc_id, rel] => (*qid, *doc_id, *rel),
162 [qid, _iter, doc_id, rel] => (*qid, *doc_id, *rel),
163 _ => continue,
164 };
165 let grade: u8 = rel.trim().parse().unwrap_or(0);
168 if grade == 0 {
169 continue;
170 }
171 grouped
172 .entry(qid.trim().to_string())
173 .or_default()
174 .insert(doc_id.trim().to_string(), grade);
175 }
176
177 let mut queries = Vec::with_capacity(grouped.len());
179 for (qid, relevant) in grouped {
180 let Some(text) = query_text.get(&qid) else {
181 continue;
182 };
183 queries.push(GoldQuery {
184 query_id: qid.clone(),
185 query: text.clone(),
186 relevant_docs: relevant,
187 reference_answer: None,
188 });
189 }
190 Ok(Self { queries })
191 }
192
193 #[must_use]
195 pub fn len(&self) -> usize {
196 self.queries.len()
197 }
198
199 #[must_use]
201 pub fn is_empty(&self) -> bool {
202 self.queries.is_empty()
203 }
204}
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct RetrievedSet {
210 pub query_id: String,
212 pub ranked: Vec<RetrievedDoc>,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct RetrievedDoc {
219 pub doc_id: String,
222 pub score: f64,
224}
225
#[cfg(test)]
// Fixtures below are known-good, so unwrapping, panicking on an unexpected
// variant, and direct indexing are acceptable in these tests.
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Checks three things: blank lines are skipped, grades are looked up per
    /// document, and `reference_answer` is read when present.
    #[test]
    fn parses_well_formed_jsonl() {
        let input = r#"{"query_id":"q1","query":"a","relevant_docs":{"d1":2,"d2":1}}
        {"query_id":"q2","query":"b","relevant_docs":{"d3":1},"reference_answer":"yes"}

        "#;
        let qrels = Qrels::from_jsonl_str(input).unwrap();
        assert_eq!(qrels.len(), 2);

        let first = &qrels.queries[0];
        assert!(first.is_relevant("d1"));
        assert_eq!(first.grade("d2"), 1);
        assert_eq!(first.grade("missing"), 0);

        let second = &qrels.queries[1];
        assert_eq!(second.reference_answer.as_deref(), Some("yes"));
    }

    /// A malformed second line is reported with its 1-based line number.
    #[test]
    fn reports_line_on_parse_error() {
        let input = "{\"query_id\":\"q1\",\"query\":\"a\",\"relevant_docs\":{}}\nnot json\n";
        let failure = Qrels::from_jsonl_str(input).unwrap_err();
        match failure {
            Error::DatasetParse { line, .. } => assert_eq!(line, 2),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Zero-graded judgments do not count as relevant.
    #[test]
    fn relevant_count_excludes_zero_grades() {
        let relevant_docs: HashMap<String, u8> = [("a", 2u8), ("b", 0u8), ("c", 1u8)]
            .into_iter()
            .map(|(doc, grade)| (doc.to_string(), grade))
            .collect();
        let gold = GoldQuery {
            query_id: "q".into(),
            query: String::new(),
            relevant_docs,
            reference_answer: None,
        };
        assert_eq!(gold.relevant_count(), 2);
    }
}