use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use asupersync::Cx;
use serde::{Deserialize, Serialize};
use crate::error::{SearchError, SearchResult};
use crate::types::{
EmbeddingMetrics, IndexMetrics, IndexableDocument, ScoredResult, SearchMetrics,
};
/// Boxed, `Send` future returned by the async traits in this module.
///
/// `'a` lets the future borrow the receiver and the call arguments; the output is always a
/// [`SearchResult`] wrapping `T`.
pub type SearchFuture<'a, T> = Pin<Box<dyn Future<Output = SearchResult<T>> + Send + 'a>>;
/// Family an embedding model belongs to.
///
/// The family supplies defaults via [`ModelCategory::default_tier`] and
/// [`ModelCategory::default_semantic_flag`]. `Display` renders a snake_case label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelCategory {
    /// Hash-based embedder (non-semantic by default).
    HashEmbedder,
    /// Static embedder.
    StaticEmbedder,
    /// Transformer-based embedder.
    TransformerEmbedder,
    /// Embedder served by an external API.
    ApiEmbedder,
}

impl fmt::Display for ModelCategory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::HashEmbedder => "hash_embedder",
            Self::StaticEmbedder => "static_embedder",
            Self::TransformerEmbedder => "transformer_embedder",
            Self::ApiEmbedder => "api_embedder",
        };
        f.write_str(label)
    }
}
impl ModelCategory {
    /// Tier assigned to a model of this category unless it overrides it.
    ///
    /// Transformer and API embedders default to [`ModelTier::Quality`]. Hash and static
    /// embedders default to [`ModelTier::Fast`].
    #[must_use]
    pub const fn default_tier(self) -> ModelTier {
        match self {
            Self::TransformerEmbedder | Self::ApiEmbedder => ModelTier::Quality,
            Self::HashEmbedder | Self::StaticEmbedder => ModelTier::Fast,
        }
    }

    /// Whether models of this category produce semantic embeddings by default.
    ///
    /// Only [`ModelCategory::HashEmbedder`] defaults to `false`.
    #[must_use]
    pub const fn default_semantic_flag(self) -> bool {
        match self {
            Self::HashEmbedder => false,
            Self::StaticEmbedder | Self::TransformerEmbedder | Self::ApiEmbedder => true,
        }
    }
}
/// Coarse speed/quality tier of an embedding model. `Display` renders a lowercase label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelTier {
    /// Fast tier; the default for hash and static embedders.
    Fast,
    /// Quality tier; the default for transformer and API embedders.
    Quality,
}

