1use std::cmp::Reverse;
2
3use candle_core::{DType, Device, IndexOp, Tensor};
4use candle_nn::{Linear, Module, VarBuilder};
5use candle_transformers::models::bert::{BertModel, Config as BertConfig};
6use rayon::prelude::*;
7use tokenizers::{
8 Encoding,
9 PaddingParams,
10 PaddingStrategy,
11 Tokenizer,
12 pad_encodings,
13};
14
15use crate::{
16 builder::{ColbertBuilder, DenseModuleData},
17 error::ColbertError,
18 modernbert::{Config as ModernBertConfig, ModernBert},
19 types::Similarities,
20 utils::normalize_l2,
21};
22
// The variants hold whole models of different sizes; the size lint is
// silenced here rather than boxing the larger variant.
#[allow(clippy::large_enum_variant)]
/// Transformer backbone that produces per-token hidden states.
///
/// `ColBERT::new` picks the variant from the first entry of the
/// `architectures` array in the model configuration.
pub enum BaseModel {
    /// ModernBERT encoder from this crate's `modernbert` module.
    ModernBert(ModernBert),
    /// BERT encoder from `candle_transformers`.
    Bert(BertModel),
}
39
40impl BaseModel {
41 fn forward(
43 &self,
44 input_ids: &Tensor,
45 attention_mask: &Tensor,
46 token_type_ids: &Tensor,
47 ) -> Result<Tensor, candle_core::Error> {
48 match self {
49 BaseModel::ModernBert(model) => {
50 model.forward(input_ids, attention_mask)
51 }
52 BaseModel::Bert(model) => {
53 model.forward(input_ids, token_type_ids, Some(attention_mask))
54 }
55 }
56 }
57}
58
59pub(crate) fn normalize_and_mask_padded(
66 embeddings: &Tensor,
67 attention_mask: &Tensor,
68) -> Result<Tensor, candle_core::Error> {
69 let normalized = normalize_l2(embeddings)?;
70 let mask = attention_mask.to_dtype(normalized.dtype())?.unsqueeze(2)?;
71 normalized.broadcast_mul(&mask)
72}
73
74#[cfg_attr(not(test), allow(dead_code))]
80pub(crate) fn filter_normalize_and_pad_compact(
81 embeddings: &Tensor,
82 attention_mask: &Tensor,
83 device: &Device,
84) -> Result<Tensor, candle_core::Error> {
85 let (batch_size, _, dim) = embeddings.dims3()?;
86 let dtype = embeddings.dtype();
87 let mut processed_embeddings: Vec<Tensor> = Vec::with_capacity(batch_size);
88 let mut max_len = 0;
89
90 for i in 0..batch_size {
91 let single_embedding = embeddings.i(i)?;
92 let single_mask = attention_mask.i(i)?.to_vec1::<u32>()?;
93
94 let mut kept_rows = Vec::new();
95 for (j, &mask_val) in single_mask.iter().enumerate() {
96 if mask_val == 1 {
97 kept_rows.push(single_embedding.i(j)?);
98 }
99 }
100
101 let (normalized, current_len) = if kept_rows.is_empty() {
102 let zeros = Tensor::zeros((1, dim), dtype, device)?;
103 (zeros, 1)
104 } else {
105 let filtered = Tensor::stack(&kept_rows, 0)?;
106 let len = filtered.dim(0)?;
107 (normalize_l2(&filtered)?, len)
108 };
109
110 if current_len > max_len {
111 max_len = current_len;
112 }
113 processed_embeddings.push(normalized);
114 }
115
116 let mut padded_tensors = Vec::with_capacity(batch_size);
117 for tensor in &processed_embeddings {
118 let current_len = tensor.dim(0)?;
119 let dim = tensor.dim(1)?;
120 let pad_len = max_len - current_len;
121
122 if pad_len > 0 {
123 let padding = Tensor::zeros((pad_len, dim), dtype, device)?;
124 let padded = Tensor::cat(&[tensor, &padding], 0)?;
125 padded_tensors.push(padded);
126 } else {
127 padded_tensors.push(tensor.clone());
128 }
129 }
130
131 Tensor::stack(&padded_tensors, 0)
132}
133
134pub(crate) fn normalize_mask_and_truncate_right_padded(
137 embeddings: &Tensor,
138 attention_mask: &Tensor,
139 max_len: usize,
140) -> Result<Tensor, candle_core::Error> {
141 let masked = normalize_and_mask_padded(embeddings, attention_mask)?;
142 masked.narrow(1, 0, max_len.max(1))
143}
144
145pub(crate) fn concatenate_embedding_batches(
146 embeddings: Vec<Tensor>,
147) -> Result<Tensor, candle_core::Error> {
148 if embeddings.is_empty() {
149 return Err(candle_core::Error::Msg(
150 "embedding batches cannot be empty".into(),
151 ));
152 }
153 if embeddings.len() == 1 {
154 return Ok(embeddings.into_iter().next().unwrap());
155 }
156
157 let mut max_tokens = 0;
158 let mut needs_padding = false;
159 for batch in &embeddings {
160 let (_, tokens, _) = batch.dims3()?;
161 if max_tokens == 0 {
162 max_tokens = tokens;
163 } else if tokens != max_tokens {
164 needs_padding = true;
165 max_tokens = max_tokens.max(tokens);
166 }
167 }
168
169 if !needs_padding {
170 return Tensor::cat(&embeddings, 0);
171 }
172
173 let mut padded_batches = Vec::with_capacity(embeddings.len());
174 for batch in embeddings {
175 let (batch_size, tokens, dim) = batch.dims3()?;
176 if tokens == max_tokens {
177 padded_batches.push(batch);
178 continue;
179 }
180
181 let padding = Tensor::zeros(
182 (batch_size, max_tokens - tokens, dim),
183 batch.dtype(),
184 batch.device(),
185 )?;
186 padded_batches.push(Tensor::cat(&[&batch, &padding], 1)?);
187 }
188
189 Tensor::cat(&padded_batches, 0)
190}
191
192pub(crate) fn compute_similarities(
199 queries_embeddings: &Tensor,
200 documents_embeddings: &Tensor,
201) -> Result<Similarities, ColbertError> {
202 let scores =
203 compute_raw_similarity(queries_embeddings, documents_embeddings)?;
204 let max_scores = scores.max(3)?;
205 let similarities = max_scores.sum(2)?;
206 let similarities_vec = similarities.to_vec2::<f32>()?;
207 Ok(Similarities {
208 data: similarities_vec,
209 })
210}
211
212pub(crate) fn compute_raw_similarity(
217 queries_embeddings: &Tensor,
218 documents_embeddings: &Tensor,
219) -> Result<Tensor, ColbertError> {
220 queries_embeddings
221 .unsqueeze(1)?
222 .broadcast_matmul(&documents_embeddings.transpose(1, 2)?.unsqueeze(0)?)
223 .map_err(ColbertError::from)
224}
225
226pub(crate) fn build_dense_layers(
234 dense_modules: Vec<DenseModuleData>,
235 device: &Device,
236) -> Result<Vec<DenseLayer>, ColbertError> {
237 const SUPPORTED_ACTIVATION: &str = "torch.nn.modules.linear.Identity";
238
239 let mut layers = Vec::with_capacity(dense_modules.len());
240 for (idx, module) in dense_modules.into_iter().enumerate() {
241 let cfg: serde_json::Value =
242 serde_json::from_slice(&module.config_bytes)?;
243
244 let activation = cfg["activation_function"]
245 .as_str()
246 .unwrap_or(SUPPORTED_ACTIVATION);
247 if activation != SUPPORTED_ACTIVATION {
248 return Err(ColbertError::Operation(format!(
249 "Dense module {idx}: unsupported activation_function '{activation}' (only {SUPPORTED_ACTIVATION} is supported)"
250 )));
251 }
252 if cfg["bias"].as_bool().unwrap_or(false) {
253 return Err(ColbertError::Operation(format!(
254 "Dense module {idx}: bias=true is not supported"
255 )));
256 }
257 let in_features = cfg["in_features"].as_u64().ok_or_else(|| {
258 ColbertError::Operation(format!(
259 "Dense module {idx}: missing 'in_features'"
260 ))
261 })? as usize;
262 let out_features = cfg["out_features"].as_u64().ok_or_else(|| {
263 ColbertError::Operation(format!(
264 "Dense module {idx}: missing 'out_features'"
265 ))
266 })? as usize;
267 let use_residual = cfg["use_residual"].as_bool().unwrap_or(false);
268
269 let vb = VarBuilder::from_buffered_safetensors(
270 module.weights_bytes,
271 DType::F32,
272 device,
273 )?;
274 let linear = candle_nn::linear_no_bias(
275 in_features,
276 out_features,
277 vb.pp("linear"),
278 )?;
279 let residual = if use_residual {
280 Some(candle_nn::linear_no_bias(
281 in_features,
282 out_features,
283 vb.pp("residual"),
284 )?)
285 } else {
286 None
287 };
288 layers.push(DenseLayer { linear, residual });
289 }
290 Ok(layers)
291}
292
/// One projection stage applied to token embeddings after the base model.
pub(crate) struct DenseLayer {
    /// Main bias-free linear projection (loaded from the `linear` weight
    /// prefix by `build_dense_layers`).
    pub(crate) linear: Linear,
    /// Optional second bias-free projection of the same input (loaded from
    /// the `residual` prefix); its output is added to that of `linear`.
    pub(crate) residual: Option<Linear>,
}
304
305impl DenseLayer {
306 pub(crate) fn forward(
309 &self,
310 x: &Tensor,
311 ) -> Result<Tensor, candle_core::Error> {
312 let proj = self.linear.forward(x)?;
313 match &self.residual {
314 Some(residual) => proj + residual.forward(x)?,
315 None => Ok(proj),
316 }
317 }
318}
319
/// ColBERT encoder: a base transformer, a stack of dense projection layers,
/// and the tokenizer settings used to encode queries and documents.
pub struct ColBERT {
    /// Transformer backbone producing token-level hidden states.
    pub(crate) model: BaseModel,
    /// Projection layers applied in order by `project`; `new` rejects an
    /// empty list.
    pub(crate) dense_layers: Vec<DenseLayer>,
    /// Tokenizer; its truncation and padding settings are overwritten by the
    /// tokenizing methods on each call.
    pub(crate) tokenizer: Tokenizer,
    /// Vocabulary id of `mask_token`, used as the padding id for queries.
    pub(crate) mask_token_id: u32,
    /// Token text used to pad queries.
    pub(crate) mask_token: String,
    /// String prepended to every query before tokenization.
    pub(crate) query_prefix: String,
    /// String prepended to every document before tokenization.
    pub(crate) document_prefix: String,
    /// Prompt inserted in front of each query text (after the prefix).
    pub(crate) query_prompt: String,
    /// Prompt inserted in front of each document text (after the prefix).
    pub(crate) document_prompt: String,
    /// When true, query embeddings are normalized without masking or
    /// truncation (see `finalize_embeddings`).
    pub(crate) do_query_expansion: bool,
    /// When true, the query attention mask is replaced by all ones; `new`
    /// forces this to false when `do_query_expansion` is false.
    pub(crate) attend_to_expansion_tokens: bool,
    /// Truncation length and fixed padding length for queries (default 32).
    pub(crate) query_length: usize,
    /// Truncation length for documents (default 180).
    pub(crate) document_length: usize,
    /// Number of texts per batch (default 32).
    pub(crate) batch_size: usize,
    /// Device on which tensors are created and the model runs.
    pub device: Device,
}
344
345impl ColBERT {
346 #[allow(clippy::too_many_arguments)]
353 pub fn new(
354 weights: Vec<u8>,
355 dense_modules: Vec<DenseModuleData>,
356 tokenizer_bytes: Vec<u8>,
357 config_bytes: Vec<u8>,
358 query_prefix: String,
359 document_prefix: String,
360 query_prompt: String,
361 document_prompt: String,
362 mask_token: String,
363 do_query_expansion: bool,
364 attend_to_expansion_tokens: bool,
365 query_length: Option<usize>,
366 document_length: Option<usize>,
367 batch_size: Option<usize>,
368 device: &Device,
369 ) -> Result<Self, ColbertError> {
370 if dense_modules.is_empty() {
371 return Err(ColbertError::Operation(
372 "ColBERT requires at least one Dense projection layer".into(),
373 ));
374 }
375
376 let vb =
377 VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
378
379 let config_value: serde_json::Value =
380 serde_json::from_slice(&config_bytes)?;
381 let architectures = config_value["architectures"]
382 .as_array()
383 .and_then(|arr| arr.first())
384 .and_then(|v| v.as_str())
385 .ok_or_else(|| {
386 ColbertError::Operation(
387 "Missing or invalid 'architectures' in config.json".into(),
388 )
389 })?;
390
391 let model = match architectures {
392 "ModernBertModel" => {
393 let config: ModernBertConfig =
394 serde_json::from_slice(&config_bytes)?;
395 let model = ModernBert::load(vb.clone(), &config)?;
396 BaseModel::ModernBert(model)
397 }
398 "BertForMaskedLM" | "BertModel" => {
399 let config: BertConfig = serde_json::from_slice(&config_bytes)?;
400 let model = BertModel::load(vb.clone(), &config)?;
401 BaseModel::Bert(model)
402 }
403 arch => {
404 return Err(ColbertError::Operation(format!(
405 "Unsupported architecture: {}",
406 arch
407 )));
408 }
409 };
410
411 let tokenizer = Tokenizer::from_bytes(&tokenizer_bytes)?;
412
413 let mask_token_id =
414 tokenizer.token_to_id(mask_token.as_str()).ok_or_else(|| {
415 ColbertError::Operation(format!(
416 "Token '{}' not found in the tokenizer's vocabulary.",
417 mask_token
418 ))
419 })?;
420
421 let dense_layers = build_dense_layers(dense_modules, device)?;
422
423 let final_attend_to_expansion_tokens = if !do_query_expansion {
425 false
426 } else {
427 attend_to_expansion_tokens
428 };
429
430 Ok(Self {
431 model,
432 dense_layers,
433 tokenizer,
434 mask_token_id,
435 mask_token,
436 query_prefix,
437 document_prefix,
438 query_prompt,
439 document_prompt,
440 do_query_expansion,
441 attend_to_expansion_tokens: final_attend_to_expansion_tokens,
442 query_length: query_length.unwrap_or(32),
443 document_length: document_length.unwrap_or(180),
444 batch_size: batch_size.unwrap_or(32),
445 device: device.clone(),
446 })
447 }
448
    /// Starts a `ColbertBuilder` for the given `repo_id`.
    ///
    /// This is a convenience shortcut for `ColbertBuilder::new(repo_id)`.
    pub fn from(repo_id: &str) -> ColbertBuilder {
        ColbertBuilder::new(repo_id)
    }
453
454 fn finalize_embeddings(
459 &self,
460 projected_embeddings: &Tensor,
461 attention_mask: &Tensor,
462 max_valid_len: usize,
463 is_query: bool,
464 ) -> Result<Tensor, candle_core::Error> {
465 if is_query && self.do_query_expansion {
466 normalize_l2(projected_embeddings).map_err(candle_core::Error::from)
467 } else {
468 normalize_mask_and_truncate_right_padded(
469 projected_embeddings,
470 attention_mask,
471 max_valid_len,
472 )
473 }
474 }
475
476 pub(crate) fn project(
481 &self,
482 token_embeddings: &Tensor,
483 ) -> Result<Tensor, candle_core::Error> {
484 let mut iter = self.dense_layers.iter();
485 let first = iter
486 .next()
487 .expect("ColBERT::new guarantees at least one Dense layer");
488 let mut out = first.forward(token_embeddings)?;
489 for layer in iter {
490 out = layer.forward(&out)?;
491 }
492 Ok(out)
493 }
494
495 pub fn document_token_lengths(
503 &mut self,
504 sentences: &[String],
505 ) -> Result<Vec<u32>, ColbertError> {
506 if sentences.is_empty() {
507 return Ok(Vec::new());
508 }
509 let _ = self.tokenizer.with_truncation(Some(
510 tokenizers::TruncationParams {
511 max_length: self.document_length,
512 ..Default::default()
513 },
514 ));
515 self.tokenizer.with_padding(None);
519
520 let prompt = self.document_prompt.as_str();
521 let prefix = self.document_prefix.as_str();
522 let prefixed_texts: Vec<String> =
523 if prompt.is_empty() && prefix.is_empty() {
524 sentences.to_vec()
525 } else {
526 sentences
527 .iter()
528 .map(|text| format!("{prefix}{prompt}{text}"))
529 .collect()
530 };
531
532 let encodings =
533 self.tokenizer.encode_batch_fast(prefixed_texts, true)?;
534 Ok(encodings.iter().map(|e| e.get_ids().len() as u32).collect())
535 }
536
    /// Encodes `sentences` as documents and also returns their token counts.
    ///
    /// The lengths come from `document_token_lengths` and the embeddings from
    /// `encode(sentences, false)`, so the texts are tokenized once by each
    /// call.
    ///
    /// # Errors
    /// Propagates errors from both calls; note that `encode` rejects an
    /// empty `sentences` slice.
    pub fn encode_documents_with_lengths(
        &mut self,
        sentences: &[String],
    ) -> Result<(Tensor, Vec<u32>), ColbertError> {
        let lengths = self.document_token_lengths(sentences)?;
        let embeddings = self.encode(sentences, false)?;
        Ok((embeddings, lengths))
    }
557
558 pub fn encode(
563 &mut self,
564 sentences: &[String],
565 is_query: bool,
566 ) -> Result<Tensor, ColbertError> {
567 if sentences.is_empty() {
568 return Err(ColbertError::Operation(
569 "Input sentences cannot be empty.".into(),
570 ));
571 }
572
573 let prompt = if is_query {
574 &self.query_prompt
575 } else {
576 &self.document_prompt
577 };
578 let prompted: Vec<String>;
579 let sentences: &[String] = if prompt.is_empty() {
580 sentences
581 } else {
582 prompted =
583 sentences.iter().map(|s| format!("{prompt}{s}")).collect();
584 &prompted
585 };
586
587 if self.device.is_cpu() {
588 let mut tokenized_batches = Vec::new();
589 for batch_sentences in sentences.chunks(self.batch_size) {
590 tokenized_batches
591 .push(self.tokenize(batch_sentences, is_query)?);
592 }
593
594 let all_embeddings = tokenized_batches
595 .into_par_iter()
596 .map(
597 |(
598 token_ids,
599 attention_mask,
600 token_type_ids,
601 max_valid_len,
602 )|
603 -> Result<Tensor, ColbertError> {
604 let token_embeddings = self.model.forward(
605 &token_ids,
606 &attention_mask,
607 &token_type_ids,
608 )?;
609 let token_embeddings =
610 if token_embeddings.is_contiguous() {
611 token_embeddings
612 } else {
613 token_embeddings.contiguous()?
614 };
615 let projected_embeddings =
616 self.project(&token_embeddings)?;
617
618 self.finalize_embeddings(
619 &projected_embeddings,
620 &attention_mask,
621 max_valid_len,
622 is_query,
623 )
624 .map_err(ColbertError::from)
625 },
626 )
627 .collect::<Result<Vec<_>, _>>()?;
628
629 return concatenate_embedding_batches(all_embeddings)
630 .map_err(ColbertError::from);
631 }
632
633 if !is_query && sentences.len() > self.batch_size {
635 let texts_with_prefix: Vec<_> = sentences
636 .iter()
637 .map(|text| format!("{}{}", self.document_prefix, text))
638 .collect();
639 let _ = self.tokenizer.with_truncation(Some(
640 tokenizers::TruncationParams {
641 max_length: self.document_length,
642 ..Default::default()
643 },
644 ));
645 self.tokenizer.with_padding(None);
646
647 let encodings =
648 self.tokenizer.encode_batch_fast(texts_with_prefix, true)?;
649 let mut indexed_encodings: Vec<(usize, Encoding)> =
650 encodings.into_iter().enumerate().collect();
651 indexed_encodings.sort_unstable_by_key(|(_, encoding)| {
652 Reverse(encoding.get_ids().len())
653 });
654
655 let mut inverse = vec![0u32; indexed_encodings.len()];
656 for (sorted_idx, (original_idx, _)) in
657 indexed_encodings.iter().enumerate()
658 {
659 inverse[*original_idx] = sorted_idx as u32;
660 }
661 let inverse_len = inverse.len();
662 let mut sorted_encodings: Vec<Encoding> = indexed_encodings
663 .into_iter()
664 .map(|(_, encoding)| encoding)
665 .collect();
666
667 let mut all_embeddings = Vec::with_capacity(
668 sorted_encodings.len().div_ceil(self.batch_size),
669 );
670 let padding = PaddingParams {
671 strategy: PaddingStrategy::BatchLongest,
672 ..Default::default()
673 };
674 let max_tokens_per_batch =
675 self.batch_size * self.document_length.max(1);
676 let mut batch_start = 0usize;
677 while batch_start < sorted_encodings.len() {
678 let first_len =
679 sorted_encodings[batch_start].get_ids().len().max(1);
680 let batch_cap = (max_tokens_per_batch / first_len).max(1);
681 let batch_end =
682 (batch_start + batch_cap).min(sorted_encodings.len());
683 let batch_encodings =
684 &mut sorted_encodings[batch_start..batch_end];
685 let first_len = batch_encodings
686 .first()
687 .map_or(0, |encoding| encoding.get_ids().len());
688 let last_len = batch_encodings
689 .last()
690 .map_or(0, |encoding| encoding.get_ids().len());
691 let has_padding = first_len != last_len;
692 if has_padding {
693 pad_encodings(batch_encodings, &padding)?;
694 }
695 let (token_ids, attention_mask, token_type_ids, max_valid_len) =
696 self.tensorize_encodings(batch_encodings, false)?;
697
698 let token_embeddings = {
699 #[cfg(feature = "cuda")]
700 {
701 let valid_lens = if has_padding {
702 Some(
703 batch_encodings
704 .iter()
705 .map(|encoding| encoding.get_ids().len())
706 .collect::<Vec<_>>(),
707 )
708 } else {
709 None
710 };
711
712 if !has_padding {
713 if let BaseModel::ModernBert(model) = &self.model {
714 model.forward_unmasked(&token_ids)?
715 } else {
716 self.model.forward(
717 &token_ids,
718 &attention_mask,
719 &token_type_ids,
720 )?
721 }
722 } else if let (
723 BaseModel::ModernBert(model),
724 Some(valid_lens),
725 ) = (&self.model, valid_lens.as_ref())
726 {
727 model
728 .forward_varlen_padded(&token_ids, valid_lens)?
729 } else {
730 self.model.forward(
731 &token_ids,
732 &attention_mask,
733 &token_type_ids,
734 )?
735 }
736 }
737 #[cfg(not(feature = "cuda"))]
738 {
739 self.model.forward(
740 &token_ids,
741 &attention_mask,
742 &token_type_ids,
743 )?
744 }
745 };
746 let token_embeddings = if token_embeddings.is_contiguous() {
747 token_embeddings
748 } else {
749 token_embeddings.contiguous()?
750 };
751 let projected_embeddings = self.project(&token_embeddings)?;
752 let final_embeddings = self.finalize_embeddings(
753 &projected_embeddings,
754 &attention_mask,
755 max_valid_len,
756 false,
757 )?;
758 all_embeddings.push(final_embeddings);
759 batch_start = batch_end;
760 }
761
762 let embeddings = concatenate_embedding_batches(all_embeddings)
763 .map_err(ColbertError::from)?;
764 let restore_indices =
765 Tensor::from_vec(inverse, inverse_len, &self.device)?;
766 return embeddings
767 .index_select(&restore_indices, 0)
768 .map_err(ColbertError::from);
769 }
770
771 let mut all_embeddings =
772 Vec::with_capacity(sentences.len().div_ceil(self.batch_size));
773 for batch_sentences in sentences.chunks(self.batch_size) {
774 let (token_ids, attention_mask, token_type_ids, max_valid_len) =
775 self.tokenize(batch_sentences, is_query)?;
776
777 let token_embeddings = self.model.forward(
778 &token_ids,
779 &attention_mask,
780 &token_type_ids,
781 )?;
782 let token_embeddings = if token_embeddings.is_contiguous() {
783 token_embeddings
784 } else {
785 token_embeddings.contiguous()?
786 };
787
788 let projected_embeddings = self.project(&token_embeddings)?;
789
790 let final_embeddings = self.finalize_embeddings(
791 &projected_embeddings,
792 &attention_mask,
793 max_valid_len,
794 is_query,
795 )?;
796
797 all_embeddings.push(final_embeddings);
798 }
799
800 concatenate_embedding_batches(all_embeddings)
801 .map_err(ColbertError::from)
802 }
803
    /// Scores each query against each document.
    ///
    /// Thin wrapper around `compute_similarities`; it does not read any
    /// state from `self`.
    ///
    /// # Errors
    /// Propagates the error returned by `compute_similarities`.
    pub fn similarity(
        &self,
        queries_embeddings: &Tensor,
        documents_embeddings: &Tensor,
    ) -> Result<Similarities, ColbertError> {
        compute_similarities(queries_embeddings, documents_embeddings)
    }
812
    /// Returns the unreduced token-level similarity tensor.
    ///
    /// Thin wrapper around `compute_raw_similarity`; it does not read any
    /// state from `self`.
    ///
    /// # Errors
    /// Propagates the error returned by `compute_raw_similarity`.
    pub fn raw_similarity(
        &self,
        queries_embeddings: &Tensor,
        documents_embeddings: &Tensor,
    ) -> Result<Tensor, ColbertError> {
        compute_raw_similarity(queries_embeddings, documents_embeddings)
    }
821
822 fn tensorize_encodings(
823 &self,
824 encodings: &[Encoding],
825 is_query: bool,
826 ) -> Result<(Tensor, Tensor, Tensor, usize), ColbertError> {
827 let device = &self.device;
828 let batch_size = encodings.len();
829 if batch_size == 0 {
830 return Err(ColbertError::Operation(
831 "Input sentences cannot be empty.".into(),
832 ));
833 }
834
835 let seq_len = encodings.first().map_or(0, |e| e.get_ids().len());
840 let needs_query_valid_len = is_query
841 && !self.do_query_expansion
842 && !self.attend_to_expansion_tokens;
843 let needs_token_type_ids = matches!(&self.model, BaseModel::Bert(_));
844 let mut max_valid_len = if needs_query_valid_len {
845 1
846 } else {
847 seq_len.max(1)
848 };
849 let flat_len = batch_size * seq_len;
850 let mut ids_vec = Vec::<u32>::with_capacity(flat_len);
851 let mut mask_vec = Vec::<u32>::with_capacity(flat_len);
852 let mut type_ids_vec =
853 needs_token_type_ids.then(|| Vec::<u32>::with_capacity(flat_len));
854 for enc in encodings {
855 ids_vec.extend(enc.get_ids());
856 let attention = enc.get_attention_mask();
857 if needs_query_valid_len {
858 let mut valid_len = 0usize;
859 for &mask in attention {
860 valid_len += mask as usize;
861 mask_vec.push(mask);
862 }
863 max_valid_len = max_valid_len.max(valid_len.max(1));
864 } else {
865 mask_vec.extend(attention);
866 }
867 if let Some(type_ids_vec) = type_ids_vec.as_mut() {
868 type_ids_vec.extend(enc.get_type_ids());
869 }
870 }
871
872 let token_ids =
873 Tensor::from_vec(ids_vec, (batch_size, seq_len), device)?;
874 let mut attention_mask =
875 Tensor::from_vec(mask_vec, (batch_size, seq_len), device)?;
876 let token_type_ids = match type_ids_vec {
877 Some(type_ids_vec) => {
878 Tensor::from_vec(type_ids_vec, (batch_size, seq_len), device)?
879 }
880 None => Tensor::zeros((1, 1), DType::U32, device)?,
881 };
882
883 if is_query && self.attend_to_expansion_tokens {
884 attention_mask = attention_mask.ones_like()?;
885 }
886
887 Ok((token_ids, attention_mask, token_type_ids, max_valid_len))
888 }
889
890 pub(crate) fn tokenize(
892 &mut self,
893 texts: &[String],
894 is_query: bool,
895 ) -> Result<(Tensor, Tensor, Tensor, usize), ColbertError> {
896 let (prefix, max_length) = if is_query {
897 (self.query_prefix.as_str(), self.query_length)
898 } else {
899 (self.document_prefix.as_str(), self.document_length)
900 };
901
902 let texts_with_prefix: Vec<_> = texts
903 .iter()
904 .map(|text| format!("{}{}", prefix, text))
905 .collect();
906
907 let _ = self.tokenizer.with_truncation(Some(
908 tokenizers::TruncationParams {
909 max_length,
910 ..Default::default()
911 },
912 ));
913
914 let padding_params = if is_query {
915 PaddingParams {
916 strategy: PaddingStrategy::Fixed(max_length),
917 pad_id: self.mask_token_id,
918 pad_token: self.mask_token.clone(),
919 ..Default::default()
920 }
921 } else {
922 PaddingParams {
923 strategy: PaddingStrategy::BatchLongest,
924 ..Default::default()
925 }
926 };
927 self.tokenizer.with_padding(Some(padding_params));
928
929 let encodings =
930 self.tokenizer.encode_batch_fast(texts_with_prefix, true)?;
931 self.tensorize_encodings(&encodings, is_query)
932 }
933}
934
#[cfg(test)]
/// Picks the device the tests run on.
///
/// With the `cuda` feature, CUDA device 0 is tried first; with the `metal`
/// feature, Metal device 0 is tried next. If neither feature is enabled or
/// device creation fails, the CPU is used.
fn test_device() -> Device {
    #[cfg(feature = "cuda")]
    {
        if let Ok(d) = Device::new_cuda(0) {
            return d;
        }
    }
    #[cfg(feature = "metal")]
    {
        if let Ok(d) = Device::new_metal(0) {
            return d;
        }
    }
    // Fallback when no accelerator is compiled in or available.
    Device::Cpu
}
959
#[cfg(test)]
// Example-based tests comparing the fast masking helpers with the compact
// reference implementation, plus a padding test for batch concatenation.
mod tests {
    use candle_core::{DType, Tensor};

