// Source path: coding_agent_search/search/two_tier_search.rs

//! Two-tier progressive search for session search (bd-3dcw, bd-2fu7e).
//!
//! This module implements a progressive search strategy that:
//! 1. Returns instant results using a fast embedding model (in-process).
//! 2. Refines those rankings in a second phase using a quality model (daemon).
//!
//! **Delegates to frankensearch**: The vector storage and search are backed by
//! `frankensearch_index::TwoTierIndex` (file-backed FSVI). This module adds
//! cass-specific layers: synchronous `Iterator`-based search, the `DocumentId`
//! enum, `message_id` for SQLite lookups, and `DaemonClient` integration.
//!
//! # Architecture
//!
//! ```text
//! User Query
//!     │
//!     ├──→ [Fast Embedder] ──→ Results in ~1ms (display immediately)
//!     │       (in-process)
//!     │
//!     └──→ [Quality Daemon] ──→ Refined scores in ~130ms
//!              (warm UDS)           │
//!                                   ▼
//!                           Smooth re-rank
//! ```
//!
//! # Usage
//!
//! ```ignore
//! use cass::search::two_tier_search::{SearchPhase, TwoTierConfig, TwoTierIndex, TwoTierSearcher};
//!
//! let index = TwoTierIndex::build("fast", "quality", &config, entries)?;
//! let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
//!
//! // Phases are produced lazily: refinement runs when the second item is requested.
//! for phase in searcher.search("authentication middleware", 10) {
//!     match phase {
//!         SearchPhase::Initial { results, latency_ms } => {
//!             // Display instant results
//!         }
//!         SearchPhase::Refined { results, latency_ms } => {
//!             // Update with refined results
//!         }
//!         SearchPhase::RefinementFailed { error } => {
//!             // Keep showing initial results
//!         }
//!     }
//! }
//! ```

49use std::cmp::Ordering;
50use std::collections::HashMap;
51use std::sync::Arc;
52use std::time::Instant;
53
54use anyhow::{Result, bail};
55use half::f16;
56use tracing::{debug, warn};
57
58use super::daemon_client::{DaemonClient, DaemonError};
59use super::embedder::Embedder;
60
61// Frankensearch types for vector storage and search delegation.
62use frankensearch::TwoTierConfig as FsTwoTierConfig;
63use frankensearch::{TwoTierIndex as FsTwoTierIndex, VectorHit as FsVectorHit};
64
/// Tunable knobs for progressive (fast → quality) search.
///
/// Field defaults come from the `Default` impl that follows the struct.
#[derive(Clone, Debug)]
pub struct TwoTierConfig {
    /// Length of every fast-tier embedding vector (default 256).
    pub fast_dimension: usize,
    /// Length of every quality-tier embedding vector (default 384).
    pub quality_dimension: usize,
    /// Share of the blended score contributed by the quality tier (default 0.7);
    /// the fast tier contributes the remaining `1 - quality_weight`.
    pub quality_weight: f32,
    /// Upper bound on how many fast-tier hits get re-scored by the quality
    /// tier (default 100).
    pub max_refinement_docs: usize,
    /// When set, progressive search stops after the fast tier (no refinement).
    pub fast_only: bool,
    /// When set, progressive search skips the fast tier and waits for
    /// quality results before returning anything.
    pub quality_only: bool,
}

impl Default for TwoTierConfig {
    fn default() -> Self {
        // 256-d fast tier, 384-d quality tier, a quality-leaning blend, and
        // both tiers enabled.
        TwoTierConfig {
            quality_only: false,
            fast_only: false,
            max_refinement_docs: 100,
            quality_weight: 0.7,
            quality_dimension: 384,
            fast_dimension: 256,
        }
    }
}
94
95impl TwoTierConfig {
96    /// Load config from environment variables.
97    pub fn from_env() -> Self {
98        let mut cfg = Self::default();
99
100        if let Ok(val) = dotenvy::var("CASS_TWO_TIER_FAST_DIM")
101            && let Ok(dim) = val.parse()
102        {
103            cfg.fast_dimension = dim;
104        }
105
106        if let Ok(val) = dotenvy::var("CASS_TWO_TIER_QUALITY_DIM")
107            && let Ok(dim) = val.parse()
108        {
109            cfg.quality_dimension = dim;
110        }
111
112        if let Ok(val) = dotenvy::var("CASS_TWO_TIER_QUALITY_WEIGHT")
113            && let Ok(weight) = val.parse::<f32>()
114        {
115            cfg.quality_weight = weight.clamp(0.0, 1.0);
116        }
117
118        if let Ok(val) = dotenvy::var("CASS_TWO_TIER_MAX_REFINEMENT")
119            && let Ok(max) = val.parse()
120        {
121            cfg.max_refinement_docs = max;
122        }
123
124        cfg
125    }
126
127    /// Create config for fast-only mode.
128    pub fn fast_only() -> Self {
129        Self {
130            fast_only: true,
131            ..Self::default()
132        }
133    }
134
135    /// Create config for quality-only mode.
136    pub fn quality_only() -> Self {
137        Self {
138            quality_only: true,
139            ..Self::default()
140        }
141    }
142
143    /// Convert to frankensearch TwoTierConfig.
144    fn to_fs_config(&self) -> FsTwoTierConfig {
145        FsTwoTierConfig {
146            quality_weight: f64::from(self.quality_weight),
147            fast_only: self.fast_only,
148            ..FsTwoTierConfig::optimized().with_env_overrides()
149        }
150    }
151}
152
/// Identifies what a two-tier index entry points at: a whole session, one
/// turn of a session, or one code block inside a turn.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum DocumentId {
    /// An entire session, by session id.
    Session(String),
    /// One turn: `(session_id, turn_index)`.
    Turn(String, usize),
    /// One code block: `(session_id, turn_index, code_block_index)`.
    CodeBlock(String, usize, usize),
}

