1use std::collections::HashMap;
12
13use lance::{Error as LanceError, Result as LanceResult};
14use serde::{Deserialize, Serialize};
15
16use crate::record::{LifecycleQueryOptions, RecordFilters};
17use crate::store::ContextStore;
18
/// Grade assigned to a relevance label whose `grade` field is omitted in the input.
fn default_grade() -> f32 {
    const DEFAULT_GRADE: f32 = 1.0;
    DEFAULT_GRADE
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RelevanceLabel {
30 pub external_id: String,
31 #[serde(default = "default_grade")]
32 pub grade: f32,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct EvalQuery {
38 pub query_id: String,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub text: Option<String>,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub vector: Option<Vec<f32>>,
43 #[serde(default)]
44 pub relevant: Vec<RelevanceLabel>,
45}
46
47impl EvalQuery {
48 fn relevance_map(&self) -> HashMap<&str, f32> {
49 self.relevant
50 .iter()
51 .map(|label| (label.external_id.as_str(), label.grade))
52 .collect()
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct EvalQuerySet {
59 pub id: String,
60 pub queries: Vec<EvalQuery>,
61}
62
63impl EvalQuerySet {
64 #[must_use]
65 pub fn new(id: impl Into<String>, queries: Vec<EvalQuery>) -> Self {
66 Self {
67 id: id.into(),
68 queries,
69 }
70 }
71
72 pub fn from_jsonl(id: impl Into<String>, contents: &str) -> LanceResult<Self> {
76 let mut queries = Vec::new();
77 for (index, line) in contents.lines().enumerate() {
78 let line = line.trim();
79 if line.is_empty() {
80 continue;
81 }
82 let query: EvalQuery = serde_json::from_str(line).map_err(|err| {
83 LanceError::invalid_input(format!(
84 "invalid eval query on line {}: {err}",
85 index + 1
86 ))
87 })?;
88 queries.push(query);
89 }
90 Ok(Self::new(id, queries))
91 }
92
93 pub fn to_jsonl(&self) -> LanceResult<String> {
95 let mut out = String::new();
96 for query in &self.queries {
97 let line = serde_json::to_string(query)
98 .map_err(|err| LanceError::invalid_input(err.to_string()))?;
99 out.push_str(&line);
100 out.push('\n');
101 }
102 Ok(out)
103 }
104}
105
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
108#[serde(rename_all = "lowercase")]
109pub enum RetrievalMode {
110 #[default]
112 Vector,
113 Hybrid,
116}
117
118impl RetrievalMode {
119 #[must_use]
120 pub fn as_str(self) -> &'static str {
121 match self {
122 Self::Vector => "vector",
123 Self::Hybrid => "hybrid",
124 }
125 }
126}
127
128#[derive(Clone)]
130pub struct EvalConfig {
131 pub k: usize,
133 pub mode: RetrievalMode,
134 pub filters: Option<RecordFilters>,
135 pub lifecycle: LifecycleQueryOptions,
136}
137
138impl Default for EvalConfig {
139 fn default() -> Self {
140 Self {
141 k: 10,
142 mode: RetrievalMode::Vector,
143 filters: None,
144 lifecycle: LifecycleQueryOptions::default(),
145 }
146 }
147}
148
149#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
151pub struct MetricScores {
152 pub recall: f64,
153 pub precision: f64,
154 pub mrr: f64,
155 pub ndcg: f64,
156 pub hit_rate: f64,
157}
158
159impl MetricScores {
160 #[must_use]
162 pub fn delta(&self, baseline: &MetricScores) -> MetricScores {
163 MetricScores {
164 recall: self.recall - baseline.recall,
165 precision: self.precision - baseline.precision,
166 mrr: self.mrr - baseline.mrr,
167 ndcg: self.ndcg - baseline.ndcg,
168 hit_rate: self.hit_rate - baseline.hit_rate,
169 }
170 }
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct QueryEval {
176 pub query_id: String,
177 pub retrieved: Vec<String>,
180 pub scores: MetricScores,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct EvalReport {
187 pub query_set_id: String,
188 pub version: u64,
190 pub k: usize,
191 pub mode: String,
192 pub distance_metric: String,
194 pub num_queries: usize,
195 pub aggregate: MetricScores,
196 pub per_query: Vec<QueryEval>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct AbReport {
202 pub query_set_id: String,
203 pub baseline: EvalReport,
204 pub candidate: EvalReport,
205 pub deltas: MetricScores,
207}
208
209fn compute_scores(retrieved: &[String], relevant: &HashMap<&str, f32>, k: usize) -> MetricScores {
211 let k = k.max(1);
212 let num_relevant = relevant.values().filter(|grade| **grade > 0.0).count();
213
214 let mut hits = 0usize;
215 let mut first_relevant_rank: Option<usize> = None;
216 let mut dcg = 0.0f64;
217 for (index, external_id) in retrieved.iter().take(k).enumerate() {
218 let grade = relevant.get(external_id.as_str()).copied().unwrap_or(0.0);
219 if grade > 0.0 {
220 hits += 1;
221 if first_relevant_rank.is_none() {
222 first_relevant_rank = Some(index + 1);
223 }
224 dcg += f64::from(grade) / ((index + 2) as f64).log2();
226 }
227 }
228
229 let mut ideal_grades: Vec<f64> = relevant
231 .values()
232 .filter(|grade| **grade > 0.0)
233 .map(|grade| f64::from(*grade))
234 .collect();
235 ideal_grades.sort_by(|a, b| b.total_cmp(a));
236 let idcg: f64 = ideal_grades
237 .iter()
238 .take(k)
239 .enumerate()
240 .map(|(index, grade)| grade / ((index + 2) as f64).log2())
241 .sum();
242
243 MetricScores {
244 recall: if num_relevant > 0 {
245 hits as f64 / num_relevant as f64
246 } else {
247 0.0
248 },
249 precision: hits as f64 / k as f64,
250 mrr: first_relevant_rank.map_or(0.0, |rank| 1.0 / rank as f64),
251 ndcg: if idcg > 0.0 { dcg / idcg } else { 0.0 },
252 hit_rate: if hits > 0 { 1.0 } else { 0.0 },
253 }
254}
255
256fn mean_scores(per_query: &[QueryEval]) -> MetricScores {
257 let n = per_query.len();
258 if n == 0 {
259 return MetricScores::default();
260 }
261 let mut agg = MetricScores::default();
262 for query in per_query {
263 agg.recall += query.scores.recall;
264 agg.precision += query.scores.precision;
265 agg.mrr += query.scores.mrr;
266 agg.ndcg += query.scores.ndcg;
267 agg.hit_rate += query.scores.hit_rate;
268 }
269 let n = n as f64;
270 MetricScores {
271 recall: agg.recall / n,
272 precision: agg.precision / n,
273 mrr: agg.mrr / n,
274 ndcg: agg.ndcg / n,
275 hit_rate: agg.hit_rate / n,
276 }
277}
278
279impl ContextStore {
280 pub async fn evaluate(
287 &self,
288 query_set: &EvalQuerySet,
289 config: &EvalConfig,
290 ) -> LanceResult<EvalReport> {
291 let mut per_query = Vec::with_capacity(query_set.queries.len());
292 for query in &query_set.queries {
293 let retrieved = self.run_eval_query(query, config).await?;
294 let relevant = query.relevance_map();
295 let scores = compute_scores(&retrieved, &relevant, config.k);
296 per_query.push(QueryEval {
297 query_id: query.query_id.clone(),
298 retrieved,
299 scores,
300 });
301 }
302
303 Ok(EvalReport {
304 query_set_id: query_set.id.clone(),
305 version: self.version(),
306 k: config.k,
307 mode: config.mode.as_str().to_string(),
308 distance_metric: self.distance_metric().as_str().to_string(),
309 num_queries: per_query.len(),
310 aggregate: mean_scores(&per_query),
311 per_query,
312 })
313 }
314
315 pub async fn evaluate_versions(
319 &mut self,
320 query_set: &EvalQuerySet,
321 config: &EvalConfig,
322 baseline_version: u64,
323 candidate_version: u64,
324 ) -> LanceResult<AbReport> {
325 let original_version = self.version();
326
327 self.checkout(baseline_version).await?;
328 let baseline = self.evaluate(query_set, config).await?;
329 self.checkout(candidate_version).await?;
330 let candidate = self.evaluate(query_set, config).await?;
331 self.checkout(original_version).await?;
332
333 let deltas = candidate.aggregate.delta(&baseline.aggregate);
334 Ok(AbReport {
335 query_set_id: query_set.id.clone(),
336 baseline,
337 candidate,
338 deltas,
339 })
340 }
341
342 async fn run_eval_query(
344 &self,
345 query: &EvalQuery,
346 config: &EvalConfig,
347 ) -> LanceResult<Vec<String>> {
348 let limit = Some(config.k);
349 let records = match config.mode {
350 RetrievalMode::Vector => {
351 let vector = query.vector.as_deref().ok_or_else(|| {
352 LanceError::invalid_input(format!(
353 "query '{}' has no vector for vector-mode eval",
354 query.query_id
355 ))
356 })?;
357 self.search_filtered_with_options(
358 vector,
359 limit,
360 config.filters.as_ref(),
361 config.lifecycle.clone(),
362 )
363 .await?
364 .into_iter()
365 .map(|hit| hit.record)
366 .collect::<Vec<_>>()
367 }
368 RetrievalMode::Hybrid => self
369 .retrieve_filtered_with_options(
370 query.text.as_deref(),
371 query.vector.as_deref(),
372 limit,
373 config.filters.as_ref(),
374 config.lifecycle.clone(),
375 )
376 .await?
377 .into_iter()
378 .map(|hit| hit.record)
379 .collect::<Vec<_>>(),
380 };
381
382 Ok(records
383 .into_iter()
384 .map(|record| record.external_id.unwrap_or_default())
385 .collect())
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::{ContextRecord, LIFECYCLE_ACTIVE};
    use crate::store::ContextStore;
    use chrono::Utc;
    use serde_json::json;
    use tempfile::TempDir;
    use uuid::Uuid;

    /// Runs `compute_scores` on string-slice fixtures.
    fn scores(retrieved: &[&str], relevant: &[(&str, f32)], k: usize) -> MetricScores {
        let ranked: Vec<String> = retrieved.iter().map(|id| (*id).to_owned()).collect();
        let labels: HashMap<&str, f32> = relevant.iter().copied().collect();
        compute_scores(&ranked, &labels, k)
    }

    /// Asserts that two floats agree to within 1e-4.
    fn approx(actual: f64, expected: f64) {
        let difference = (actual - expected).abs();
        assert!(difference < 1e-4, "expected {expected}, got {actual}");
    }

    /// Renders the temp directory's path as the store URI.
    fn store_uri(dir: &TempDir) -> String {
        dir.path().to_string_lossy().to_string()
    }

    /// Drives a future to completion on a fresh Tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(future)
    }

    /// Builds the query set "qs" with the single query "q1", which labels
    /// `relevant_id` with grade 1.0.
    fn single_query_set(text: Option<&str>, vector: Vec<f32>, relevant_id: &str) -> EvalQuerySet {
        let query = EvalQuery {
            query_id: "q1".to_string(),
            text: text.map(str::to_string),
            vector: Some(vector),
            relevant: vec![RelevanceLabel {
                external_id: relevant_id.to_string(),
                grade: 1.0,
            }],
        };
        EvalQuerySet::new("qs", vec![query])
    }

    /// Default config with only the cutoff and retrieval mode overridden.
    fn config_with(k: usize, mode: RetrievalMode) -> EvalConfig {
        EvalConfig {
            k,
            mode,
            ..Default::default()
        }
    }

    #[test]
    fn metrics_perfect_ranking() {
        let s = scores(&["a", "b"], &[("a", 1.0), ("b", 1.0)], 2);
        for value in [s.recall, s.precision, s.mrr, s.ndcg, s.hit_rate] {
            approx(value, 1.0);
        }
    }

    #[test]
    fn metrics_single_relevant_at_rank_two() {
        let s = scores(&["x", "a"], &[("a", 1.0)], 2);
        approx(s.recall, 1.0);
        approx(s.precision, 0.5);
        approx(s.mrr, 0.5);
        approx(s.hit_rate, 1.0);
        // The only gain sits at rank 2, so DCG = 1 / log2(3) and IDCG = 1.
        approx(s.ndcg, 1.0 / 3.0_f64.log2());
    }

    #[test]
    fn metrics_no_relevant_in_topk() {
        let s = scores(&["x", "y"], &[("a", 1.0)], 2);
        for value in [s.recall, s.precision, s.mrr, s.ndcg, s.hit_rate] {
            approx(value, 0.0);
        }
    }

    #[test]
    fn metrics_graded_ndcg() {
        let s = scores(&["a", "b"], &[("a", 1.0), ("b", 3.0)], 2);
        // Retrieved order is (1, 3); the ideal order is (3, 1).
        let dcg = 1.0 / 2.0_f64.log2() + 3.0 / 3.0_f64.log2();
        let idcg = 3.0 / 2.0_f64.log2() + 1.0 / 3.0_f64.log2();
        approx(s.ndcg, dcg / idcg);
        approx(s.recall, 1.0);
    }

    #[test]
    fn metrics_precision_is_over_k() {
        // One result returned but k = 2: precision divides by k, not by the result count.
        let s = scores(&["a"], &[("a", 1.0)], 2);
        approx(s.precision, 0.5);
        approx(s.recall, 1.0);
        approx(s.hit_rate, 1.0);
    }

    #[test]
    fn query_set_jsonl_round_trip() {
        let jsonl = concat!(
            "{\"query_id\":\"q1\",\"vector\":[1.0,0.0],\"relevant\":[{\"external_id\":\"a\"}]}\n",
            "\n",
            "{\"query_id\":\"q2\",\"text\":\"hi\",\"relevant\":[{\"external_id\":\"b\",\"grade\":2.0}]}\n",
        );
        let parsed = EvalQuerySet::from_jsonl("set-1", jsonl).unwrap();
        assert_eq!(parsed.queries.len(), 2);
        assert_eq!(parsed.queries[0].query_id, "q1");
        assert_eq!(parsed.queries[1].relevant[0].grade, 2.0);
        // The first label omits `grade`, so the serde default applies.
        assert_eq!(parsed.queries[0].relevant[0].grade, 1.0);

        let serialized = parsed.to_jsonl().unwrap();
        let reparsed = EvalQuerySet::from_jsonl("set-1", &serialized).unwrap();
        assert_eq!(reparsed.queries.len(), 2);
        assert_eq!(reparsed.queries[1].relevant[0].external_id, "b");
    }

    /// Returns a zero vector of the store's embedding dimension whose leading
    /// components are overwritten with `lead`.
    fn embedding(store: &ContextStore, lead: &[f32]) -> Vec<f32> {
        let dim = store.embedding_dim() as usize;
        let mut vector = vec![0.0f32; dim];
        vector[..lead.len()].copy_from_slice(lead);
        vector
    }

    /// Builds an active text record with the given external id, text and embedding.
    fn record(external_id: &str, text: &str, embedding: Vec<f32>) -> ContextRecord {
        ContextRecord {
            id: Uuid::new_v4().to_string(),
            external_id: Some(external_id.to_owned()),
            run_id: "run".to_owned(),
            bot_id: None,
            session_id: None,
            tenant: None,
            source: None,
            created_at: Utc::now(),
            role: "user".to_owned(),
            state_metadata: None,
            metadata: None,
            relationships: Vec::new(),
            expires_at: None,
            retention_policy: None,
            lifecycle_status: LIFECYCLE_ACTIVE.to_owned(),
            retired_at: None,
            retired_reason: None,
            supersedes_id: None,
            superseded_by_id: None,
            content_type: "text/plain".to_owned(),
            text_payload: Some(text.to_owned()),
            binary_payload: None,
            embedding: Some(embedding),
        }
    }

    #[test]
    fn evaluate_vector_mode_scores_query_set() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let a = embedding(&store, &[1.0]);
            let b = embedding(&store, &[0.5]);
            let c = embedding(&store, &[0.0, 1.0]);
            store
                .add(&[
                    record("doc-a", "alpha", a.clone()),
                    record("doc-b", "beta", b),
                    record("doc-c", "gamma", c),
                ])
                .await
                .unwrap();

            // Query with doc-a's own embedding; only doc-a is labelled relevant.
            let query_set = single_query_set(None, a, "doc-a");
            let config = config_with(2, RetrievalMode::Vector);
            let report = store.evaluate(&query_set, &config).await.unwrap();

            assert_eq!(report.num_queries, 1);
            assert_eq!(report.mode, "vector");
            assert_eq!(report.k, 2);
            assert_eq!(report.per_query[0].retrieved.first().unwrap(), "doc-a");
            approx(report.aggregate.recall, 1.0);
            approx(report.aggregate.precision, 0.5);
            approx(report.aggregate.mrr, 1.0);
            approx(report.aggregate.hit_rate, 1.0);
        });
    }

    #[test]
    fn evaluate_respects_lifecycle_visibility() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let q = embedding(&store, &[1.0]);
            let mut retired = record("doc-a", "alpha", q.clone());
            retired.retired_at = Some(Utc::now());
            store.add(&[retired]).await.unwrap();

            let query_set = single_query_set(None, q, "doc-a");

            // With default lifecycle options the retired record is not found.
            let default_cfg = config_with(5, RetrievalMode::Vector);
            let hidden = store.evaluate(&query_set, &default_cfg).await.unwrap();
            approx(hidden.aggregate.recall, 0.0);

            // With both lifecycle flags enabled the same record is found.
            let include_retired = EvalConfig {
                lifecycle: LifecycleQueryOptions::new(true, true),
                ..config_with(5, RetrievalMode::Vector)
            };
            let visible = store.evaluate(&query_set, &include_retired).await.unwrap();
            approx(visible.aggregate.recall, 1.0);
        });
    }

    #[test]
    fn evaluate_respects_filters() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let shared = embedding(&store, &[1.0]);
            let mut a = record("doc-a", "alpha", shared.clone());
            a.tenant = Some("x".to_string());
            let mut b = record("doc-b", "beta", shared.clone());
            b.tenant = Some("y".to_string());
            store.add(&[a, b]).await.unwrap();

            // doc-b is the relevant record but belongs to tenant "y".
            let query_set = single_query_set(None, shared, "doc-b");
            let config = EvalConfig {
                filters: Some(RecordFilters::from_json_value(json!({"tenant": "x"})).unwrap()),
                ..config_with(5, RetrievalMode::Vector)
            };
            let report = store.evaluate(&query_set, &config).await.unwrap();
            // The tenant "x" filter excludes doc-b, so nothing relevant is found.
            approx(report.aggregate.recall, 0.0);
        });
    }

    #[test]
    fn evaluate_hybrid_mode_finds_relevant() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let a = embedding(&store, &[1.0]);
            let b = embedding(&store, &[0.0, 1.0]);
            store
                .add(&[
                    record("doc-a", "alpha unique", a.clone()),
                    record("doc-b", "beta other", b),
                ])
                .await
                .unwrap();

            let query_set = single_query_set(Some("alpha"), a, "doc-a");
            let config = config_with(2, RetrievalMode::Hybrid);
            let report = store.evaluate(&query_set, &config).await.unwrap();
            approx(report.aggregate.hit_rate, 1.0);
        });
    }

    #[test]
    fn config_ab_delta_detects_k_sensitivity() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let a = embedding(&store, &[1.0]);
            let b = embedding(&store, &[0.5]);
            store
                .add(&[
                    record("doc-a", "alpha", a.clone()),
                    record("doc-b", "beta", b),
                ])
                .await
                .unwrap();

            // Query with doc-a's embedding while labelling doc-b as relevant.
            let query_set = single_query_set(None, a, "doc-b");
            let k1 = config_with(1, RetrievalMode::Vector);
            let k2 = config_with(2, RetrievalMode::Vector);
            let at_1 = store.evaluate(&query_set, &k1).await.unwrap();
            let at_2 = store.evaluate(&query_set, &k2).await.unwrap();
            approx(at_1.aggregate.recall, 0.0);
            approx(at_2.aggregate.recall, 1.0);

            let delta = at_2.aggregate.delta(&at_1.aggregate);
            approx(delta.recall, 1.0);
        });
    }

    #[test]
    fn evaluate_versions_same_version_is_zero_delta_and_restores() {
        let dir = TempDir::new().unwrap();
        let uri = store_uri(&dir);
        block_on(async {
            let mut store = ContextStore::open(&uri).await.unwrap();
            let a = embedding(&store, &[1.0]);
            store
                .add(&[record("doc-a", "alpha", a.clone())])
                .await
                .unwrap();
            let version = store.version();

            let query_set = single_query_set(None, a, "doc-a");
            let config = config_with(1, RetrievalMode::Vector);
            let ab = store
                .evaluate_versions(&query_set, &config, version, version)
                .await
                .unwrap();

            approx(ab.deltas.recall, 0.0);
            approx(ab.deltas.ndcg, 0.0);
            assert_eq!(ab.baseline.version, version);
            assert_eq!(ab.candidate.version, version);
            assert_eq!(
                store.version(),
                version,
                "store restored to original version"
            );
        });
    }
}