Skip to main content

lexa_core/
embed.rs

1//! Embedding + reranking inference.
2//!
3//! Production backbone is **`nomic-ai/nomic-embed-text-v1.5` (quantized)**:
4//! 768-dim, Matryoshka-trained, Apache-2 licensed. The deep-tier reranker is
5//! `BAAI/bge-reranker-base`. The `Hash` backend is a deterministic
6//! FNV-1a-into-fixed-dim fallback used only by tests and CI smoke runs.
7
8use fastembed::{
9    EmbeddingModel, InitOptions, RerankInitOptions, RerankerModel, TextEmbedding, TextRerank,
10};
11
12use crate::{LexaError, Result};
13
/// Width of the full-precision embeddings produced by the production model.
///
/// This value is tied to the model chosen in `Embedder::new`: switching models
/// means changing this constant as well, and re-indexing any existing
/// `vectors_bin` data that was built at the old width.
pub const EMBEDDING_DIMS: usize = 768;

/// Width of the Matryoshka "preview" embedding (a prefix of the full vector).
///
/// Nomic v1.5 was trained with Matryoshka representation learning at the
/// canonical prefix widths {64, 128, 256, 512, 768}; 256 is the published Exa
/// choice for a coarse retrieval pass. A second binary-quantized index is kept
/// at this width and serves as stage one of two-stage KNN: a 256-bit Hamming
/// scan over the whole corpus, followed by a full 768-bit Hamming pass over
/// the top-K survivors.
pub const PREVIEW_DIMS: usize = 256;

/// Task prefix that must lead every *query* before it is embedded.
///
/// Nomic v1.5 was trained on asymmetric query/document pairs; embedding a
/// query without this marker silently costs several points of nDCG.
const QUERY_PREFIX: &str = "search_query: ";

/// Task prefix that must lead every *document* before it is embedded; the
/// counterpart of `QUERY_PREFIX`.
const DOCUMENT_PREFIX: &str = "search_document: ";
33
/// Inference engine that `Embedder` and `Reranker` are built on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingBackend {
    /// Real ONNX inference via fastembed-rs.
    FastEmbed,
    /// Deterministic FNV-1a feature hashing into a fixed-width vector.
    /// Intended for CI / offline use only.
    Hash,
}

/// Settings shared by `Embedder::new` and `Reranker::new`.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which inference engine to instantiate.
    pub backend: EmbeddingBackend,
    /// Forwarded to fastembed's `with_show_download_progress` option.
    pub show_download_progress: bool,
}

