Skip to main content

kbolt_core/engine/
eval_ops.rs

1use kbolt_types::{
2    EvalCase, EvalJudgment, EvalModeFailure, EvalModeReport, EvalQueryReport, EvalRunReport,
3    SearchMode, SearchRequest, SearchResult,
4};
5use std::collections::{HashMap, HashSet};
6use std::path::Path;
7
8use crate::eval_store::load_eval_dataset_with_file;
9
10use super::*;
11
12impl Engine {
13    pub fn run_eval(&self, eval_file: Option<&Path>) -> Result<EvalRunReport> {
14        let dataset = load_eval_dataset_with_file(&self.config.config_dir, eval_file)?;
15        self.validate_eval_cases(&dataset.cases)?;
16        let eval_runs = self.eval_runs();
17        let total_cases = dataset.cases.len();
18        let mut reports = Vec::with_capacity(eval_runs.len());
19        let mut failed_modes = Vec::new();
20
21        for (mode, no_rerank) in eval_runs {
22            match self.run_eval_mode(&dataset.cases, mode.clone(), no_rerank) {
23                Ok(report) => reports.push(report),
24                Err(err) => failed_modes.push(EvalModeFailure {
25                    mode,
26                    no_rerank,
27                    error: err.to_string(),
28                }),
29            }
30        }
31
32        Ok(EvalRunReport {
33            total_cases,
34            modes: reports,
35            failed_modes,
36        })
37    }
38
39    fn validate_eval_cases(&self, cases: &[EvalCase]) -> Result<()> {
40        let mut seen_spaces = HashSet::new();
41        let mut seen_collections = HashSet::new();
42
43        for case in cases {
44            if let Some(space_name) = case.space.as_deref() {
45                if seen_spaces.insert(space_name.to_string()) {
46                    self.resolve_space_row(Some(space_name), None)?;
47                }
48            }
49
50            for collection in &case.collections {
51                let key = (case.space.clone(), collection.clone());
52                if seen_collections.insert(key) {
53                    self.validate_eval_collection(case.space.as_deref(), collection)?;
54                }
55            }
56        }
57
58        Ok(())
59    }
60
61    fn validate_eval_collection(&self, space: Option<&str>, collection: &str) -> Result<()> {
62        let resolved_space = self.resolve_space_row(space, Some(collection))?;
63        let collection_row = self.storage.get_collection(resolved_space.id, collection)?;
64        let chunk_count = self.storage.count_chunks_in_collection(collection_row.id)?;
65        if chunk_count == 0 {
66            return Err(KboltError::InvalidInput(format!(
67                "eval collection '{collection}' in space '{}' has no indexed chunks; run `kbolt --space {} update --collection {collection}`",
68                resolved_space.name, resolved_space.name
69            ))
70            .into());
71        }
72
73        Ok(())
74    }
75
76    fn eval_runs(&self) -> Vec<(SearchMode, bool)> {
77        let mut runs = vec![
78            (SearchMode::Keyword, true),
79            (SearchMode::Auto, true),
80            (SearchMode::Auto, false),
81        ];
82        if self.embedder.is_some() {
83            runs.push((SearchMode::Semantic, true));
84        }
85        runs.push((SearchMode::Deep, true));
86        runs.push((SearchMode::Deep, false));
87        runs
88    }
89
90    fn run_eval_mode(
91        &self,
92        cases: &[EvalCase],
93        mode: SearchMode,
94        no_rerank: bool,
95    ) -> Result<EvalModeReport> {
96        let mut query_reports = Vec::with_capacity(cases.len());
97        let mut ndcg_sum = 0.0_f32;
98        let mut recall_sum = 0.0_f32;
99        let mut mrr_sum = 0.0_f32;
100        let mut latencies = Vec::with_capacity(cases.len());
101
102        for case in cases {
103            let response = self.search(SearchRequest {
104                query: case.query.clone(),
105                mode: mode.clone(),
106                space: case.space.clone(),
107                collections: case.collections.clone(),
108                limit: 10,
109                min_score: 0.0,
110                no_rerank,
111                debug: false,
112            })?;
113
114            let returned_paths = dedupe_result_paths(case, &response.results);
115            let judgment_map = judgment_map(&case.judgments);
116            let relevant_path_count = case
117                .judgments
118                .iter()
119                .filter(|judgment| judgment.relevance > 0)
120                .count();
121            let matched_paths = returned_paths
122                .iter()
123                .filter(|path| relevance_for_path(&judgment_map, path.as_str()) > 0)
124                .cloned()
125                .collect::<Vec<_>>();
126            let first_relevant_rank = returned_paths
127                .iter()
128                .position(|path| relevance_for_path(&judgment_map, path.as_str()) > 0)
129                .map(|index| index + 1);
130            let matched_top_10 = returned_paths
131                .iter()
132                .take(10)
133                .filter(|path| relevance_for_path(&judgment_map, path.as_str()) > 0)
134                .count();
135            ndcg_sum += ndcg_at_k(&returned_paths, &judgment_map, 10);
136            recall_sum += matched_top_10 as f32 / relevant_path_count as f32;
137            mrr_sum += first_relevant_rank
138                .map(|rank| 1.0_f32 / rank as f32)
139                .unwrap_or(0.0);
140            latencies.push(response.elapsed_ms);
141            query_reports.push(EvalQueryReport {
142                query: case.query.clone(),
143                space: case.space.clone(),
144                collections: case.collections.clone(),
145                judgments: case.judgments.clone(),
146                returned_paths,
147                matched_paths,
148                first_relevant_rank,
149                elapsed_ms: response.elapsed_ms,
150            });
151        }
152
153        let case_count = cases.len() as f32;
154        Ok(EvalModeReport {
155            mode,
156            no_rerank,
157            ndcg_at_10: ndcg_sum / case_count,
158            recall_at_10: recall_sum / case_count,
159            mrr_at_10: mrr_sum / case_count,
160            latency_p50_ms: percentile_ms(&latencies, 0.50),
161            latency_p95_ms: percentile_ms(&latencies, 0.95),
162            queries: query_reports,
163        })
164    }
165}
166
167fn dedupe_result_paths(case: &EvalCase, results: &[SearchResult]) -> Vec<String> {
168    let mut seen = HashSet::new();
169    let mut deduped = Vec::new();
170
171    for result in results {
172        let path = if case.space.is_some() {
173            result.path.clone()
174        } else {
175            format!("{}/{}", result.space, result.path)
176        };
177
178        if seen.insert(path.clone()) {
179            deduped.push(path);
180        }
181    }
182
183    deduped
184}
185
186fn judgment_map<'a>(judgments: &'a [EvalJudgment]) -> HashMap<&'a str, u8> {
187    judgments
188        .iter()
189        .map(|judgment| (judgment.path.as_str(), judgment.relevance))
190        .collect()
191}
192
/// Returns the relevance grade recorded for `path`, or `0` when the path has no
/// judgment.
fn relevance_for_path(judgments: &HashMap<&str, u8>, path: &str) -> u8 {
    match judgments.get(path) {
        Some(&grade) => grade,
        None => 0,
    }
}
196
197fn ndcg_at_k(returned_paths: &[String], judgments: &HashMap<&str, u8>, k: usize) -> f32 {
198    let dcg = dcg_at_k(
199        &returned_paths
200            .iter()
201            .take(k)
202            .map(|path| relevance_for_path(judgments, path.as_str()))
203            .collect::<Vec<_>>(),
204    );
205    let mut ideal_relevances = judgments.values().copied().collect::<Vec<_>>();
206    ideal_relevances.sort_unstable_by(|left, right| right.cmp(left));
207    let ideal_dcg = dcg_at_k(&ideal_relevances.into_iter().take(k).collect::<Vec<_>>());
208    if ideal_dcg == 0.0 {
209        0.0
210    } else {
211        dcg / ideal_dcg
212    }
213}
214
/// Discounted cumulative gain of a ranked list of relevance grades.
///
/// Each grade contributes `(2^grade - 1) / log2(position + 1)` with 1-based
/// positions, so the item at index 0 is discounted by `log2(2)`.
fn dcg_at_k(relevances: &[u8]) -> f32 {
    let gain = |grade: u8| 2_f32.powi(i32::from(grade)) - 1.0;
    let discount = |rank_index: usize| (rank_index as f32 + 2.0).log2();

    relevances
        .iter()
        .copied()
        .enumerate()
        .map(|(rank_index, grade)| gain(grade) / discount(rank_index))
        .sum()
}
226
/// Nearest-rank percentile of `samples`, where `percentile` is a fraction
/// (`0.50` for the median).
///
/// Returns `0` for an empty slice. The computed rank is clamped into the valid
/// index range, so fractions at or below zero give the minimum and fractions at
/// or above one give the maximum.
fn percentile_ms(samples: &[u64], percentile: f32) -> u64 {
    let last_index = match samples.len().checked_sub(1) {
        Some(index) => index,
        None => return 0,
    };

    // 1-based nearest rank, converted to a 0-based index and clamped.
    let rank = (samples.len() as f32 * percentile).ceil() as usize;
    let index = rank.saturating_sub(1).min(last_index);

    // Selection places the element that a full sort would put at `index` there.
    let mut scratch = samples.to_vec();
    *scratch.select_nth_unstable(index).1
}
239
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;

    use tempfile::tempdir;

    use crate::config::{ChunkingConfig, Config, RankingConfig, ReapingConfig};
    use crate::models::{Embedder, Expander};
    use crate::storage::Storage;
    use kbolt_types::{AddCollectionRequest, SearchMode, UpdateOptions};

    use super::*;

    /// Document indexed by every end-to-end test in this module.
    const TRAITS_DOC: &str = r#"