impl fmt::Display for ModelTier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Fast => "fast",
            Self::Quality => "quality",
        })
    }
}
/// Descriptive, serializable metadata for an embedding model.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"potion-multilingual-128M"`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Length of the embedding vectors the model produces.
    pub dimension: usize,
    /// Model family.
    pub category: ModelCategory,
    /// Speed/quality tier.
    pub tier: ModelTier,
    /// Whether the embeddings carry semantic meaning.
    pub is_semantic: bool,
    /// Whether the model supports MRL-style prefix truncation.
    pub supports_mrl: bool,
    /// Hugging Face Hub repository id, if any (e.g. `"minishlab/potion-multilingual-128M"`).
    pub huggingface_id: Option<String>,
    /// Model size in bytes, if recorded.
    pub size_bytes: Option<u64>,
    /// License identifier, if recorded (e.g. `"apache-2.0"`).
    pub license: Option<String>,
}
/// Asynchronous text-embedding model.
///
/// Every async method returns a boxed [`SearchFuture`], so the trait is object-safe and can be
/// used as `dyn Embedder`.
pub trait Embedder: Send + Sync {
    /// Embeds a single `text` into a dense vector, using the `asupersync` context `cx`.
    fn embed<'a>(&'a self, cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per input in input order.
    ///
    /// The default awaits [`Embedder::embed`] sequentially, once per text, and stops at the
    /// first error. Implementations with a native batch path should override it.
    fn embed_batch<'a>(
        &'a self,
        cx: &'a Cx,
        texts: &'a [&'a str],
    ) -> SearchFuture<'a, Vec<Vec<f32>>> {
        Box::pin(async move {
            let mut embeddings = Vec::with_capacity(texts.len());
            for &text in texts {
                let vector = self.embed(cx, text).await?;
                embeddings.push(vector);
            }
            Ok(embeddings)
        })
    }

    /// Length of the vectors produced by this embedder.
    fn dimension(&self) -> usize;

    /// Identifier of this embedder.
    fn id(&self) -> &str;

    /// Human-readable model name.
    fn model_name(&self) -> &str;

    /// Whether the embedder can currently serve requests; defaults to `true`.
    fn is_ready(&self) -> bool {
        true
    }

    /// Whether the produced vectors carry semantic meaning.
    fn is_semantic(&self) -> bool;

    /// Model family of this embedder.
    fn category(&self) -> ModelCategory;

    /// Speed/quality tier; defaults to the [`ModelCategory::default_tier`] of
    /// [`Embedder::category`].
    fn tier(&self) -> ModelTier {
        self.category().default_tier()
    }

    /// Whether the model supports MRL-style prefix truncation; defaults to `false`.
    fn supports_mrl(&self) -> bool {
        false
    }

    /// Shortens `embedding` to its first `target_dim` components.
    ///
    /// If `target_dim` is smaller than the input length, the prefix is re-normalized to unit L2
    /// length. Otherwise an unmodified copy is returned. The default does not consult
    /// [`Embedder::supports_mrl`].
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::InvalidConfig`] when `target_dim` is `0`.
    fn truncate_embedding(&self, embedding: &[f32], target_dim: usize) -> SearchResult<Vec<f32>> {
        match target_dim {
            0 => Err(SearchError::InvalidConfig {
                field: "target_dim".to_owned(),
                value: "0".to_owned(),
                reason: "target dimension must be at least 1".to_owned(),
            }),
            dim if dim < embedding.len() => Ok(l2_normalize(&embedding[..dim])),
            // Nothing to cut: return the input as-is (not re-normalized).
            _ => Ok(embedding.to_vec()),
        }
    }
}
/// Synchronous counterpart of [`Embedder`] for models that embed without awaiting.
///
/// Wrap an implementation in [`SyncEmbedderAdapter`] to use it where an [`Embedder`] is expected.
pub trait SyncEmbed: Send + Sync {
    /// Embeds a single `text` into a dense vector.
    fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>>;

    /// Embeds every entry of `texts` in input order, stopping at the first error.
    ///
    /// The default calls [`SyncEmbed::embed_sync`] once per text.
    fn embed_batch_sync(&self, texts: &[&str]) -> SearchResult<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            embeddings.push(self.embed_sync(text)?);
        }
        Ok(embeddings)
    }

    /// Length of the vectors produced by this embedder.
    fn dimension(&self) -> usize;

    /// Identifier of this embedder.
    fn id(&self) -> &str;

    /// Human-readable model name; defaults to [`SyncEmbed::id`].
    fn model_name(&self) -> &str {
        self.id()
    }

    /// Whether the embedder can currently serve requests; defaults to `true`.
    fn is_ready(&self) -> bool {
        true
    }

    /// Whether the produced vectors carry semantic meaning.
    fn is_semantic(&self) -> bool;

    /// Model family of this embedder.
    fn category(&self) -> ModelCategory;

    /// Speed/quality tier; defaults to the [`ModelCategory::default_tier`] of
    /// [`SyncEmbed::category`].
    fn tier(&self) -> ModelTier {
        self.category().default_tier()
    }

    /// Whether the model supports MRL-style prefix truncation; defaults to `false`.
    fn supports_mrl(&self) -> bool {
        false
    }
}
pub struct SyncEmbedderAdapter<T: SyncEmbed>(pub T);
impl<T: SyncEmbed + 'static> Embedder for SyncEmbedderAdapter<T> {
fn embed<'a>(&'a self, _cx: &'a Cx, text: &'a str) -> SearchFuture<'a, Vec<f32>> {
Box::pin(async move { self.0.embed_sync(text) })
}
fn embed_batch<'a>(
&'a self,
_cx: &'a Cx,
texts: &'a [&'a str],
) -> SearchFuture<'a, Vec<Vec<f32>>> {
Box::pin(async move { self.0.embed_batch_sync(texts) })
}
fn dimension(&self) -> usize {
self.0.dimension()
}
fn id(&self) -> &str {
self.0.id()
}
fn model_name(&self) -> &str {
self.0.model_name()
}
fn is_ready(&self) -> bool {
self.0.is_ready()
}
fn is_semantic(&self) -> bool {
self.0.is_semantic()
}
fn category(&self) -> ModelCategory {
self.0.category()
}
fn tier(&self) -> ModelTier {
self.0.tier()
}
fn supports_mrl(&self) -> bool {
self.0.supports_mrl()
}
}
/// Returns a copy of `vec` scaled to unit Euclidean (L2) length.
///
/// Degenerate input yields an all-zero vector of the same length. This covers a squared norm
/// that is non-finite (NaN or infinite components, or overflow) and a squared norm below
/// `f32::EPSILON`. Dividing by such a norm would produce garbage.
#[must_use]
pub fn l2_normalize(vec: &[f32]) -> Vec<f32> {
    let squared_norm: f32 = vec.iter().map(|&component| component * component).sum();
    if squared_norm.is_finite() && squared_norm >= f32::EPSILON {
        let scale = 1.0 / squared_norm.sqrt();
        vec.iter().map(|&component| component * scale).collect()
    } else {
        vec![0.0; vec.len()]
    }
}
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` in two cases:
/// - the slices differ in length;
/// - the product of their L2 norms is non-finite or below `f32::EPSILON` (zero vectors,
///   NaN/infinite components, or overflow).
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm_a * norm_b;
    if !denom.is_finite() || denom < f32::EPSILON {
        return 0.0;
    }
    // The three sums are rounded independently, so the ratio can land a hair outside [-1, 1]
    // (e.g. 1.0000001 for identical vectors). Clamp so callers can rely on the true range.
    (dot / denom).clamp(-1.0, 1.0)
}
/// Keeps the first `target_dim` components of `embedding` and re-normalizes them to unit L2
/// length.
///
/// When `target_dim` is at least the input length, an unmodified copy is returned. That copy is
/// *not* re-normalized. A `target_dim` of `0` on non-empty input yields an empty vector.
#[must_use]
pub fn truncate_embedding(embedding: &[f32], target_dim: usize) -> Vec<f32> {
    if target_dim < embedding.len() {
        l2_normalize(&embedding[..target_dim])
    } else {
        embedding.to_vec()
    }
}
/// A candidate document handed to a [`Reranker`].
#[derive(Debug, Clone)]
pub struct RerankDocument {
    /// Identifier of the document; presumably echoed in the matching [`RerankScore::doc_id`].
    pub doc_id: String,
    /// Text the reranker scores against the query.
    pub text: String,
}
/// Relevance score a reranker assigned to one document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankScore {
    /// Identifier of the scored document.
    pub doc_id: String,
    /// Relevance score; higher ranks first.
    pub score: f32,
    /// Rank before reranking. [`SyncRerankerAdapter`] uses it as a tie-breaker.
    /// NOTE(review): presumably a 0-based position in the input, to be confirmed with callers.
    pub original_rank: usize,
    /// Unnormalized model output, if exposed.
    /// Omitted from serialized output when `None`; defaults to `None` if absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_logit: Option<f32>,
}
/// Asynchronous reranker that re-scores candidate documents for a query.
///
/// Object-safe: the async method returns a boxed [`SearchFuture`].
pub trait Reranker: Send + Sync {
    /// Scores each of `documents` against `query`, using the `asupersync` context `cx`.
    ///
    /// Callers expect the result ordered by descending score. [`SyncRerankerAdapter`]
    /// enforces this ordering for synchronous implementations.
    fn rerank<'a>(
        &'a self,
        cx: &'a Cx,
        query: &'a str,
        documents: &'a [RerankDocument],
    ) -> SearchFuture<'a, Vec<RerankScore>>;

    /// Identifier of this reranker.
    fn id(&self) -> &str;

    /// Human-readable model name.
    fn model_name(&self) -> &str;

    /// Maximum input length the model accepts; defaults to `512`.
    ///
    /// NOTE(review): the unit (tokens vs. characters) is not established here; confirm it with
    /// the implementations.
    fn max_length(&self) -> usize {
        512
    }

    /// Whether the reranker can currently be used; defaults to `true`.
    fn is_available(&self) -> bool {
        true
    }
}
/// Synchronous counterpart of [`Reranker`].
///
/// Results do not need to be sorted. Wrap an implementation in [`SyncRerankerAdapter`], which
/// orders them before handing them to async callers.
pub trait SyncRerank: Send + Sync {
    /// Scores each of `documents` against `query`.
    fn rerank_sync(
        &self,
        query: &str,
        documents: &[RerankDocument],
    ) -> SearchResult<Vec<RerankScore>>;

