1use fastembed::{
9 EmbeddingModel, InitOptions, RerankInitOptions, RerankerModel, TextEmbedding, TextRerank,
10};
11
12use crate::{LexaError, Result};
13
/// Dimensionality of vectors produced by both backends. `hash_embedding`
/// emits this many components; presumably it matches the output width of the
/// Nomic Embed v1.5 model used by the fastembed backend — TODO confirm.
pub const EMBEDDING_DIMS: usize = 768;

/// Reduced dimensionality for cheap preview vectors (see `matryoshka_truncate`).
pub const PREVIEW_DIMS: usize = 256;

/// Task prefix prepended to queries for the fastembed backend
/// (see `Embedder::embed_query`); the hash backend uses raw text.
const QUERY_PREFIX: &str = "search_query: ";

/// Task prefix prepended to documents for the fastembed backend
/// (see `Embedder::embed_documents`).
const DOCUMENT_PREFIX: &str = "search_document: ";
33
/// Selects how vectors (and rerank scores) are produced.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum EmbeddingBackend {
    /// Real model inference via the `fastembed` crate (may download model
    /// weights on first use).
    FastEmbed,
    /// Deterministic hashing fallback (`hash_embedding`) — no model needed;
    /// useful for tests and offline runs.
    Hash,
}
41
/// Configuration shared by [`Embedder::new`] and [`Reranker::new`].
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which backend to use. The `Default` impl reads the `LEXA_EMBEDDER`
    /// environment variable to pick this.
    pub backend: EmbeddingBackend,
    /// Forwarded to fastembed's init options to control download progress output.
    pub show_download_progress: bool,
}
47
48impl Default for EmbeddingConfig {
49 fn default() -> Self {
50 let backend = match std::env::var("LEXA_EMBEDDER").ok().as_deref() {
51 Some("hash") => EmbeddingBackend::Hash,
52 _ => EmbeddingBackend::FastEmbed,
53 };
54 Self {
55 backend,
56 show_download_progress: true,
57 }
58 }
59}
60
/// Text-embedding engine: either a loaded fastembed model or a hashing stub.
pub enum Embedder {
    /// Loaded fastembed model (boxed to keep the enum itself small).
    Fast(Box<TextEmbedding>),
    /// Hashing fallback; vectors come from [`hash_embedding`].
    Hash,
}
65
66impl Embedder {
67 pub fn new(config: &EmbeddingConfig) -> Result<Self> {
68 match config.backend {
69 EmbeddingBackend::Hash => Ok(Self::Hash),
70 EmbeddingBackend::FastEmbed => {
71 let options = InitOptions::new(EmbeddingModel::NomicEmbedTextV15Q)
72 .with_show_download_progress(config.show_download_progress);
73 TextEmbedding::try_new(options)
74 .map(Box::new)
75 .map(Self::Fast)
76 .map_err(|error| LexaError::Embedding(error.to_string()))
77 }
78 }
79 }
80
81 pub fn embed_documents(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
83 let prefixed: Vec<String> = match self {
84 Self::Fast(_) => texts
85 .iter()
86 .map(|text| format!("{DOCUMENT_PREFIX}{text}"))
87 .collect(),
88 Self::Hash => texts.to_vec(),
89 };
90 self.encode(&prefixed)
91 }
92
93 pub fn embed_query(&mut self, query: &str) -> Result<Vec<f32>> {
96 let prefixed = match self {
97 Self::Fast(_) => format!("{QUERY_PREFIX}{query}"),
98 Self::Hash => query.to_string(),
99 };
100 Ok(self.encode(&[prefixed])?.remove(0))
101 }
102
103 fn encode(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
104 match self {
105 Self::Fast(model) => model
106 .embed(texts, None)
107 .map_err(|error| LexaError::Embedding(error.to_string())),
108 Self::Hash => Ok(texts.iter().map(|text| hash_embedding(text)).collect()),
109 }
110 }
111}
112
/// Relevance reranker: either a loaded fastembed reranker model or a
/// hashing stub that scores by cosine of hash embeddings.
pub enum Reranker {
    /// Loaded fastembed reranker (boxed to keep the enum itself small).
    Fast(Box<TextRerank>),
    /// Fallback scoring via [`hash_embedding`] + [`cosine`].
    Hash,
}
117
118impl Reranker {
119 pub fn new(config: &EmbeddingConfig) -> Result<Self> {
120 match config.backend {
121 EmbeddingBackend::Hash => Ok(Self::Hash),
122 EmbeddingBackend::FastEmbed => {
123 let options = RerankInitOptions::new(RerankerModel::BGERerankerBase)
124 .with_show_download_progress(config.show_download_progress);
125 TextRerank::try_new(options)
126 .map(Box::new)
127 .map(Self::Fast)
128 .map_err(|error| LexaError::Embedding(error.to_string()))
129 }
130 }
131 }
132
133 pub fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
134 match self {
135 Self::Fast(model) => {
136 let refs: Vec<&str> = documents.iter().map(String::as_str).collect();
137 model
138 .rerank(query, refs, false, None)
139 .map(|items| {
140 items
141 .into_iter()
142 .map(|item| (item.index, item.score))
143 .collect()
144 })
145 .map_err(|error| LexaError::Embedding(error.to_string()))
146 }
147 Self::Hash => {
148 let q = hash_embedding(query);
149 let mut scores: Vec<(usize, f32)> = documents
150 .iter()
151 .enumerate()
152 .map(|(idx, text)| (idx, cosine(&q, &hash_embedding(text))))
153 .collect();
154 scores.sort_by(|left, right| {
155 right
156 .1
157 .partial_cmp(&left.1)
158 .unwrap_or(std::cmp::Ordering::Equal)
159 });
160 Ok(scores)
161 }
162 }
163 }
164}
165
/// Truncates `vector` to its first `target_dims` components and re-normalizes
/// the prefix to unit L2 length (matryoshka-style dimensionality reduction).
///
/// If `target_dims` exceeds the input length, the whole vector is used.
/// An all-zero prefix is returned as-is to avoid dividing by zero.
pub fn matryoshka_truncate(vector: &[f32], target_dims: usize) -> Vec<f32> {
    let mut truncated: Vec<f32> = vector.iter().take(target_dims).copied().collect();
    let magnitude = truncated
        .iter()
        .map(|component| component * component)
        .sum::<f32>()
        .sqrt();
    if magnitude > 0.0 {
        truncated.iter_mut().for_each(|component| *component /= magnitude);
    }
    truncated
}
182
183pub fn hash_embedding(text: &str) -> Vec<f32> {
184 let mut out = vec![0.0; EMBEDDING_DIMS];
185 for token in tokenize(text) {
186 let hash = fnv1a(token.as_bytes());
187 let idx = (hash as usize) % EMBEDDING_DIMS;
188 let sign = if hash & 1 == 0 { 1.0 } else { -1.0 };
189 out[idx] += sign;
190 }
191 normalize(&mut out);
192 out
193}
194
/// Splits `text` into lowercase ASCII-alphanumeric tokens, dropping tokens of
/// length 0 or 1.
///
/// The previous `.trim()` was dead code: splitting on every non-ASCII-
/// alphanumeric character means segments can never contain whitespace.
/// Filtering before lowercasing also avoids allocating a `String` for
/// segments that would be discarded anyway.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|segment| segment.len() > 1)
        .map(|segment| segment.to_ascii_lowercase())
        .collect()
}
203
/// Scales `values` in place to unit L2 norm; an all-zero slice is left
/// untouched to avoid dividing by zero.
fn normalize(values: &mut [f32]) {
    let magnitude = values
        .iter()
        .fold(0.0_f32, |acc, component| acc + component * component)
        .sqrt();
    if magnitude > 0.0 {
        values.iter_mut().for_each(|component| *component /= magnitude);
    }
}
212
/// Dot product of two vectors; for unit-length inputs (as produced by
/// `hash_embedding` and `matryoshka_truncate`) this equals cosine similarity.
/// If the slices differ in length, extra components of the longer one are
/// ignored (zip truncates).
pub fn cosine(left: &[f32], right: &[f32]) -> f32 {
    let mut accumulator = 0.0_f32;
    for (l, r) in left.iter().zip(right) {
        accumulator += l * r;
    }
    accumulator
}
216
/// Serializes a vector into a contiguous byte buffer, four native-endian
/// bytes per `f32`.
///
/// NOTE(review): `to_ne_bytes` makes the blob endianness-dependent —
/// presumably fine because blobs are read back on the same machine; confirm
/// before sharing blobs across hosts.
pub fn vector_blob(vector: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(vector));
    bytes.extend(vector.iter().flat_map(|value| value.to_ne_bytes()));
    bytes
}
230
/// 64-bit FNV-1a hash: XOR each byte into the state, then multiply by the
/// FNV prime (offset basis 0xcbf29ce484222325, prime 0x100000001b3).
fn fnv1a(bytes: &[u8]) -> u64 {
    bytes.iter().fold(0xcbf29ce484222325_u64, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(0x100000001b3)
    })
}
239
#[cfg(test)]
mod tests {
    use super::*;

    // Truncating (3,4,0,0) to 2 dims must yield the unit vector (0.6, 0.8).
    #[test]
    fn matryoshka_truncate_normalizes() {
        let v = vec![3.0, 4.0, 0.0, 0.0];
        let t = matryoshka_truncate(&v, 2);
        assert_eq!(t.len(), 2);
        let norm = t.iter().map(|value| value * value).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert!((t[0] - 0.6).abs() < 1e-6);
        assert!((t[1] - 0.8).abs() < 1e-6);
    }

    // Requesting more dims than the input has must not pad or panic.
    #[test]
    fn matryoshka_truncate_caps_at_input_len() {
        let v = vec![1.0, 0.0, 0.0];
        assert_eq!(matryoshka_truncate(&v, 8).len(), 3);
    }

    // The hashing fallback must produce vectors of the canonical width.
    #[test]
    fn hash_embedding_has_canonical_dims() {
        assert_eq!(hash_embedding("hello world").len(), EMBEDDING_DIMS);
    }
}