1use std::fmt;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14
15use asupersync::Cx;
16use serde::{Deserialize, Serialize};
17
18use crate::error::{SearchError, SearchResult};
19use crate::types::{
20 EmbeddingMetrics, IndexMetrics, IndexableDocument, ScoredResult, SearchMetrics,
21};
22
/// Boxed, pinned, `Send` future that resolves to a [`SearchResult`] of `T`.
///
/// The async-style trait methods in this module return this alias instead of
/// being `async fn`, which keeps the traits usable as trait objects
/// (for example `&dyn Embedder`).
pub type SearchFuture<'a, T> = Pin<Box<dyn Future<Output = SearchResult<T>> + Send + 'a>>;
25
/// Broad family an embedding model belongs to.
///
/// The category selects the defaults returned by `default_tier` and
/// `default_semantic_flag`, and its `Display` form is a `snake_case` label
/// such as `hash_embedder`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelCategory {
    /// Hash-based embedder: fast tier by default and the only category whose
    /// default semantic flag is `false`.
    HashEmbedder,
    /// Static embedder: fast tier by default, semantic by default.
    /// NOTE(review): presumably precomputed/lookup-table vectors — confirm.
    StaticEmbedder,
    /// Transformer embedder: quality tier by default, semantic by default.
    TransformerEmbedder,
    /// API embedder: quality tier by default, semantic by default.
    /// NOTE(review): presumably backed by a remote service — confirm.
    ApiEmbedder,
}
43
44impl fmt::Display for ModelCategory {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 Self::HashEmbedder => write!(f, "hash_embedder"),
48 Self::StaticEmbedder => write!(f, "static_embedder"),
49 Self::TransformerEmbedder => write!(f, "transformer_embedder"),
50 Self::ApiEmbedder => write!(f, "api_embedder"),
51 }
52 }
53}
54
55impl ModelCategory {
56 #[must_use]
58 pub const fn default_tier(self) -> ModelTier {
59 match self {
60 Self::HashEmbedder | Self::StaticEmbedder => ModelTier::Fast,
61 Self::TransformerEmbedder | Self::ApiEmbedder => ModelTier::Quality,
62 }
63 }
64
65 #[must_use]
67 pub const fn default_semantic_flag(self) -> bool {
68 !matches!(self, Self::HashEmbedder)
69 }
70}
71
/// Speed/quality tier of a model.
///
/// `ModelCategory::default_tier` maps each category onto one of these, and the
/// `Display` form is the lowercase variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelTier {
    /// Default tier of hash and static embedders; displayed as `fast`.
    Fast,
    /// Default tier of transformer and API embedders; displayed as `quality`.
    Quality,
}
80
81impl fmt::Display for ModelTier {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::Fast => write!(f, "fast"),
85 Self::Quality => write!(f, "quality"),
86 }
87 }
88}
89
/// Serializable description of an embedding model.
///
/// Plain data: the struct round-trips through serde and compares field by field.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Identifier of the model.
    pub id: String,
    /// Display name of the model.
    pub name: String,
    /// Embedding dimension.
    /// NOTE(review): presumably the length of the vectors the model emits — confirm.
    pub dimension: usize,
    /// Category the model belongs to.
    pub category: ModelCategory,
    /// Tier assigned to the model.
    pub tier: ModelTier,
    /// Whether the model is considered semantic.
    pub is_semantic: bool,
    /// Whether the model supports MRL.
    /// NOTE(review): presumably Matryoshka Representation Learning, i.e. truncatable
    /// embeddings — confirm.
    pub supports_mrl: bool,
    /// Optional Hugging Face identifier.
    /// NOTE(review): presumably an `owner/repo` style id — confirm.
    pub huggingface_id: Option<String>,
    /// Optional size in bytes.
    /// NOTE(review): presumably the size of the model artifact — confirm.
    pub size_bytes: Option<u64>,
    /// Optional license string.
    pub license: Option<String>,
}
114
115pub trait Embedder: Send + Sync {
128 fn embed<'a>(&'a self, cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>>;
136
137 fn embed_batch<'a>(
147 &'a self,
148 cx: &'a Cx,
149 texts: &'a [&'a str],
150 ) -> SearchFuture<'a, Vec<Vec<f32>>> {
151 Box::pin(async move {
152 let mut out = Vec::with_capacity(texts.len());
153 for text in texts {
154 out.push(self.embed(cx, text).await?);
155 }
156 Ok(out)
157 })
158 }
159
160 fn dimension(&self) -> usize;
162
163 fn id(&self) -> &str;
168
169 fn model_name(&self) -> &str;
171
172 fn is_ready(&self) -> bool {
174 true
175 }
176
177 fn is_semantic(&self) -> bool;
181
182 fn category(&self) -> ModelCategory;
184
185 fn tier(&self) -> ModelTier {
187 self.category().default_tier()
188 }
189
190 fn supports_mrl(&self) -> bool {
193 false
194 }
195
196 fn truncate_embedding(&self, embedding: &[f32], target_dim: usize) -> SearchResult<Vec<f32>> {
202 if target_dim == 0 {
203 return Err(SearchError::InvalidConfig {
204 field: "target_dim".to_owned(),
205 value: "0".to_owned(),
206 reason: "target dimension must be at least 1".to_owned(),
207 });
208 }
209
210 if target_dim >= embedding.len() {
211 return Ok(embedding.to_vec());
212 }
213
214 Ok(l2_normalize(&embedding[..target_dim]))
215 }
216}
217
218pub trait SyncEmbed: Send + Sync {
248 fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>>;
255
256 fn embed_batch_sync(&self, texts: &[&str]) -> SearchResult<Vec<Vec<f32>>> {
265 texts.iter().map(|t| self.embed_sync(t)).collect()
266 }
267
268 fn dimension(&self) -> usize;
270
271 fn id(&self) -> &str;
273
274 fn model_name(&self) -> &str {
276 self.id()
277 }
278
279 fn is_ready(&self) -> bool {
281 true
282 }
283
284 fn is_semantic(&self) -> bool;
286
287 fn category(&self) -> ModelCategory;
289
290 fn tier(&self) -> ModelTier {
292 self.category().default_tier()
293 }
294
295 fn supports_mrl(&self) -> bool {
297 false
298 }
299}
300
/// Newtype wrapper that exposes a [`SyncEmbed`] implementor as an [`Embedder`].
///
/// The wrapped value is public as field `0`.
pub struct SyncEmbedderAdapter<T: SyncEmbed>(pub T);
307
308impl<T: SyncEmbed + 'static> Embedder for SyncEmbedderAdapter<T> {
309 fn embed<'a>(&'a self, _cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>> {
310 Box::pin(async move { self.0.embed_sync(text) })
311 }
312
313 fn embed_batch<'a>(
314 &'a self,
315 _cx: &'a Cx,
316 texts: &'a [&'a str],
317 ) -> SearchFuture<'a, Vec<Vec<f32>>> {
318 Box::pin(async move { self.0.embed_batch_sync(texts) })
319 }
320
321 fn dimension(&self) -> usize {
322 self.0.dimension()
323 }
324
325 fn id(&self) -> &str {
326 self.0.id()
327 }
328
329 fn model_name(&self) -> &str {
330 self.0.model_name()
331 }
332
333 fn is_ready(&self) -> bool {
334 self.0.is_ready()
335 }
336
337 fn is_semantic(&self) -> bool {
338 self.0.is_semantic()
339 }
340
341 fn category(&self) -> ModelCategory {
342 self.0.category()
343 }
344
345 fn tier(&self) -> ModelTier {
346 self.0.tier()
347 }
348
349 fn supports_mrl(&self) -> bool {
350 self.0.supports_mrl()
351 }
352}
353
/// Scales `vec` to unit L2 length and returns the result as a new vector.
///
/// If the squared norm is not finite (NaN or overflow) or is smaller than
/// `f32::EPSILON`, an all-zero vector of the same length is returned instead.
#[must_use]
pub fn l2_normalize(vec: &[f32]) -> Vec<f32> {
    let squared_norm = vec
        .iter()
        .fold(0.0_f32, |acc, &component| acc + component * component);

    if squared_norm.is_finite() && squared_norm >= f32::EPSILON {
        // Multiply by the reciprocal once instead of dividing per component.
        let scale = squared_norm.sqrt().recip();
        vec.iter().map(|&component| component * scale).collect()
    } else {
        vec![0.0; vec.len()]
    }
}
368
/// Cosine similarity of `a` and `b`, always within `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, or when the product of the
/// two norms is not finite or is smaller than `f32::EPSILON` (for example when
/// either vector is all zeros).
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulates the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if !denom.is_finite() || denom < f32::EPSILON {
        return 0.0;
    }

    // Rounding can push the quotient marginally outside the mathematical
    // range; clamp so callers (e.g. `acos`) never see |value| > 1.
    (dot / denom).clamp(-1.0, 1.0)
}
394
395#[must_use]
402pub fn truncate_embedding(embedding: &[f32], target_dim: usize) -> Vec<f32> {
403 if target_dim >= embedding.len() {
404 return embedding.to_vec();
405 }
406 l2_normalize(&embedding[..target_dim])
407}
408
/// A candidate document handed to a reranker.
#[derive(Debug, Clone)]
pub struct RerankDocument {
    /// Identifier of the document.
    pub doc_id: String,
    /// Text of the document.
    /// NOTE(review): presumably the content scored against the query — confirm.
    pub text: String,
}
423
/// Score assigned to one document by a reranker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankScore {
    /// Identifier of the scored document.
    pub doc_id: String,
    /// Reranker score; `SyncRerankerAdapter` orders results by this, highest first.
    pub score: f32,
    /// Rank recorded for the document; `SyncRerankerAdapter` uses it (ascending)
    /// to break score ties.
    /// NOTE(review): presumably the position before reranking — confirm.
    pub original_rank: usize,
    /// Optional raw logit. Omitted from serialized output when `None` and
    /// read as `None` when absent on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_logit: Option<f32>,
}
440
/// Asynchronous reranker that scores documents against a query.
pub trait Reranker: Send + Sync {
    /// Scores `documents` against `query` and resolves to the scores.
    ///
    /// NOTE(review): `SyncRerankerAdapter` sorts its output by descending score,
    /// which suggests callers expect that order from every implementor — confirm.
    fn rerank<'a>(
        &'a self,
        cx: &'a Cx,
        query: &'a str,
        documents: &'a [RerankDocument],
    ) -> SearchFuture<'a, Vec<RerankScore>>;

    /// Identifier of this reranker.
    fn id(&self) -> &str;

    /// Name of the underlying model.
    fn model_name(&self) -> &str;

    /// Maximum input length; `512` unless overridden.
    /// NOTE(review): unit (tokens vs. characters) is not visible here — confirm.
    fn max_length(&self) -> usize {
        512
    }

    /// Whether the reranker can be used; `true` unless overridden.
    fn is_available(&self) -> bool {
        true
    }
}
485
/// Synchronous counterpart of [`Reranker`].
///
/// Wrap an implementor in [`SyncRerankerAdapter`] to use it where a
/// [`Reranker`] is expected; the adapter sorts the returned scores itself.
pub trait SyncRerank: Send + Sync {
    /// Scores `documents` against `query` and returns the scores.
    fn rerank_sync(
        &self,
        query: &str,
        documents: &[RerankDocument],
    ) -> SearchResult<Vec<RerankScore>>;

