1use std::path::Path;
7
8use candle_core::quantized::gguf_file;
9use candle_core::quantized::{QMatMul, QTensor};
10use candle_core::{D, DType, Device, IndexOp, Module, Result as CandleResult, Tensor};
11use candle_nn::Linear;
12
13use crate::error::EmbeddingError;
14
15pub fn mean_pool(embeddings: &Tensor, attention_mask: &Tensor) -> CandleResult<Tensor> {
26 let mask = attention_mask
28 .unsqueeze(D::Minus1)?
29 .to_dtype(embeddings.dtype())?;
30 let masked = embeddings.broadcast_mul(&mask)?;
31 let summed = masked.sum(1)?; let counts = mask.sum(1)?; let counts = counts.clamp(1e-9, f64::MAX)?;
35 summed.broadcast_div(&counts)
36}
37
38pub fn l2_normalize(tensor: &Tensor) -> CandleResult<Tensor> {
42 let norm_sq = tensor.sqr()?.sum_keepdim(D::Minus1)?;
43 let norm = norm_sq.sqrt()?.clamp(1e-12, f64::MAX)?;
44 tensor.broadcast_div(&norm)
45}
46
47#[derive(Debug, Clone)]
52struct RmsNorm {
53 weight: Tensor,
54 eps: f64,
55}
56
57impl RmsNorm {
58 fn from_qtensor(qtensor: &QTensor, eps: f64) -> CandleResult<Self> {
59 let weight = qtensor.dequantize(&Device::Cpu)?;
60 Ok(Self { weight, eps })
61 }
62}
63
64impl Module for RmsNorm {
65 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
66 #[allow(clippy::cast_possible_truncation)]
67 let eps = self.eps as f32;
68 candle_nn::ops::rms_norm(x, &self.weight, eps)
69 }
70}
71
72#[derive(Debug, Clone)]
77struct RotaryEmbedding {
78 sin: Tensor,
79 cos: Tensor,
80}
81
82impl RotaryEmbedding {
83 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
84 fn new(head_dim: usize, max_seq_len: usize, rope_theta: f64) -> CandleResult<Self> {
85 let inv_freq: Vec<f32> = (0..head_dim)
86 .step_by(2)
87 .map(|i| 1f32 / rope_theta.powf(i as f64 / head_dim as f64) as f32)
88 .collect();
89 let inv_freq_len = inv_freq.len();
90 let inv_freq =
91 Tensor::from_vec(inv_freq, (1, inv_freq_len), &Device::Cpu)?.to_dtype(DType::F32)?;
92 #[allow(clippy::cast_possible_truncation)]
93 let max_seq_u32 = max_seq_len as u32;
94 let t = Tensor::arange(0u32, max_seq_u32, &Device::Cpu)?
95 .to_dtype(DType::F32)?
96 .reshape((max_seq_len, 1))?;
97 let freqs = t.matmul(&inv_freq)?;
98 Ok(Self {
99 sin: freqs.sin()?,
100 cos: freqs.cos()?,
101 })
102 }
103
104 fn apply(&self, q: &Tensor, k: &Tensor) -> CandleResult<(Tensor, Tensor)> {
105 let (_b, _h, seq_len, _d) = q.dims4()?;
106 let cos = self.cos.i(..seq_len)?;
107 let sin = self.sin.i(..seq_len)?;
108 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
109 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
110 Ok((q_embed, k_embed))
111 }
112}
113
114struct LayerWeights {
119 attn_q: QMatMul,
120 attn_k: QMatMul,
121 attn_v: QMatMul,
122 attn_o: QMatMul,
123 attn_q_norm: RmsNorm,
124 attn_k_norm: RmsNorm,
125 attn_norm: RmsNorm,
126 post_attn_norm: RmsNorm,
127 ffn_norm: RmsNorm,
128 post_ffn_norm: RmsNorm,
129 ffn_gate: QMatMul,
130 ffn_up: QMatMul,
131 ffn_down: QMatMul,
132 n_heads: usize,
133 n_kv_heads: usize,
134 head_dim: usize,
135 rotary: RotaryEmbedding,
136}
137
138impl LayerWeights {
139 fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
141 let (b_sz, seq_len, _hidden) = x.dims3()?;
142 let residual = x;
143
144 let x = self.attn_norm.forward(x)?;
146
147 let q = self.attn_q.forward(&x)?;
149 let k = self.attn_k.forward(&x)?;
150 let v = self.attn_v.forward(&x)?;
151
152 let q = q
154 .reshape((b_sz, seq_len, self.n_heads, self.head_dim))?
155 .transpose(1, 2)?;
156 let k = k
157 .reshape((b_sz, seq_len, self.n_kv_heads, self.head_dim))?
158 .transpose(1, 2)?;
159 let v = v
160 .reshape((b_sz, seq_len, self.n_kv_heads, self.head_dim))?
161 .transpose(1, 2)?;
162
163 let q = self.attn_q_norm.forward(&q.contiguous()?)?;
165 let k = self.attn_k_norm.forward(&k.contiguous()?)?;
166
167 let (q, k) = self.rotary.apply(&q, &k)?;
169
170 let repeat = self.n_heads / self.n_kv_heads;
172 let k = if repeat > 1 {
173 let k = k.unsqueeze(2)?;
174 k.expand((b_sz, self.n_kv_heads, repeat, seq_len, self.head_dim))?
175 .reshape((b_sz, self.n_heads, seq_len, self.head_dim))?
176 } else {
177 k
178 };
179 let v = if repeat > 1 {
180 let v = v.unsqueeze(2)?;
181 v.expand((b_sz, self.n_kv_heads, repeat, seq_len, self.head_dim))?
182 .reshape((b_sz, self.n_heads, seq_len, self.head_dim))?
183 } else {
184 v
185 };
186
187 #[allow(clippy::cast_precision_loss)]
189 let scale = 1.0 / (self.head_dim as f64).sqrt();
190 let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
191 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
192 let attn_out = attn_weights.matmul(&v)?;
193
194 let q_dim = self.n_heads * self.head_dim;
196 let attn_out = attn_out.transpose(1, 2)?.reshape((b_sz, seq_len, q_dim))?;
197 let attn_out = self.attn_o.forward(&attn_out)?;
198
199 let x = (residual + self.post_attn_norm.forward(&attn_out)?)?;
201 let residual = &x;
202
203 let ff_in = self.ffn_norm.forward(&x)?;
205 let gate = self.ffn_gate.forward(&ff_in)?;
206 let up = self.ffn_up.forward(&ff_in)?;
207 let ff_out = (candle_nn::Activation::Gelu.forward(&gate)? * up)?;
208 let ff_out = self.ffn_down.forward(&ff_out)?;
209 let out = (residual + self.post_ffn_norm.forward(&ff_out)?)?;
210
211 Ok(out)
212 }
213}
214
215pub struct EmbeddingGemmaModel {
221 token_embd: Tensor,
222 layers: Vec<LayerWeights>,
223 output_norm: RmsNorm,
224 dense1: Linear,
225 dense2: Linear,
226 tokenizer: tokenizers::Tokenizer,
227}
228
229impl EmbeddingGemmaModel {
230 pub fn load(
232 gguf_path: &Path,
233 dense1_path: &Path,
234 dense2_path: &Path,
235 tokenizer_path: &Path,
236 ) -> Result<Self, EmbeddingError> {
237 let device = Device::Cpu;
238
239 let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
241 .map_err(|e| EmbeddingError::ModelLoad(format!("tokenizer load failed: {e}")))?;
242
243 let mut file = std::fs::File::open(gguf_path)
245 .map_err(|e| EmbeddingError::ModelLoad(format!("failed to open GGUF: {e}")))?;
246 let ct = gguf_file::Content::read(&mut file)
247 .map_err(|e| EmbeddingError::ModelLoad(format!("failed to read GGUF: {e}")))?;
248
249 let arch = match ct.metadata.get("general.architecture") {
251 Some(gguf_file::Value::String(s)) => s.clone(),
252 _ => "gemma3".to_string(),
253 };
254 let get_meta_u32 = |key: &str| -> Result<u32, EmbeddingError> {
255 let full_key = format!("{arch}.{key}");
256 match ct.metadata.get(&full_key) {
257 Some(gguf_file::Value::U32(v)) => Ok(*v),
258 #[allow(clippy::cast_possible_truncation)]
259 Some(gguf_file::Value::U64(v)) => Ok(*v as u32),
260 _ => Err(EmbeddingError::ModelLoad(format!(
261 "missing or invalid GGUF metadata: {arch}.{key}"
262 ))),
263 }
264 };
265 let get_meta_f32 = |key: &str| -> Result<f32, EmbeddingError> {
266 let full_key = format!("{arch}.{key}");
267 match ct.metadata.get(&full_key) {
268 Some(gguf_file::Value::F32(v)) => Ok(*v),
269 _ => Err(EmbeddingError::ModelLoad(format!(
270 "missing or invalid GGUF metadata: {arch}.{key}"
271 ))),
272 }
273 };
274
275 #[allow(clippy::cast_possible_truncation)]
276 let n_layers = get_meta_u32("block_count")? as usize;
277 #[allow(clippy::cast_possible_truncation)]
278 let n_heads = get_meta_u32("attention.head_count")? as usize;
279 #[allow(clippy::cast_possible_truncation)]
280 let n_kv_heads = get_meta_u32("attention.head_count_kv")? as usize;
281 #[allow(clippy::cast_possible_truncation)]
282 let head_dim = get_meta_u32("attention.key_length")? as usize;
283 let rms_eps =
284 f64::from(get_meta_f32("attention.layer_norm_rms_epsilon").unwrap_or(1e-6_f32));
285 let rope_theta = f64::from(get_meta_f32("rope.freq_base").unwrap_or(10000.0_f32));
286 let max_seq_len = 2048_usize;
287
288 let token_embd = ct
290 .tensor(&mut file, "token_embd.weight", &device)
291 .map_err(|e| EmbeddingError::ModelLoad(format!("token_embd: {e}")))?
292 .dequantize(&device)
293 .map_err(|e| EmbeddingError::ModelLoad(format!("token_embd dequant: {e}")))?;
294
295 let rotary = RotaryEmbedding::new(head_dim, max_seq_len, rope_theta)
297 .map_err(|e| EmbeddingError::ModelLoad(format!("rotary: {e}")))?;
298
299 let mut layers = Vec::with_capacity(n_layers);
301 for i in 0..n_layers {
302 let prefix = format!("blk.{i}");
303 let layer = Self::load_layer(
304 &ct, &mut file, &device, &prefix, rms_eps, n_heads, n_kv_heads, head_dim, &rotary,
305 )?;
306 layers.push(layer);
307 }
308
309 let output_norm_tensor = ct
311 .tensor(&mut file, "output_norm.weight", &device)
312 .map_err(|e| EmbeddingError::ModelLoad(format!("output_norm: {e}")))?;
313 let output_norm = RmsNorm::from_qtensor(&output_norm_tensor, rms_eps)
314 .map_err(|e| EmbeddingError::ModelLoad(format!("output_norm rmsnorm: {e}")))?;
315
316 let dense1 = Self::load_dense(dense1_path, &device)?;
318 let dense2 = Self::load_dense(dense2_path, &device)?;
319
320 Ok(Self {
321 token_embd,
322 layers,
323 output_norm,
324 dense1,
325 dense2,
326 tokenizer,
327 })
328 }
329
330 #[allow(clippy::too_many_arguments)]
331 fn load_layer(
332 ct: &gguf_file::Content,
333 file: &mut std::fs::File,
334 device: &Device,
335 prefix: &str,
336 rms_eps: f64,
337 n_heads: usize,
338 n_kv_heads: usize,
339 head_dim: usize,
340 rotary: &RotaryEmbedding,
341 ) -> Result<LayerWeights, EmbeddingError> {
342 macro_rules! qt {
343 ($name:expr) => {{
344 let full = format!("{}.{}", prefix, $name);
345 ct.tensor(file, &full, device)
346 .map_err(|e| EmbeddingError::ModelLoad(format!("{full}: {e}")))?
347 }};
348 }
349 macro_rules! qm {
350 ($name:expr) => {{
351 let t = qt!($name);
352 let full = format!("{}.{}", prefix, $name);
353 QMatMul::from_qtensor(t)
354 .map_err(|e| EmbeddingError::ModelLoad(format!("{full} qmatmul: {e}")))?
355 }};
356 }
357 macro_rules! rn {
358 ($name:expr) => {{
359 let t = qt!($name);
360 let full = format!("{}.{}", prefix, $name);
361 RmsNorm::from_qtensor(&t, rms_eps)
362 .map_err(|e| EmbeddingError::ModelLoad(format!("{full} rmsnorm: {e}")))?
363 }};
364 }
365
366 Ok(LayerWeights {
367 attn_q: qm!("attn_q.weight"),
368 attn_k: qm!("attn_k.weight"),
369 attn_v: qm!("attn_v.weight"),
370 attn_o: qm!("attn_output.weight"),
371 attn_q_norm: rn!("attn_q_norm.weight"),
372 attn_k_norm: rn!("attn_k_norm.weight"),
373 attn_norm: rn!("attn_norm.weight"),
374 post_attn_norm: rn!("post_attention_norm.weight"),
375 ffn_norm: rn!("ffn_norm.weight"),
376 post_ffn_norm: rn!("post_ffw_norm.weight"),
377 ffn_gate: qm!("ffn_gate.weight"),
378 ffn_up: qm!("ffn_up.weight"),
379 ffn_down: qm!("ffn_down.weight"),
380 n_heads,
381 n_kv_heads,
382 head_dim,
383 rotary: rotary.clone(),
384 })
385 }
386
387 fn load_dense(path: &Path, device: &Device) -> Result<Linear, EmbeddingError> {
388 let tensors = candle_core::safetensors::load(path, device).map_err(|e| {
389 EmbeddingError::ModelLoad(format!("dense safetensors load {}: {e}", path.display()))
390 })?;
391 let weight = tensors
392 .get("linear.weight")
393 .or_else(|| tensors.get("weight"))
394 .or_else(|| tensors.get("0.weight"))
395 .ok_or_else(|| {
396 let keys: Vec<_> = tensors.keys().collect();
397 EmbeddingError::ModelLoad(format!(
398 "no weight tensor found in {}, available keys: {keys:?}",
399 path.display()
400 ))
401 })?
402 .clone();
403 let bias = tensors
404 .get("linear.bias")
405 .or_else(|| tensors.get("bias"))
406 .or_else(|| tensors.get("0.bias"))
407 .cloned();
408 Ok(Linear::new(weight, bias))
409 }
410
411 pub fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
415 let encoding = self
416 .tokenizer
417 .encode(text, true)
418 .map_err(|e| EmbeddingError::Tokenization(e.to_string()))?;
419 let token_ids = encoding.get_ids();
420 let attention_mask_data: Vec<f32> = encoding
421 .get_attention_mask()
422 .iter()
423 .map(|&v| if v == 0 { 0.0_f32 } else { 1.0_f32 })
424 .collect();
425
426 let device = Device::Cpu;
427
428 let input_ids = Tensor::new(token_ids, &device)
430 .map_err(|e| EmbeddingError::Inference(format!("input tensor: {e}")))?
431 .unsqueeze(0)
432 .map_err(|e| EmbeddingError::Inference(format!("unsqueeze: {e}")))?;
433 let attention_mask = Tensor::new(&attention_mask_data[..], &device)
434 .map_err(|e| EmbeddingError::Inference(format!("mask tensor: {e}")))?
435 .unsqueeze(0)
436 .map_err(|e| EmbeddingError::Inference(format!("mask unsqueeze: {e}")))?;
437
438 let mut hidden = self
440 .token_embd
441 .index_select(
442 &input_ids
443 .squeeze(0)
444 .map_err(|e| EmbeddingError::Inference(format!("squeeze: {e}")))?,
445 0,
446 )
447 .map_err(|e| EmbeddingError::Inference(format!("embedding lookup: {e}")))?
448 .unsqueeze(0)
449 .map_err(|e| EmbeddingError::Inference(format!("embd unsqueeze: {e}")))?;
450
451 let hidden_dim = hidden
453 .dim(D::Minus1)
454 .map_err(|e| EmbeddingError::Inference(format!("hidden dim: {e}")))?;
455 #[allow(clippy::cast_precision_loss)]
456 let scale = (hidden_dim as f64).sqrt();
457 hidden = hidden
458 .affine(scale, 0.0)
459 .map_err(|e| EmbeddingError::Inference(format!("embd scale: {e}")))?;
460
461 for layer in &self.layers {
463 hidden = layer
464 .forward(&hidden)
465 .map_err(|e| EmbeddingError::Inference(format!("layer forward: {e}")))?;
466 }
467
468 hidden = self
470 .output_norm
471 .forward(&hidden)
472 .map_err(|e| EmbeddingError::Inference(format!("output norm: {e}")))?;
473
474 let pooled = mean_pool(&hidden, &attention_mask)
476 .map_err(|e| EmbeddingError::Inference(format!("mean pool: {e}")))?;
477
478 let projected = self
480 .dense1
481 .forward(&pooled)
482 .map_err(|e| EmbeddingError::Inference(format!("dense1: {e}")))?;
483 let projected = self
484 .dense2
485 .forward(&projected)
486 .map_err(|e| EmbeddingError::Inference(format!("dense2: {e}")))?;
487
488 let normalized = l2_normalize(&projected)
490 .map_err(|e| EmbeddingError::Inference(format!("l2 normalize: {e}")))?;
491
492 let result: Vec<f32> = normalized
494 .squeeze(0)
495 .map_err(|e| EmbeddingError::Inference(format!("result squeeze: {e}")))?
496 .to_vec1()
497 .map_err(|e| EmbeddingError::Inference(format!("to_vec1: {e}")))?;
498
499 Ok(result)
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOLERANCE: f32 = 1e-5;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Runs `mean_pool` on a single-batch input of three 2-d tokens and
    /// returns the pooled row.
    fn pool_three_tokens(values: Vec<f32>, mask: Vec<f32>) -> Vec<f32> {
        let device = Device::Cpu;
        let embeddings = Tensor::from_vec(values, (1, 3, 2), &device).unwrap();
        let mask = Tensor::from_vec(mask, (1, 3), &device).unwrap();
        mean_pool(&embeddings, &mask)
            .unwrap()
            .squeeze(0)
            .unwrap()
            .to_vec1()
            .unwrap()
    }

    #[test]
    fn mean_pool_averages_over_sequence() {
        let pooled = pool_three_tokens(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![1.0, 1.0, 1.0]);

        assert_close(pooled[0], 3.0);
        assert_close(pooled[1], 4.0);
    }

    #[test]
    fn mean_pool_respects_attention_mask() {
        // The third token is masked out, so its large values must not leak in.
        let pooled = pool_three_tokens(vec![1.0, 2.0, 3.0, 4.0, 99.0, 99.0], vec![1.0, 1.0, 0.0]);

        assert_close(pooled[0], 2.0);
        assert_close(pooled[1], 3.0);
    }

    #[test]
    fn l2_normalize_produces_unit_vector() {
        let input = Tensor::from_vec(vec![3.0_f32, 4.0], (1, 2), &Device::Cpu).unwrap();

        let unit: Vec<f32> = l2_normalize(&input)
            .unwrap()
            .squeeze(0)
            .unwrap()
            .to_vec1()
            .unwrap();

        assert_close(unit[0], 0.6);
        assert_close(unit[1], 0.8);

        let magnitude: f32 = unit.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert_close(magnitude, 1.0);
    }

    #[test]
    fn l2_normalize_handles_batch() {
        let input = Tensor::from_vec(vec![3.0_f32, 4.0, 0.0, 5.0], (2, 2), &Device::Cpu).unwrap();

        let rows: Vec<Vec<f32>> = l2_normalize(&input).unwrap().to_vec2().unwrap();

        assert_close(rows[0][0], 0.6);
        assert_close(rows[0][1], 0.8);
        assert_close(rows[1][0], 0.0);
        assert_close(rows[1][1], 1.0);
    }
}