impl DocumentId {
    /// The session this document belongs to, whatever its granularity.
    pub fn session_id(&self) -> &str {
        let (Self::Session(session) | Self::Turn(session, _) | Self::CodeBlock(session, _, _)) =
            self;
        session
    }

    /// String form stored as the frankensearch `doc_id`.
    ///
    /// Layout is `<tag>:<session>[:<turn>[:<block>]]` with tag `s`, `t` or `c`.
    /// The numeric fields always trail the session id, so the encoding stays
    /// unambiguous even when the session id itself contains `:`.
    fn encode(&self) -> String {
        let (tag, session, suffix) = match self {
            Self::Session(session) => ('s', session, String::new()),
            Self::Turn(session, turn) => ('t', session, format!(":{turn}")),
            Self::CodeBlock(session, turn, block) => ('c', session, format!(":{turn}:{block}")),
        };
        format!("{tag}:{session}{suffix}")
    }
}
183
/// Descriptive information recorded alongside a two-tier index.
#[derive(Clone, Debug)]
pub struct TwoTierMetadata {
    /// Id of the embedder that produced the fast tier (e.g. "potion-128m").
    pub fast_embedder_id: String,
    /// Id of the embedder that produced the quality tier (e.g. "minilm-384").
    pub quality_embedder_id: String,
    /// Number of documents stored in the index.
    pub doc_count: usize,
    /// Build time, in seconds since the Unix epoch.
    pub built_at: i64,
    /// Lifecycle state of the index build.
    pub status: IndexStatus,
}