    use super::{
        concatenate_embedding_batches,
        filter_normalize_and_pad_compact,
        normalize_and_mask_padded,
        normalize_mask_and_truncate_right_padded,
    };

    /// With one fully valid item and one right-padded item, the fast path
    /// and the compact path must agree element-wise (tolerance 1e-6).
    #[test]
    fn fast_document_path_matches_compact_path_for_right_padded_masks() {
        let device = super::test_device();
        // Two items, four tokens each, embedding dim 2.
        let embeddings = Tensor::from_vec(
            vec![
                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
                12.0, 13.0, 14.0, 15.0, 16.0,
            ],
            (2, 4, 2),
            &device,
        )
        .unwrap();
        // First item keeps all four tokens, second keeps the first two.
        let attention_mask =
            Tensor::from_vec(vec![1u32, 1, 1, 1, 1, 1, 0, 0], (2, 4), &device)
                .unwrap();

        let compact = filter_normalize_and_pad_compact(
            &embeddings,
            &attention_mask,
            &device,
        )
        .unwrap();
        let fast =
            normalize_and_mask_padded(&embeddings, &attention_mask).unwrap();

        let compact = compact.to_vec3::<f32>().unwrap();
        let fast = fast.to_vec3::<f32>().unwrap();

        assert_eq!(compact.len(), fast.len());
        for (compact_doc, fast_doc) in compact.iter().zip(fast.iter()) {
            assert_eq!(compact_doc.len(), fast_doc.len());
            for (compact_row, fast_row) in
                compact_doc.iter().zip(fast_doc.iter())
            {
                assert_eq!(compact_row.len(), fast_row.len());
                for (compact_value, fast_value) in
                    compact_row.iter().zip(fast_row.iter())
                {
                    assert!((compact_value - fast_value).abs() < 1e-6);
                }
            }
        }
    }