    /// Identifier of this reranker.
    fn id(&self) -> &str;

    /// Human-readable model name.
    fn model_name(&self) -> &str;

    /// Maximum input length the model accepts; defaults to `512`.
    ///
    /// NOTE(review): the unit (tokens vs. characters) is not established here; confirm it.
    fn max_length(&self) -> usize {
        512
    }

    /// Whether the reranker can currently be used; defaults to `true`.
    fn is_available(&self) -> bool {
        true
    }
}
/// Adapts a synchronous [`SyncRerank`] implementation to the async [`Reranker`] trait.
///
/// The wrapped `rerank_sync` call runs inline when the returned future is first polled. The `Cx`
/// argument is not used. The output is put into a deterministic order:
/// - descending `score`;
/// - ties broken by ascending `original_rank`, then by `doc_id`;
/// - NaN scores always last.
pub struct SyncRerankerAdapter<T: SyncRerank>(pub T);

impl<T: SyncRerank + 'static> Reranker for SyncRerankerAdapter<T> {
    fn rerank<'a>(
        &'a self,
        _cx: &'a Cx,
        query: &'a str,
        documents: &'a [RerankDocument],
    ) -> SearchFuture<'a, Vec<RerankScore>> {
        Box::pin(async move {
            let mut scores = self.0.rerank_sync(query, documents)?;
            scores.sort_by(|lhs, rhs| {
                // `total_cmp` alone ranks a positive NaN above every real score, which would put
                // a broken score at the top of a descending list. Partition NaNs to the end first
                // (`false < true`), then order real scores descending.
                lhs.score
                    .is_nan()
                    .cmp(&rhs.score.is_nan())
                    .then_with(|| rhs.score.total_cmp(&lhs.score))
                    .then_with(|| lhs.original_rank.cmp(&rhs.original_rank))
                    .then_with(|| lhs.doc_id.cmp(&rhs.doc_id))
            });
            Ok(scores)
        })
    }

    fn id(&self) -> &str {
        self.0.id()
    }

    fn model_name(&self) -> &str {
        self.0.model_name()
    }

    fn max_length(&self) -> usize {
        self.0.max_length()
    }

    fn is_available(&self) -> bool {
        self.0.is_available()
    }
}
/// Asynchronous lexical search index.
///
/// Object-safe: every async method returns a boxed [`SearchFuture`].
pub trait LexicalSearch: Send + Sync {
    /// Searches the index for `query`.
    ///
    /// `limit` is the requested result count. NOTE(review): presumably an upper bound; confirm
    /// per implementation.
    fn search<'a>(
        &'a self,
        cx: &'a Cx,
        query: &'a str,
        limit: usize,
    ) -> SearchFuture<'a, Vec<ScoredResult>>;

    /// Adds a single document to the index.
    fn index_document<'a>(
        &'a self,
        cx: &'a Cx,
        doc: &'a IndexableDocument,
    ) -> SearchFuture<'a, ()>;

    /// Adds every document in `docs`.
    ///
    /// The default awaits [`LexicalSearch::index_document`] sequentially, in order, and stops at
    /// the first error.
    fn index_documents<'a>(
        &'a self,
        cx: &'a Cx,
        docs: &'a [IndexableDocument],
    ) -> SearchFuture<'a, ()> {
        Box::pin(async move {
            for document in docs.iter() {
                self.index_document(cx, document).await?;
            }
            Ok(())
        })
    }

    /// Commits pending index changes.
    ///
    /// NOTE(review): the exact visibility/durability semantics are implementation-defined.
    fn commit<'a>(&'a self, cx: &'a Cx) -> SearchFuture<'a, ()>;

    /// Number of documents in the index, as reported by the implementation.
    fn doc_count(&self) -> usize;
}
/// Sink for search, embedding, indexing and error telemetry.
///
/// Requires `Debug + Send + Sync` and is object-safe, so one exporter can be shared behind a
/// [`SharedMetricsExporter`].
pub trait MetricsExporter: fmt::Debug + Send + Sync {
    /// Receives metrics for a completed search.
    fn on_search_completed(&self, metrics: &SearchMetrics);