/// Lifecycle state of a two-tier index build.
#[derive(Clone, Debug)]
pub enum IndexStatus {
    /// Build still in progress; `progress` reports how far along it is.
    Building { progress: f32 },
    /// Build finished; carries the latency figures recorded for each tier.
    Complete {
        fast_latency_ms: u64,
        quality_latency_ms: u64,
    },
    /// Build aborted; `error` describes why.
    Failed { error: String },
}
212
213/// Two-tier index entry with both fast and quality embeddings.
214#[derive(Debug, Clone)]
215pub struct TwoTierEntry {
216    /// Document identifier.
217    pub doc_id: DocumentId,
218    /// Message ID for SQLite lookup.
219    pub message_id: u64,
220    /// Fast embedding (f16 quantized).
221    pub fast_embedding: Vec<f16>,
222    /// Quality embedding (f16 quantized).
223    pub quality_embedding: Vec<f16>,
224}
225
226/// Two-tier index for progressive search.
227///
228/// Delegates vector storage and search to frankensearch's file-backed FSVI
229/// `TwoTierIndex`, with cass-specific side tables for `DocumentId` enum
230/// and `message_id` SQLite foreign keys.
231#[derive(Debug)]
232pub struct TwoTierIndex {
233    /// Index metadata.
234    pub metadata: TwoTierMetadata,
235    /// Frankensearch file-backed two-tier index (None when empty).
236    fs_index: Option<FsTwoTierIndex>,
237    /// Document IDs in index order (cass-specific enum).
238    doc_ids: Vec<DocumentId>,
239    /// Message IDs for SQLite lookup (parallel to doc_ids).
240    message_ids: Vec<u64>,
241    /// Temp directory holding FSVI files (kept alive for index lifetime).
242    _tmpdir: Option<tempfile::TempDir>,
243}
244
245impl TwoTierIndex {
246    /// Build a two-tier index from entries.
247    ///
248    /// Creates a temporary FSVI index via frankensearch's `TwoTierIndexBuilder`,
249    /// then opens it for search. The temp directory is kept alive as long as the
250    /// index exists.
251    pub fn build(
252        fast_embedder_id: impl Into<String>,
253        quality_embedder_id: impl Into<String>,
254        config: &TwoTierConfig,
255        entries: impl IntoIterator<Item = TwoTierEntry>,
256    ) -> Result<Self> {
257        let fast_embedder_id = fast_embedder_id.into();
258        let quality_embedder_id = quality_embedder_id.into();
259        let entries: Vec<TwoTierEntry> = entries.into_iter().collect();
260        let doc_count = entries.len();
261
262        let tmpdir = tempfile::TempDir::new()?;
263
264        if doc_count == 0 {
265            return Ok(Self {
266                metadata: TwoTierMetadata {
267                    fast_embedder_id,
268                    quality_embedder_id,
269                    doc_count: 0,
270                    built_at: chrono::Utc::now().timestamp(),
271                    status: IndexStatus::Complete {
272                        fast_latency_ms: 0,
273                        quality_latency_ms: 0,
274                    },
275                },
276                fs_index: None,
277                doc_ids: Vec::new(),
278                message_ids: Vec::new(),
279                _tmpdir: None,
280            });
281        }
282
283        // Validate dimensions
284        for (i, entry) in entries.iter().enumerate() {
285            if entry.fast_embedding.len() != config.fast_dimension {
286                bail!(
287                    "fast embedding dimension mismatch at index {}: expected {}, got {}",
288                    i,
289                    config.fast_dimension,
290                    entry.fast_embedding.len()
291                );
292            }
293            if entry.quality_embedding.len() != config.quality_dimension {
294                bail!(
295                    "quality embedding dimension mismatch at index {}: expected {}, got {}",
296                    i,
297                    config.quality_dimension,
298                    entry.quality_embedding.len()
299                );
300            }
301        }
302
303        // Build frankensearch index
304        let fs_config = config.to_fs_config();
305        let mut builder = FsTwoTierIndex::create(tmpdir.path(), fs_config.clone())
306            .map_err(|e| anyhow::anyhow!("failed to create fs index builder: {e}"))?;
307        builder.set_fast_embedder_id(&fast_embedder_id);
308        builder.set_quality_embedder_id(&quality_embedder_id);
309
310        let mut metadata_by_encoded_id = HashMap::with_capacity(doc_count);
311
312        for entry in entries {
313            let doc_id_str = entry.doc_id.encode();
314            if metadata_by_encoded_id
315                .insert(doc_id_str.clone(), (entry.doc_id.clone(), entry.message_id))
316                .is_some()
317            {
318                bail!(
319                    "duplicate document id encountered while building two-tier index: {doc_id_str}"
320                );
321            }
322            let fast_f32: Vec<f32> = entry.fast_embedding.iter().map(|v| f32::from(*v)).collect();
323            let quality_f32: Vec<f32> = entry
324                .quality_embedding
325                .iter()
326                .map(|v| f32::from(*v))
327                .collect();
328
329            builder
330                .add_record(&doc_id_str, &fast_f32, Some(&quality_f32))
331                .map_err(|e| anyhow::anyhow!("failed to add record {doc_id_str}: {e}"))?;
332        }
333
334        let fs_index = builder
335            .finish()
336            .map_err(|e| anyhow::anyhow!("failed to finish fs index: {e}"))?;
337
338        // frankensearch persists records sorted by doc_id hash/doc_id, so hit indices
339        // are in fast-index order rather than cass insertion order. Rebuild our side
340        // tables to match that canonical order before any search results are exposed.
341        let mut doc_ids = Vec::with_capacity(doc_count);
342        let mut message_ids = Vec::with_capacity(doc_count);
343        for idx in 0..doc_count {
344            let encoded = fs_index
345                .doc_id_at(idx)
346                .map_err(|e| anyhow::anyhow!("failed to read fs doc_id at index {idx}: {e}"))?;
347            let (doc_id, message_id) = metadata_by_encoded_id.remove(encoded).ok_or_else(|| {
348                anyhow::anyhow!(
349                    "frankensearch index returned unknown doc_id at index {idx}: {encoded}"
350                )
351            })?;
352            doc_ids.push(doc_id);
353            message_ids.push(message_id);
354        }
355
356        Ok(Self {
357            metadata: TwoTierMetadata {
358                fast_embedder_id,
359                quality_embedder_id,
360                doc_count,
361                built_at: chrono::Utc::now().timestamp(),
362                status: IndexStatus::Complete {
363                    fast_latency_ms: 0,
364                    quality_latency_ms: 0,
365                },
366            },
367            fs_index: Some(fs_index),
368            doc_ids,
369            message_ids,
370            _tmpdir: Some(tmpdir),
371        })
372    }
373
374    /// Get the number of documents in the index.
375    pub fn len(&self) -> usize {
376        self.metadata.doc_count
377    }
378
379    /// Check if the index is empty.
380    pub fn is_empty(&self) -> bool {
381        self.metadata.doc_count == 0
382    }
383
384    /// Get document ID at index.
385    pub fn doc_id(&self, idx: usize) -> Option<&DocumentId> {
386        self.doc_ids.get(idx)
387    }
388
389    /// Get message ID at index.
390    pub fn message_id(&self, idx: usize) -> Option<u64> {
391        self.message_ids.get(idx).copied()
392    }
393
394    /// Search using fast embeddings only.
395    ///
396    /// Delegates to frankensearch's `TwoTierIndex::search_fast()`.
397    pub fn search_fast(&self, query_vec: &[f32], k: usize) -> Vec<ScoredResult> {
398        if self.is_empty() || k == 0 {
399            return Vec::new();
400        }
401
402        let Some(fs_index) = &self.fs_index else {
403            return Vec::new();
404        };
405
406        match fs_index.search_fast(query_vec, k) {
407            Ok(hits) => self.hits_to_scored_results(hits),
408            Err(e) => {
409                warn!(error = %e, "frankensearch fast search failed");
410                Vec::new()
411            }
412        }
413    }
414
415    /// Search using quality embeddings only.
416    ///
417    /// Delegates to frankensearch's quality search via `search_fast` on the
418    /// quality index. Since frankensearch's `TwoTierIndex` stores both tiers,
419    /// we use `quality_scores_for_hits` with all documents as candidates.
420    pub fn search_quality(&self, query_vec: &[f32], k: usize) -> Vec<ScoredResult> {
421        if self.is_empty() || k == 0 {
422            return Vec::new();
423        }
424
425        let Some(fs_index) = &self.fs_index else {
426            return Vec::new();
427        };
428
429        // Build candidate hits for all docs to get quality scores
430        let all_hits: Vec<FsVectorHit> = (0..self.metadata.doc_count)
431            .map(|i| FsVectorHit {
432                index: i as u32,
433                score: 0.0,
434                doc_id: self.doc_ids[i].encode(),
435            })
436            .collect();
437
438        match fs_index.quality_scores_for_hits(query_vec, &all_hits) {
439            Ok(scores) => {
440                // Build scored results and sort by score descending.
441                // Documents without quality-tier vectors (None) are skipped.
442                let mut results: Vec<ScoredResult> = scores
443                    .iter()
444                    .enumerate()
445                    .filter_map(|(idx, score)| {
446                        let s = (*score)?;
447                        let message_id = *self.message_ids.get(idx)?;
448                        Some(ScoredResult {
449                            idx,
450                            message_id,
451                            score: s,
452                        })
453                    })
454                    .collect();
455                results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
456                results.truncate(k);
457                results
458            }
459            Err(e) => {
460                warn!(error = %e, "frankensearch quality search failed");
461                Vec::new()
462            }
463        }
464    }
465
466    /// Get quality scores for a set of document indices.
467    pub fn quality_scores_for_indices(&self, query_vec: &[f32], indices: &[usize]) -> Vec<f32> {
468        let Some(fs_index) = &self.fs_index else {
469            return vec![0.0; indices.len()];
470        };
471
472        let hits: Vec<FsVectorHit> = indices
473            .iter()
474            .filter_map(|&idx| {
475                if idx < self.metadata.doc_count {
476                    Some(FsVectorHit {
477                        index: idx as u32,
478                        score: 0.0,
479                        doc_id: self.doc_ids[idx].encode(),
480                    })
481                } else {
482                    None
483                }
484            })
485            .collect();
486
487        match fs_index.quality_scores_for_hits(query_vec, &hits) {
488            Ok(scores) => scores.into_iter().map(|s| s.unwrap_or(0.0)).collect(),
489            Err(e) => {
490                warn!(error = %e, "frankensearch quality scoring failed; using zero scores");
491                vec![0.0; indices.len()]
492            }
493        }
494    }
495
496    /// Convert frankensearch VectorHits to cass ScoredResults.
497    fn hits_to_scored_results(&self, hits: Vec<FsVectorHit>) -> Vec<ScoredResult> {
498        hits.into_iter()
499            .filter_map(|hit| {
500                let idx = hit.index as usize;
501                if idx < self.metadata.doc_count {
502                    Some(ScoredResult {
503                        idx,
504                        message_id: self.message_ids[idx],
505                        score: hit.score,
506                    })
507                } else {
508                    None
509                }
510            })
511            .collect()
512    }
513}
514
/// A single hit: its position in the index, its SQLite message id, and its score.
#[derive(Clone, Debug)]
pub struct ScoredResult {
    /// Position of the document inside the two-tier index.
    pub idx: usize,
    /// Message id for the SQLite lookup.
    pub message_id: u64,
    /// Similarity (or blended) score; results are ranked highest first.
    pub score: f32,
}