Trait objects use dynamic dispatch, while generics use monomorphization.
Choose trait objects for heterogenous collections and generics for zero-cost abstraction.
"#;

    /// Eval manifest with a single case targeting `TRAITS_DOC` in `default/rust`.
    const SINGLE_CASE_MANIFEST: &str = r#"
[[cases]]
query = "trait object generic"
space = "default"
collections = ["rust"]
judgments = [{ path = "rust/guides/traits.md", relevance = 1 }]
"#;

    /// Embedder whose vectors depend only on the input text (token and byte counts).
    #[derive(Default)]
    struct DeterministicEmbedder;

    impl Embedder for DeterministicEmbedder {
        fn embed_batch(
            &self,
            _kind: crate::models::EmbeddingInputKind,
            texts: &[String],
        ) -> crate::Result<Vec<Vec<f32>>> {
            Ok(texts
                .iter()
                .map(|text| {
                    let token_count = text.split_whitespace().count().max(1) as f32;
                    let byte_count = text.len().max(1) as f32;
                    vec![token_count, byte_count]
                })
                .collect())
        }
    }

    /// Expander that always yields one fixed variant of the query.
    #[derive(Default)]
    struct DeterministicExpander;

    impl Expander for DeterministicExpander {
        fn expand(&self, query: &str, _max_variants: usize) -> crate::Result<Vec<String>> {
            Ok(vec![format!("explain {query}")])
        }
    }

    /// Expander that always errors, used to force a mode failure.
    struct FailingExpander;

    impl Expander for FailingExpander {
        fn expand(&self, _query: &str, _max_variants: usize) -> crate::Result<Vec<String>> {
            Err(KboltError::Inference("expander unavailable".to_string()).into())
        }
    }

    #[test]
    fn run_eval_reports_keyword_auto_and_deep_without_embedder() {
        let root = tempdir().expect("create temp root");
        let collection_dir = seed_collection(root.path(), "rust", "guides/traits.md", TRAITS_DOC);
        let engine = test_engine(root.path(), None, Some(Arc::new(DeterministicExpander)));
        seed_eval_file(&engine, SINGLE_CASE_MANIFEST);
        add_rust_collection(&engine, collection_dir);
        index_rust_collection(&engine, true);

        let report = engine.run_eval(None).expect("run eval");

        assert_eq!(report.total_cases, 1);
        let mode_labels: Vec<_> = report
            .modes
            .iter()
            .map(|m| (m.mode.clone(), m.no_rerank))
            .collect();
        assert_eq!(
            mode_labels,
            vec![
                (SearchMode::Keyword, true),
                (SearchMode::Auto, true),
                (SearchMode::Auto, false),
                (SearchMode::Deep, true),
                (SearchMode::Deep, false),
            ]
        );
        for mode in &report.modes {
            assert_eq!(mode.queries.len(), 1);
            assert_eq!(mode.queries[0].first_relevant_rank, Some(1));
            assert_eq!(mode.ndcg_at_10, 1.0);
            assert_eq!(mode.recall_at_10, 1.0);
            assert_eq!(mode.mrr_at_10, 1.0);
        }
    }

    #[test]
    fn run_eval_includes_semantic_mode_when_embedder_is_available() {
        let root = tempdir().expect("create temp root");
        let collection_dir = seed_collection(root.path(), "rust", "guides/traits.md", TRAITS_DOC);
        let engine = test_engine(
            root.path(),
            Some(Arc::new(DeterministicEmbedder)),
            Some(Arc::new(DeterministicExpander)),
        );
        seed_eval_file(&engine, SINGLE_CASE_MANIFEST);
        add_rust_collection(&engine, collection_dir);
        // Embeddings must be built for the semantic pass to have anything to search.
        index_rust_collection(&engine, false);

        let report = engine.run_eval(None).expect("run eval");

        let semantic = report
            .modes
            .iter()
            .find(|mode| mode.mode == SearchMode::Semantic)
            .expect("semantic report");
        assert_eq!(semantic.queries[0].first_relevant_rank, Some(1));
        assert_eq!(semantic.ndcg_at_10, 1.0);
        assert_eq!(semantic.recall_at_10, 1.0);
        assert_eq!(semantic.mrr_at_10, 1.0);
    }

    #[test]
    fn run_eval_keeps_successful_modes_when_later_mode_fails() {
        let root = tempdir().expect("create temp root");
        let collection_dir = seed_collection(root.path(), "rust", "guides/traits.md", TRAITS_DOC);
        let engine = test_engine(root.path(), None, Some(Arc::new(FailingExpander)));
        seed_eval_file(&engine, SINGLE_CASE_MANIFEST);
        add_rust_collection(&engine, collection_dir);
        index_rust_collection(&engine, true);

        let report = engine.run_eval(None).expect("run eval");

        assert!(report
            .modes
            .iter()
            .any(|mode| mode.mode == SearchMode::Keyword));
        assert!(report.failed_modes.iter().any(
            |mode| mode.mode == SearchMode::Deep && mode.error.contains("expander unavailable")
        ));
    }

    #[test]
    fn ndcg_at_10_is_zero_for_irrelevant_results() {
        let cases = graded_judgments(&[("rust/a.md", 2), ("rust/b.md", 1)]);
        let judgments = judgment_map(&cases);

        let score = ndcg_at_k(&paths(&["rust/c.md"]), &judgments, 10);

        assert_eq!(score, 0.0);
    }

    #[test]
    fn ndcg_at_10_is_one_for_perfect_ranking() {
        let cases = graded_judgments(&[("rust/a.md", 2), ("rust/b.md", 1)]);
        let judgments = judgment_map(&cases);

        let score = ndcg_at_k(&paths(&["rust/a.md", "rust/b.md"]), &judgments, 10);

        assert!((score - 1.0).abs() < 1e-6, "unexpected score: {score}");
    }

    #[test]
    fn ndcg_at_10_uses_graded_relevance_ordering() {
        let cases = graded_judgments(&[("rust/a.md", 2), ("rust/b.md", 1)]);
        let judgments = judgment_map(&cases);

        let perfect = ndcg_at_k(&paths(&["rust/a.md", "rust/b.md"]), &judgments, 10);
        let swapped = ndcg_at_k(&paths(&["rust/b.md", "rust/a.md"]), &judgments, 10);

        assert!(perfect > swapped, "perfect={perfect}, swapped={swapped}");
    }

    #[test]
    fn ndcg_at_10_handles_fewer_results_than_k() {
        let cases = graded_judgments(&[("rust/a.md", 2), ("rust/b.md", 1), ("rust/c.md", 1)]);
        let judgments = judgment_map(&cases);

        let score = ndcg_at_k(&paths(&["rust/a.md", "rust/b.md"]), &judgments, 10);

        assert!(score > 0.0 && score < 1.0, "unexpected score: {score}");
    }

    #[test]
    fn run_eval_supports_explicit_manifest_path() {
        let root = tempdir().expect("create temp root");
        let collection_dir = seed_collection(root.path(), "rust", "guides/traits.md", TRAITS_DOC);
        let engine = test_engine(root.path(), None, Some(Arc::new(DeterministicExpander)));
        let eval_file = root.path().join("bench").join("scifact.toml");
        if let Some(parent) = eval_file.parent() {
            fs::create_dir_all(parent).expect("create bench dir");
        }
        fs::write(&eval_file, SINGLE_CASE_MANIFEST).expect("write eval file");
        add_rust_collection(&engine, collection_dir);
        index_rust_collection(&engine, true);

        let report = engine.run_eval(Some(&eval_file)).expect("run eval");

        assert_eq!(report.total_cases, 1);
        assert!(report.modes.iter().all(|mode| mode.recall_at_10 >= 0.0));
    }

    #[test]
    fn run_eval_fails_when_manifest_references_missing_collection() {
        let root = tempdir().expect("create temp root");
        let engine = test_engine(root.path(), None, Some(Arc::new(DeterministicExpander)));
        seed_eval_file(&engine, SINGLE_CASE_MANIFEST);

        let err = engine
            .run_eval(None)
            .expect_err("missing collection should fail");
        assert!(
            err.to_string().contains("collection not found: rust"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn run_eval_fails_when_collection_has_not_been_indexed() {
        let root = tempdir().expect("create temp root");
        let collection_dir = seed_collection(root.path(), "rust", "guides/traits.md", TRAITS_DOC);
        let engine = test_engine(root.path(), None, Some(Arc::new(DeterministicExpander)));
        seed_eval_file(&engine, SINGLE_CASE_MANIFEST);
        // Registered with `no_index: true` and never updated, so it has no chunks.
        add_rust_collection(&engine, collection_dir);

        let err = engine
            .run_eval(None)
            .expect_err("unindexed collection should fail");
        assert!(
            err.to_string().contains("has no indexed chunks"),
            "unexpected error: {err}"
        );
        assert!(
            err.to_string()
                .contains("kbolt --space default update --collection rust"),
            "unexpected error: {err}"
        );
    }

    /// Builds an engine whose config and cache directories live under `root`.
    ///
    /// The caller owns the temp directory behind `root`, so it is removed when the
    /// test finishes instead of being leaked. Declare the `TempDir` before the
    /// engine so the engine is dropped first.
    fn test_engine(
        root: &Path,
        embedder: Option<Arc<dyn Embedder>>,
        expander: Option<Arc<dyn Expander>>,
    ) -> Engine {
        let config_dir = root.join("config");
        let cache_dir = root.join("cache");
        let storage = Storage::new(&cache_dir).expect("create storage");
        let config = Config {
            config_dir,
            cache_dir,
            default_space: None,
            providers: std::collections::HashMap::new(),
            roles: crate::config::RoleBindingsConfig::default(),
            reaping: ReapingConfig { days: 7 },
            chunking: ChunkingConfig::default(),
            ranking: RankingConfig::default(),
        };
        Engine::from_parts_with_models(storage, config, embedder, None, expander)
    }

    /// Writes `content` to `<root>/<collection>/<relative_path>` and returns the
    /// collection directory.
    fn seed_collection(
        root: &Path,
        collection: &str,
        relative_path: &str,
        content: &str,
    ) -> PathBuf {
        let collection_dir = root.join(collection);
        let full_path = collection_dir.join(relative_path);
        if let Some(parent) = full_path.parent() {
            fs::create_dir_all(parent).expect("create file parent");
        }
        fs::write(full_path, content).expect("write test document");
        collection_dir
    }

    /// Writes `content` as `eval.toml` in the engine's config directory.
    fn seed_eval_file(engine: &Engine, content: &str) {
        fs::create_dir_all(&engine.config().config_dir).expect("create config dir");
        fs::write(engine.config().config_dir.join("eval.toml"), content).expect("write eval file");
    }

    /// Registers `collection_dir` as collection `rust` in space `default` without
    /// indexing it.
    fn add_rust_collection(engine: &Engine, collection_dir: PathBuf) {
        engine
            .add_collection(AddCollectionRequest {
                path: collection_dir,
                space: Some("default".to_string()),
                name: Some("rust".to_string()),
                description: None,
                extensions: None,
                no_index: true,
            })
            .expect("add collection");
    }

    /// Runs an update of `default/rust`, with `no_embed` passed through to the
    /// update options.
    fn index_rust_collection(engine: &Engine, no_embed: bool) {
        engine
            .update(UpdateOptions {
                space: Some("default".to_string()),
                collections: vec!["rust".to_string()],
                no_embed,
                dry_run: false,
                verbose: false,
            })
            .expect("update collection");
    }

    /// Builds judgments from `(path, relevance)` pairs.
    fn graded_judgments(grades: &[(&str, u8)]) -> Vec<EvalJudgment> {
        grades
            .iter()
            .map(|&(path, relevance)| EvalJudgment {
                path: path.to_string(),
                relevance,
            })
            .collect()
    }

    /// Converts string literals into the owned paths `ndcg_at_k` expects.
    fn paths(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|path| path.to_string()).collect()
    }
}