pub mod config;
pub(crate) mod model;

#[cfg(test)]
mod tests;

pub use config::{SINTER_EMBEDDING_DIM, SINTER_MAX_SEQ_LEN, SinterConfig};

use std::sync::Arc;

use candle_core::{Device, IndexOp, Tensor};
use half::f16;
use parking_lot::Mutex;
use tracing::{debug, info, warn};

use crate::embedding::device::select_device;
use crate::embedding::error::EmbeddingError;
use crate::embedding::utils::load_tokenizer;

use model::Qwen2ForEmbedding;

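/// Backend selection for [`SinterEmbedder`]: either a real GGUF-loaded
/// transformer guarded by a mutex, or a deterministic stub for tests.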
enum EmbedderBackend {
    Model {
        model: Arc<Mutex<Qwen2ForEmbedding>>,
        tokenizer: Arc<tokenizers::Tokenizer>,
        device: Device,
    },
    Stub {
        device: Device,
    },
}

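/// Embedding engine wrapping a quantized Qwen2 transformer (or a test stub),
/// produced by [`SinterEmbedder::load`]. Embeddings are L2-normalized `f16`
/// vectors of length `embedding_dim`.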
pub struct SinterEmbedder {
    /// Active backend: full transformer model or deterministic test stub.
    backend: EmbedderBackend,
    config: SinterConfig,
}

impl std::fmt::Debug for SinterEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SinterEmbedder")
            .field(
                "backend",
                &match &self.backend {
                    EmbedderBackend::Model { device, .. } => format!("Model({:?})", device),
                    EmbedderBackend::Stub { device } => format!("Stub({:?})", device),
                },
            )
            .field("embedding_dim", &self.config.embedding_dim)
            .field("max_seq_len", &self.config.max_seq_len)
            .finish()
    }
}

impl SinterEmbedder {
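    /// Loads the embedder: validates `config`, selects a compute device, and
    /// either builds the stub backend (when `config.testing_stub` is set) or
    /// loads the GGUF model and tokenizer from the configured paths.
    ///
    /// # Example
    ///
    /// Illustrative sketch only (marked `ignore`): how a `SinterConfig` is
    /// obtained is outside this module and assumed here.
    ///
    /// ```ignore
    /// let embedder = SinterEmbedder::load(config)?;
    /// let vector = embedder.embed("hello world")?;
    /// assert_eq!(vector.len(), embedder.embedding_dim());
    /// ```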
    pub fn load(config: SinterConfig) -> Result<Self, EmbeddingError> {
        config.validate()?;

        let device = select_device()?;
        debug!(?device, "Selected compute device for Sinter");

        if config.testing_stub {
            warn!("Sinter running in STUB mode (testing only)");
            return Ok(Self {
                backend: EmbedderBackend::Stub { device },
                config,
            });
        }

        if !config.model_available() {
            return Err(EmbeddingError::ModelNotFound {
                path: config.model_path.clone(),
            });
        }
        if !config.tokenizer_available() {
            return Err(EmbeddingError::ModelNotFound {
                path: config.tokenizer_path.clone(),
            });
        }

        let (model, tokenizer) = Self::load_model(&config, &device)?;

        info!(
            model_path = %config.model_path.display(),
            embedding_dim = config.embedding_dim,
            max_seq_len = config.max_seq_len,
            hidden_size = model.config().hidden_size,
            num_layers = model.config().num_layers,
            "Sinter model loaded successfully (full transformer)"
        );

        Ok(Self {
            backend: EmbedderBackend::Model {
                model: Arc::new(Mutex::new(model)),
                tokenizer: Arc::new(tokenizer),
                device,
            },
            config,
        })
    }

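    /// Loads the tokenizer and the GGUF-quantized Qwen2 weights, then checks
    /// that the configured `embedding_dim` fits within the model's hidden size.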
    fn load_model(
        config: &SinterConfig,
        device: &Device,
    ) -> Result<(Qwen2ForEmbedding, tokenizers::Tokenizer), EmbeddingError> {
        let tokenizer = load_tokenizer(&config.tokenizer_path).map_err(|e| {
            EmbeddingError::TokenizationFailed {
                reason: format!("Failed to load tokenizer: {}", e),
            }
        })?;

        let mut model_file = std::fs::File::open(&config.model_path)?;
        let model_content = candle_core::quantized::gguf_file::Content::read(&mut model_file)
            .map_err(|e| EmbeddingError::ModelLoadFailed {
                reason: format!("Failed to read GGUF content: {}", e),
            })?;

        let model = Qwen2ForEmbedding::from_gguf(
            model_content,
            &mut model_file,
            device,
            config.max_seq_len,
        )
        .map_err(|e| EmbeddingError::ModelLoadFailed {
            reason: format!("Failed to load Qwen2 model: {}", e),
        })?;

        if config.embedding_dim > model.config().hidden_size {
            return Err(EmbeddingError::InvalidConfig {
                reason: format!(
                    "embedding_dim ({}) exceeds model hidden_size ({})",
                    config.embedding_dim,
                    model.config().hidden_size
                ),
            });
        }

        info!(
            hidden_size = model.config().hidden_size,
            num_layers = model.config().num_layers,
            "Qwen2 transformer loaded"
        );

        Ok((model, tokenizer))
    }

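    /// Embeds a single text into an L2-normalized `f16` vector of length
    /// `embedding_dim`, dispatching to the transformer or stub backend.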
    pub fn embed(&self, text: &str) -> Result<Vec<f16>, EmbeddingError> {
        match &self.backend {
            EmbedderBackend::Model {
                model,
                tokenizer,
                device,
            } => self.embed_with_model(text, model, tokenizer, device),
            EmbedderBackend::Stub { .. } => self.embed_stub(text),
        }
    }

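    /// Embeds a batch of texts, preserving input order. Returns an empty
    /// vector for an empty batch.
    ///
    /// ```ignore
    /// // Sketch; assumes `embedder` was obtained via `SinterEmbedder::load`.
    /// let vectors = embedder.embed_batch(&["first", "second"])?;
    /// assert_eq!(vectors.len(), 2);
    /// ```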
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f16>>, EmbeddingError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        match &self.backend {
            EmbedderBackend::Model {
                model,
                tokenizer,
                device,
            } => self.embed_batch_with_model(texts, model, tokenizer, device),
            EmbedderBackend::Stub { .. } => {
                texts.iter().map(|text| self.embed_stub(text)).collect()
            }
        }
    }

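    /// Full inference path: tokenize, truncate to `max_seq_len`, run the
    /// transformer, take the last token's hidden state (last-token pooling),
    /// slice it to `embedding_dim`, then L2-normalize and convert to `f16`.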
    fn embed_with_model(
        &self,
        text: &str,
        model: &Arc<Mutex<Qwen2ForEmbedding>>,
        tokenizer: &tokenizers::Tokenizer,
        device: &Device,
    ) -> Result<Vec<f16>, EmbeddingError> {
        let encoding =
            tokenizer
                .encode(text, true)
                .map_err(|e| EmbeddingError::TokenizationFailed {
                    reason: e.to_string(),
                })?;

        let mut tokens: Vec<u32> = encoding.get_ids().to_vec();
        if tokens.is_empty() {
            return Ok(vec![f16::from_f32(0.0); self.config.embedding_dim]);
        }

        if tokens.len() > self.config.max_seq_len {
            tokens.truncate(self.config.max_seq_len);
        }

        debug!(
            text_len = text.len(),
            token_count = tokens.len(),
            "Generating embedding (transformer forward pass)"
        );

        let input_ids = Tensor::new(&tokens[..], device)
            .map_err(|e| EmbeddingError::InferenceFailed {
                reason: format!("Failed to create input tensor: {}", e),
            })?
            .unsqueeze(0)
            .map_err(|e| EmbeddingError::InferenceFailed {
                reason: format!("Failed to unsqueeze input: {}", e),
            })?;

        let hidden_states =
            model
                .lock()
                .forward(&input_ids)
                .map_err(|e| EmbeddingError::InferenceFailed {
                    reason: format!("Transformer forward pass failed: {}", e),
                })?;

        // Last-token pooling: the final position's hidden state summarizes the
        // sequence; keep only its first `embedding_dim` components.
        let last_idx = tokens.len() - 1;
        let embedding = hidden_states
            .i((0, last_idx, ..self.config.embedding_dim))
            .map_err(|e| EmbeddingError::InferenceFailed {
                reason: format!("Failed to extract last token embedding: {}", e),
            })?
            .to_vec1::<f32>()
            .map_err(|e| EmbeddingError::InferenceFailed {
                reason: format!("Failed to convert embedding to vec: {}", e),
            })?;

        Ok(self.normalize_and_convert_f16(embedding))
    }

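    /// Batch inference. Texts are embedded one at a time: the model sits
    /// behind a mutex, so forward passes are serialized regardless.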
    fn embed_batch_with_model(
        &self,
        texts: &[&str],
        model: &Arc<Mutex<Qwen2ForEmbedding>>,
        tokenizer: &tokenizers::Tokenizer,
        device: &Device,
    ) -> Result<Vec<Vec<f16>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed_with_model(text, model, tokenizer, device)?);
        }
        Ok(results)
    }

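    /// Test-only embedding: hashes the text to seed a 64-bit LCG, generating a
    /// deterministic pseudo-random vector. Same input, same output, no model.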
    fn embed_stub(&self, text: &str) -> Result<Vec<f16>, EmbeddingError> {
        use std::hash::{DefaultHasher, Hash, Hasher};

        debug!(text_len = text.len(), "Generating stub embedding");

        // Seed the generator from the text so stub embeddings are deterministic.
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let seed = hasher.finish();

        let mut embedding = Vec::with_capacity(self.config.embedding_dim);
        let mut state = seed;

        for _ in 0..self.config.embedding_dim {
            // 64-bit LCG step; map the high 32 bits to a value in [-1.0, 1.0].
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let value = ((state >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embedding.push(value);
        }

        Ok(self.normalize_and_convert_f16(embedding))
    }

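    /// L2-normalizes the vector, dividing each component by `sqrt(Σ xᵢ²)` so
    /// cosine similarity reduces to a dot product, then converts to `f16`.
    /// Zero vectors are left unscaled to avoid dividing by zero.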
    fn normalize_and_convert_f16(&self, mut embedding: Vec<f32>) -> Vec<f16> {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm > 0.0 {
            for x in &mut embedding {
                *x /= norm;
            }
        }

        embedding.into_iter().map(f16::from_f32).collect()
    }

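    /// Dimension of the vectors produced by [`Self::embed`].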
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }

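    /// Returns `true` if this embedder uses the deterministic test stub.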
    pub fn is_stub(&self) -> bool {
        matches!(self.backend, EmbedderBackend::Stub { .. })
    }

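    /// Returns `true` if a real transformer model is loaded.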
    pub fn has_model(&self) -> bool {
        matches!(self.backend, EmbedderBackend::Model { .. })
    }

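    /// The configuration this embedder was loaded with.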
    pub fn config(&self) -> &SinterConfig {
        &self.config
    }
}