    /// Identifier of this reranker.
    fn id(&self) -> &str;

    /// Name of the underlying model.
    fn model_name(&self) -> &str;

    /// Maximum input length; `512` unless overridden.
    /// NOTE(review): unit (tokens vs. characters) is not visible here — confirm.
    fn max_length(&self) -> usize {
        512
    }

    /// Whether the reranker can be used; `true` unless overridden.
    fn is_available(&self) -> bool {
        true
    }
}
526
/// Newtype wrapper that exposes a [`SyncRerank`] implementor as a [`Reranker`].
///
/// The wrapped value is public as field `0`.
pub struct SyncRerankerAdapter<T: SyncRerank>(pub T);
533
534impl<T: SyncRerank + 'static> Reranker for SyncRerankerAdapter<T> {
535 fn rerank<'a>(
536 &'a self,
537 _cx: &'a Cx,
538 query: &'a str,
539 documents: &'a [RerankDocument],
540 ) -> SearchFuture<'a, Vec<RerankScore>> {
541 Box::pin(async move {
542 let mut scores = self.0.rerank_sync(query, documents)?;
543 scores.sort_by(|lhs, rhs| {
544 rhs.score
545 .total_cmp(&lhs.score)
546 .then_with(|| lhs.original_rank.cmp(&rhs.original_rank))
547 .then_with(|| lhs.doc_id.cmp(&rhs.doc_id))
548 });
549 Ok(scores)
550 })
551 }
552
553 fn id(&self) -> &str {
554 self.0.id()
555 }
556
557 fn model_name(&self) -> &str {
558 self.0.model_name()
559 }
560
561 fn max_length(&self) -> usize {
562 self.0.max_length()
563 }
564
565 fn is_available(&self) -> bool {
566 self.0.is_available()
567 }
568}
569
570pub trait LexicalSearch: Send + Sync {
580 fn search<'a>(
587 &'a self,
588 cx: &'a Cx,
589 query: &'a str,
590 limit: usize,
591 ) -> SearchFuture<'a, Vec<ScoredResult>>;
592
593 fn index_document<'a>(&'a self, cx: &'a Cx, doc: &'a IndexableDocument)
599 -> SearchFuture<'a, ()>;
600
601 fn index_documents<'a>(
607 &'a self,
608 cx: &'a Cx,
609 docs: &'a [IndexableDocument],
610 ) -> SearchFuture<'a, ()> {
611 Box::pin(async move {
612 for doc in docs {
613 self.index_document(cx, doc).await?;
614 }
615 Ok(())
616 })
617 }
618
619 fn commit<'a>(&'a self, cx: &'a Cx) -> SearchFuture<'a, ()>;
625
626 fn doc_count(&self) -> usize;
628}
629
/// Sink for search, embedding, index and error events.
///
/// All callbacks take `&self` and return nothing.
/// NOTE(review): the call sites are outside this file; the descriptions below
/// follow the method names — confirm against callers.
pub trait MetricsExporter: fmt::Debug + Send + Sync {
    /// Receives the metrics of a search.
    fn on_search_completed(&self, metrics: &SearchMetrics);