/// One step of a progressive search, as yielded by the search iterator.
#[derive(Clone, Debug)]
pub enum SearchPhase {
    /// Hits from the fast embedding tier.
    Initial {
        results: Vec<ScoredResult>,
        latency_ms: u64,
    },
    /// Hits after quality-tier scoring (requires the daemon).
    Refined {
        results: Vec<ScoredResult>,
        latency_ms: u64,
    },
    /// A phase could not run; `error` explains why. Callers should keep any
    /// results they already have.
    RefinementFailed { error: String },
}
542
543/// Two-tier searcher that coordinates fast and quality search.
544pub struct TwoTierSearcher<'a, D: DaemonClient> {
545    index: &'a TwoTierIndex,
546    daemon: Option<Arc<D>>,
547    fast_embedder: Arc<dyn Embedder>,
548    config: TwoTierConfig,
549}
550
551impl<'a, D: DaemonClient> TwoTierSearcher<'a, D> {
552    /// Create a new two-tier searcher.
553    pub fn new(
554        index: &'a TwoTierIndex,
555        fast_embedder: Arc<dyn Embedder>,
556        daemon: Option<Arc<D>>,
557        config: TwoTierConfig,
558    ) -> Self {
559        Self {
560            index,
561            daemon,
562            fast_embedder,
563            config,
564        }
565    }
566
567    /// Perform two-tier progressive search.
568    ///
569    /// Returns an iterator that yields search phases:
570    /// 1. Initial results from fast embeddings
571    /// 2. Refined results from quality embeddings (if daemon available)
572    pub fn search(&self, query: &str, k: usize) -> impl Iterator<Item = SearchPhase> + '_ {
573        TwoTierSearchIter::new(self, query.to_string(), k)
574    }
575
576    /// Perform fast-only search (no daemon refinement).
577    pub fn search_fast_only(&self, query: &str, k: usize) -> Result<Vec<ScoredResult>> {
578        let start = Instant::now();
579        let query_vec = self.fast_embedder.embed_sync(query)?;
580        let results = self.index.search_fast(&query_vec, k);
581        debug!(
582            query_len = query.len(),
583            k = k,
584            result_count = results.len(),
585            latency_ms = start.elapsed().as_millis(),
586            "Fast-only search completed"
587        );
588        Ok(results)
589    }
590
591    /// Perform quality-only search (wait for daemon).
592    pub fn search_quality_only(
593        &self,
594        query: &str,
595        k: usize,
596    ) -> Result<Vec<ScoredResult>, TwoTierError> {
597        let start = Instant::now();
598
599        let daemon = self
600            .daemon
601            .as_ref()
602            .ok_or_else(|| TwoTierError::DaemonUnavailable("no daemon configured".into()))?;
603
604        if !daemon.is_available() {
605            return Err(TwoTierError::DaemonUnavailable(
606                "daemon not available".into(),
607            ));
608        }
609
610        let request_id = format!("quality-{:016x}", rand::random::<u64>());
611        let query_vec = daemon
612            .embed(query, &request_id)
613            .map_err(TwoTierError::DaemonError)?;
614
615        let results = self.index.search_quality(&query_vec, k);
616        debug!(
617            query_len = query.len(),
618            k = k,
619            result_count = results.len(),
620            latency_ms = start.elapsed().as_millis(),
621            "Quality-only search completed"
622        );
623        Ok(results)
624    }
625}
626
627/// Iterator for two-tier search phases.
628struct TwoTierSearchIter<'a, D: DaemonClient> {
629    searcher: &'a TwoTierSearcher<'a, D>,
630    query: String,
631    k: usize,
632    phase: u8,
633    fast_results: Option<Vec<ScoredResult>>,
634}
635
636impl<'a, D: DaemonClient> TwoTierSearchIter<'a, D> {
637    fn new(searcher: &'a TwoTierSearcher<'a, D>, query: String, k: usize) -> Self {
638        Self {
639            searcher,
640            query,
641            k,
642            phase: 0,
643            fast_results: None,
644        }
645    }
646}
647
648impl<'a, D: DaemonClient> Iterator for TwoTierSearchIter<'a, D> {
649    type Item = SearchPhase;
650
651    fn next(&mut self) -> Option<Self::Item> {
652        match self.phase {
653            0 => {
654                if self.searcher.config.quality_only {
655                    self.phase = 2;
656                    let start = Instant::now();
657                    return match self.searcher.search_quality_only(&self.query, self.k) {
658                        Ok(results) => Some(SearchPhase::Refined {
659                            results,
660                            latency_ms: start.elapsed().as_millis() as u64,
661                        }),
662                        Err(e) => Some(SearchPhase::RefinementFailed {
663                            error: e.to_string(),
664                        }),
665                    };
666                }
667
668                // Phase 1: Fast search
669                self.phase = 1;
670                let start = Instant::now();
671
672                match self.searcher.fast_embedder.embed_sync(&self.query) {
673                    Ok(query_vec) => {
674                        let results = self.searcher.index.search_fast(&query_vec, self.k);
675                        let latency_ms = start.elapsed().as_millis() as u64;
676                        self.fast_results = Some(results.clone());
677
678                        if self.searcher.config.fast_only {
679                            self.phase = 2;
680                        }
681
682                        Some(SearchPhase::Initial {
683                            results,
684                            latency_ms,
685                        })
686                    }
687                    Err(e) => {
688                        warn!(error = %e, "Fast embedding failed");
689                        self.phase = 2;
690                        Some(SearchPhase::RefinementFailed {
691                            error: format!("fast embedding failed: {e}"),
692                        })
693                    }
694                }
695            }
696            1 => {
697                // Phase 2: Quality refinement
698                self.phase = 2;
699
700                let daemon = match &self.searcher.daemon {
701                    Some(d) if d.is_available() => d,
702                    _ => {
703                        return Some(SearchPhase::RefinementFailed {
704                            error: "daemon unavailable".to_string(),
705                        });
706                    }
707                };
708
709                let start = Instant::now();
710                let request_id = format!("refine-{:016x}", rand::random::<u64>());
711
712                match daemon.embed(&self.query, &request_id) {
713                    Ok(query_vec) => {
714                        let results = if let Some(fast_results) = self.fast_results.as_ref() {
715                            let refine_cap = self.searcher.config.max_refinement_docs;
716                            let candidates: Vec<usize> = fast_results
717                                .iter()
718                                .take(refine_cap)
719                                .map(|sr| sr.idx)
720                                .collect();
721                            if candidates.is_empty() {
722                                fast_results.clone()
723                            } else {
724                                let quality_scores = self
725                                    .searcher
726                                    .index
727                                    .quality_scores_for_indices(&query_vec, &candidates);
728
729                                let weight = self.searcher.config.quality_weight;
730                                let fast_scores: Vec<f32> =
731                                    fast_results.iter().map(|sr| sr.score).collect();
732                                let fast_norm = normalize_scores(&fast_scores);
733                                let quality_norm = normalize_scores(&quality_scores);
734
735                                let mut blended: Vec<ScoredResult> =
736                                    Vec::with_capacity(fast_results.len());
737                                for (idx, fast) in fast_results.iter().enumerate() {
738                                    let fast_s = fast_norm.get(idx).copied().unwrap_or(0.0);
739                                    let score = if idx < quality_norm.len() {
740                                        let quality_s =
741                                            quality_norm.get(idx).copied().unwrap_or(0.0);
742                                        (1.0 - weight) * fast_s + weight * quality_s
743                                    } else {
744                                        // Unrefined documents get a penalized score that assumes 0.0 for quality
745                                        // to preserve their original ranking but place them appropriately below
746                                        // high-quality refined items.
747                                        fast_s * (1.0 - weight)
748                                    };
749                                    blended.push(ScoredResult {
750                                        idx: fast.idx,
751                                        message_id: fast.message_id,
752                                        score,
753                                    });
754                                }
755
756                                blended.sort_by(|a, b| {
757                                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
758                                });
759                                blended.truncate(self.k);
760                                blended
761                            }
762                        } else {
763                            self.searcher.index.search_quality(&query_vec, self.k)
764                        };
765
766                        let latency_ms = start.elapsed().as_millis() as u64;
767                        Some(SearchPhase::Refined {
768                            results,
769                            latency_ms,
770                        })
771                    }
772                    Err(e) => Some(SearchPhase::RefinementFailed {
773                        error: e.to_string(),
774                    }),
775                }
776            }
777            _ => None,
778        }
779    }
780}
781
782/// Errors specific to two-tier search.
783#[derive(Debug, thiserror::Error)]
784pub enum TwoTierError {
785    #[error("daemon unavailable: {0}")]
786    DaemonUnavailable(String),
787
788    #[error("daemon error: {0}")]
789    DaemonError(#[from] DaemonError),
790
791    #[error("embedding failed: {0}")]
792    EmbeddingFailed(String),
793
794    #[error("index error: {0}")]
795    IndexError(String),
796}
797
/// Min-max normalise `scores` into `[0, 1]`.
///
/// Only finite values take part in finding the min and max. Non-finite
/// entries (`NaN`, `±inf`) map to `0.0`. The result is all zeros when there
/// are no finite values. When every finite value is the same (within
/// `f32::EPSILON`), each finite entry maps to `1.0`. The output always has
/// the same length as the input.
pub fn normalize_scores(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    let (min, max) = scores
        .iter()
        .copied()
        .filter(|s| s.is_finite())
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
            (lo.min(s), hi.max(s))
        });

    if min.is_infinite() || max.is_infinite() {
        return vec![0.0; scores.len()];
    }

    // `max >= min`, so the range is never negative.
    let range = max - min;

    if range < f32::EPSILON {
        return scores
            .iter()
            .map(|&s| if s.is_finite() { 1.0 } else { 0.0 })
            .collect();
    }

    // When the finite scores span more than f32::MAX (e.g. [-f32::MAX, f32::MAX]),
    // `max - min` overflows to +inf and `(s - min) / range` would yield inf/inf = NaN.
    // Only in that case, redo the arithmetic in f64, which cannot overflow here.
    // Ordinary inputs keep their exact f32 results.
    let overflowed = !range.is_finite();

    scores
        .iter()
        .map(|&s| {
            if !s.is_finite() {
                0.0
            } else if overflowed {
                ((f64::from(s) - f64::from(min)) / (f64::from(max) - f64::from(min))) as f32
            } else {
                (s - min) / range
            }
        })
        .collect()
}

