1use std::cmp::Ordering;
50use std::collections::HashMap;
51use std::sync::Arc;
52use std::time::Instant;
53
54use anyhow::{Result, bail};
55use half::f16;
56use tracing::{debug, warn};
57
58use super::daemon_client::{DaemonClient, DaemonError};
59use super::embedder::Embedder;
60
61use frankensearch::TwoTierConfig as FsTwoTierConfig;
63use frankensearch::{TwoTierIndex as FsTwoTierIndex, VectorHit as FsVectorHit};
64
/// Tuning knobs for two-tier (fast + quality) search.
#[derive(Debug, Clone)]
pub struct TwoTierConfig {
    /// Expected length of every fast embedding; `TwoTierIndex::build` rejects entries whose
    /// fast embedding has a different length.
    pub fast_dimension: usize,
    /// Expected length of every quality embedding; validated the same way at build time.
    pub quality_dimension: usize,
    /// Weight applied to the normalized quality score when refined results are blended; the
    /// normalized fast score receives `1.0 - quality_weight`.
    pub quality_weight: f32,
    /// Upper bound on how many fast-phase results are re-scored during refinement.
    pub max_refinement_docs: usize,
    /// When set, the search iterator stops after the initial (fast) phase.
    pub fast_only: bool,
    /// When set, the search iterator skips the fast phase and emits a single quality result.
    pub quality_only: bool,
}
81
82impl Default for TwoTierConfig {
83 fn default() -> Self {
84 Self {
85 fast_dimension: 256,
86 quality_dimension: 384,
87 quality_weight: 0.7,
88 max_refinement_docs: 100,
89 fast_only: false,
90 quality_only: false,
91 }
92 }
93}
94
95impl TwoTierConfig {
96 pub fn from_env() -> Self {
98 let mut cfg = Self::default();
99
100 if let Ok(val) = dotenvy::var("CASS_TWO_TIER_FAST_DIM")
101 && let Ok(dim) = val.parse()
102 {
103 cfg.fast_dimension = dim;
104 }
105
106 if let Ok(val) = dotenvy::var("CASS_TWO_TIER_QUALITY_DIM")
107 && let Ok(dim) = val.parse()
108 {
109 cfg.quality_dimension = dim;
110 }
111
112 if let Ok(val) = dotenvy::var("CASS_TWO_TIER_QUALITY_WEIGHT")
113 && let Ok(weight) = val.parse::<f32>()
114 {
115 cfg.quality_weight = weight.clamp(0.0, 1.0);
116 }
117
118 if let Ok(val) = dotenvy::var("CASS_TWO_TIER_MAX_REFINEMENT")
119 && let Ok(max) = val.parse()
120 {
121 cfg.max_refinement_docs = max;
122 }
123
124 cfg
125 }
126
127 pub fn fast_only() -> Self {
129 Self {
130 fast_only: true,
131 ..Self::default()
132 }
133 }
134
135 pub fn quality_only() -> Self {
137 Self {
138 quality_only: true,
139 ..Self::default()
140 }
141 }
142
143 fn to_fs_config(&self) -> FsTwoTierConfig {
145 FsTwoTierConfig {
146 quality_weight: f64::from(self.quality_weight),
147 fast_only: self.fast_only,
148 ..FsTwoTierConfig::optimized().with_env_overrides()
149 }
150 }
151}
152
/// Identifies an indexed document at session, turn, or code-block granularity.
///
/// The first field of every variant is the session id (see `session_id`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DocumentId {
    /// A whole session, keyed by its session id.
    Session(String),
    /// A session id plus a turn number (encoded as `t:{id}:{turn}`).
    Turn(String, usize),
    /// A session id, a turn number, and a block number (encoded as `c:{id}:{turn}:{block}`).
    CodeBlock(String, usize, usize),
}
163
164impl DocumentId {
165 pub fn session_id(&self) -> &str {
167 match self {
168 Self::Session(id) => id,
169 Self::Turn(id, _) => id,
170 Self::CodeBlock(id, _, _) => id,
171 }
172 }
173
174 fn encode(&self) -> String {
176 match self {
177 Self::Session(id) => format!("s:{id}"),
178 Self::Turn(id, turn) => format!("t:{id}:{turn}"),
179 Self::CodeBlock(id, turn, block) => format!("c:{id}:{turn}:{block}"),
180 }
181 }
182}
183
/// Descriptive information recorded when a `TwoTierIndex` is built.
#[derive(Debug, Clone)]
pub struct TwoTierMetadata {
    /// Identifier of the fast embedder, as passed to `TwoTierIndex::build`.
    pub fast_embedder_id: String,
    /// Identifier of the quality embedder, as passed to `TwoTierIndex::build`.
    pub quality_embedder_id: String,
    /// Number of entries the index was built from.
    pub doc_count: usize,
    /// Value of `chrono::Utc::now().timestamp()` captured when the index was constructed.
    pub built_at: i64,
    /// Build state of the index.
    pub status: IndexStatus,
}
198
/// Build state of a two-tier index.
#[derive(Debug, Clone)]
pub enum IndexStatus {
    /// Build in progress.
    // NOTE(review): never constructed in this file; `progress` is presumably a 0..1
    // fraction — confirm against the producer.
    Building { progress: f32 },
    /// Build finished. `TwoTierIndex::build` currently records 0 for both latencies.
    Complete {
        fast_latency_ms: u64,
        quality_latency_ms: u64,
    },
    /// Build failed with the given message (not constructed in this file).
    Failed { error: String },
}
212
/// One document handed to `TwoTierIndex::build`.
#[derive(Debug, Clone)]
pub struct TwoTierEntry {
    /// Identity of the document; its encoded form must be unique within one build.
    pub doc_id: DocumentId,
    /// Caller-supplied message id, echoed back in `ScoredResult::message_id`.
    pub message_id: u64,
    /// Fast-tier vector; its length must equal `TwoTierConfig::fast_dimension`.
    pub fast_embedding: Vec<f16>,
    /// Quality-tier vector; its length must equal `TwoTierConfig::quality_dimension`.
    pub quality_embedding: Vec<f16>,
}
225
/// Two-tier vector index backed by a frankensearch index plus side tables that map index
/// positions back to document and message ids.
#[derive(Debug)]
pub struct TwoTierIndex {
    /// Build-time metadata (embedder ids, document count, timestamp, status).
    pub metadata: TwoTierMetadata,
    /// Underlying frankensearch index; `None` when the index was built from zero entries.
    fs_index: Option<FsTwoTierIndex>,
    /// Document ids, ordered to match the positions reported by the frankensearch index.
    doc_ids: Vec<DocumentId>,
    /// Message ids, parallel to `doc_ids`.
    message_ids: Vec<u64>,
    /// Temporary directory the frankensearch index was created in; held here so it lives as
    /// long as the index does. `None` for an empty index.
    _tmpdir: Option<tempfile::TempDir>,
}
244
245impl TwoTierIndex {
246 pub fn build(
252 fast_embedder_id: impl Into<String>,
253 quality_embedder_id: impl Into<String>,
254 config: &TwoTierConfig,
255 entries: impl IntoIterator<Item = TwoTierEntry>,
256 ) -> Result<Self> {
257 let fast_embedder_id = fast_embedder_id.into();
258 let quality_embedder_id = quality_embedder_id.into();
259 let entries: Vec<TwoTierEntry> = entries.into_iter().collect();
260 let doc_count = entries.len();
261
262 let tmpdir = tempfile::TempDir::new()?;
263
264 if doc_count == 0 {
265 return Ok(Self {
266 metadata: TwoTierMetadata {
267 fast_embedder_id,
268 quality_embedder_id,
269 doc_count: 0,
270 built_at: chrono::Utc::now().timestamp(),
271 status: IndexStatus::Complete {
272 fast_latency_ms: 0,
273 quality_latency_ms: 0,
274 },
275 },
276 fs_index: None,
277 doc_ids: Vec::new(),
278 message_ids: Vec::new(),
279 _tmpdir: None,
280 });
281 }
282
283 for (i, entry) in entries.iter().enumerate() {
285 if entry.fast_embedding.len() != config.fast_dimension {
286 bail!(
287 "fast embedding dimension mismatch at index {}: expected {}, got {}",
288 i,
289 config.fast_dimension,
290 entry.fast_embedding.len()
291 );
292 }
293 if entry.quality_embedding.len() != config.quality_dimension {
294 bail!(
295 "quality embedding dimension mismatch at index {}: expected {}, got {}",
296 i,
297 config.quality_dimension,
298 entry.quality_embedding.len()
299 );
300 }
301 }
302
303 let fs_config = config.to_fs_config();
305 let mut builder = FsTwoTierIndex::create(tmpdir.path(), fs_config.clone())
306 .map_err(|e| anyhow::anyhow!("failed to create fs index builder: {e}"))?;
307 builder.set_fast_embedder_id(&fast_embedder_id);
308 builder.set_quality_embedder_id(&quality_embedder_id);
309
310 let mut metadata_by_encoded_id = HashMap::with_capacity(doc_count);
311
312 for entry in entries {
313 let doc_id_str = entry.doc_id.encode();
314 if metadata_by_encoded_id
315 .insert(doc_id_str.clone(), (entry.doc_id.clone(), entry.message_id))
316 .is_some()
317 {
318 bail!(
319 "duplicate document id encountered while building two-tier index: {doc_id_str}"
320 );
321 }
322 let fast_f32: Vec<f32> = entry.fast_embedding.iter().map(|v| f32::from(*v)).collect();
323 let quality_f32: Vec<f32> = entry
324 .quality_embedding
325 .iter()
326 .map(|v| f32::from(*v))
327 .collect();
328
329 builder
330 .add_record(&doc_id_str, &fast_f32, Some(&quality_f32))
331 .map_err(|e| anyhow::anyhow!("failed to add record {doc_id_str}: {e}"))?;
332 }
333
334 let fs_index = builder
335 .finish()
336 .map_err(|e| anyhow::anyhow!("failed to finish fs index: {e}"))?;
337
338 let mut doc_ids = Vec::with_capacity(doc_count);
342 let mut message_ids = Vec::with_capacity(doc_count);
343 for idx in 0..doc_count {
344 let encoded = fs_index
345 .doc_id_at(idx)
346 .map_err(|e| anyhow::anyhow!("failed to read fs doc_id at index {idx}: {e}"))?;
347 let (doc_id, message_id) = metadata_by_encoded_id.remove(encoded).ok_or_else(|| {
348 anyhow::anyhow!(
349 "frankensearch index returned unknown doc_id at index {idx}: {encoded}"
350 )
351 })?;
352 doc_ids.push(doc_id);
353 message_ids.push(message_id);
354 }
355
356 Ok(Self {
357 metadata: TwoTierMetadata {
358 fast_embedder_id,
359 quality_embedder_id,
360 doc_count,
361 built_at: chrono::Utc::now().timestamp(),
362 status: IndexStatus::Complete {
363 fast_latency_ms: 0,
364 quality_latency_ms: 0,
365 },
366 },
367 fs_index: Some(fs_index),
368 doc_ids,
369 message_ids,
370 _tmpdir: Some(tmpdir),
371 })
372 }
373
374 pub fn len(&self) -> usize {
376 self.metadata.doc_count
377 }
378
379 pub fn is_empty(&self) -> bool {
381 self.metadata.doc_count == 0
382 }
383
384 pub fn doc_id(&self, idx: usize) -> Option<&DocumentId> {
386 self.doc_ids.get(idx)
387 }
388
389 pub fn message_id(&self, idx: usize) -> Option<u64> {
391 self.message_ids.get(idx).copied()
392 }
393
394 pub fn search_fast(&self, query_vec: &[f32], k: usize) -> Vec<ScoredResult> {
398 if self.is_empty() || k == 0 {
399 return Vec::new();
400 }
401
402 let Some(fs_index) = &self.fs_index else {
403 return Vec::new();
404 };
405
406 match fs_index.search_fast(query_vec, k) {
407 Ok(hits) => self.hits_to_scored_results(hits),
408 Err(e) => {
409 warn!(error = %e, "frankensearch fast search failed");
410 Vec::new()
411 }
412 }
413 }
414
415 pub fn search_quality(&self, query_vec: &[f32], k: usize) -> Vec<ScoredResult> {
421 if self.is_empty() || k == 0 {
422 return Vec::new();
423 }
424
425 let Some(fs_index) = &self.fs_index else {
426 return Vec::new();
427 };
428
429 let all_hits: Vec<FsVectorHit> = (0..self.metadata.doc_count)
431 .map(|i| FsVectorHit {
432 index: i as u32,
433 score: 0.0,
434 doc_id: self.doc_ids[i].encode(),
435 })
436 .collect();
437
438 match fs_index.quality_scores_for_hits(query_vec, &all_hits) {
439 Ok(scores) => {
440 let mut results: Vec<ScoredResult> = scores
443 .iter()
444 .enumerate()
445 .filter_map(|(idx, score)| {
446 let s = (*score)?;
447 let message_id = *self.message_ids.get(idx)?;
448 Some(ScoredResult {
449 idx,
450 message_id,
451 score: s,
452 })
453 })
454 .collect();
455 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
456 results.truncate(k);
457 results
458 }
459 Err(e) => {
460 warn!(error = %e, "frankensearch quality search failed");
461 Vec::new()
462 }
463 }
464 }
465
466 pub fn quality_scores_for_indices(&self, query_vec: &[f32], indices: &[usize]) -> Vec<f32> {
468 let Some(fs_index) = &self.fs_index else {
469 return vec![0.0; indices.len()];
470 };
471
472 let hits: Vec<FsVectorHit> = indices
473 .iter()
474 .filter_map(|&idx| {
475 if idx < self.metadata.doc_count {
476 Some(FsVectorHit {
477 index: idx as u32,
478 score: 0.0,
479 doc_id: self.doc_ids[idx].encode(),
480 })
481 } else {
482 None
483 }
484 })
485 .collect();
486
487 match fs_index.quality_scores_for_hits(query_vec, &hits) {
488 Ok(scores) => scores.into_iter().map(|s| s.unwrap_or(0.0)).collect(),
489 Err(e) => {
490 warn!(error = %e, "frankensearch quality scoring failed; using zero scores");
491 vec![0.0; indices.len()]
492 }
493 }
494 }
495
496 fn hits_to_scored_results(&self, hits: Vec<FsVectorHit>) -> Vec<ScoredResult> {
498 hits.into_iter()
499 .filter_map(|hit| {
500 let idx = hit.index as usize;
501 if idx < self.metadata.doc_count {
502 Some(ScoredResult {
503 idx,
504 message_id: self.message_ids[idx],
505 score: hit.score,
506 })
507 } else {
508 None
509 }
510 })
511 .collect()
512 }
513}
514
/// A single search hit.
#[derive(Debug, Clone)]
pub struct ScoredResult {
    /// Position of the document inside the `TwoTierIndex` (usable with `doc_id` / `message_id`).
    pub idx: usize,
    /// Message id recorded for that position when the index was built.
    pub message_id: u64,
    /// Score of the hit; higher ranks first in every result list produced here.
    pub score: f32,
}
525
/// One step emitted by the progressive search iterator.
#[derive(Debug, Clone)]
pub enum SearchPhase {
    /// Results of the fast tier, with the time spent embedding and searching.
    Initial {
        results: Vec<ScoredResult>,
        latency_ms: u64,
    },
    /// Results after quality scoring (blended refinement, or quality-only search).
    Refined {
        results: Vec<ScoredResult>,
        latency_ms: u64,
    },
    /// The step could not be completed. Also emitted when the fast embedding itself fails
    /// or when quality-only search fails.
    RefinementFailed { error: String },
}
542
/// Runs two-tier searches against a borrowed `TwoTierIndex`.
pub struct TwoTierSearcher<'a, D: DaemonClient> {
    /// Index being searched.
    index: &'a TwoTierIndex,
    /// Client used to obtain quality embeddings; `None` disables refinement and
    /// quality-only search.
    daemon: Option<Arc<D>>,
    /// Embedder used for the fast tier.
    fast_embedder: Arc<dyn Embedder>,
    /// Search behavior (mode flags, blend weight, refinement cap).
    config: TwoTierConfig,
}
550
551impl<'a, D: DaemonClient> TwoTierSearcher<'a, D> {
552 pub fn new(
554 index: &'a TwoTierIndex,
555 fast_embedder: Arc<dyn Embedder>,
556 daemon: Option<Arc<D>>,
557 config: TwoTierConfig,
558 ) -> Self {
559 Self {
560 index,
561 daemon,
562 fast_embedder,
563 config,
564 }
565 }
566
567 pub fn search(&self, query: &str, k: usize) -> impl Iterator<Item = SearchPhase> + '_ {
573 TwoTierSearchIter::new(self, query.to_string(), k)
574 }
575
576 pub fn search_fast_only(&self, query: &str, k: usize) -> Result<Vec<ScoredResult>> {
578 let start = Instant::now();
579 let query_vec = self.fast_embedder.embed_sync(query)?;
580 let results = self.index.search_fast(&query_vec, k);
581 debug!(
582 query_len = query.len(),
583 k = k,
584 result_count = results.len(),
585 latency_ms = start.elapsed().as_millis(),
586 "Fast-only search completed"
587 );
588 Ok(results)
589 }
590
591 pub fn search_quality_only(
593 &self,
594 query: &str,
595 k: usize,
596 ) -> Result<Vec<ScoredResult>, TwoTierError> {
597 let start = Instant::now();
598
599 let daemon = self
600 .daemon
601 .as_ref()
602 .ok_or_else(|| TwoTierError::DaemonUnavailable("no daemon configured".into()))?;
603
604 if !daemon.is_available() {
605 return Err(TwoTierError::DaemonUnavailable(
606 "daemon not available".into(),
607 ));
608 }
609
610 let request_id = format!("quality-{:016x}", rand::random::<u64>());
611 let query_vec = daemon
612 .embed(query, &request_id)
613 .map_err(TwoTierError::DaemonError)?;
614
615 let results = self.index.search_quality(&query_vec, k);
616 debug!(
617 query_len = query.len(),
618 k = k,
619 result_count = results.len(),
620 latency_ms = start.elapsed().as_millis(),
621 "Quality-only search completed"
622 );
623 Ok(results)
624 }
625}
626
/// State machine behind `TwoTierSearcher::search`.
struct TwoTierSearchIter<'a, D: DaemonClient> {
    /// Searcher that owns the index, embedders and configuration.
    searcher: &'a TwoTierSearcher<'a, D>,
    /// Query text, owned so the iterator does not borrow from the caller.
    query: String,
    /// Maximum number of results per phase.
    k: usize,
    /// Next step to run: 0 = initial (or quality-only) search, 1 = refinement,
    /// anything else = exhausted.
    phase: u8,
    /// Fast-phase hits kept for the refinement step.
    fast_results: Option<Vec<ScoredResult>>,
}
635
636impl<'a, D: DaemonClient> TwoTierSearchIter<'a, D> {
637 fn new(searcher: &'a TwoTierSearcher<'a, D>, query: String, k: usize) -> Self {
638 Self {
639 searcher,
640 query,
641 k,
642 phase: 0,
643 fast_results: None,
644 }
645 }
646}
647
648impl<'a, D: DaemonClient> Iterator for TwoTierSearchIter<'a, D> {
649 type Item = SearchPhase;
650
651 fn next(&mut self) -> Option<Self::Item> {
652 match self.phase {
653 0 => {
654 if self.searcher.config.quality_only {
655 self.phase = 2;
656 let start = Instant::now();
657 return match self.searcher.search_quality_only(&self.query, self.k) {
658 Ok(results) => Some(SearchPhase::Refined {
659 results,
660 latency_ms: start.elapsed().as_millis() as u64,
661 }),
662 Err(e) => Some(SearchPhase::RefinementFailed {
663 error: e.to_string(),
664 }),
665 };
666 }
667
668 self.phase = 1;
670 let start = Instant::now();
671
672 match self.searcher.fast_embedder.embed_sync(&self.query) {
673 Ok(query_vec) => {
674 let results = self.searcher.index.search_fast(&query_vec, self.k);
675 let latency_ms = start.elapsed().as_millis() as u64;
676 self.fast_results = Some(results.clone());
677
678 if self.searcher.config.fast_only {
679 self.phase = 2;
680 }
681
682 Some(SearchPhase::Initial {
683 results,
684 latency_ms,
685 })
686 }
687 Err(e) => {
688 warn!(error = %e, "Fast embedding failed");
689 self.phase = 2;
690 Some(SearchPhase::RefinementFailed {
691 error: format!("fast embedding failed: {e}"),
692 })
693 }
694 }
695 }
696 1 => {
697 self.phase = 2;
699
700 let daemon = match &self.searcher.daemon {
701 Some(d) if d.is_available() => d,
702 _ => {
703 return Some(SearchPhase::RefinementFailed {
704 error: "daemon unavailable".to_string(),
705 });
706 }
707 };
708
709 let start = Instant::now();
710 let request_id = format!("refine-{:016x}", rand::random::<u64>());
711
712 match daemon.embed(&self.query, &request_id) {
713 Ok(query_vec) => {
714 let results = if let Some(fast_results) = self.fast_results.as_ref() {
715 let refine_cap = self.searcher.config.max_refinement_docs;
716 let candidates: Vec<usize> = fast_results
717 .iter()
718 .take(refine_cap)
719 .map(|sr| sr.idx)
720 .collect();
721 if candidates.is_empty() {
722 fast_results.clone()
723 } else {
724 let quality_scores = self
725 .searcher
726 .index
727 .quality_scores_for_indices(&query_vec, &candidates);
728
729 let weight = self.searcher.config.quality_weight;
730 let fast_scores: Vec<f32> =
731 fast_results.iter().map(|sr| sr.score).collect();
732 let fast_norm = normalize_scores(&fast_scores);
733 let quality_norm = normalize_scores(&quality_scores);
734
735 let mut blended: Vec<ScoredResult> =
736 Vec::with_capacity(fast_results.len());
737 for (idx, fast) in fast_results.iter().enumerate() {
738 let fast_s = fast_norm.get(idx).copied().unwrap_or(0.0);
739 let score = if idx < quality_norm.len() {
740 let quality_s =
741 quality_norm.get(idx).copied().unwrap_or(0.0);
742 (1.0 - weight) * fast_s + weight * quality_s
743 } else {
744 fast_s * (1.0 - weight)
748 };
749 blended.push(ScoredResult {
750 idx: fast.idx,
751 message_id: fast.message_id,
752 score,
753 });
754 }
755
756 blended.sort_by(|a, b| {
757 b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
758 });
759 blended.truncate(self.k);
760 blended
761 }
762 } else {
763 self.searcher.index.search_quality(&query_vec, self.k)
764 };
765
766 let latency_ms = start.elapsed().as_millis() as u64;
767 Some(SearchPhase::Refined {
768 results,
769 latency_ms,
770 })
771 }
772 Err(e) => Some(SearchPhase::RefinementFailed {
773 error: e.to_string(),
774 }),
775 }
776 }
777 _ => None,
778 }
779 }
780}
781
/// Errors reported by two-tier search.
#[derive(Debug, thiserror::Error)]
pub enum TwoTierError {
    /// No daemon is configured, or the configured daemon reports itself unavailable.
    #[error("daemon unavailable: {0}")]
    DaemonUnavailable(String),

