reflex/embedding/sinter/
mod.rs

1//! Sinter embedder (GGUF + tokenizer).
2//!
3//! Use [`SinterConfig::stub`] for tests/examples without model files.
4
5/// Sinter configuration.
6pub mod config;
7pub(crate) mod model;
8
9#[cfg(test)]
10mod tests;
11
12pub use config::{SINTER_EMBEDDING_DIM, SINTER_MAX_SEQ_LEN, SinterConfig};
13
14use std::sync::Arc;
15
16use candle_core::{Device, IndexOp, Tensor};
17use half::f16;
18use parking_lot::Mutex;
19use tracing::{debug, info, warn};
20
21use crate::embedding::device::select_device;
22use crate::embedding::error::EmbeddingError;
23use crate::embedding::utils::load_tokenizer;
24
25use model::Qwen2ForEmbedding;
26
27enum EmbedderBackend {
28    Model {
29        model: Arc<Mutex<Qwen2ForEmbedding>>,
30        tokenizer: Arc<tokenizers::Tokenizer>,
31        device: Device,
32    },
33    Stub {
34        device: Device,
35    },
36}
37
38/// Embedding generator for semantic search (supports stub mode).
39pub struct SinterEmbedder {
40    backend: EmbedderBackend,
41    config: SinterConfig,
42}
43
44impl std::fmt::Debug for SinterEmbedder {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("SinterEmbedder")
47            .field(
48                "backend",
49                &match &self.backend {
50                    EmbedderBackend::Model { device, .. } => format!("Model({:?})", device),
51                    EmbedderBackend::Stub { device } => format!("Stub({:?})", device),
52                },
53            )
54            .field("embedding_dim", &self.config.embedding_dim)
55            .field("max_seq_len", &self.config.max_seq_len)
56            .finish()
57    }
58}
59
60impl SinterEmbedder {
61    /// Loads the embedder from a config (stub mode is supported).
62    pub fn load(config: SinterConfig) -> Result<Self, EmbeddingError> {
63        config.validate()?;
64
65        let device = select_device()?;
66        debug!(?device, "Selected compute device for Sinter");
67
68        if config.testing_stub {
69            warn!("Sinter running in STUB mode (testing only)");
70            return Ok(Self {
71                backend: EmbedderBackend::Stub { device },
72                config,
73            });
74        }
75
76        if !config.model_available() || !config.tokenizer_available() {
77            return Err(EmbeddingError::ModelNotFound {
78                path: config.model_path.clone(),
79            });
80        }
81
82        let (model, tokenizer) = Self::load_model(&config, &device)?;
83
84        info!(
85            model_path = %config.model_path.display(),
86            embedding_dim = config.embedding_dim,
87            max_seq_len = config.max_seq_len,
88            hidden_size = model.config().hidden_size,
89            num_layers = model.config().num_layers,
90            "Sinter model loaded successfully (full transformer)"
91        );
92
93        Ok(Self {
94            backend: EmbedderBackend::Model {
95                model: Arc::new(Mutex::new(model)),
96                tokenizer: Arc::new(tokenizer),
97                device,
98            },
99            config,
100        })
101    }
102
103    fn load_model(
104        config: &SinterConfig,
105        device: &Device,
106    ) -> Result<(Qwen2ForEmbedding, tokenizers::Tokenizer), EmbeddingError> {
107        let tokenizer = load_tokenizer(&config.tokenizer_path).map_err(|e| {
108            EmbeddingError::TokenizationFailed {
109                reason: format!("Failed to load tokenizer: {}", e),
110            }
111        })?;
112
113        let mut model_file = std::fs::File::open(&config.model_path)?;
114        let model_content = candle_core::quantized::gguf_file::Content::read(&mut model_file)
115            .map_err(|e| EmbeddingError::ModelLoadFailed {
116                reason: format!("Failed to read GGUF content: {}", e),
117            })?;
118
119        let model = Qwen2ForEmbedding::from_gguf(
120            model_content,
121            &mut model_file,
122            device,
123            config.max_seq_len,
124        )
125        .map_err(|e| EmbeddingError::ModelLoadFailed {
126            reason: format!("Failed to load Qwen2 model: {}", e),
127        })?;
128
129        // Validate embedding dimension
130        if config.embedding_dim > model.config().hidden_size {
131            return Err(EmbeddingError::InvalidConfig {
132                reason: format!(
133                    "embedding_dim ({}) exceeds model hidden_size ({})",
134                    config.embedding_dim,
135                    model.config().hidden_size
136                ),
137            });
138        }
139
140        info!(
141            hidden_size = model.config().hidden_size,
142            num_layers = model.config().num_layers,
143            "Qwen2 transformer loaded"
144        );
145
146        Ok((model, tokenizer))
147    }
148
149    /// Generates an embedding for a single string.
150    pub fn embed(&self, text: &str) -> Result<Vec<f16>, EmbeddingError> {
151        match &self.backend {
152            EmbedderBackend::Model {
153                model,
154                tokenizer,
155                device,
156            } => self.embed_with_model(text, model, tokenizer, device),
157            EmbedderBackend::Stub { .. } => self.embed_stub(text),
158        }
159    }
160
161    /// Generates embeddings for a batch of strings.
162    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f16>>, EmbeddingError> {
163        if texts.is_empty() {
164            return Ok(vec![]);
165        }
166
167        match &self.backend {
168            EmbedderBackend::Model {
169                model,
170                tokenizer,
171                device,
172            } => self.embed_batch_with_model(texts, model, tokenizer, device),
173            EmbedderBackend::Stub { .. } => {
174                texts.iter().map(|text| self.embed_stub(text)).collect()
175            }
176        }
177    }
178
179    fn embed_with_model(
180        &self,
181        text: &str,
182        model: &Arc<Mutex<Qwen2ForEmbedding>>,
183        tokenizer: &tokenizers::Tokenizer,
184        device: &Device,
185    ) -> Result<Vec<f16>, EmbeddingError> {
186        let encoding =
187            tokenizer
188                .encode(text, true)
189                .map_err(|e| EmbeddingError::TokenizationFailed {
190                    reason: e.to_string(),
191                })?;
192
193        let mut tokens: Vec<u32> = encoding.get_ids().to_vec();
194        if tokens.is_empty() {
195            return Ok(vec![f16::from_f32(0.0); self.config.embedding_dim]);
196        }
197
198        if tokens.len() > self.config.max_seq_len {
199            tokens.truncate(self.config.max_seq_len);
200        }
201
202        debug!(
203            text_len = text.len(),
204            token_count = tokens.len(),
205            "Generating embedding (transformer forward pass)"
206        );
207
208        // Create input tensor: [1, seq_len]
209        let input_ids = Tensor::new(&tokens[..], device)
210            .map_err(|e| EmbeddingError::InferenceFailed {
211                reason: format!("Failed to create input tensor: {}", e),
212            })?
213            .unsqueeze(0)
214            .map_err(|e| EmbeddingError::InferenceFailed {
215                reason: format!("Failed to unsqueeze input: {}", e),
216            })?;
217
218        // Run full transformer forward pass
219        let hidden_states =
220            model
221                .lock()
222                .forward(&input_ids)
223                .map_err(|e| EmbeddingError::InferenceFailed {
224                    reason: format!("Transformer forward pass failed: {}", e),
225                })?;
226
227        // Last-token pooling: extract the final token's hidden state
228        // hidden_states shape: [1, seq_len, hidden_size]
229        let last_idx = tokens.len() - 1;
230        let embedding = hidden_states
231            .i((0, last_idx, ..self.config.embedding_dim))
232            .map_err(|e| EmbeddingError::InferenceFailed {
233                reason: format!("Failed to extract last token embedding: {}", e),
234            })?
235            .to_vec1::<f32>()
236            .map_err(|e| EmbeddingError::InferenceFailed {
237                reason: format!("Failed to convert embedding to vec: {}", e),
238            })?;
239
240        Ok(self.normalize_and_convert_f16(embedding))
241    }
242
243    fn embed_batch_with_model(
244        &self,
245        texts: &[&str],
246        model: &Arc<Mutex<Qwen2ForEmbedding>>,
247        tokenizer: &tokenizers::Tokenizer,
248        device: &Device,
249    ) -> Result<Vec<Vec<f16>>, EmbeddingError> {
250        // Process sequentially for now (proper batching would need padding)
251        let mut results = Vec::with_capacity(texts.len());
252        for text in texts {
253            results.push(self.embed_with_model(text, model, tokenizer, device)?);
254        }
255        Ok(results)
256    }
257
258    fn embed_stub(&self, text: &str) -> Result<Vec<f16>, EmbeddingError> {
259        use std::hash::{DefaultHasher, Hash, Hasher};
260
261        debug!(text_len = text.len(), "Generating stub embedding");
262
263        let mut hasher = DefaultHasher::new();
264        text.hash(&mut hasher);
265        let seed = hasher.finish();
266
267        let mut embedding = Vec::with_capacity(self.config.embedding_dim);
268        let mut state = seed;
269
270        for _ in 0..self.config.embedding_dim {
271            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
272            let value = ((state >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
273            embedding.push(value);
274        }
275
276        let result = self.normalize_and_convert_f16(embedding);
277
278        Ok(result)
279    }
280
281    fn normalize_and_convert_f16(&self, mut embedding: Vec<f32>) -> Vec<f16> {
282        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
283
284        if norm > 0.0 {
285            for x in &mut embedding {
286                *x /= norm;
287            }
288        }
289
290        embedding.into_iter().map(f16::from_f32).collect()
291    }
292
293    /// Returns the configured output embedding dimension.
294    pub fn embedding_dim(&self) -> usize {
295        self.config.embedding_dim
296    }
297
298    /// Returns `true` if running in stub mode.
299    pub fn is_stub(&self) -> bool {
300        matches!(self.backend, EmbedderBackend::Stub { .. })
301    }
302
303    /// Returns `true` if a model is loaded.
304    pub fn has_model(&self) -> bool {
305        matches!(self.backend, EmbedderBackend::Model { .. })
306    }
307
308    /// Returns the embedder configuration.
309    pub fn config(&self) -> &SinterConfig {
310        &self.config
311    }
312}