/// Blend two score vectors, weighting the second by `quality_weight`.
///
/// Each vector is min-max normalised independently with [`normalize_scores`].
/// The pair is then combined position by position as
/// `(1 - quality_weight) * fast + quality_weight * quality`. The output
/// length is the shorter of the two inputs; extra entries are dropped.
pub fn blend_scores(fast: &[f32], quality: &[f32], quality_weight: f32) -> Vec<f32> {
    let fast_weight = 1.0 - quality_weight;
    normalize_scores(fast)
        .into_iter()
        .zip(normalize_scores(quality))
        .map(|(f, q)| fast_weight * f + quality_weight * q)
        .collect()
}
849
#[cfg(test)]
mod tests {
    // `super::*` also brings in the names the parent module has in scope
    // (e.g. `Arc`, `HashMap`, `f16`), so they are not imported again here.
    use super::*;
    use crate::search::daemon_client::{DaemonClient, DaemonError};
    use crate::search::embedder::{Embedder, EmbedderError};
    use crate::search::hash_embedder::HashEmbedder;
    use frankensearch::ModelCategory;

    // ------------------------------------------------------------------
    // Test doubles
    // ------------------------------------------------------------------

    /// Fake quality daemon. It embeds every text as an all-ones vector of
    /// `dim` elements, reports `available` from `is_available`, and always
    /// rejects `rerank` with `DaemonError::Unavailable`.
    struct TestDaemon {
        dim: usize,
        available: bool,
    }