    /// Receives the metrics of an embedding operation.
    fn on_embedding_completed(&self, metrics: &EmbeddingMetrics);

    /// Receives the metrics of an index update.
    fn on_index_updated(&self, metrics: &IndexMetrics);

    /// Receives an error.
    fn on_error(&self, error: &SearchError);
}
649
/// Reference-counted, dynamically dispatched [`MetricsExporter`].
pub type SharedMetricsExporter = Arc<dyn MetricsExporter>;
652
/// [`MetricsExporter`] whose callbacks all do nothing.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpMetricsExporter;
658
/// Every callback has an empty body: the metrics or error passed in is ignored.
impl MetricsExporter for NoOpMetricsExporter {
    fn on_search_completed(&self, _: &SearchMetrics) {}

    fn on_embedding_completed(&self, _: &EmbeddingMetrics) {}

    fn on_index_updated(&self, _: &IndexMetrics) {}

    fn on_error(&self, _: &SearchError) {}
}
668
/// Unit tests for the enums, helpers, adapters and trait object safety in this module.
#[cfg(test)]
mod tests {
    use asupersync::test_utils::run_test_with_cx;

    use super::*;

    /// Sync reranker that deliberately returns its scores out of order so the
    /// adapter's sorting can be observed.
    struct UnsortedSyncReranker;

