Skip to main content

frankensearch_core/
traits.rs

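//! Core traits for the frankensearch search pipeline.
//!
//! - [`Embedder`]: Text embedding model interface (hash, model2vec, fastembed).
//! - [`Reranker`]: Cross-encoder reranking model interface.
//! - [`LexicalSearch`]: Full-text search backend interface (Tantivy, FTS5).
//! - [`MetricsExporter`]: Telemetry sink for search, embedding, and index events.
//! - [`SyncEmbed`] / [`SyncRerank`]: Blocking counterparts of the async traits,
//!   bridged to them by [`SyncEmbedderAdapter`] and [`SyncRerankerAdapter`].
//!
//! Async operations are represented as boxed futures so the traits remain
//! dyn-compatible for runtime polymorphism (`Box<dyn Embedder>`, etc.).
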
10use std::fmt;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14
15use asupersync::Cx;
16use serde::{Deserialize, Serialize};
17
18use crate::error::{SearchError, SearchResult};
19use crate::types::{
20    EmbeddingMetrics, IndexMetrics, IndexableDocument, ScoredResult, SearchMetrics,
21};
22
23/// Boxed future carrying a `SearchResult<T>`.
24pub type SearchFuture<'a, T> = Pin<Box<dyn Future<Output = SearchResult<T>> + Send + 'a>>;
25
// ─── Model Category ─────────────────────────────────────────────────────────
27
28/// Classification of an embedding model by its speed/quality tradeoff.
29///
30/// Used by `EmbedderStack` to pair a fast-tier and quality-tier embedder
31/// for the two-tier progressive search pipeline.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub enum ModelCategory {
34    /// Hash-based (FNV-1a): ultra-fast, deterministic, not semantically meaningful.
35    HashEmbedder,
36    /// Static token embeddings (Model2Vec/potion): fast with good semantic quality.
37    StaticEmbedder,
38    /// Transformer inference (MiniLM/BGE): highest quality but slower.
39    TransformerEmbedder,
40    /// Cloud API embeddings (`OpenAI`, Gemini): high quality, network-dependent latency.
41    ApiEmbedder,
42}
43
44impl fmt::Display for ModelCategory {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        match self {
47            Self::HashEmbedder => write!(f, "hash_embedder"),
48            Self::StaticEmbedder => write!(f, "static_embedder"),
49            Self::TransformerEmbedder => write!(f, "transformer_embedder"),
50            Self::ApiEmbedder => write!(f, "api_embedder"),
51        }
52    }
53}
54
55impl ModelCategory {
56    /// Returns the default progressive tier for this model category.
57    #[must_use]
58    pub const fn default_tier(self) -> ModelTier {
59        match self {
60            Self::HashEmbedder | Self::StaticEmbedder => ModelTier::Fast,
61            Self::TransformerEmbedder | Self::ApiEmbedder => ModelTier::Quality,
62        }
63    }
64
65    /// Whether this category is semantically meaningful by default.
66    #[must_use]
67    pub const fn default_semantic_flag(self) -> bool {
68        !matches!(self, Self::HashEmbedder)
69    }
70}
71
72/// Tier assignment in the progressive two-tier pipeline.
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
74pub enum ModelTier {
75    /// Ultra-fast path for immediate results.
76    Fast,
77    /// Higher-quality path for deferred refinement.
78    Quality,
79}
80
81impl fmt::Display for ModelTier {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            Self::Fast => write!(f, "fast"),
85            Self::Quality => write!(f, "quality"),
86        }
87    }
88}
89
90/// Static metadata describing an embedder implementation.
91#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
92pub struct ModelInfo {
93    /// Stable model identifier used in index metadata.
94    pub id: String,
95    /// Human-friendly model name.
96    pub name: String,
97    /// Embedding dimensionality.
98    pub dimension: usize,
99    /// Embedder category by architecture/performance profile.
100    pub category: ModelCategory,
101    /// Default tier assignment in progressive search.
102    pub tier: ModelTier,
103    /// Whether embeddings encode semantic similarity.
104    pub is_semantic: bool,
105    /// Whether Matryoshka truncation is supported.
106    pub supports_mrl: bool,
107    /// Optional upstream model id (e.g., `HuggingFace`).
108    pub huggingface_id: Option<String>,
109    /// Optional model footprint on disk.
110    pub size_bytes: Option<u64>,
111    /// Optional model license string.
112    pub license: Option<String>,
113}
114
// ─── Embedder Trait ─────────────────────────────────────────────────────────
116
117/// Core trait for text embedding models.
118///
119/// Implementations run under structured concurrency, so each async operation
120/// receives a capability context (`&Cx`) as its first parameter.
121///
122/// # Contract
123///
124/// - `embed()` and `embed_batch()` are cancel-aware and return boxed futures.
125/// - `dimension()` must be constant for the lifetime of the embedder.
126/// - `id()` must be stable across process restarts (it's stored in FSVI headers).
127pub trait Embedder: Send + Sync {
128    /// Embed a single text string into a vector of f32 floats.
129    ///
130    /// The returned vector has exactly `self.dimension()` elements.
131    ///
132    /// # Errors
133    ///
134    /// Returns `SearchError` if embedding inference fails.
135    fn embed<'a>(&'a self, cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>>;
136
137    /// Embed a batch of text strings.
138    ///
139    /// Default implementation calls `embed` in a loop. Neural models should
140    /// override this to exploit batch inference (ONNX has high fixed overhead
141    /// but low marginal cost per additional input).
142    ///
143    /// # Errors
144    ///
145    /// Returns `SearchError` if any embedding inference fails.
146    fn embed_batch<'a>(
147        &'a self,
148        cx: &'a Cx,
149        texts: &'a [&'a str],
150    ) -> SearchFuture<'a, Vec<Vec<f32>>> {
151        Box::pin(async move {
152            let mut out = Vec::with_capacity(texts.len());
153            for text in texts {
154                out.push(self.embed(cx, text).await?);
155            }
156            Ok(out)
157        })
158    }
159
160    /// The dimensionality of embedding vectors produced by this model.
161    fn dimension(&self) -> usize;
162
163    /// A unique, stable identifier for this embedder.
164    ///
165    /// Examples: `"fnv-hash-384"`, `"potion-multilingual-128M"`, `"all-MiniLM-L6-v2"`.
166    /// Stored in FSVI index headers for embedder-revision matching.
167    fn id(&self) -> &str;
168
169    /// Human-readable model name.
170    fn model_name(&self) -> &str;
171
172    /// Whether this embedder is loaded and operational.
173    fn is_ready(&self) -> bool {
174        true
175    }
176
177    /// Whether this embedder produces semantically meaningful vectors.
178    ///
179    /// Hash embedders return `false`; neural models return `true`.
180    fn is_semantic(&self) -> bool;
181
182    /// The speed/quality category of this embedder.
183    fn category(&self) -> ModelCategory;
184
185    /// Default progressive tier assignment.
186    fn tier(&self) -> ModelTier {
187        self.category().default_tier()
188    }
189
190    /// Whether this model supports Matryoshka Representation Learning
191    /// (dimension truncation for faster search with controlled quality loss).
192    fn supports_mrl(&self) -> bool {
193        false
194    }
195
196    /// Truncate and re-normalize embedding to `target_dim`.
197    ///
198    /// # Errors
199    ///
200    /// Returns `InvalidConfig` when `target_dim` is zero.
201    fn truncate_embedding(&self, embedding: &[f32], target_dim: usize) -> SearchResult<Vec<f32>> {
202        if target_dim == 0 {
203            return Err(SearchError::InvalidConfig {
204                field: "target_dim".to_owned(),
205                value: "0".to_owned(),
206                reason: "target dimension must be at least 1".to_owned(),
207            });
208        }
209
210        if target_dim >= embedding.len() {
211            return Ok(embedding.to_vec());
212        }
213
214        Ok(l2_normalize(&embedding[..target_dim]))
215    }
216}
217
// ─── Synchronous Embedder Bridge ─────────────────────────────────────────
219
220/// Synchronous embedding interface for host projects that call embedders from
221/// non-async contexts.
222///
223/// Implement this trait for embedders whose `embed` operations are inherently
224/// synchronous (e.g., hash embedders, CPU-only ONNX inference). The companion
225/// [`SyncEmbedderAdapter`] wraps any `SyncEmbed` implementor into a full
226/// async [`Embedder`], suitable for use anywhere frankensearch expects one.
227///
228/// # Example
229///
230/// ```ignore
231/// use frankensearch_core::traits::{SyncEmbed, SyncEmbedderAdapter, Embedder};
232///
233/// struct MyHashEmbedder { dim: usize }
234///
235/// impl SyncEmbed for MyHashEmbedder {
236///     fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>> { /* ... */ }
237///     fn dimension(&self) -> usize { self.dim }
238///     fn id(&self) -> &str { "my-hash" }
239///     fn model_name(&self) -> &str { "My Hash Embedder" }
240///     fn is_semantic(&self) -> bool { false }
241///     fn category(&self) -> ModelCategory { ModelCategory::HashEmbedder }
242/// }
243///
244/// // Use it as a full async Embedder:
245/// let adapted: Box<dyn Embedder> = Box::new(SyncEmbedderAdapter(MyHashEmbedder { dim: 256 }));
246/// ```
247pub trait SyncEmbed: Send + Sync {
248    /// Synchronously embed a single text into a vector.
249    ///
250    /// # Errors
251    ///
252    /// Returns [`SearchError`] when embedding fails (for example model load,
253    /// inference, or input validation failures).
254    fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>>;
255
256    /// Synchronously embed a batch of texts.
257    ///
258    /// Default implementation calls [`embed_sync`](Self::embed_sync) for each text.
259    ///
260    /// # Errors
261    ///
262    /// Returns the first [`SearchError`] encountered while embedding any item
263    /// in the batch.
264    fn embed_batch_sync(&self, texts: &[&str]) -> SearchResult<Vec<Vec<f32>>> {
265        texts.iter().map(|t| self.embed_sync(t)).collect()
266    }
267
268    /// The output dimensionality of embedding vectors.
269    fn dimension(&self) -> usize;
270
271    /// Unique, stable identifier for this embedder (stored in index headers).
272    fn id(&self) -> &str;
273
274    /// Human-readable model name.
275    fn model_name(&self) -> &str {
276        self.id()
277    }
278
279    /// Whether the embedder is loaded and operational.
280    fn is_ready(&self) -> bool {
281        true
282    }
283
284    /// Whether this embedder produces semantically meaningful vectors.
285    fn is_semantic(&self) -> bool;
286
287    /// The speed/quality category of this embedder.
288    fn category(&self) -> ModelCategory;
289
290    /// Default progressive tier assignment.
291    fn tier(&self) -> ModelTier {
292        self.category().default_tier()
293    }
294
295    /// Whether this model supports Matryoshka Representation Learning.
296    fn supports_mrl(&self) -> bool {
297        false
298    }
299}
300
301/// Adapts a [`SyncEmbed`] implementor into a full async [`Embedder`].
302///
303/// The sync `embed_sync()` call is wrapped in `Box::pin(async move { ... })`,
304/// which is zero-cost for pure computation (hash embedders) and acceptable for
305/// blocking ONNX inference when called from a `spawn_blocking` context.
306pub struct SyncEmbedderAdapter<T: SyncEmbed>(pub T);
307
308impl<T: SyncEmbed + 'static> Embedder for SyncEmbedderAdapter<T> {
309    fn embed<'a>(&'a self, _cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>> {
310        Box::pin(async move { self.0.embed_sync(text) })
311    }
312
313    fn embed_batch<'a>(
314        &'a self,
315        _cx: &'a Cx,
316        texts: &'a [&'a str],
317    ) -> SearchFuture<'a, Vec<Vec<f32>>> {
318        Box::pin(async move { self.0.embed_batch_sync(texts) })
319    }
320
321    fn dimension(&self) -> usize {
322        self.0.dimension()
323    }
324
325    fn id(&self) -> &str {
326        self.0.id()
327    }
328
329    fn model_name(&self) -> &str {
330        self.0.model_name()
331    }
332
333    fn is_ready(&self) -> bool {
334        self.0.is_ready()
335    }
336
337    fn is_semantic(&self) -> bool {
338        self.0.is_semantic()
339    }
340
341    fn category(&self) -> ModelCategory {
342        self.0.category()
343    }
344
345    fn tier(&self) -> ModelTier {
346        self.0.tier()
347    }
348
349    fn supports_mrl(&self) -> bool {
350        self.0.supports_mrl()
351    }
352}
353
// ─── Embedding Utilities ──────────────────────────────────────────────────
355
/// L2-normalizes a vector to unit length.
///
/// Returns an all-zero vector of the same length when the input has zero norm
/// or contains a non-finite component (`NaN` / `±inf`). Callers therefore never
/// receive `NaN`s from a division by zero.
///
/// Every other input is scaled to unit length. That includes vectors whose
/// squared norm underflows (tiny components) or overflows (huge components) in
/// `f32`. Such extreme inputs are first divided by their largest absolute
/// component, which brings the sum of squares back into range without changing
/// the direction.
#[must_use]
pub fn l2_normalize(vec: &[f32]) -> Vec<f32> {
    let norm_sq: f32 = vec.iter().map(|x| x * x).sum();

    // Fast path: the direct formula, used whenever the squared norm is finite
    // and not vanishingly small (results are identical to the plain formula).
    if norm_sq.is_finite() && norm_sq >= f32::EPSILON {
        let inv_norm = 1.0 / norm_sq.sqrt();
        return vec.iter().map(|x| x * inv_norm).collect();
    }

    // A NaN component poisons the sum; check it here because `f32::max` in the
    // fold below silently skips NaN.
    if norm_sq.is_nan() {
        return vec![0.0; vec.len()];
    }

    // Slow path: the sum of squares was zero, tiny (possibly underflowed), or
    // overflowed to +inf.
    let max_abs = vec.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
    if max_abs == 0.0 || !max_abs.is_finite() {
        // Genuinely zero vector, or an infinite component.
        return vec![0.0; vec.len()];
    }

    // After dividing by `max_abs`, every component lies in [-1, 1] and at least
    // one is exactly ±1, so the scaled squared norm is in [1, len]. It can
    // neither underflow to zero nor overflow.
    let scaled_norm_sq: f32 = vec
        .iter()
        .map(|x| {
            let scaled = x / max_abs;
            scaled * scaled
        })
        .sum();
    let inv_norm = 1.0 / scaled_norm_sq.sqrt();
    vec.iter().map(|x| (x / max_abs) * inv_norm).collect()
}
368
/// Computes cosine similarity between two vectors.
///
/// Returns `0.0` (never panics) when:
/// - the vectors have different lengths,
/// - either vector is all zeros, or
/// - either vector contains a non-finite component (`NaN` / `±inf`).
///
/// Vectors with extremely small norms (product of norms below `f32::EPSILON`)
/// or extremely large ones (squares overflow `f32`) are handled specially. Each
/// vector is rescaled by its largest absolute component before the ratio is
/// taken, which leaves the cosine unchanged while keeping the arithmetic in
/// range.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Runtime length check — a debug assertion would vanish in release builds,
    // and `zip` would silently truncate mismatched vectors.
    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    // Fast path: the direct formula whenever both norms are well inside range.
    let denom = norm_a * norm_b;
    if denom.is_finite() && denom >= f32::EPSILON {
        return dot / denom;
    }

    // Slow path: zero, tiny, overflowing, or NaN-containing inputs.
    let scale_a = a.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
    let scale_b = b.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
    if scale_a == 0.0 || scale_b == 0.0 || !scale_a.is_finite() || !scale_b.is_finite() {
        // A zero vector or an infinite component: similarity is undefined.
        return 0.0;
    }

    // Scaled components lie in [-1, 1] with at least one at ±1 per vector, so
    // both scaled squared norms are >= 1 and the division below is safe.
    let (mut scaled_dot, mut scaled_sq_a, mut scaled_sq_b) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (x, y) in a.iter().zip(b) {
        let (x, y) = (x / scale_a, y / scale_b);
        scaled_dot += x * y;
        scaled_sq_a += x * x;
        scaled_sq_b += y * y;
    }
    let similarity = scaled_dot / (scaled_sq_a.sqrt() * scaled_sq_b.sqrt());

    // A NaN component (skipped by `f32::max` above) surfaces here as NaN.
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
394
395/// Truncates an embedding to a target dimension and re-normalizes.
396///
397/// Only meaningful for models that support Matryoshka Representation Learning (MRL),
398/// where the first N dimensions capture most of the variance.
399///
400/// Returns the original vector unchanged if `target_dim >= embedding.len()`.
401#[must_use]
402pub fn truncate_embedding(embedding: &[f32], target_dim: usize) -> Vec<f32> {
403    if target_dim >= embedding.len() {
404        return embedding.to_vec();
405    }
406    l2_normalize(&embedding[..target_dim])
407}
408
// ─── Reranker Trait ─────────────────────────────────────────────────────────
410
/// Input to a reranker: a document identifier together with the text to score.
///
/// The text travels explicitly because a cross-encoder consumes the raw
/// query+document pair. `ScoredResult` deliberately carries no text, so the
/// common (non-reranking) path does not pay for it in memory.
#[derive(Clone, Debug)]
pub struct RerankDocument {
    /// Identifier of the document being reranked.
    pub doc_id: String,
    /// Document text fed to the cross-encoder.
    pub text: String,
}
423
424/// A reranking score for a single document.
425#[derive(Debug, Clone, Serialize, Deserialize)]
426pub struct RerankScore {
427    /// Document identifier.
428    pub doc_id: String,
429    /// Cross-encoder relevance score (typically sigmoid-activated logit).
430    pub score: f32,
431    /// Position before reranking (for rank-change tracking).
432    pub original_rank: usize,
433    /// Raw pre-sigmoid logit, when the backend exposes it.
434    ///
435    /// Some cross-encoder implementations only emit a final score (after sigmoid
436    /// activation). When the raw logit is unavailable, this field is `None`.
437    #[serde(default, skip_serializing_if = "Option::is_none")]
438    pub raw_logit: Option<f32>,
439}
440
441/// Core trait for cross-encoder reranking models.
442///
443/// Cross-encoders process query+document pairs together through a transformer,
444/// producing more accurate relevance scores than bi-encoder cosine similarity.
445/// This accuracy comes at the cost of not being able to pre-compute anything:
446/// every query-document pair requires a full inference pass.
447///
448/// # Graceful Failure
449///
450/// The reranking step should never block search results. If the model is
451/// unavailable or inference fails, implementations should return
452/// `Err(SearchError::RerankFailed { .. })` and callers should fall back
453/// to the original RRF scores.
454pub trait Reranker: Send + Sync {
455    /// Score and re-rank documents against a query.
456    ///
457    /// Returns documents sorted by descending cross-encoder score.
458    ///
459    /// # Errors
460    ///
461    /// Returns `SearchError::RerankFailed` if cross-encoder inference fails.
462    fn rerank<'a>(
463        &'a self,
464        cx: &'a Cx,
465        query: &'a str,
466        documents: &'a [RerankDocument],
467    ) -> SearchFuture<'a, Vec<RerankScore>>;
468
469    /// A unique identifier for this reranker model.
470    fn id(&self) -> &str;
471
472    /// Human-friendly reranker model name.
473    fn model_name(&self) -> &str;
474
475    /// Maximum supported token length for query+document pair input.
476    fn max_length(&self) -> usize {
477        512
478    }
479
480    /// Whether this reranker is loaded and ready for inference.
481    fn is_available(&self) -> bool {
482        true
483    }
484}
485
// ─── Synchronous Reranker Bridge ────────────────────────────────────────────
487
488/// Synchronous reranking interface for host projects that call rerankers from
489/// non-async contexts.
490///
491/// Implement this trait for rerankers whose `rerank` operations are inherently
492/// synchronous (e.g., blocking ONNX inference). The companion
493/// [`SyncRerankerAdapter`] wraps any `SyncRerank` implementor into a full
494/// async [`Reranker`], suitable for use anywhere frankensearch expects one.
495pub trait SyncRerank: Send + Sync {
496    /// Synchronously rerank documents against a query.
497    ///
498    /// Returns documents sorted by descending cross-encoder score.
499    ///
500    /// # Errors
501    ///
502    /// Returns [`SearchError`] when reranking fails (for example model load,
503    /// inference, or input validation failures).
504    fn rerank_sync(
505        &self,
506        query: &str,
507        documents: &[RerankDocument],
508    ) -> SearchResult<Vec<RerankScore>>;
509
510    /// A unique identifier for this reranker model.
511    fn id(&self) -> &str;
512
513    /// Human-friendly reranker model name.
514    fn model_name(&self) -> &str;
515
516    /// Maximum supported token length for query+document pair input.
517    fn max_length(&self) -> usize {
518        512
519    }
520
521    /// Whether this reranker is loaded and ready for inference.
522    fn is_available(&self) -> bool {
523        true
524    }
525}
526
527/// Adapts a [`SyncRerank`] implementor into a full async [`Reranker`].
528///
529/// The sync `rerank_sync()` call is wrapped in `Box::pin(async move { ... })`,
530/// which is acceptable for blocking ONNX inference when called from a
531/// `spawn_blocking` context.
532pub struct SyncRerankerAdapter<T: SyncRerank>(pub T);
533
534impl<T: SyncRerank + 'static> Reranker for SyncRerankerAdapter<T> {
535    fn rerank<'a>(
536        &'a self,
537        _cx: &'a Cx,
538        query: &'a str,
539        documents: &'a [RerankDocument],
540    ) -> SearchFuture<'a, Vec<RerankScore>> {
541        Box::pin(async move {
542            let mut scores = self.0.rerank_sync(query, documents)?;
543            scores.sort_by(|lhs, rhs| {
544                rhs.score
545                    .total_cmp(&lhs.score)
546                    .then_with(|| lhs.original_rank.cmp(&rhs.original_rank))
547                    .then_with(|| lhs.doc_id.cmp(&rhs.doc_id))
548            });
549            Ok(scores)
550        })
551    }
552
553    fn id(&self) -> &str {
554        self.0.id()
555    }
556
557    fn model_name(&self) -> &str {
558        self.0.model_name()
559    }
560
561    fn max_length(&self) -> usize {
562        self.0.max_length()
563    }
564
565    fn is_available(&self) -> bool {
566        self.0.is_available()
567    }
568}
569
// ─── Lexical Search Trait ───────────────────────────────────────────────────
571
572/// Trait for full-text lexical search backends.
573///
574/// Two implementations are planned:
575/// - `TantivyIndex` in `frankensearch-lexical` (default, via `lexical` feature)
576/// - FTS5 adapter in `frankensearch-storage` (alternative, via `fts5` feature)
577///
578/// Both produce `ScoredResult` with `source = ScoreSource::Lexical`.
579pub trait LexicalSearch: Send + Sync {
580    /// Search for documents matching the query, returning up to `limit` results
581    /// sorted by BM25 relevance.
582    ///
583    /// # Errors
584    ///
585    /// Returns `SearchError` if the query cannot be parsed or the search backend fails.
586    fn search<'a>(
587        &'a self,
588        cx: &'a Cx,
589        query: &'a str,
590        limit: usize,
591    ) -> SearchFuture<'a, Vec<ScoredResult>>;
592
593    /// Index a single document for full-text search.
594    ///
595    /// # Errors
596    ///
597    /// Returns `SearchError` if the document cannot be indexed.
598    fn index_document<'a>(&'a self, cx: &'a Cx, doc: &'a IndexableDocument)
599    -> SearchFuture<'a, ()>;
600
601    /// Index a batch of documents.
602    ///
603    /// # Errors
604    ///
605    /// Returns `SearchError` if any document cannot be indexed.
606    fn index_documents<'a>(
607        &'a self,
608        cx: &'a Cx,
609        docs: &'a [IndexableDocument],
610    ) -> SearchFuture<'a, ()> {
611        Box::pin(async move {
612            for doc in docs {
613                self.index_document(cx, doc).await?;
614            }
615            Ok(())
616        })
617    }
618
619    /// Commit any pending writes to the index.
620    ///
621    /// # Errors
622    ///
623    /// Returns `SearchError` if the commit fails (e.g., I/O error).
624    fn commit<'a>(&'a self, cx: &'a Cx) -> SearchFuture<'a, ()>;
625
626    /// Number of documents currently indexed.
627    fn doc_count(&self) -> usize;
628}
629
// ─── Metrics Exporter Trait ─────────────────────────────────────────────────
631
632/// Trait for exporting search/index/embed telemetry to external consumers.
633///
634/// Implementations must be non-blocking and fast, because callbacks are invoked
635/// directly from hot paths.
636pub trait MetricsExporter: fmt::Debug + Send + Sync {
637    /// Called when a search request completes.
638    fn on_search_completed(&self, metrics: &SearchMetrics);
639
640    /// Called when an embedding operation completes.
641    fn on_embedding_completed(&self, metrics: &EmbeddingMetrics);
642
643    /// Called when index state changes after an update/commit.
644    fn on_index_updated(&self, metrics: &IndexMetrics);
645
646    /// Called when a search pipeline error is observed.
647    fn on_error(&self, error: &SearchError);
648}
649
650/// Shared handle for dynamic telemetry exporters.
651pub type SharedMetricsExporter = Arc<dyn MetricsExporter>;
652
653/// No-op exporter used when no telemetry sink is attached.
654///
655/// This is intentionally empty so callers can cheaply opt out of telemetry.
656#[derive(Debug, Default, Clone, Copy)]
657pub struct NoOpMetricsExporter;
658
659impl MetricsExporter for NoOpMetricsExporter {
660    fn on_search_completed(&self, _: &SearchMetrics) {}
661
662    fn on_embedding_completed(&self, _: &EmbeddingMetrics) {}
663
664    fn on_index_updated(&self, _: &IndexMetrics) {}
665
666    fn on_error(&self, _: &SearchError) {}
667}
668
#[cfg(test)]
mod tests {
    use asupersync::test_utils::run_test_with_cx;

    use super::*;

    /// Every `ModelCategory` variant, so per-variant tests cannot silently skip
    /// one.
    const ALL_CATEGORIES: [ModelCategory; 4] = [
        ModelCategory::HashEmbedder,
        ModelCategory::StaticEmbedder,
        ModelCategory::TransformerEmbedder,
        ModelCategory::ApiEmbedder,
    ];

    /// Sync reranker that returns scores out of order (including a score tie)
    /// so the adapter's sorting is observable.
    struct UnsortedSyncReranker;

    impl SyncRerank for UnsortedSyncReranker {
        fn rerank_sync(
            &self,
            _query: &str,
            _documents: &[RerankDocument],
        ) -> SearchResult<Vec<RerankScore>> {
            Ok(vec![
                RerankScore {
                    doc_id: "doc-a".to_owned(),
                    score: 0.8,
                    original_rank: 2,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-b".to_owned(),
                    score: 0.8,
                    original_rank: 1,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-c".to_owned(),
                    score: 0.3,
                    original_rank: 0,
                    raw_logit: None,
                },
            ])
        }

        fn id(&self) -> &'static str {
            "unsorted-sync-reranker"
        }

        fn model_name(&self) -> &'static str {
            "Unsorted Sync Reranker"
        }
    }

    /// Minimal sync embedder: `"a"` maps to `[1, 0]`, anything else to `[0, 1]`.
    /// Relies on the `SyncEmbed` defaults for everything optional.
    struct FixedSyncEmbedder;

    impl SyncEmbed for FixedSyncEmbedder {
        fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>> {
            Ok(if text == "a" {
                vec![1.0, 0.0]
            } else {
                vec![0.0, 1.0]
            })
        }

        fn dimension(&self) -> usize {
            2
        }

        fn id(&self) -> &'static str {
            "fixed-sync-embedder"
        }

        fn is_semantic(&self) -> bool {
            false
        }

        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    #[test]
    fn model_category_display() {
        assert_eq!(ModelCategory::HashEmbedder.to_string(), "hash_embedder");
        assert_eq!(ModelCategory::StaticEmbedder.to_string(), "static_embedder");
        assert_eq!(
            ModelCategory::TransformerEmbedder.to_string(),
            "transformer_embedder"
        );
        assert_eq!(ModelCategory::ApiEmbedder.to_string(), "api_embedder");
    }

    #[test]
    fn model_category_serialization() {
        for category in ALL_CATEGORIES {
            let json = serde_json::to_string(&category).unwrap();
            let decoded: ModelCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, category);
        }
    }

    #[test]
    fn model_category_equality() {
        assert_eq!(ModelCategory::HashEmbedder, ModelCategory::HashEmbedder);
        assert_ne!(ModelCategory::HashEmbedder, ModelCategory::StaticEmbedder);
        assert_ne!(
            ModelCategory::StaticEmbedder,
            ModelCategory::TransformerEmbedder
        );
        assert_ne!(ModelCategory::TransformerEmbedder, ModelCategory::ApiEmbedder);
    }

    #[test]
    fn model_category_default_tier() {
        assert_eq!(ModelCategory::HashEmbedder.default_tier(), ModelTier::Fast);
        assert_eq!(
            ModelCategory::StaticEmbedder.default_tier(),
            ModelTier::Fast
        );
        assert_eq!(
            ModelCategory::TransformerEmbedder.default_tier(),
            ModelTier::Quality
        );
        assert_eq!(
            ModelCategory::ApiEmbedder.default_tier(),
            ModelTier::Quality
        );
    }

    #[test]
    fn model_tier_display() {
        assert_eq!(ModelTier::Fast.to_string(), "fast");
        assert_eq!(ModelTier::Quality.to_string(), "quality");
    }

    #[test]
    fn model_info_roundtrip() {
        let info = ModelInfo {
            id: "potion-multilingual-128M".to_owned(),
            name: "Potion 128M".to_owned(),
            dimension: 256,
            category: ModelCategory::StaticEmbedder,
            tier: ModelTier::Fast,
            is_semantic: true,
            supports_mrl: false,
            huggingface_id: Some("minishlab/potion-multilingual-128M".to_owned()),
            size_bytes: Some(128_000_000),
            license: Some("apache-2.0".to_owned()),
        };

        let json = serde_json::to_string(&info).unwrap();
        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, info);
    }

    #[test]
    fn rerank_document_construction() {
        let doc = RerankDocument {
            doc_id: "doc-1".into(),
            text: "Some content".into(),
        };
        assert_eq!(doc.doc_id, "doc-1");
        assert_eq!(doc.text, "Some content");
    }

    #[test]
    fn rerank_score_serialization() {
        let score = RerankScore {
            doc_id: "doc-1".into(),
            score: 0.92,
            original_rank: 3,
            raw_logit: None,
        };

        let json = serde_json::to_string(&score).unwrap();
        let decoded: RerankScore = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.doc_id, "doc-1");
        assert!((decoded.score - 0.92).abs() < 1e-6);
        assert_eq!(decoded.original_rank, 3);
    }

    #[test]
    fn rerank_score_raw_logit_is_optional_on_the_wire() {
        let without_logit = RerankScore {
            doc_id: "doc-1".into(),
            score: 0.5,
            original_rank: 0,
            raw_logit: None,
        };
        // `skip_serializing_if` drops the field entirely when `None`.
        let json = serde_json::to_string(&without_logit).unwrap();
        assert!(!json.contains("raw_logit"));

        // `#[serde(default)]` accepts payloads that omit the field.
        let decoded: RerankScore =
            serde_json::from_str(r#"{"doc_id":"doc-2","score":0.25,"original_rank":4}"#).unwrap();
        assert_eq!(decoded.raw_logit, None);

        let with_logit = RerankScore {
            raw_logit: Some(1.5),
            ..without_logit
        };
        let json = serde_json::to_string(&with_logit).unwrap();
        let decoded: RerankScore = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.raw_logit, Some(1.5));
    }

    // Compile-time checks for trait object safety
    #[test]
    fn embedder_trait_is_object_safe() {
        fn _takes_dyn_embedder(_: &dyn Embedder) {}
    }

    #[test]
    fn reranker_trait_is_object_safe() {
        fn _takes_dyn_reranker(_: &dyn Reranker) {}
    }

    #[test]
    fn lexical_search_trait_is_object_safe() {
        fn _takes_dyn_lexical(_: &dyn LexicalSearch) {}
    }

    #[test]
    fn metrics_exporter_trait_is_object_safe() {
        fn _takes_dyn_metrics_exporter(_: &dyn MetricsExporter) {}
    }

    #[test]
    fn sync_reranker_adapter_sorts_descending_for_trait_contract() {
        run_test_with_cx(|cx| async move {
            let adapter = SyncRerankerAdapter(UnsortedSyncReranker);
            let docs = vec![
                RerankDocument {
                    doc_id: "doc-a".to_owned(),
                    text: "alpha".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-b".to_owned(),
                    text: "beta".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-c".to_owned(),
                    text: "gamma".to_owned(),
                },
            ];
            let scores = adapter
                .rerank(&cx, "query", &docs)
                .await
                .expect("adapter rerank should succeed");
            let ids = scores
                .iter()
                .map(|score| score.doc_id.as_str())
                .collect::<Vec<_>>();
            assert_eq!(ids, vec!["doc-b", "doc-a", "doc-c"]);
        });
    }

    #[test]
    fn sync_embedder_adapter_delegates_metadata() {
        let adapter = SyncEmbedderAdapter(FixedSyncEmbedder);
        assert_eq!(adapter.dimension(), 2);
        assert_eq!(adapter.id(), "fixed-sync-embedder");
        // `SyncEmbed::model_name` defaults to the id.
        assert_eq!(adapter.model_name(), "fixed-sync-embedder");
        assert!(adapter.is_ready());
        assert!(!adapter.is_semantic());
        assert_eq!(adapter.category(), ModelCategory::HashEmbedder);
        assert_eq!(adapter.tier(), ModelTier::Fast);
        assert!(!adapter.supports_mrl());
    }

    #[test]
    fn sync_embedder_adapter_embeds_batch_in_order() {
        run_test_with_cx(|cx| async move {
            let adapter = SyncEmbedderAdapter(FixedSyncEmbedder);
            let texts = ["a", "b", "a"];
            let vectors = adapter
                .embed_batch(&cx, &texts)
                .await
                .expect("adapter batch embedding should succeed");
            assert_eq!(
                vectors,
                vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]
            );
        });
    }

    #[test]
    fn embedder_truncate_embedding_rejects_zero_and_renormalizes() {
        let adapter = SyncEmbedderAdapter(FixedSyncEmbedder);
        assert!(matches!(
            adapter.truncate_embedding(&[3.0, 4.0, 12.0], 0),
            Err(SearchError::InvalidConfig { .. })
        ));

        let truncated = adapter
            .truncate_embedding(&[3.0, 4.0, 12.0], 2)
            .expect("non-zero target dimension is valid");
        assert_eq!(truncated.len(), 2);
        assert!((truncated[0] - 0.6).abs() < 1e-6);
        assert!((truncated[1] - 0.8).abs() < 1e-6);

        // No truncation requested: the input comes back unchanged.
        let untouched = adapter
            .truncate_embedding(&[3.0, 4.0], 5)
            .expect("non-zero target dimension is valid");
        assert_eq!(untouched, vec![3.0_f32, 4.0]);
    }

    #[test]
    fn noop_metrics_exporter_callbacks_are_noops() {
        let exporter = NoOpMetricsExporter;

        let search_metrics = SearchMetrics {
            mode: crate::types::SearchMode::Hybrid,
            query_class: None,
            total_latency_ms: 10.0,
            phase1_latency_ms: Some(4.0),
            phase2_latency_ms: Some(6.0),
            result_count: 8,
            lexical_candidates: 30,
            semantic_candidates: 25,
            refined: true,
        };
        let embedding_metrics = EmbeddingMetrics {
            embedder_id: "fnv-hash-384".into(),
            batch_size: 1,
            duration_ms: 0.07,
            dimension: 384,
            is_semantic: false,
        };
        let index_metrics = IndexMetrics {
            doc_count: 100,
            index_size_bytes: 4096,
            updated_docs: 1,
            staleness_detected: false,
        };

        exporter.on_search_completed(&search_metrics);
        exporter.on_embedding_completed(&embedding_metrics);
        exporter.on_index_updated(&index_metrics);
        exporter.on_error(&SearchError::SearchTimeout {
            elapsed_ms: 11,
            budget_ms: 10,
        });
    }

    // ─── Utility function tests ─────────────────────────────────────────

    #[test]
    fn l2_normalize_produces_unit_vector() {
        let v = vec![3.0, 4.0];
        let normalized = l2_normalize(&v);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        let normalized = l2_normalize(&v);
        assert!(normalized.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < f32::EPSILON);
    }

    #[test]
    fn cosine_similarity_length_mismatch_is_zero() {
        // Mismatched lengths are rejected at runtime instead of panicking.
        assert!(cosine_similarity(&[1.0, 2.0], &[1.0]).abs() < f32::EPSILON);
    }

    #[test]
    fn truncate_embedding_reduces_dim() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let t = truncate_embedding(&v, 2);
        assert_eq!(t.len(), 2);
        let norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn truncate_embedding_noop_when_larger() {
        let v = vec![1.0, 2.0];
        assert_eq!(truncate_embedding(&v, 10), v);
    }

    #[test]
    fn truncate_embedding_to_zero_dims_is_empty() {
        assert!(truncate_embedding(&[1.0, 2.0], 0).is_empty());
    }

    #[test]
    fn model_category_default_semantic_flag() {
        assert!(!ModelCategory::HashEmbedder.default_semantic_flag());
        assert!(ModelCategory::StaticEmbedder.default_semantic_flag());
        assert!(ModelCategory::TransformerEmbedder.default_semantic_flag());
        assert!(ModelCategory::ApiEmbedder.default_semantic_flag());
    }
}