    impl DaemonClient for TestDaemon {
        fn id(&self) -> &str {
            "test-daemon"
        }

        fn is_available(&self) -> bool {
            self.available
        }

        fn embed(&self, _text: &str, _request_id: &str) -> Result<Vec<f32>, DaemonError> {
            Ok(vec![1.0; self.dim])
        }

        fn embed_batch(
            &self,
            texts: &[&str],
            _request_id: &str,
        ) -> Result<Vec<Vec<f32>>, DaemonError> {
            Ok(vec![vec![1.0; self.dim]; texts.len()])
        }

        fn rerank(
            &self,
            _query: &str,
            _documents: &[&str],
            _request_id: &str,
        ) -> Result<Vec<f32>, DaemonError> {
            Err(DaemonError::Unavailable(
                "rerank unsupported in test daemon".to_string(),
            ))
        }
    }

    /// Fast-tier embedder whose `embed_sync` always fails. Used to exercise
    /// the searcher's handling of a failed fast embedding.
    struct FailingEmbedder {
        dim: usize,
    }

    impl Embedder for FailingEmbedder {
        fn embed_sync(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            Err(EmbedderError::EmbeddingFailed {
                model: "failing-embedder".to_string(),
                source: Box::new(std::io::Error::other("synthetic fast embed failure")),
            })
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn id(&self) -> &str {
            "failing-embedder"
        }

        fn is_semantic(&self) -> bool {
            false
        }

        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    /// Fast-tier embedder that returns `value` repeated `dim` times for any text.
    struct ConstantEmbedder {
        dim: usize,
        value: f32,
    }

    impl Embedder for ConstantEmbedder {
        fn embed_sync(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            Ok(vec![self.value; self.dim])
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn id(&self) -> &str {
            "constant-embedder"
        }

        fn is_semantic(&self) -> bool {
            false
        }

        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    /// Build `count` entries with deterministic ramp embeddings.
    ///
    /// In both tiers, component `j` of entry `i` is `(i + j) * 0.01`, stored
    /// as `f16`. Doc ids are `Session("session-{i}")` and message ids are `i`.
    fn make_test_entries(count: usize, fast_dim: usize, quality_dim: usize) -> Vec<TwoTierEntry> {
        (0..count)
            .map(|i| TwoTierEntry {
                doc_id: DocumentId::Session(format!("session-{i}")),
                message_id: i as u64,
                fast_embedding: (0..fast_dim)
                    .map(|j| f16::from_f32((i + j) as f32 * 0.01))
                    .collect(),
                quality_embedding: (0..quality_dim)
                    .map(|j| f16::from_f32((i + j) as f32 * 0.01))
                    .collect(),
            })
            .collect()
    }

    // ------------------------------------------------------------------
    // Index construction
    // ------------------------------------------------------------------

    #[test]
    fn test_two_tier_index_creation() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(10, config.fast_dimension, config.quality_dimension);

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        assert_eq!(index.len(), 10);
        assert!(!index.is_empty());
        assert!(matches!(
            index.metadata.status,
            IndexStatus::Complete { .. }
        ));
    }

    #[test]
    fn test_empty_index() {
        let config = TwoTierConfig::default();
        let entries: Vec<TwoTierEntry> = Vec::new();

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_dimension_mismatch_fast() {
        let config = TwoTierConfig::default();
        let entries = vec![TwoTierEntry {
            doc_id: DocumentId::Session("test".into()),
            message_id: 1,
            fast_embedding: vec![f16::from_f32(1.0); 128], // Wrong dimension
            quality_embedding: vec![f16::from_f32(1.0); config.quality_dimension],
        }];

        let result = TwoTierIndex::build("fast", "quality", &config, entries);
        assert!(result.is_err());
    }

    #[test]
    fn test_dimension_mismatch_quality() {
        let config = TwoTierConfig::default();
        let entries = vec![TwoTierEntry {
            doc_id: DocumentId::Session("test".into()),
            message_id: 1,
            fast_embedding: vec![f16::from_f32(1.0); config.fast_dimension],
            quality_embedding: vec![f16::from_f32(1.0); 128], // Wrong dimension
        }];

        let result = TwoTierIndex::build("fast", "quality", &config, entries);
        assert!(result.is_err());
    }

    #[test]
    fn test_side_tables_follow_frankensearch_index_order() {
        let config = TwoTierConfig::default();
        // Deliberately inserted out of doc-id order, so the side tables must
        // follow the frankensearch index order rather than insertion order.
        let entries = vec![
            TwoTierEntry {
                doc_id: DocumentId::Session("session-z".into()),
                message_id: 300,
                fast_embedding: vec![f16::from_f32(1.0); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(1.0); config.quality_dimension],
            },
            TwoTierEntry {
                doc_id: DocumentId::Session("session-a".into()),
                message_id: 100,
                fast_embedding: vec![f16::from_f32(0.5); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(0.5); config.quality_dimension],
            },
            TwoTierEntry {
                doc_id: DocumentId::Session("session-m".into()),
                message_id: 200,
                fast_embedding: vec![f16::from_f32(0.25); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(0.25); config.quality_dimension],
            },
        ];
        let expected_by_encoded = HashMap::from([
            ("s:session-z".to_string(), 300_u64),
            ("s:session-a".to_string(), 100_u64),
            ("s:session-m".to_string(), 200_u64),
        ]);

        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();
        let fs_index = index.fs_index.as_ref().expect("non-empty fs index");

        for idx in 0..index.len() {
            let encoded = fs_index.doc_id_at(idx).expect("fs doc_id");
            assert_eq!(index.doc_ids[idx].encode(), encoded);
            assert_eq!(index.message_ids[idx], expected_by_encoded[encoded]);
        }
    }

    // ------------------------------------------------------------------
    // Index search
    // ------------------------------------------------------------------

    #[test]
    fn test_fast_search() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(100, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.fast_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let results = index.search_fast(&query, 10);

        assert_eq!(results.len(), 10);
        // Results should be sorted by score descending
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_quality_search() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(100, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.quality_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let results = index.search_quality(&query, 10);

        assert_eq!(results.len(), 10);
        // Results should be sorted by score descending
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_quality_scores_for_indices() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(10, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let query: Vec<f32> = (0..config.quality_dimension)
            .map(|i| i as f32 * 0.01)
            .collect();
        let indices = vec![0, 2, 4];
        let scores = index.quality_scores_for_indices(&query, &indices);

        assert_eq!(scores.len(), 3);
        assert!(
            scores.iter().all(|s| s.is_finite()),
            "expected finite scores, got {scores:?}"
        );
    }

    #[test]
    fn test_search_fast_dimension_mismatch_returns_empty() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.fast_dimension.saturating_sub(1)];
        let results = index.search_fast(&bad_query, 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_quality_dimension_mismatch_returns_empty() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.quality_dimension.saturating_sub(1)];
        let results = index.search_quality(&bad_query, 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_quality_scores_for_indices_dimension_mismatch_returns_zeros() {
        let config = TwoTierConfig::default();
        let entries = make_test_entries(5, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-256", "quality-384", &config, entries).unwrap();

        let bad_query = vec![0.5; config.quality_dimension.saturating_sub(1)];
        let scores = index.quality_scores_for_indices(&bad_query, &[0, 2, 4]);
        assert_eq!(scores, vec![0.0, 0.0, 0.0]);
    }

    // ------------------------------------------------------------------
    // Score normalization and blending
    // ------------------------------------------------------------------

    #[test]
    fn test_score_normalization() {
        let scores = [0.8_f32, 0.6, 0.4, 0.2];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 4);
        assert!((normalized[0] - 1.0).abs() < 0.001);
        assert!((normalized[1] - 2.0 / 3.0).abs() < 0.001);
        assert!((normalized[2] - 1.0 / 3.0).abs() < 0.001);
        assert!((normalized[3] - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_score_normalization_constant() {
        let scores = [0.5_f32, 0.5, 0.5];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 3);
        for n in &normalized {
            assert!((n - 1.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_score_normalization_constant_with_nan_keeps_nan_zeroed() {
        let scores = [f32::NAN, 0.5, 0.5];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 3);
        assert_eq!(normalized[0], 0.0);
        assert!((normalized[1] - 1.0).abs() < 0.001);
        assert!((normalized[2] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_score_normalization_with_infinite_values_keeps_non_finite_zeroed() {
        let scores = [f32::NEG_INFINITY, 2.0, f32::INFINITY, 4.0];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized.len(), 4);
        assert_eq!(normalized[0], 0.0);
        assert_eq!(normalized[2], 0.0);
        assert!((normalized[1] - 0.0).abs() < 0.001);
        assert!((normalized[3] - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_score_normalization_all_non_finite_is_zeroed() {
        let scores = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
        let normalized = normalize_scores(&scores);

        assert_eq!(normalized, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_score_normalization_empty() {
        let scores: Vec<f32> = vec![];
        let normalized = normalize_scores(&scores);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_blend_scores() {
        let fast = [0.8_f32, 0.6, 0.4];
        let quality = [0.4_f32, 0.8, 0.6];
        let blended = blend_scores(&fast, &quality, 0.5);

        // fast normalizes to [1, 0.5, 0] and quality to [0, 1, 0.5]; an equal
        // blend averages them position by position.
        let expected = [0.5_f32, 0.75, 0.25];
        assert_eq!(blended.len(), expected.len());
        for (got, want) in blended.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_blend_scores_truncates_to_shorter_input() {
        let blended = blend_scores(&[1.0, 0.0, 0.5], &[0.0, 1.0], 0.5);
        assert_eq!(blended.len(), 2);
    }

    // ------------------------------------------------------------------
    // Document ids and configuration
    // ------------------------------------------------------------------

    #[test]
    fn test_document_id_session() {
        let doc_id = DocumentId::Session("test-session".into());
        assert_eq!(doc_id.session_id(), "test-session");
    }

    #[test]
    fn test_document_id_turn() {
        let doc_id = DocumentId::Turn("test-session".into(), 5);
        assert_eq!(doc_id.session_id(), "test-session");
    }

    #[test]
    fn test_document_id_code_block() {
        let doc_id = DocumentId::CodeBlock("test-session".into(), 3, 2);
        assert_eq!(doc_id.session_id(), "test-session");
    }

    #[test]
    fn test_config_defaults() {
        let config = TwoTierConfig::default();
        assert_eq!(config.fast_dimension, 256);
        assert_eq!(config.quality_dimension, 384);
        assert!((config.quality_weight - 0.7).abs() < 0.001);
        assert_eq!(config.max_refinement_docs, 100);
        assert!(!config.fast_only);
        assert!(!config.quality_only);
    }

    #[test]
    fn test_config_fast_only() {
        let config = TwoTierConfig::fast_only();
        assert!(config.fast_only);
        assert!(!config.quality_only);
    }

    #[test]
    fn test_config_quality_only() {
        let config = TwoTierConfig::quality_only();
        assert!(!config.fast_only);
        assert!(config.quality_only);
    }

    // ------------------------------------------------------------------
    // Searcher phases
    // ------------------------------------------------------------------

    #[test]
    fn test_quality_only_mode_emits_only_refined_phase() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_only: true,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(HashEmbedder::new(config.fast_dimension));
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::Refined { .. }));
    }

    #[test]
    fn test_quality_only_mode_without_daemon_reports_failure() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_only: true,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(HashEmbedder::new(config.fast_dimension));
        // The daemon is attached but reports itself unavailable.
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: false,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::RefinementFailed { .. }));
    }

    #[test]
    fn test_fast_embedding_failure_yields_failure_phase() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            fast_only: false,
            quality_only: false,
            ..Default::default()
        };
        let entries = make_test_entries(4, config.fast_dimension, config.quality_dimension);
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(FailingEmbedder {
            dim: config.fast_dimension,
        });
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 3).collect();

        assert_eq!(phases.len(), 1);
        assert!(matches!(phases[0], SearchPhase::RefinementFailed { .. }));
    }

    #[test]
    fn test_refinement_scores_are_normalized() {
        let config = TwoTierConfig {
            fast_dimension: 8,
            quality_dimension: 8,
            quality_weight: 0.6,
            max_refinement_docs: 3,
            ..Default::default()
        };
        // Large embedding magnitudes, so raw (un-normalized) similarity scores
        // would likely fall well outside [0, 1].
        let entries: Vec<TwoTierEntry> = (0..5)
            .map(|i| TwoTierEntry {
                doc_id: DocumentId::Session(format!("s{i}")),
                message_id: i as u64 + 1,
                fast_embedding: vec![f16::from_f32(20.0 + i as f32); config.fast_dimension],
                quality_embedding: vec![f16::from_f32(10.0 + i as f32); config.quality_dimension],
            })
            .collect();
        let index = TwoTierIndex::build("fast-8", "quality-8", &config, entries).unwrap();

        let fast_embedder: Arc<dyn Embedder> = Arc::new(ConstantEmbedder {
            dim: config.fast_dimension,
            value: 10.0,
        });
        let daemon = Arc::new(TestDaemon {
            dim: config.quality_dimension,
            available: true,
        });
        let searcher = TwoTierSearcher::new(&index, fast_embedder, Some(daemon), config);
        let phases: Vec<SearchPhase> = searcher.search("query", 5).collect();

        assert_eq!(phases.len(), 2);
        let SearchPhase::Refined { results, .. } = &phases[1] else {
            panic!("expected refined phase");
        };
        assert!(
            results.iter().all(|r| (0.0..=1.0).contains(&r.score)),
            "expected normalized refined scores, got {:?}",
            results.iter().map(|r| r.score).collect::<Vec<_>>()
        );
    }
}
1358}