    impl SyncRerank for UnsortedSyncReranker {
        fn rerank_sync(
            &self,
            _query: &str,
            _documents: &[RerankDocument],
        ) -> SearchResult<Vec<RerankScore>> {
            Ok(vec![
                RerankScore {
                    doc_id: "doc-a".to_owned(),
                    score: 0.8,
                    original_rank: 2,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-b".to_owned(),
                    score: 0.8,
                    original_rank: 1,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-c".to_owned(),
                    score: 0.3,
                    original_rank: 0,
                    raw_logit: None,
                },
            ])
        }

        fn id(&self) -> &'static str {
            "unsorted-sync-reranker"
        }

        fn model_name(&self) -> &'static str {
            "Unsorted Sync Reranker"
        }
    }

    #[test]
    fn model_category_display() {
        assert_eq!(ModelCategory::HashEmbedder.to_string(), "hash_embedder");
        assert_eq!(ModelCategory::StaticEmbedder.to_string(), "static_embedder");
        assert_eq!(
            ModelCategory::TransformerEmbedder.to_string(),
            "transformer_embedder"
        );
        assert_eq!(ModelCategory::ApiEmbedder.to_string(), "api_embedder");
    }

    #[test]
    fn model_category_serialization() {
        let json = serde_json::to_string(&ModelCategory::StaticEmbedder).unwrap();
        let decoded: ModelCategory = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, ModelCategory::StaticEmbedder);
    }

    #[test]
    fn model_category_equality() {
        assert_eq!(ModelCategory::HashEmbedder, ModelCategory::HashEmbedder);
        assert_ne!(ModelCategory::HashEmbedder, ModelCategory::StaticEmbedder);
        assert_ne!(
            ModelCategory::StaticEmbedder,
            ModelCategory::TransformerEmbedder
        );
    }

    #[test]
    fn model_category_default_tier() {
        assert_eq!(ModelCategory::HashEmbedder.default_tier(), ModelTier::Fast);
        assert_eq!(
            ModelCategory::StaticEmbedder.default_tier(),
            ModelTier::Fast
        );
        assert_eq!(
            ModelCategory::TransformerEmbedder.default_tier(),
            ModelTier::Quality
        );
        assert_eq!(
            ModelCategory::ApiEmbedder.default_tier(),
            ModelTier::Quality
        );
    }

    #[test]
    fn model_tier_display() {
        assert_eq!(ModelTier::Fast.to_string(), "fast");
        assert_eq!(ModelTier::Quality.to_string(), "quality");
    }

    #[test]
    fn model_info_roundtrip() {
        let info = ModelInfo {
            id: "potion-multilingual-128M".to_owned(),
            name: "Potion 128M".to_owned(),
            dimension: 256,
            category: ModelCategory::StaticEmbedder,
            tier: ModelTier::Fast,
            is_semantic: true,
            supports_mrl: false,
            huggingface_id: Some("minishlab/potion-multilingual-128M".to_owned()),
            size_bytes: Some(128_000_000),
            license: Some("apache-2.0".to_owned()),
        };

        let json = serde_json::to_string(&info).unwrap();
        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, info);
    }

    #[test]
    fn rerank_document_construction() {
        let doc = RerankDocument {
            doc_id: "doc-1".into(),
            text: "Some content".into(),
        };
        assert_eq!(doc.doc_id, "doc-1");
        assert_eq!(doc.text, "Some content");
    }

    #[test]
    fn rerank_score_serialization() {
        let score = RerankScore {
            doc_id: "doc-1".into(),
            score: 0.92,
            original_rank: 3,
            raw_logit: None,
        };

        let json = serde_json::to_string(&score).unwrap();
        // `raw_logit: None` is skipped on output and defaults back to `None`.
        assert!(!json.contains("raw_logit"));
        let decoded: RerankScore = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.doc_id, "doc-1");
        assert!((decoded.score - 0.92).abs() < 1e-6);
        assert_eq!(decoded.original_rank, 3);
        assert!(decoded.raw_logit.is_none());
    }

    // The four tests below only need to compile: a `&dyn Trait` parameter is
    // rejected by the compiler if the trait is not object safe.
    #[test]
    fn embedder_trait_is_object_safe() {
        fn _takes_dyn_embedder(_: &dyn Embedder) {}
    }

    #[test]
    fn reranker_trait_is_object_safe() {
        fn _takes_dyn_reranker(_: &dyn Reranker) {}
    }

    #[test]
    fn lexical_search_trait_is_object_safe() {
        fn _takes_dyn_lexical(_: &dyn LexicalSearch) {}
    }

    #[test]
    fn metrics_exporter_trait_is_object_safe() {
        fn _takes_dyn_metrics_exporter(_: &dyn MetricsExporter) {}
    }

    #[test]
    fn sync_reranker_adapter_sorts_descending_for_trait_contract() {
        run_test_with_cx(|cx| async move {
            let adapter = SyncRerankerAdapter(UnsortedSyncReranker);
            let docs = vec![
                RerankDocument {
                    doc_id: "doc-a".to_owned(),
                    text: "alpha".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-b".to_owned(),
                    text: "beta".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-c".to_owned(),
                    text: "gamma".to_owned(),
                },
            ];
            let scores = adapter
                .rerank(&cx, "query", &docs)
                .await
                .expect("adapter rerank should succeed");
            let ids = scores
                .iter()
                .map(|score| score.doc_id.as_str())
                .collect::<Vec<_>>();
            // doc-a and doc-b tie on score; the lower original_rank (doc-b) wins.
            assert_eq!(ids, vec!["doc-b", "doc-a", "doc-c"]);
        });
    }

    #[test]
    fn noop_metrics_exporter_callbacks_are_noops() {
        let exporter = NoOpMetricsExporter;

        let search_metrics = SearchMetrics {
            mode: crate::types::SearchMode::Hybrid,
            query_class: None,
            total_latency_ms: 10.0,
            phase1_latency_ms: Some(4.0),
            phase2_latency_ms: Some(6.0),
            result_count: 8,
            lexical_candidates: 30,
            semantic_candidates: 25,
            refined: true,
        };
        let embedding_metrics = EmbeddingMetrics {
            embedder_id: "fnv-hash-384".into(),
            batch_size: 1,
            duration_ms: 0.07,
            dimension: 384,
            is_semantic: false,
        };
        let index_metrics = IndexMetrics {
            doc_count: 100,
            index_size_bytes: 4096,
            updated_docs: 1,
            staleness_detected: false,
        };

        exporter.on_search_completed(&search_metrics);
        exporter.on_embedding_completed(&embedding_metrics);
        exporter.on_index_updated(&index_metrics);
        exporter.on_error(&SearchError::SearchTimeout {
            elapsed_ms: 11,
            budget_ms: 10,
        });
    }

    #[test]
    fn l2_normalize_produces_unit_vector() {
        let v = vec![3.0, 4.0];
        let normalized = l2_normalize(&v);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        let normalized = l2_normalize(&v);
        assert!(normalized.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < f32::EPSILON);
    }

    #[test]
    fn cosine_similarity_length_mismatch_is_zero() {
        let a = vec![1.0];
        let b = vec![1.0, 2.0];
        assert!(cosine_similarity(&a, &b).abs() < f32::EPSILON);
    }

    #[test]
    fn truncate_embedding_reduces_dim() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let t = truncate_embedding(&v, 2);
        assert_eq!(t.len(), 2);
        let norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn truncate_embedding_noop_when_larger() {
        let v = vec![1.0, 2.0];
        assert_eq!(truncate_embedding(&v, 10), v);
    }

    #[test]
    fn model_category_default_semantic_flag() {
        assert!(!ModelCategory::HashEmbedder.default_semantic_flag());
        assert!(ModelCategory::StaticEmbedder.default_semantic_flag());
        assert!(ModelCategory::TransformerEmbedder.default_semantic_flag());
        assert!(ModelCategory::ApiEmbedder.default_semantic_flag());
    }
}