    /// Receives metrics for a completed embedding call.
    fn on_embedding_completed(&self, metrics: &EmbeddingMetrics);

    /// Receives metrics for an index update.
    fn on_index_updated(&self, metrics: &IndexMetrics);

    /// Receives an error raised by the search stack.
    fn on_error(&self, error: &SearchError);
}

/// Reference-counted, dynamically dispatched [`MetricsExporter`] handle.
pub type SharedMetricsExporter = Arc<dyn MetricsExporter>;
/// [`MetricsExporter`] that discards every event; useful as a default when telemetry is off.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpMetricsExporter;

impl MetricsExporter for NoOpMetricsExporter {
    fn on_search_completed(&self, _metrics: &SearchMetrics) {}

    fn on_embedding_completed(&self, _metrics: &EmbeddingMetrics) {}

    fn on_index_updated(&self, _metrics: &IndexMetrics) {}

    fn on_error(&self, _error: &SearchError) {}
}
#[cfg(test)]
mod tests {
    use asupersync::test_utils::run_test_with_cx;

    use super::*;

    /// Reranker that returns scores out of order, to check the adapter's sorting contract.
    struct UnsortedSyncReranker;

    impl SyncRerank for UnsortedSyncReranker {
        fn rerank_sync(
            &self,
            _query: &str,
            _documents: &[RerankDocument],
        ) -> SearchResult<Vec<RerankScore>> {
            Ok(vec![
                RerankScore {
                    doc_id: "doc-a".to_owned(),
                    score: 0.8,
                    original_rank: 2,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-b".to_owned(),
                    score: 0.8,
                    original_rank: 1,
                    raw_logit: None,
                },
                RerankScore {
                    doc_id: "doc-c".to_owned(),
                    score: 0.3,
                    original_rank: 0,
                    raw_logit: None,
                },
            ])
        }
        fn id(&self) -> &'static str {
            "unsorted-sync-reranker"
        }
        fn model_name(&self) -> &'static str {
            "Unsorted Sync Reranker"
        }
    }

    /// Deterministic embedder for exercising `SyncEmbedderAdapter` and the `Embedder` defaults.
    /// The first component is `1.0` for texts starting with `'a'` (else `0.0`); the second is
    /// always `1.0`.
    struct PrefixSyncEmbedder;

    impl SyncEmbed for PrefixSyncEmbedder {
        fn embed_sync(&self, text: &str) -> SearchResult<Vec<f32>> {
            let first = if text.starts_with('a') { 1.0 } else { 0.0 };
            Ok(vec![first, 1.0])
        }
        fn dimension(&self) -> usize {
            2
        }
        fn id(&self) -> &'static str {
            "prefix-sync-embedder"
        }
        fn is_semantic(&self) -> bool {
            false
        }
        fn category(&self) -> ModelCategory {
            ModelCategory::HashEmbedder
        }
    }

    const ALL_CATEGORIES: [ModelCategory; 4] = [
        ModelCategory::HashEmbedder,
        ModelCategory::StaticEmbedder,
        ModelCategory::TransformerEmbedder,
        ModelCategory::ApiEmbedder,
    ];

    #[test]
    fn model_category_display() {
        assert_eq!(ModelCategory::HashEmbedder.to_string(), "hash_embedder");
        assert_eq!(ModelCategory::StaticEmbedder.to_string(), "static_embedder");
        assert_eq!(
            ModelCategory::TransformerEmbedder.to_string(),
            "transformer_embedder"
        );
        assert_eq!(ModelCategory::ApiEmbedder.to_string(), "api_embedder");
    }

    #[test]
    fn model_category_serialization() {
        for category in ALL_CATEGORIES {
            let json = serde_json::to_string(&category).unwrap();
            let decoded: ModelCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, category);
        }
    }

    #[test]
    fn model_category_equality() {
        assert_eq!(ModelCategory::HashEmbedder, ModelCategory::HashEmbedder);
        assert_ne!(ModelCategory::HashEmbedder, ModelCategory::StaticEmbedder);
        assert_ne!(
            ModelCategory::StaticEmbedder,
            ModelCategory::TransformerEmbedder
        );
    }

    #[test]
    fn model_category_default_tier() {
        assert_eq!(ModelCategory::HashEmbedder.default_tier(), ModelTier::Fast);
        assert_eq!(
            ModelCategory::StaticEmbedder.default_tier(),
            ModelTier::Fast
        );
        assert_eq!(
            ModelCategory::TransformerEmbedder.default_tier(),
            ModelTier::Quality
        );
        assert_eq!(
            ModelCategory::ApiEmbedder.default_tier(),
            ModelTier::Quality
        );
    }

    #[test]
    fn model_tier_display() {
        assert_eq!(ModelTier::Fast.to_string(), "fast");
        assert_eq!(ModelTier::Quality.to_string(), "quality");
    }

    #[test]
    fn model_info_roundtrip() {
        let info = ModelInfo {
            id: "potion-multilingual-128M".to_owned(),
            name: "Potion 128M".to_owned(),
            dimension: 256,
            category: ModelCategory::StaticEmbedder,
            tier: ModelTier::Fast,
            is_semantic: true,
            supports_mrl: false,
            huggingface_id: Some("minishlab/potion-multilingual-128M".to_owned()),
            size_bytes: Some(128_000_000),
            license: Some("apache-2.0".to_owned()),
        };
        let json = serde_json::to_string(&info).unwrap();
        let decoded: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, info);
    }

    #[test]
    fn rerank_document_construction() {
        let doc = RerankDocument {
            doc_id: "doc-1".into(),
            text: "Some content".into(),
        };
        assert_eq!(doc.doc_id, "doc-1");
        assert_eq!(doc.text, "Some content");
    }

    #[test]
    fn rerank_score_serialization() {
        let score = RerankScore {
            doc_id: "doc-1".into(),
            score: 0.92,
            original_rank: 3,
            raw_logit: None,
        };
        let json = serde_json::to_string(&score).unwrap();
        let decoded: RerankScore = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.doc_id, "doc-1");
        assert!((decoded.score - 0.92).abs() < 1e-6);
        assert_eq!(decoded.original_rank, 3);
    }

    #[test]
    fn embedder_trait_is_object_safe() {
        fn _takes_dyn_embedder(_: &dyn Embedder) {}
    }

    #[test]
    fn reranker_trait_is_object_safe() {
        fn _takes_dyn_reranker(_: &dyn Reranker) {}
    }

    #[test]
    fn lexical_search_trait_is_object_safe() {
        fn _takes_dyn_lexical(_: &dyn LexicalSearch) {}
    }

    #[test]
    fn metrics_exporter_trait_is_object_safe() {
        fn _takes_dyn_metrics_exporter(_: &dyn MetricsExporter) {}
    }

    #[test]
    fn sync_reranker_adapter_sorts_descending_for_trait_contract() {
        run_test_with_cx(|cx| async move {
            let adapter = SyncRerankerAdapter(UnsortedSyncReranker);
            let docs = vec![
                RerankDocument {
                    doc_id: "doc-a".to_owned(),
                    text: "alpha".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-b".to_owned(),
                    text: "beta".to_owned(),
                },
                RerankDocument {
                    doc_id: "doc-c".to_owned(),
                    text: "gamma".to_owned(),
                },
            ];
            let scores = adapter
                .rerank(&cx, "query", &docs)
                .await
                .expect("adapter rerank should succeed");
            let ids = scores
                .iter()
                .map(|score| score.doc_id.as_str())
                .collect::<Vec<_>>();
            assert_eq!(ids, vec!["doc-b", "doc-a", "doc-c"]);
        });
    }

    #[test]
    fn sync_embedder_adapter_forwards_metadata() {
        let embedder = SyncEmbedderAdapter(PrefixSyncEmbedder);
        assert_eq!(embedder.dimension(), 2);
        assert_eq!(embedder.id(), "prefix-sync-embedder");
        // `SyncEmbed::model_name` defaults to the id.
        assert_eq!(embedder.model_name(), "prefix-sync-embedder");
        assert!(embedder.is_ready());
        assert!(!embedder.is_semantic());
        assert_eq!(embedder.category(), ModelCategory::HashEmbedder);
        assert_eq!(embedder.tier(), ModelTier::Fast);
        assert!(!embedder.supports_mrl());
    }

    #[test]
    fn sync_embedder_adapter_embeds_batch_in_order() {
        run_test_with_cx(|cx| async move {
            let adapter = SyncEmbedderAdapter(PrefixSyncEmbedder);
            let texts = ["alpha", "beta", "apple"];
            let batch = adapter
                .embed_batch(&cx, &texts)
                .await
                .expect("adapter embed_batch should succeed");
            assert_eq!(
                batch,
                vec![vec![1.0, 1.0], vec![0.0, 1.0], vec![1.0, 1.0]]
            );
        });
    }

    #[test]
    fn embedder_default_truncate_embedding_validates_and_renormalizes() {
        let embedder = SyncEmbedderAdapter(PrefixSyncEmbedder);
        assert!(embedder.truncate_embedding(&[3.0, 4.0], 0).is_err());
        let truncated = embedder
            .truncate_embedding(&[3.0, 4.0, 12.0], 2)
            .expect("non-zero target_dim should succeed");
        assert_eq!(truncated.len(), 2);
        assert!((truncated[0] - 0.6).abs() < 1e-6);
        assert!((truncated[1] - 0.8).abs() < 1e-6);
        assert_eq!(
            embedder
                .truncate_embedding(&[3.0, 4.0], 5)
                .expect("oversized target_dim should succeed"),
            vec![3.0, 4.0]
        );
    }

    #[test]
    fn noop_metrics_exporter_callbacks_are_noops() {
        let exporter = NoOpMetricsExporter;
        let search_metrics = SearchMetrics {
            mode: crate::types::SearchMode::Hybrid,
            query_class: None,
            total_latency_ms: 10.0,
            phase1_latency_ms: Some(4.0),
            phase2_latency_ms: Some(6.0),
            result_count: 8,
            lexical_candidates: 30,
            semantic_candidates: 25,
            refined: true,
        };
        let embedding_metrics = EmbeddingMetrics {
            embedder_id: "fnv-hash-384".into(),
            batch_size: 1,
            duration_ms: 0.07,
            dimension: 384,
            is_semantic: false,
        };
        let index_metrics = IndexMetrics {
            doc_count: 100,
            index_size_bytes: 4096,
            updated_docs: 1,
            staleness_detected: false,
        };
        exporter.on_search_completed(&search_metrics);
        exporter.on_embedding_completed(&embedding_metrics);
        exporter.on_index_updated(&index_metrics);
        exporter.on_error(&SearchError::SearchTimeout {
            elapsed_ms: 11,
            budget_ms: 10,
        });
    }

    #[test]
    fn l2_normalize_produces_unit_vector() {
        let v = vec![3.0, 4.0];
        let normalized = l2_normalize(&v);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_zero_vector() {
        let v = vec![0.0, 0.0, 0.0];
        let normalized = l2_normalize(&v);
        assert!(normalized.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![1.0, 2.0];
        let b = vec![0.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < f32::EPSILON);
    }

    #[test]
    fn truncate_embedding_reduces_dim() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let t = truncate_embedding(&v, 2);
        assert_eq!(t.len(), 2);
        let norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn truncate_embedding_noop_when_larger() {
        let v = vec![1.0, 2.0];
        assert_eq!(truncate_embedding(&v, 10), v);
    }

    #[test]
    fn model_category_default_semantic_flag() {
        assert!(!ModelCategory::HashEmbedder.default_semantic_flag());
        assert!(ModelCategory::StaticEmbedder.default_semantic_flag());
        assert!(ModelCategory::TransformerEmbedder.default_semantic_flag());
        assert!(ModelCategory::ApiEmbedder.default_semantic_flag());
    }
}