impl Default for EmbeddingConfig {
    /// Build a config from the environment.
    ///
    /// The backend is `EmbeddingBackend::Hash` only when `LEXA_EMBEDDER` is
    /// exactly `hash`. Any other value, an unset variable, or a non-UTF-8
    /// value selects `EmbeddingBackend::FastEmbed`. Download progress is on.
    fn default() -> Self {
        let selector = std::env::var("LEXA_EMBEDDER");
        let wants_hash = matches!(selector.as_deref(), Ok("hash"));
        let backend = if wants_hash {
            EmbeddingBackend::Hash
        } else {
            EmbeddingBackend::FastEmbed
        };
        Self {
            show_download_progress: true,
            backend,
        }
    }
}
60
61pub enum Embedder {
62    Fast(Box<TextEmbedding>),
63    Hash,
64}
65
66impl Embedder {
67    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
68        match config.backend {
69            EmbeddingBackend::Hash => Ok(Self::Hash),
70            EmbeddingBackend::FastEmbed => {
71                let options = InitOptions::new(EmbeddingModel::NomicEmbedTextV15Q)
72                    .with_show_download_progress(config.show_download_progress);
73                TextEmbedding::try_new(options)
74                    .map(Box::new)
75                    .map(Self::Fast)
76                    .map_err(|error| LexaError::Embedding(error.to_string()))
77            }
78        }
79    }
80
81    /// Encode a batch of *documents* (passages to be indexed).
82    pub fn embed_documents(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
83        let prefixed: Vec<String> = match self {
84            Self::Fast(_) => texts
85                .iter()
86                .map(|text| format!("{DOCUMENT_PREFIX}{text}"))
87                .collect(),
88            Self::Hash => texts.to_vec(),
89        };
90        self.encode(&prefixed)
91    }
92
93    /// Encode a *query* string. Symmetric with `embed_documents` — without
94    /// the matching task prefix, retrieval quality drops measurably.
95    pub fn embed_query(&mut self, query: &str) -> Result<Vec<f32>> {
96        let prefixed = match self {
97            Self::Fast(_) => format!("{QUERY_PREFIX}{query}"),
98            Self::Hash => query.to_string(),
99        };
100        Ok(self.encode(&[prefixed])?.remove(0))
101    }
102
103    fn encode(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
104        match self {
105            Self::Fast(model) => model
106                .embed(texts, None)
107                .map_err(|error| LexaError::Embedding(error.to_string())),
108            Self::Hash => Ok(texts.iter().map(|text| hash_embedding(text)).collect()),
109        }
110    }
111}
112
113pub enum Reranker {
114    Fast(Box<TextRerank>),
115    Hash,
116}
117
118impl Reranker {
119    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
120        match config.backend {
121            EmbeddingBackend::Hash => Ok(Self::Hash),
122            EmbeddingBackend::FastEmbed => {
123                let options = RerankInitOptions::new(RerankerModel::BGERerankerBase)
124                    .with_show_download_progress(config.show_download_progress);
125                TextRerank::try_new(options)
126                    .map(Box::new)
127                    .map(Self::Fast)
128                    .map_err(|error| LexaError::Embedding(error.to_string()))
129            }
130        }
131    }
132
133    pub fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
134        match self {
135            Self::Fast(model) => {
136                let refs: Vec<&str> = documents.iter().map(String::as_str).collect();
137                model
138                    .rerank(query, refs, false, None)
139                    .map(|items| {
140                        items
141                            .into_iter()
142                            .map(|item| (item.index, item.score))
143                            .collect()
144                    })
145                    .map_err(|error| LexaError::Embedding(error.to_string()))
146            }
147            Self::Hash => {
148                let q = hash_embedding(query);
149                let mut scores: Vec<(usize, f32)> = documents
150                    .iter()
151                    .enumerate()
152                    .map(|(idx, text)| (idx, cosine(&q, &hash_embedding(text))))
153                    .collect();
154                scores.sort_by(|left, right| {
155                    right
156                        .1
157                        .partial_cmp(&left.1)
158                        .unwrap_or(std::cmp::Ordering::Equal)
159                });
160                Ok(scores)
161            }
162        }
163    }
164}
165
/// Keep the first `target_dims` components of a Matryoshka-trained embedding
/// and rescale the result to unit L2 length.
///
/// Nomic v1.5 is MRL-trained at canonical dims 64, 128, 256, 512 and 768, so
/// any such prefix is a valid embedding in the same vector space. fastembed
/// already returns L2-normalized vectors, but a prefix of one is generally
/// shorter than 1, so we renormalize to keep cosine scores in `[-1, 1]`.
///
/// If `target_dims` exceeds the input length, the whole vector is kept. An
/// all-zero prefix is returned unchanged.
pub fn matryoshka_truncate(vector: &[f32], target_dims: usize) -> Vec<f32> {
    let prefix = &vector[..target_dims.min(vector.len())];
    let length = prefix.iter().map(|component| component * component).sum::<f32>().sqrt();
    if length > 0.0 {
        prefix.iter().map(|component| component / length).collect()
    } else {
        prefix.to_vec()
    }
}
182
183pub fn hash_embedding(text: &str) -> Vec<f32> {
184    let mut out = vec![0.0; EMBEDDING_DIMS];
185    for token in tokenize(text) {
186        let hash = fnv1a(token.as_bytes());
187        let idx = (hash as usize) % EMBEDDING_DIMS;
188        let sign = if hash & 1 == 0 { 1.0 } else { -1.0 };
189        out[idx] += sign;
190    }
191    normalize(&mut out);
192    out
193}
194
/// Split `text` into lowercase ASCII-alphanumeric tokens of at least two bytes.
///
/// Every character that is not ASCII alphanumeric acts as a separator, so
/// non-ASCII letters split words (e.g. `café` yields `caf`). Single-character
/// fragments are discarded.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|ch: char| !ch.is_ascii_alphanumeric()) {
        if piece.len() < 2 {
            continue;
        }
        tokens.push(piece.to_ascii_lowercase());
    }
    tokens
}
203
/// Scale `values` in place to unit L2 length.
///
/// A zero-length (or NaN-length) vector is left untouched.
fn normalize(values: &mut [f32]) {
    let magnitude = values.iter().map(|value| value * value).sum::<f32>().sqrt();
    if !(magnitude > 0.0) {
        return;
    }
    values.iter_mut().for_each(|value| *value /= magnitude);
}
212
/// Dot product of two embeddings.
///
/// This equals cosine similarity only when both inputs are already
/// L2-normalized, as `hash_embedding` output is. No normalization happens
/// here. Slices of different lengths are compared over the shorter prefix.
pub fn cosine(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right).map(|(&l, &r)| l * r);
    products.sum::<f32>()
}
216
/// Pack an f32 embedding into a raw *native-endian* byte buffer,
/// 4 bytes per component.
///
/// `sqlite-vec` accepts both JSON arrays and raw f32 BLOBs of length
/// `dims * 4` bytes. The BLOB form skips the JSON tokenizer on every insert
/// and every query. sqlite-vec reads the BLOB with `memcpy` into `f32`s,
/// i.e. in the host's native byte order, which is why `to_ne_bytes` is used.
/// On x86_64 and arm64 that order is little-endian.
pub fn vector_blob(vector: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(vector.len() * std::mem::size_of::<f32>());
    blob.extend(vector.iter().flat_map(|component| component.to_ne_bytes()));
    blob
}
230
/// 64-bit FNV-1a hash of `bytes`: XOR each byte in, then multiply by the FNV
/// prime (wrapping).
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(PRIME))
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of a vector, for normalization checks.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|value| value * value).sum::<f32>().sqrt()
    }

    #[test]
    fn matryoshka_truncate_normalizes() {
        let v = vec![3.0, 4.0, 0.0, 0.0];
        let t = matryoshka_truncate(&v, 2);
        assert_eq!(t.len(), 2);
        assert!((l2_norm(&t) - 1.0).abs() < 1e-6);
        assert!((t[0] - 0.6).abs() < 1e-6);
        assert!((t[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn matryoshka_truncate_caps_at_input_len() {
        let v = vec![1.0, 0.0, 0.0];
        assert_eq!(matryoshka_truncate(&v, 8).len(), 3);
    }

    #[test]
    fn matryoshka_truncate_keeps_zero_vector() {
        assert_eq!(matryoshka_truncate(&[0.0, 0.0, 0.0], 2), vec![0.0f32, 0.0]);
    }

    #[test]
    fn hash_embedding_has_canonical_dims() {
        assert_eq!(hash_embedding("hello world").len(), EMBEDDING_DIMS);
    }

    #[test]
    fn hash_embedding_is_deterministic_and_unit_norm() {
        let first = hash_embedding("Rust embeddings are fun");
        assert_eq!(first, hash_embedding("Rust embeddings are fun"));
        assert!((l2_norm(&first) - 1.0).abs() < 1e-5);
        assert!((cosine(&first, &first) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn hash_embedding_without_tokens_is_zero() {
        assert!(hash_embedding("").iter().all(|value| *value == 0.0));
        // Single-character fragments are dropped by the tokenizer.
        assert!(hash_embedding("a ! ?").iter().all(|value| *value == 0.0));
    }

    #[test]
    fn tokenize_lowercases_and_drops_short_fragments() {
        assert_eq!(tokenize("Hello, World! a b2"), vec!["hello", "world", "b2"]);
    }

    #[test]
    fn fnv1a_matches_reference_vectors() {
        assert_eq!(fnv1a(b""), 0xcbf29ce484222325);
        assert_eq!(fnv1a(b"a"), 0xaf63dc4c8601ec8c);
    }

    #[test]
    fn vector_blob_is_four_native_bytes_per_component() {
        let blob = vector_blob(&[1.0, -2.5]);
        assert_eq!(blob.len(), 8);
        assert_eq!(&blob[..4], &1.0f32.to_ne_bytes());
        assert_eq!(&blob[4..], &(-2.5f32).to_ne_bytes());
    }

    #[test]
    fn hash_reranker_sorts_by_descending_score() {
        let mut reranker = Reranker::Hash;
        let docs = vec![
            "rust ownership".to_string(),
            "gardening tips".to_string(),
            "rust ownership rules".to_string(),
        ];
        let ranked = match reranker.rerank("rust ownership", &docs) {
            Ok(ranked) => ranked,
            Err(_) => panic!("hash reranker must not fail"),
        };
        assert_eq!(ranked.len(), docs.len());
        assert!(ranked.windows(2).all(|pair| pair[0].1 >= pair[1].1));
        // The document identical to the query has the maximal score.
        assert_eq!(ranked[0].0, 0);
    }
}