    /// When the longest valid length (3) is shorter than the sequence (4),
    /// truncating the fast path to 3 must reproduce the compact output.
    /// The comparison is exact, with no tolerance.
    #[test]
    fn fast_query_path_matches_compact_path_for_right_padded_masks() {
        let device = super::test_device();
        let embeddings = Tensor::from_vec(
            vec![
                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
                12.0, 13.0, 14.0, 15.0, 16.0,
            ],
            (2, 4, 2),
            &device,
        )
        .unwrap();
        // Valid lengths are 3 and 2.
        let attention_mask =
            Tensor::from_vec(vec![1u32, 1, 1, 0, 1, 1, 0, 0], (2, 4), &device)
                .unwrap();

        let compact = filter_normalize_and_pad_compact(
            &embeddings,
            &attention_mask,
            &device,
        )
        .unwrap();
        let fast = normalize_mask_and_truncate_right_padded(
            &embeddings,
            &attention_mask,
            3,
        )
        .unwrap();

        assert_eq!(
            compact.to_vec3::<f32>().unwrap(),
            fast.to_vec3::<f32>().unwrap()
        );
    }

    /// Masked rows must come out as exact zeros, while kept rows are
    /// normalized (the inputs here are axis-aligned, so they become unit
    /// basis vectors).
    #[test]
    fn fast_document_path_zeroes_masked_rows() {
        let device = super::test_device();
        let embeddings = Tensor::from_vec(
            vec![1.0f32, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
            (1, 4, 2),
            &device,
        )
        .unwrap();
        let attention_mask =
            Tensor::from_vec(vec![1u32, 1, 0, 0], (1, 4), &device).unwrap();

        let fast = normalize_and_mask_padded(&embeddings, &attention_mask)
            .unwrap()
            .to_vec3::<f32>()
            .unwrap();

        assert!((fast[0][0][0] - 1.0).abs() < 1e-6);
        assert!((fast[0][0][1] - 0.0).abs() < 1e-6);
        assert!((fast[0][1][0] - 0.0).abs() < 1e-6);
        assert!((fast[0][1][1] - 1.0).abs() < 1e-6);
        assert_eq!(fast[0][2], vec![0.0, 0.0]);
        assert_eq!(fast[0][3], vec![0.0, 0.0]);
    }

    /// Batches with different token counts cannot be concatenated with a
    /// plain `Tensor::cat`; the helper pads them to the longest count first.
    #[test]
    fn concatenate_embedding_batches_pads_variable_sequence_lengths() {
        let device = super::test_device();
        let first = Tensor::zeros((64, 514, 128), DType::F32, &device).unwrap();
        let second =
            Tensor::zeros((64, 519, 128), DType::F32, &device).unwrap();

        // Sanity check: the naive concatenation is rejected.
        assert!(Tensor::cat(&[&first, &second], 0).is_err());

        let combined =
            concatenate_embedding_batches(vec![first, second]).unwrap();
        assert_eq!(combined.dims3().unwrap(), (128, 519, 128));
    }
}
1089
1090#[cfg(test)]
1091mod hegel_tests {
1092 use candle_core::{Device, Tensor};
1106 use candle_nn::{Linear, Module};
1107 use hegel::{TestCase, generators as gs};
1108
1109 use super::{
1110 DenseLayer,
1111 compute_raw_similarity,
1112 compute_similarities,
1113 concatenate_embedding_batches,
1114 filter_normalize_and_pad_compact,
1115 normalize_and_mask_padded,
1116 normalize_mask_and_truncate_right_padded,
1117 test_device,
1118 };
1119
    #[hegel::composite]
    // Generator for an embeddings tensor and an unconstrained 0/1 mask.
    //
    // Draws a batch size in 1..=3, a sequence length in 1..=8 and a dim in
    // 1..=6, then exactly b*s*d finite floats in [-5, 5] and exactly b*s mask
    // values in {0, 1}. Mask bits may appear in any position, not only as a
    // right-padded suffix.
    fn embeddings_with_free_mask(
        tc: TestCase,
        dev: Device,
    ) -> (Tensor, Tensor) {
        let b: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
        let s: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
        let d: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
        // min_size == max_size pins the vector to the exact element count.
        let emb_data: Vec<f32> = tc.draw(
            gs::vecs(
                gs::floats::<f32>()
                    .min_value(-5.0)
                    .max_value(5.0)
                    .allow_nan(false)
                    .allow_infinity(false),
            )
            .min_size(b * s * d)
            .max_size(b * s * d),
        );
        let mask_data: Vec<u32> = tc.draw(
            gs::vecs(gs::integers::<u32>().min_value(0).max_value(1))
                .min_size(b * s)
                .max_size(b * s),
        );
        let embeddings = Tensor::from_vec(emb_data, (b, s, d), &dev).unwrap();
        let mask = Tensor::from_vec(mask_data, (b, s), &dev).unwrap();
        (embeddings, mask)
    }
1158
    #[hegel::composite]
    // Generator for an embeddings tensor with a right-padded mask.
    //
    // Shapes and float ranges match `embeddings_with_free_mask`. For every
    // batch item a valid length in 0..=s is drawn and the mask row is `valid`
    // ones followed by zeros. The third tuple element is the largest valid
    // length drawn, which can be 0 when every item is fully masked.
    fn embeddings_with_right_padded_mask(
        tc: TestCase,
        dev: Device,
    ) -> (Tensor, Tensor, usize) {
        let b: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
        let s: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
        let d: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
        let emb_data: Vec<f32> = tc.draw(
            gs::vecs(
                gs::floats::<f32>()
                    .min_value(-5.0)
                    .max_value(5.0)
                    .allow_nan(false)
                    .allow_infinity(false),
            )
            .min_size(b * s * d)
            .max_size(b * s * d),
        );
        let mut mask_flat = Vec::<u32>::with_capacity(b * s);
        let mut max_valid = 0usize;
        for _ in 0..b {
            let valid: usize =
                tc.draw(gs::integers::<usize>().min_value(0).max_value(s));
            max_valid = max_valid.max(valid);
            // Ones for positions before `valid`, zeros from there on.
            for j in 0..s {
                mask_flat.push(u32::from(j < valid));
            }
        }
        let embeddings = Tensor::from_vec(emb_data, (b, s, d), &dev).unwrap();
        let mask = Tensor::from_vec(mask_flat, (b, s), &dev).unwrap();
        (embeddings, mask, max_valid)
    }
1199
    #[hegel::composite]
    // Generator for 1..=4 embedding batches that share the batch size
    // (1..=3) and dim (1..=4) but each draw their own token count (1..=6),
    // filled with finite floats in [-3, 3].
    fn embedding_batch_list(tc: TestCase, dev: Device) -> Vec<Tensor> {
        let n_batches: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
        let batch: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
        let dim: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
        // Factory closure: a fresh float generator is built per draw.
        let finite = || {
            gs::floats::<f32>()
                .min_value(-3.0)
                .max_value(3.0)
                .allow_nan(false)
                .allow_infinity(false)
        };
        let mut out = Vec::with_capacity(n_batches);
        for _ in 0..n_batches {
            let tokens: usize =
                tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
            let data: Vec<f32> = tc.draw(
                gs::vecs(finite())
                    .min_size(batch * tokens * dim)
                    .max_size(batch * tokens * dim),
            );
            out.push(
                Tensor::from_vec(data, (batch, tokens, dim), &dev).unwrap(),
            );
        }
        out
    }
1233
    #[hegel::composite]
    // Generator for a (queries, documents) tensor pair sharing one embedding
    // dim (1..=6). Batch sizes are drawn independently in 1..=3 and token
    // counts in 1..=6; values are finite floats in [-1, 1].
    fn query_doc_pair(tc: TestCase, dev: Device) -> (Tensor, Tensor) {
        let dim: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
        let q_batch: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
        let q_tokens: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
        let d_batch: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
        let d_tokens: usize =
            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
        // Factory closure: a fresh float generator is built per draw.
        let finite = || {
            gs::floats::<f32>()
                .min_value(-1.0)
                .max_value(1.0)
                .allow_nan(false)
                .allow_infinity(false)
        };
        let q_data: Vec<f32> = tc.draw(
            gs::vecs(finite())
                .min_size(q_batch * q_tokens * dim)
                .max_size(q_batch * q_tokens * dim),
        );
        let d_data: Vec<f32> = tc.draw(
            gs::vecs(finite())
                .min_size(d_batch * d_tokens * dim)
                .max_size(d_batch * d_tokens * dim),
        );
        let q =
            Tensor::from_vec(q_data, (q_batch, q_tokens, dim), &dev).unwrap();
        let d =
            Tensor::from_vec(d_data, (d_batch, d_tokens, dim), &dev).unwrap();
        (q, d)
    }
1271
    #[hegel::test(test_cases = 200)]
    // Property: `normalize_and_mask_padded` keeps the input shape, turns every
    // masked row into exact zeros, and leaves every unmasked row with a
    // squared norm of at most 1 (plus 1e-4 slack).
    fn normalize_and_mask_padded_respects_mask(tc: TestCase) {
        let dev = test_device();
        let (emb, mask) = tc.draw(embeddings_with_free_mask(dev));
        let out = normalize_and_mask_padded(&emb, &mask).unwrap();
        assert_eq!(out.dims(), emb.dims(), "shape must be preserved");

        let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
        let mask_v: Vec<Vec<u32>> = mask.to_vec2::<u32>().unwrap();
        for (b_idx, row_block) in out_v.iter().enumerate() {
            for (s_idx, row) in row_block.iter().enumerate() {
                let bit = mask_v[b_idx][s_idx];
                if bit == 0 {
                    for v in row {
                        assert_eq!(
                            *v, 0.0,
                            "masked row at ({b_idx},{s_idx}) not zeroed",
                        );
                    }
                } else {
                    // Only an upper bound is checked: the norm is allowed to
                    // be below 1 (e.g. for an all-zero input row).
                    let n2: f32 = row.iter().map(|v| v * v).sum();
                    assert!(
                        n2 <= 1.0 + 1e-4,
                        "unmasked row at ({b_idx},{s_idx}) has n²={n2}",
                    );
                }
            }
        }
    }
1309
1310 #[hegel::test(test_cases = 200)]
1313 fn truncate_right_padded_has_expected_shape(tc: TestCase) {
1314 let dev = test_device();
1315 let (emb, mask, max_valid) =
1316 tc.draw(embeddings_with_right_padded_mask(dev));
1317 let (b, _, d) = emb.dims3().unwrap();
1318 let out =
1319 normalize_mask_and_truncate_right_padded(&emb, &mask, max_valid)
1320 .unwrap();
1321 assert_eq!(out.dim(0).unwrap(), b);
1322 assert_eq!(out.dim(1).unwrap(), max_valid.max(1));
1323 assert_eq!(out.dim(2).unwrap(), d);
1324 }
1325
1326 #[hegel::test(test_cases = 200)]
1331 fn truncate_right_padded_matches_compact(tc: TestCase) {
1332 let dev = test_device();
1333 let (emb, mask, max_valid) =
1334 tc.draw(embeddings_with_right_padded_mask(dev.clone()));
1335 let fast =
1336 normalize_mask_and_truncate_right_padded(&emb, &mask, max_valid)
1337 .unwrap();
1338 let compact =
1339 filter_normalize_and_pad_compact(&emb, &mask, &dev).unwrap();
1340
1341 let (fast_b, fast_s, fast_d) = fast.dims3().unwrap();
1346 let (comp_b, comp_s, comp_d) = compact.dims3().unwrap();
1347 assert_eq!(fast_b, comp_b);
1348 assert_eq!(fast_d, comp_d);
1349 let common = fast_s.min(comp_s);
1350 let fast_cmp = fast.narrow(1, 0, common).unwrap();
1351 let comp_cmp = compact.narrow(1, 0, common).unwrap();
1352
1353 let fv: Vec<Vec<Vec<f32>>> = fast_cmp.to_vec3::<f32>().unwrap();
1354 let cv: Vec<Vec<Vec<f32>>> = comp_cmp.to_vec3::<f32>().unwrap();
1355 for (fb, cb) in fv.iter().zip(cv.iter()) {
1356 for (fr, cr) in fb.iter().zip(cb.iter()) {
1357 for (fv, cv) in fr.iter().zip(cr.iter()) {
1358 assert!(
1359 (fv - cv).abs() < 1e-5,
1360 "fast vs compact divergence: {fv} vs {cv}",
1361 );
1362 }
1363 }
1364 }
1365 }
1366
1367 #[hegel::test(test_cases = 100)]
1370 fn concatenate_single_is_identity(tc: TestCase) {
1371 let dev = test_device();
1372 let list = tc.draw(embedding_batch_list(dev));
1373 let only = list.into_iter().next().unwrap();
1374 let clone = only.to_vec3::<f32>().unwrap();
1375 let out = concatenate_embedding_batches(vec![only.clone()]).unwrap();
1376 let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
1377 assert_eq!(clone, out_v);
1378 }
1379
1380 #[hegel::test(test_cases = 150)]
1384 fn concatenate_shape_and_zero_padding(tc: TestCase) {
1385 let dev = test_device();
1386 let list = tc.draw(embedding_batch_list(dev));
1387 let expected_batch: usize =
1388 list.iter().map(|t| t.dim(0).unwrap()).sum();
1389 let expected_tokens: usize =
1390 list.iter().map(|t| t.dim(1).unwrap()).max().unwrap();
1391 let expected_dim = list[0].dim(2).unwrap();
1392
1393 let originals: Vec<Vec<Vec<Vec<f32>>>> =
1394 list.iter().map(|t| t.to_vec3::<f32>().unwrap()).collect();
1395
1396 let out = concatenate_embedding_batches(list).unwrap();
1397 assert_eq!(out.dim(0).unwrap(), expected_batch);
1398 assert_eq!(out.dim(1).unwrap(), expected_tokens);
1399 assert_eq!(out.dim(2).unwrap(), expected_dim);
1400
1401 let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
1402 let mut row = 0usize;
1403 for orig_batch in originals {
1404 let tokens_here = orig_batch[0].len();
1405 for orig_row in orig_batch {
1406 let out_row = &out_v[row];
1407 for (t, ot) in orig_row.iter().enumerate() {
1409 assert_eq!(&out_row[t], ot);
1410 }
1411 for (t, pad_row) in out_row.iter().enumerate().skip(tokens_here)
1413 {
1414 for v in pad_row {
1415 assert_eq!(
1416 *v, 0.0,
1417 "pad region at (row={row}, t={t}) not zero",
1418 );
1419 }
1420 }
1421 row += 1;
1422 }
1423 }
1424 }
1425
1426 fn naive_raw_similarity(q: &Tensor, d: &Tensor) -> Vec<Vec<Vec<Vec<f32>>>> {
1431 let qv: Vec<Vec<Vec<f32>>> = q.to_vec3::<f32>().unwrap();
1432 let dv: Vec<Vec<Vec<f32>>> = d.to_vec3::<f32>().unwrap();
1433 qv.iter()
1434 .map(|query| {
1435 dv.iter()
1436 .map(|doc| {
1437 query
1438 .iter()
1439 .map(|qt| {
1440 doc.iter()
1441 .map(|dt| {
1442 qt.iter()
1443 .zip(dt.iter())
1444 .map(|(a, b)| a * b)
1445 .sum::<f32>()
1446 })
1447 .collect::<Vec<f32>>()
1448 })
1449 .collect::<Vec<Vec<f32>>>()
1450 })
1451 .collect::<Vec<Vec<Vec<f32>>>>()
1452 })
1453 .collect()
1454 }
1455
1456 fn naive_max_sim(q: &Tensor, d: &Tensor) -> Vec<Vec<f32>> {
1457 naive_raw_similarity(q, d)
1458 .iter()
1459 .map(|query| {
1460 query
1461 .iter()
1462 .map(|doc| {
1463 doc.iter()
1464 .map(|per_qtok| {
1465 per_qtok
1466 .iter()
1467 .copied()
1468 .fold(f32::NEG_INFINITY, f32::max)
1469 })
1470 .sum::<f32>()
1471 })
1472 .collect::<Vec<f32>>()
1473 })
1474 .collect()
1475 }
1476
/// Asserts that two row-major matrices have identical shapes and that every
/// pair of corresponding entries differs by strictly less than `tol`.
///
/// # Panics
///
/// Panics when the row counts differ, when a pair of rows has different
/// lengths, or when any entry pair is not within `tol` (NaN never is).
fn approx_eq_matrix(a: &[Vec<f32>], b: &[Vec<f32>], tol: f32) {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).for_each(|(row_a, row_b)| {
        assert_eq!(row_a.len(), row_b.len());
        row_a.iter().zip(row_b).for_each(|(x, y)| {
            assert!(
                (x - y).abs() < tol,
                "matrix drift: {x} vs {y} (tol={tol})",
            );
        });
    });
}
1489
1490 #[hegel::test(test_cases = 200)]
1492 fn similarity_matches_naive_maxsim(tc: TestCase) {
1493 let dev = test_device();
1494 let (q, d) = tc.draw(query_doc_pair(dev));
1495 let got = compute_similarities(&q, &d).unwrap();
1496 let want = naive_max_sim(&q, &d);
1497 approx_eq_matrix(&got.data, &want, 1e-4);
1498 }
1499
1500 #[hegel::test(test_cases = 150)]
1505 fn raw_similarity_matches_naive(tc: TestCase) {
1506 let dev = test_device();
1507 let (q, d) = tc.draw(query_doc_pair(dev));
1508 let raw = compute_raw_similarity(&q, &d).unwrap();
1509 let (nq, nd, qt, dt) = raw.dims4().unwrap();
1510 let flat = raw.reshape((nq * nd, qt, dt)).unwrap();
1511 let got: Vec<Vec<Vec<f32>>> = flat.to_vec3::<f32>().unwrap();
1512 let want = naive_raw_similarity(&q, &d);
1513
1514 let mut idx = 0usize;
1515 for query_block in &want {
1516 for doc_block in query_block {
1517 let got_slab = &got[idx];
1518 idx += 1;
1519 assert_eq!(got_slab.len(), doc_block.len());
1520 for (g_row, w_row) in got_slab.iter().zip(doc_block.iter()) {
1521 assert_eq!(g_row.len(), w_row.len());
1522 for (x, y) in g_row.iter().zip(w_row.iter()) {
1523 assert!(
1524 (x - y).abs() < 1e-4,
1525 "raw sim drift: {x} vs {y}",
1526 );
1527 }
1528 }
1529 }
1530 }
1531 assert_eq!(idx, nq * nd);
1532 }
1533
1534 #[hegel::test(test_cases = 100)]
1537 fn similarity_shape_contract(tc: TestCase) {
1538 let dev = test_device();
1539 let (q, d) = tc.draw(query_doc_pair(dev));
1540 let nq = q.dim(0).unwrap();
1541 let nd = d.dim(0).unwrap();
1542 let out = compute_similarities(&q, &d).unwrap();
1543 assert_eq!(out.data.len(), nq);
1544 for row in &out.data {
1545 assert_eq!(row.len(), nd);
1546 }
1547 }
1548
1549 #[hegel::test(test_cases = 150)]
1553 fn zero_doc_token_is_non_decreasing(tc: TestCase) {
1554 let dev = test_device();
1555 let (q, d) = tc.draw(query_doc_pair(dev.clone()));
1556 let (db, dt, dd) = d.dims3().unwrap();
1557 let zeros = Tensor::zeros((db, 1, dd), d.dtype(), &dev).unwrap();
1558 let d_padded = Tensor::cat(&[&d, &zeros], 1).unwrap();
1559 assert_eq!(d_padded.dim(1).unwrap(), dt + 1);
1560
1561 let before = compute_similarities(&q, &d).unwrap();
1562 let after = compute_similarities(&q, &d_padded).unwrap();
1563 for (rb, ra) in before.data.iter().zip(after.data.iter()) {
1564 for (vb, va) in rb.iter().zip(ra.iter()) {
1565 assert!(
1566 *va + 1e-4 >= *vb,
1567 "zero-doc-token decreased similarity: {vb} → {va}",
1568 );
1569 }
1570 }
1571 }
1572
1573 #[hegel::composite]
1594 fn weight_matrix(
1595 tc: TestCase,
1596 out_features: usize,
1597 in_features: usize,
1598 dev: Device,
1599 ) -> Tensor {
1600 let n = out_features * in_features;
1601 let data: Vec<f32> = tc.draw(
1602 gs::vecs(
1603 gs::floats::<f32>()
1604 .min_value(-1.0)
1605 .max_value(1.0)
1606 .allow_nan(false)
1607 .allow_infinity(false),
1608 )
1609 .min_size(n)
1610 .max_size(n),
1611 );
1612 Tensor::from_vec(data, (out_features, in_features), &dev).unwrap()
1613 }
1614
1615 #[hegel::composite]
1617 fn activations(
1618 tc: TestCase,
1619 batch: usize,
1620 tokens: usize,
1621 dim: usize,
1622 dev: Device,
1623 ) -> Tensor {
1624 let n = batch * tokens * dim;
1625 let data: Vec<f32> = tc.draw(
1626 gs::vecs(
1627 gs::floats::<f32>()
1628 .min_value(-1.0)
1629 .max_value(1.0)
1630 .allow_nan(false)
1631 .allow_infinity(false),
1632 )
1633 .min_size(n)
1634 .max_size(n),
1635 );
1636 Tensor::from_vec(data, (batch, tokens, dim), &dev).unwrap()
1637 }
1638
1639 fn max_abs_diff(a: &Tensor, b: &Tensor) -> f32 {
1642 let diff = (a - b).unwrap().abs().unwrap();
1643 let flat: Vec<f32> = diff.flatten_all().unwrap().to_vec1().unwrap();
1644 flat.into_iter().fold(0.0f32, f32::max)
1645 }
1646
1647 #[hegel::test(test_cases = 100)]
1651 fn dense_layer_without_residual_matches_plain_linear(tc: TestCase) {
1652 let dev = test_device();
1653 let in_dim: usize =
1654 tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1655 let out_dim: usize =
1656 tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1657 let batch: usize =
1658 tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1659 let tokens: usize =
1660 tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1661
1662 let w = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1663 let x = tc.draw(activations(batch, tokens, in_dim, dev));
1664
1665 let layer = DenseLayer {
1666 linear: Linear::new(w.clone(), None),
1667 residual: None,
1668 };
1669 let plain = Linear::new(w, None);
1670
1671 let got = layer.forward(&x).unwrap();
1672 let want = plain.forward(&x).unwrap();
1673 assert_eq!(got.dims(), want.dims());
1674 assert!(
1675 max_abs_diff(&got, &want) < 1e-5,
1676 "no-residual DenseLayer diverged from plain Linear",
1677 );
1678 }
1679
1680 #[hegel::test(test_cases = 200)]
1684 fn dense_layer_with_residual_matches_summed_weights(tc: TestCase) {
1685 let dev = test_device();
1686 let in_dim: usize =
1687 tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1688 let out_dim: usize =
1689 tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1690 let batch: usize =
1691 tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1692 let tokens: usize =
1693 tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1694
1695 let w_linear = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1696 let w_residual = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1697 let x = tc.draw(activations(batch, tokens, in_dim, dev));
1698
1699 let layer = DenseLayer {
1700 linear: Linear::new(w_linear.clone(), None),
1701 residual: Some(Linear::new(w_residual.clone(), None)),
1702 };
1703 let summed = Linear::new((&w_linear + &w_residual).unwrap(), None);
1704
1705 let got = layer.forward(&x).unwrap();
1706 let want = summed.forward(&x).unwrap();
1707 assert_eq!(got.dims(), want.dims());
1708 assert!(
1709 max_abs_diff(&got, &want) < 1e-4,
1710 "residual DenseLayer diverged from Linear(linear + residual)",
1711 );
1712 }
1713
1714 #[hegel::test(test_cases = 200)]
1719 fn two_layer_chain_equivalent_to_composed_weights(tc: TestCase) {
1720 let dev = test_device();
1721 let in_dim: usize =
1722 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1723 let mid_dim: usize =
1724 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1725 let out_dim: usize =
1726 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1727 let batch: usize =
1728 tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1729 let tokens: usize =
1730 tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1731
1732 let w1 = tc.draw(weight_matrix(mid_dim, in_dim, dev.clone()));
1733 let w2 = tc.draw(weight_matrix(out_dim, mid_dim, dev.clone()));
1734 let x = tc.draw(activations(batch, tokens, in_dim, dev));
1735
1736 let layers = [
1737 DenseLayer {
1738 linear: Linear::new(w1.clone(), None),
1739 residual: None,
1740 },
1741 DenseLayer {
1742 linear: Linear::new(w2.clone(), None),
1743 residual: None,
1744 },
1745 ];
1746
1747 let mut iter = layers.iter();
1751 let first = iter.next().unwrap();
1752 let mut chain_out = first.forward(&x).unwrap();
1753 for layer in iter {
1754 chain_out = layer.forward(&chain_out).unwrap();
1755 }
1756
1757 let composed_weight = w2.matmul(&w1).unwrap();
1759 let composed = Linear::new(composed_weight, None);
1760 let reference = composed.forward(&x).unwrap();
1761
1762 assert_eq!(chain_out.dims(), reference.dims());
1763 assert!(
1764 max_abs_diff(&chain_out, &reference) < 1e-3,
1765 "two-layer chain diverged from composed-weight Linear",
1766 );
1767 }
1768
1769 #[hegel::test(test_cases = 100)]
1774 fn chain_output_dim_matches_last_layer_out_features(tc: TestCase) {
1775 let dev = test_device();
1776 let in_dim: usize =
1777 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1778 let mid_dim: usize =
1779 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1780 let final_dim: usize =
1781 tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1782 let batch: usize =
1783 tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1784 let tokens: usize =
1785 tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1786 let mid_has_residual: bool = tc.draw(gs::booleans());
1787
1788 let w1 = tc.draw(weight_matrix(mid_dim, in_dim, dev.clone()));
1789 let w1_res = mid_has_residual
1790 .then(|| tc.draw(weight_matrix(mid_dim, in_dim, dev.clone())));
1791 let w2 = tc.draw(weight_matrix(final_dim, mid_dim, dev.clone()));
1792 let x = tc.draw(activations(batch, tokens, in_dim, dev));
1793
1794 let layers = [
1795 DenseLayer {
1796 linear: Linear::new(w1, None),
1797 residual: w1_res.map(|w| Linear::new(w, None)),
1798 },
1799 DenseLayer {
1800 linear: Linear::new(w2, None),
1801 residual: None,
1802 },
1803 ];
1804
1805 let mut iter = layers.iter();
1806 let first = iter.next().unwrap();
1807 let mut out = first.forward(&x).unwrap();
1808 for layer in iter {
1809 out = layer.forward(&out).unwrap();
1810 }
1811 assert_eq!(out.dims(), &[batch, tokens, final_dim]);
1812 }
1813
1814 #[hegel::test(test_cases = 150)]
1819 fn similarity_linear_in_positive_query_scale(tc: TestCase) {
1820 let dev = test_device();
1821 let (q, d) = tc.draw(query_doc_pair(dev));
1822 let k: f32 = tc.draw(
1823 gs::floats::<f32>()
1824 .min_value(0.25)
1825 .max_value(4.0)
1826 .allow_nan(false)
1827 .allow_infinity(false),
1828 );
1829 let q_scaled = q.affine(f64::from(k), 0.0).unwrap();
1830
1831 let base = compute_similarities(&q, &d).unwrap();
1832 let scaled = compute_similarities(&q_scaled, &d).unwrap();
1833 for (rb, rs) in base.data.iter().zip(scaled.data.iter()) {
1834 for (vb, vs) in rb.iter().zip(rs.iter()) {
1835 assert!(
1836 (*vs - vb * k).abs() < 1e-3,
1837 "scale-linearity drift: k·{vb}={} vs {vs} (k={k})",
1838 vb * k,
1839 );
1840 }
1841 }
1842 }
1843}