    /// A daemon call failed; convertible from `DaemonError` via `From`.
    #[error("daemon error: {0}")]
    DaemonError(#[from] DaemonError),

    /// Embedding failure described by a message (not constructed in this file).
    #[error("embedding failed: {0}")]
    EmbeddingFailed(String),

    /// Index failure described by a message (not constructed in this file).
    #[error("index error: {0}")]
    IndexError(String),
}
797
/// Min-max normalizes `scores` into `[0.0, 1.0]`.
///
/// The minimum and maximum are taken over the finite inputs only. Each finite score maps to
/// `(score - min) / (max - min)`; every non-finite input (NaN, ±infinity) maps to `0.0`.
///
/// Special cases:
/// - an empty slice yields an empty vector;
/// - when no input is finite, every output is `0.0`;
/// - when all finite inputs are (nearly) equal, each of them maps to `1.0`.
///
/// The scaling is carried out in `f64` so that a spread wider than `f32::MAX` (for example
/// `f32::MAX` next to `f32::MIN`) cannot overflow to infinity and produce NaN.
pub fn normalize_scores(scores: &[f32]) -> Vec<f32> {
    let (min, max) = scores
        .iter()
        .copied()
        .filter(|s| s.is_finite())
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
            (lo.min(s), hi.max(s))
        });

    // The fold never saw a finite value (this also covers the empty slice).
    if min > max {
        return vec![0.0; scores.len()];
    }

    // Degenerate spread: all finite scores are considered equally good.
    if (max - min).abs() < f32::EPSILON {
        return scores
            .iter()
            .map(|s| if s.is_finite() { 1.0 } else { 0.0 })
            .collect();
    }

    let lo = f64::from(min);
    let span = f64::from(max) - lo;

    scores
        .iter()
        .map(|&s| {
            if s.is_finite() {
                ((f64::from(s) - lo) / span) as f32
            } else {
                0.0
            }
        })
        .collect()
}
837
838pub fn blend_scores(fast: &[f32], quality: &[f32], quality_weight: f32) -> Vec<f32> {
840 let fast_norm = normalize_scores(fast);
841 let quality_norm = normalize_scores(quality);
842
843 fast_norm
844 .iter()
845 .zip(quality_norm.iter())
846 .map(|(&f, &q)| (1.0 - quality_weight) * f + quality_weight * q)
847 .collect()
848}
849
// Unit tests for index construction, both search tiers, score normalization and the
// progressive search iterator.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::daemon_client::{DaemonClient, DaemonError};
    use crate::search::embedder::{Embedder, EmbedderError};
    use crate::search::hash_embedder::HashEmbedder;
    use frankensearch::ModelCategory;
    use std::sync::Arc;

    // Daemon double: returns all-ones vectors of length `dim`; availability is configurable.
    struct TestDaemon {
        dim: usize,
        available: bool,
    }

    // Embedder double whose `embed_sync` always fails.
    struct FailingEmbedder {
        dim: usize,
    }

    // Embedder double that returns `value` repeated `dim` times for any text.
    struct ConstantEmbedder {
        dim: usize,
        value: f32,
    }

    impl Embedder for FailingEmbedder {
        fn embed_sync(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            Err(EmbedderError::EmbeddingFailed {
                model: "failing-embedder".to_string(),
                source: Box::new(std::io::Error::other("synthetic fast embed failure")),
            })
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn id(&self) -> &str {
            "failing-embedder"
        }

        fn is_semantic(&self) -> bool {
            false
        }

        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    impl Embedder for ConstantEmbedder {
        fn embed_sync(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            Ok(vec![self.value; self.dim])
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn id(&self) -> &str {
            "constant-embedder"
        }

        fn is_semantic(&self) -> bool {
            false
        }

        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    impl DaemonClient for TestDaemon {
        fn id(&self) -> &str {
            "test-daemon"
        }

        fn is_available(&self) -> bool {
            self.available
        }

        fn embed(&self, _text: &str, _request_id: &str) -> Result<Vec<f32>, DaemonError> {
            Ok(vec![1.0; self.dim])
        }

        fn embed_batch(
            &self,
            texts: &[&str],
            _request_id: &str,
        ) -> Result<Vec<Vec<f32>>, DaemonError> {
            Ok(vec![vec![1.0; self.dim]; texts.len()])
        }

        // Reranking is not exercised by these tests.
        fn rerank(
            &self,
            _query: &str,
            _documents: &[&str],
            _request_id: &str,
        ) -> Result<Vec<f32>, DaemonError> {
            Err(DaemonError::Unavailable(
                "rerank unsupported in test daemon".to_string(),
            ))
        }
    }

    // Builds `count` session entries; component j of entry i is (i + j) * 0.01 in both tiers.
    fn make_test_entries(count: usize, fast_dim: usize, quality_dim: usize) -> Vec<TwoTierEntry> {
        (0..count)
            .map(|i| TwoTierEntry {
                doc_id: DocumentId::Session(format!("session-{}", i)),
                message_id: i as u64,
                fast_embedding: (0..fast_dim)
                    .map(|j| f16::from_f32((i + j) as f32 * 0.01))
                    .collect(),
                quality_embedding: (0..quality_dim)
                    .map(|j| f16::from_f32((i + j) as f32 * 0.01))
                    .collect(),
            })
            .collect()
    }

    // A non-empty build reports the right size and a Complete status.
    #[test]
    fn test_two_tier_index_creation() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(10, config.fast_dimension, config.quality_dimension);

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        assert_eq!(index.len(), 10);
        assert!(!index.is_empty());
        assert!(matches!(
            index.metadata.status,
            IndexStatus::Complete { .. }
        ));
    }

    // Building from no entries succeeds and yields an empty index.
    #[test]
    fn test_empty_index() {
        let config = TwoTierConfig::default();
        let entries: Vec<TwoTierEntry> = Vec::new();

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    // A fast embedding of the wrong length (128 instead of the configured 256) is rejected.
    #[test]
    fn test_dimension_mismatch_fast() {
        let config = TwoTierConfig::default();
        let entries = vec![TwoTierEntry {
            doc_id: DocumentId::Session("test".into()),
            message_id: 1,
            fast_embedding: vec![f16::from_f32(1.0); 128],
            quality_embedding: vec![f16::from_f32(1.0); config.quality_dimension],
        }];

        let result = TwoTierIndex::build("fast", "quality", &config, entries);
        assert!(result.is_err());
    }

    // A quality embedding of the wrong length (128 instead of the configured 384) is rejected.
    #[test]
    fn test_dimension_mismatch_quality() {
        let config = TwoTierConfig::default();
        let entries = vec![TwoTierEntry {
            doc_id: DocumentId::Session("test".into()),
            message_id: 1,
            fast_embedding: vec![f16::from_f32(1.0); config.fast_dimension],
            quality_embedding: vec![f16::from_f32(1.0); 128],
        }];

        let result = TwoTierIndex::build("fast", "quality", &config, entries);
        assert!(result.is_err());
    }

    // Fast search returns k hits in non-increasing score order.
    #[test]
    fn test_fast_search() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(100, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.fast_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let results = index.search_fast(&query, 10);

        assert_eq!(results.len(), 10);
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    // The side tables must line up with the order reported by the frankensearch index,
    // whatever order the entries were supplied in.
    #[test]
    fn test_side_tables_follow_frankensearch_index_order() {
        let config = TwoTierConfig::default();
        let entries = vec![
            TwoTierEntry {
                doc_id: DocumentId::Session("session-z".into()),
                message_id: 300,
                fast_embedding: vec![f16::from_f32(1.0); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(1.0); config.quality_dimension],
            },
            TwoTierEntry {
                doc_id: DocumentId::Session("session-a".into()),
                message_id: 100,
                fast_embedding: vec![f16::from_f32(0.5); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(0.5); config.quality_dimension],
            },
            TwoTierEntry {
                doc_id: DocumentId::Session("session-m".into()),
                message_id: 200,
                fast_embedding: vec![f16::from_f32(0.25); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(0.25); config.quality_dimension],
            },
        ];
        let expected_by_encoded = HashMap::from([
            ("s:session-z".to_string(), 300_u64),
            ("s:session-a".to_string(), 100_u64),
            ("s:session-m".to_string(), 200_u64),
        ]);

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();
        let fs_index = index.fs_index.as_ref().expect("non-empty fs index");

        for idx in 0..index.len() {
            let encoded = fs_index.doc_id_at(idx).expect("fs doc_id");
            assert_eq!(index.doc_ids[idx].encode(), encoded);
            assert_eq!(index.message_ids[idx], expected_by_encoded[encoded]);
        }
    }

    // Quality search returns k hits in non-increasing score order.
    #[test]
    fn test_quality_search() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(100, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.quality_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let results = index.search_quality(&query, 10);

        assert_eq!(results.len(), 10);
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    // The largest score maps to 1.0 and the smallest to 0.0.
    #[test]
    fn test_score_normalization() {
        let scores = vec![0.8, 0.6, 0.4, 0.2];
        let normalized = normalize_scores(&scores);

        assert!((normalized[0] - 1.0).abs() < 0.001);
        assert!((normalized[3] - 0.0).abs() < 0.001);
    }

    // Identical scores all normalize to 1.0.
    #[test]
    fn test_score_normalization_constant() {
        let scores = vec![0.5, 0.5, 0.5];
        let normalized = normalize_scores(&scores);

        for n in &normalized {
            assert!((n - 1.0).abs() < 0.001);
        }
    }

    // In the constant case a NaN input still maps to 0.0.
    #[test]
    fn test_score_normalization_constant_with_nan_keeps_nan_zeroed() {
        let scores = vec![f32::NAN, 0.5, 0.5];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 3);
        assert_eq!(normalized[0], 0.0);
        assert!((normalized[1] - 1.0).abs() < 0.001);
        assert!((normalized[2] - 1.0).abs() < 0.001);
    }

    // Infinite inputs are ignored for min/max and map to 0.0.
    #[test]
    fn test_score_normalization_with_infinite_values_keeps_non_finite_zeroed() {
        let scores = vec![f32::NEG_INFINITY, 2.0, f32::INFINITY, 4.0];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 4);
        assert_eq!(normalized[0], 0.0);
        assert_eq!(normalized[2], 0.0);
        assert!((normalized[1] - 0.0).abs() < 0.001);
        assert!((normalized[3] - 1.0).abs() < 0.001);
    }

    // Empty input yields empty output.
    #[test]
    fn test_score_normalization_empty() {
        let scores: Vec<f32> = vec![];
        let normalized = normalize_scores(&scores);
        assert!(normalized.is_empty());
    }

    // Blending equal-length inputs keeps the length (values are not checked here).
    #[test]
    fn test_blend_scores() {
        let fast = vec![0.8, 0.6, 0.4];
        let quality = vec![0.4, 0.8, 0.6];
        let blended = blend_scores(&fast, &quality, 0.5);

        assert_eq!(blended.len(), 3);
    }

    // `session_id` returns the id for each variant.
    #[test]
    fn test_document_id_session() {
        let doc_id = DocumentId::Session("test-session".into());
        assert_eq!(doc_id.session_id(), "test-session");
    }

    #[test]
    fn test_document_id_turn() {
        let doc_id = DocumentId::Turn("test-session".into(), 5);
        assert_eq!(doc_id.session_id(), "test-session");
    }

    #[test]
    fn test_document_id_code_block() {
        let doc_id = DocumentId::CodeBlock("test-session".into(), 3, 2);
        assert_eq!(doc_id.session_id(), "test-session");
    }

    // Pins the default configuration values.
    #[test]
    fn test_config_defaults() {
        let config = TwoTierConfig::default();
        assert_eq!(config.fast_dimension, 256);
        assert_eq!(config.quality_dimension, 384);
        assert!((config.quality_weight - 0.7).abs() < 0.001);
        assert_eq!(config.max_refinement_docs, 100);
        assert!(!config.fast_only);
        assert!(!config.quality_only);
    }

    #[test]
    fn test_config_fast_only() {
        let config = TwoTierConfig::fast_only();
        assert!(config.fast_only);
        assert!(!config.quality_only);
    }

    #[test]
    fn test_config_quality_only() {
        let config = TwoTierConfig::quality_only();
        assert!(!config.fast_only);
        assert!(config.quality_only);
    }

    // One score is returned per requested index.
    #[test]
    fn test_quality_scores_for_indices() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(10, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.quality_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let indices = vec![0, 2, 4];
        let scores = index.quality_scores_for_indices(&query, &indices);

        assert_eq!(scores.len(), 3);
    }

    // A query one element too short makes the search methods return nothing rather than panic.
    #[test]
    fn test_search_fast_dimension_mismatch_returns_empty() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.fast_dimension.saturating_sub(1)];
        let results = index.search_fast(&bad_query, 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_quality_dimension_mismatch_returns_empty() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.quality_dimension.saturating_sub(1)];
        let results = index.search_quality(&bad_query, 5);
        assert!(results.is_empty());
    }

    // On a scoring failure every requested index gets a zero score.
    #[test]
    fn test_quality_scores_for_indices_dimension_mismatch_returns_zeros() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.quality_dimension.saturating_sub(1)];
        let scores = index.quality_scores_for_indices(&bad_query, &[0, 2, 4]);
        assert_eq!(scores, vec![0.0, 0.0, 0.0]);
    }

    // Quality-only mode yields exactly one phase, and it is Refined.
    #[test]
    fn test_quality_only_mode_emits_only_refined_phase() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_only: true,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(HashEmbedder::new(config.fast_dimension));
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::Refined { .. }));
    }

    // Quality-only mode with a daemon that reports itself unavailable yields a single
    // RefinementFailed phase.
    #[test]
    fn test_quality_only_mode_without_daemon_reports_failure() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_only: true,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(HashEmbedder::new(config.fast_dimension));
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: false,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::RefinementFailed { .. }));
    }

    // A failing fast embedder ends the search with a single RefinementFailed phase.
    #[test]
    fn test_fast_embedding_failure_yields_failure_phase() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            fast_only: false,
            quality_only: false,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(FailingEmbedder {
            dim: config.fast_dimension,
        });
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::RefinementFailed { .. }));
    }

    // With a refinement cap (3) below the number of fast hits (5), every refined score must
    // still fall inside [0, 1].
    #[test]
    fn test_refinement_scores_are_normalized() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_weight: 0.6,
            max_refinement_docs: 3,
            ..Default::default()
        };
        let entries: Vec<TwoTierEntry> = (0..5)
            .map(|i| TwoTierEntry {
                doc_id: DocumentId::Session(format!("s{i}")),
                message_id: i as u64 + 1,
                fast_embedding: vec![f16::from_f32(20.0 + i as f32); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(10.0 + i as f32); config.quality_dimension],
            })
            .collect();
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(ConstantEmbedder {
            dim: config.fast_dimension,
            value: 10.0,
        });
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 5).collect();

        assert_eq!(phases.len(), 2);
        let SearchPhase::Refined { results, .. } = &phases[1] else {
            panic!("expected refined phase");
        };
        assert!(
            results.iter().all(|r| (0.0..=1.0).contains(&r.score)),
            "expected normalized refined scores, got {:?}",
            results.iter().map(|r| r.score).collect::<Vec<_>>()
        );
    }
}