Skip to main content

docbert_pylate/
model.rs

1use std::cmp::Reverse;
2
3use candle_core::{DType, Device, IndexOp, Tensor};
4use candle_nn::{Linear, Module, VarBuilder};
5use candle_transformers::models::bert::{BertModel, Config as BertConfig};
6use rayon::prelude::*;
7use tokenizers::{
8    Encoding,
9    PaddingParams,
10    PaddingStrategy,
11    Tokenizer,
12    pad_encodings,
13};
14
15use crate::{
16    builder::{ColbertBuilder, DenseModuleData},
17    error::ColbertError,
18    modernbert::{Config as ModernBertConfig, ModernBert},
19    types::Similarities,
20    utils::normalize_l2,
21};
22
23/// An enum to abstract over different underlying BERT-based models.
24///
25/// This allows `ColBERT` to use different architectures like
26/// `BertModel` or `ModernBert` without changing the core logic.
27///
28/// `ModernBert` and `BertModel` have very different struct sizes
29/// (`clippy::large_enum_variant` would suggest boxing one), but boxing would
30/// add an indirection on the hot forward path and we hold exactly one of
31/// these per `ColBERT` anyway, so the size difference is immaterial.
32#[allow(clippy::large_enum_variant)]
33pub enum BaseModel {
34    /// A variant holding a `ModernBert` model.
35    ModernBert(ModernBert),
36    /// A variant holding a standard `BertModel`.
37    Bert(BertModel),
38}
39
40impl BaseModel {
41    /// Performs a forward pass through the appropriate underlying model.
42    fn forward(
43        &self,
44        input_ids: &Tensor,
45        attention_mask: &Tensor,
46        token_type_ids: &Tensor,
47    ) -> Result<Tensor, candle_core::Error> {
48        match self {
49            BaseModel::ModernBert(model) => {
50                model.forward(input_ids, attention_mask)
51            }
52            BaseModel::Bert(model) => {
53                model.forward(input_ids, token_type_ids, Some(attention_mask))
54            }
55        }
56    }
57}
58
59/// Normalizes embeddings and zeros out masked rows without changing the sequence length.
60///
61/// This is the fast path used for document encoding. Documents are tokenized with
62/// right-padding to the batch longest sequence, so filtering rows out and then padding them
63/// back in produces the same layout as simply zeroing the masked rows after normalization.
64/// Keeping the whole operation on-device avoids per-row CPU roundtrips and tiny tensor ops.
65pub(crate) fn normalize_and_mask_padded(
66    embeddings: &Tensor,
67    attention_mask: &Tensor,
68) -> Result<Tensor, candle_core::Error> {
69    let normalized = normalize_l2(embeddings)?;
70    let mask = attention_mask.to_dtype(normalized.dtype())?.unsqueeze(2)?;
71    normalized.broadcast_mul(&mask)
72}
73
74/// Filters rows with a mask, normalizes the kept rows, and pads back to the batch max length.
75///
76/// Reference implementation retained to back the fast-path parity
77/// tests in `mod tests` below — production code paths use
78/// [`normalize_mask_and_truncate_right_padded`] instead.
79#[cfg_attr(not(test), allow(dead_code))]
80pub(crate) fn filter_normalize_and_pad_compact(
81    embeddings: &Tensor,
82    attention_mask: &Tensor,
83    device: &Device,
84) -> Result<Tensor, candle_core::Error> {
85    let (batch_size, _, dim) = embeddings.dims3()?;
86    let dtype = embeddings.dtype();
87    let mut processed_embeddings: Vec<Tensor> = Vec::with_capacity(batch_size);
88    let mut max_len = 0;
89
90    for i in 0..batch_size {
91        let single_embedding = embeddings.i(i)?;
92        let single_mask = attention_mask.i(i)?.to_vec1::<u32>()?;
93
94        let mut kept_rows = Vec::new();
95        for (j, &mask_val) in single_mask.iter().enumerate() {
96            if mask_val == 1 {
97                kept_rows.push(single_embedding.i(j)?);
98            }
99        }
100
101        let (normalized, current_len) = if kept_rows.is_empty() {
102            let zeros = Tensor::zeros((1, dim), dtype, device)?;
103            (zeros, 1)
104        } else {
105            let filtered = Tensor::stack(&kept_rows, 0)?;
106            let len = filtered.dim(0)?;
107            (normalize_l2(&filtered)?, len)
108        };
109
110        if current_len > max_len {
111            max_len = current_len;
112        }
113        processed_embeddings.push(normalized);
114    }
115
116    let mut padded_tensors = Vec::with_capacity(batch_size);
117    for tensor in &processed_embeddings {
118        let current_len = tensor.dim(0)?;
119        let dim = tensor.dim(1)?;
120        let pad_len = max_len - current_len;
121
122        if pad_len > 0 {
123            let padding = Tensor::zeros((pad_len, dim), dtype, device)?;
124            let padded = Tensor::cat(&[tensor, &padding], 0)?;
125            padded_tensors.push(padded);
126        } else {
127            padded_tensors.push(tensor.clone());
128        }
129    }
130
131    Tensor::stack(&padded_tensors, 0)
132}
133
134/// Fast path for right-padded masks: normalize on-device, zero masked rows, then trim the
135/// shared padded suffix down to the batch's maximum valid length.
136pub(crate) fn normalize_mask_and_truncate_right_padded(
137    embeddings: &Tensor,
138    attention_mask: &Tensor,
139    max_len: usize,
140) -> Result<Tensor, candle_core::Error> {
141    let masked = normalize_and_mask_padded(embeddings, attention_mask)?;
142    masked.narrow(1, 0, max_len.max(1))
143}
144
145pub(crate) fn concatenate_embedding_batches(
146    embeddings: Vec<Tensor>,
147) -> Result<Tensor, candle_core::Error> {
148    if embeddings.is_empty() {
149        return Err(candle_core::Error::Msg(
150            "embedding batches cannot be empty".into(),
151        ));
152    }
153    if embeddings.len() == 1 {
154        return Ok(embeddings.into_iter().next().unwrap());
155    }
156
157    let mut max_tokens = 0;
158    let mut needs_padding = false;
159    for batch in &embeddings {
160        let (_, tokens, _) = batch.dims3()?;
161        if max_tokens == 0 {
162            max_tokens = tokens;
163        } else if tokens != max_tokens {
164            needs_padding = true;
165            max_tokens = max_tokens.max(tokens);
166        }
167    }
168
169    if !needs_padding {
170        return Tensor::cat(&embeddings, 0);
171    }
172
173    let mut padded_batches = Vec::with_capacity(embeddings.len());
174    for batch in embeddings {
175        let (batch_size, tokens, dim) = batch.dims3()?;
176        if tokens == max_tokens {
177            padded_batches.push(batch);
178            continue;
179        }
180
181        let padding = Tensor::zeros(
182            (batch_size, max_tokens - tokens, dim),
183            batch.dtype(),
184            batch.device(),
185        )?;
186        padded_batches.push(Tensor::cat(&[&batch, &padding], 1)?);
187    }
188
189    Tensor::cat(&padded_batches, 0)
190}
191
192/// Computes MaxSim similarity scores between query and document embeddings.
193///
194/// `queries_embeddings` has shape `(n_queries, q_tokens, dim)` and
195/// `documents_embeddings` has shape `(n_documents, d_tokens, dim)`. The
196/// returned matrix is `(n_queries, n_documents)` where each entry is
197/// `Σ_t max_k dot(Q[i,t], D[j,k])`.
198pub(crate) fn compute_similarities(
199    queries_embeddings: &Tensor,
200    documents_embeddings: &Tensor,
201) -> Result<Similarities, ColbertError> {
202    let scores =
203        compute_raw_similarity(queries_embeddings, documents_embeddings)?;
204    let max_scores = scores.max(3)?;
205    let similarities = max_scores.sum(2)?;
206    let similarities_vec = similarities.to_vec2::<f32>()?;
207    Ok(Similarities {
208        data: similarities_vec,
209    })
210}
211
212/// Computes the raw, un-reduced similarity matrix between query and document embeddings.
213///
214/// Output shape is `(n_queries, n_documents, q_tokens, d_tokens)` where each
215/// entry is `dot(Q[i,t], D[j,k])`.
216pub(crate) fn compute_raw_similarity(
217    queries_embeddings: &Tensor,
218    documents_embeddings: &Tensor,
219) -> Result<Tensor, ColbertError> {
220    queries_embeddings
221        .unsqueeze(1)?
222        .broadcast_matmul(&documents_embeddings.transpose(1, 2)?.unsqueeze(0)?)
223        .map_err(ColbertError::from)
224}
225
226/// Builds the Dense layer chain from raw module bytes.
227///
228/// Each entry's `config.json` must be `{ in_features, out_features, bias,
229/// activation_function, use_residual? }`. We only support bias=false and
230/// `activation_function == "torch.nn.modules.linear.Identity"` because every
231/// PyLate ColBERT model we ship is built that way; anything else is a
232/// fail-fast in the loader rather than a silently-wrong projection.
233pub(crate) fn build_dense_layers(
234    dense_modules: Vec<DenseModuleData>,
235    device: &Device,
236) -> Result<Vec<DenseLayer>, ColbertError> {
237    const SUPPORTED_ACTIVATION: &str = "torch.nn.modules.linear.Identity";
238
239    let mut layers = Vec::with_capacity(dense_modules.len());
240    for (idx, module) in dense_modules.into_iter().enumerate() {
241        let cfg: serde_json::Value =
242            serde_json::from_slice(&module.config_bytes)?;
243
244        let activation = cfg["activation_function"]
245            .as_str()
246            .unwrap_or(SUPPORTED_ACTIVATION);
247        if activation != SUPPORTED_ACTIVATION {
248            return Err(ColbertError::Operation(format!(
249                "Dense module {idx}: unsupported activation_function '{activation}' (only {SUPPORTED_ACTIVATION} is supported)"
250            )));
251        }
252        if cfg["bias"].as_bool().unwrap_or(false) {
253            return Err(ColbertError::Operation(format!(
254                "Dense module {idx}: bias=true is not supported"
255            )));
256        }
257        let in_features = cfg["in_features"].as_u64().ok_or_else(|| {
258            ColbertError::Operation(format!(
259                "Dense module {idx}: missing 'in_features'"
260            ))
261        })? as usize;
262        let out_features = cfg["out_features"].as_u64().ok_or_else(|| {
263            ColbertError::Operation(format!(
264                "Dense module {idx}: missing 'out_features'"
265            ))
266        })? as usize;
267        let use_residual = cfg["use_residual"].as_bool().unwrap_or(false);
268
269        let vb = VarBuilder::from_buffered_safetensors(
270            module.weights_bytes,
271            DType::F32,
272            device,
273        )?;
274        let linear = candle_nn::linear_no_bias(
275            in_features,
276            out_features,
277            vb.pp("linear"),
278        )?;
279        let residual = if use_residual {
280            Some(candle_nn::linear_no_bias(
281                in_features,
282                out_features,
283                vb.pp("residual"),
284            )?)
285        } else {
286            None
287        };
288        layers.push(DenseLayer { linear, residual });
289    }
290    Ok(layers)
291}
292
293/// One Dense projection layer in the SentenceTransformers pipeline.
294///
295/// Mirrors PyLate's `Dense` module: a linear projection optionally summed
296/// with a parallel learned `residual` projection of the same shape. Both
297/// branches are bias-free in every model we ship, and the activation in the
298/// model's `config.json` is always `torch.nn.modules.linear.Identity`, so
299/// this struct doesn't carry an activation.
300pub(crate) struct DenseLayer {
301    pub(crate) linear: Linear,
302    pub(crate) residual: Option<Linear>,
303}
304
305impl DenseLayer {
306    /// Applies the dense layer: `linear(x) + residual(x)` when residual is
307    /// present, else just `linear(x)`.
308    pub(crate) fn forward(
309        &self,
310        x: &Tensor,
311    ) -> Result<Tensor, candle_core::Error> {
312        let proj = self.linear.forward(x)?;
313        match &self.residual {
314            Some(residual) => proj + residual.forward(x)?,
315            None => Ok(proj),
316        }
317    }
318}
319
320/// The main ColBERT model structure.
321///
322/// This struct encapsulates the language model, the chain of Dense
323/// projection layers declared in `modules.json`, the tokenizer, and all
324/// necessary configuration for performing encoding and similarity
325/// calculations based on the ColBERT architecture.
326pub struct ColBERT {
327    pub(crate) model: BaseModel,
328    pub(crate) dense_layers: Vec<DenseLayer>,
329    pub(crate) tokenizer: Tokenizer,
330    pub(crate) mask_token_id: u32,
331    pub(crate) mask_token: String,
332    pub(crate) query_prefix: String,
333    pub(crate) document_prefix: String,
334    pub(crate) query_prompt: String,
335    pub(crate) document_prompt: String,
336    pub(crate) do_query_expansion: bool,
337    pub(crate) attend_to_expansion_tokens: bool,
338    pub(crate) query_length: usize,
339    pub(crate) document_length: usize,
340    pub(crate) batch_size: usize,
341    /// The device (CPU or GPU) on which the model is loaded.
342    pub device: Device,
343}
344
345impl ColBERT {
346    /// Creates a new instance of the `ColBERT` model from byte buffers.
347    ///
348    /// `dense_modules` carries the ordered list of Dense projection layers
349    /// declared in `modules.json`. They are applied left-to-right after the
350    /// transformer, with the last layer's `out_features` as the final stored
351    /// embedding dimension.
352    #[allow(clippy::too_many_arguments)]
353    pub fn new(
354        weights: Vec<u8>,
355        dense_modules: Vec<DenseModuleData>,
356        tokenizer_bytes: Vec<u8>,
357        config_bytes: Vec<u8>,
358        query_prefix: String,
359        document_prefix: String,
360        query_prompt: String,
361        document_prompt: String,
362        mask_token: String,
363        do_query_expansion: bool,
364        attend_to_expansion_tokens: bool,
365        query_length: Option<usize>,
366        document_length: Option<usize>,
367        batch_size: Option<usize>,
368        device: &Device,
369    ) -> Result<Self, ColbertError> {
370        if dense_modules.is_empty() {
371            return Err(ColbertError::Operation(
372                "ColBERT requires at least one Dense projection layer".into(),
373            ));
374        }
375
376        let vb =
377            VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
378
379        let config_value: serde_json::Value =
380            serde_json::from_slice(&config_bytes)?;
381        let architectures = config_value["architectures"]
382            .as_array()
383            .and_then(|arr| arr.first())
384            .and_then(|v| v.as_str())
385            .ok_or_else(|| {
386                ColbertError::Operation(
387                    "Missing or invalid 'architectures' in config.json".into(),
388                )
389            })?;
390
391        let model = match architectures {
392            "ModernBertModel" => {
393                let config: ModernBertConfig =
394                    serde_json::from_slice(&config_bytes)?;
395                let model = ModernBert::load(vb.clone(), &config)?;
396                BaseModel::ModernBert(model)
397            }
398            "BertForMaskedLM" | "BertModel" => {
399                let config: BertConfig = serde_json::from_slice(&config_bytes)?;
400                let model = BertModel::load(vb.clone(), &config)?;
401                BaseModel::Bert(model)
402            }
403            arch => {
404                return Err(ColbertError::Operation(format!(
405                    "Unsupported architecture: {}",
406                    arch
407                )));
408            }
409        };
410
411        let tokenizer = Tokenizer::from_bytes(&tokenizer_bytes)?;
412
413        let mask_token_id =
414            tokenizer.token_to_id(mask_token.as_str()).ok_or_else(|| {
415                ColbertError::Operation(format!(
416                    "Token '{}' not found in the tokenizer's vocabulary.",
417                    mask_token
418                ))
419            })?;
420
421        let dense_layers = build_dense_layers(dense_modules, device)?;
422
423        // If do_query_expansion is false, attend_to_expansion_tokens should also be false
424        let final_attend_to_expansion_tokens = if !do_query_expansion {
425            false
426        } else {
427            attend_to_expansion_tokens
428        };
429
430        Ok(Self {
431            model,
432            dense_layers,
433            tokenizer,
434            mask_token_id,
435            mask_token,
436            query_prefix,
437            document_prefix,
438            query_prompt,
439            document_prompt,
440            do_query_expansion,
441            attend_to_expansion_tokens: final_attend_to_expansion_tokens,
442            query_length: query_length.unwrap_or(32),
443            document_length: document_length.unwrap_or(180),
444            batch_size: batch_size.unwrap_or(32),
445            device: device.clone(),
446        })
447    }
448
449    /// Creates a `ColbertBuilder` to construct a `ColBERT` model from a Hugging Face repository.
450    pub fn from(repo_id: &str) -> ColbertBuilder {
451        ColbertBuilder::new(repo_id)
452    }
453
454    /// Finalizes projected embeddings after the linear layer.
455    ///
456    /// Queries without query expansion and documents both use an on-device right-padding fast
457    /// path that preserves the same batch max token count while avoiding row-by-row CPU work.
458    fn finalize_embeddings(
459        &self,
460        projected_embeddings: &Tensor,
461        attention_mask: &Tensor,
462        max_valid_len: usize,
463        is_query: bool,
464    ) -> Result<Tensor, candle_core::Error> {
465        if is_query && self.do_query_expansion {
466            normalize_l2(projected_embeddings).map_err(candle_core::Error::from)
467        } else {
468            normalize_mask_and_truncate_right_padded(
469                projected_embeddings,
470                attention_mask,
471                max_valid_len,
472            )
473        }
474    }
475
476    /// Applies the full Dense projection chain: `dense_layers[0]` first,
477    /// then `dense_layers[1]`, and so on. The constructor guarantees the
478    /// chain has at least one layer, so this never returns the input
479    /// unchanged.
480    pub(crate) fn project(
481        &self,
482        token_embeddings: &Tensor,
483    ) -> Result<Tensor, candle_core::Error> {
484        let mut iter = self.dense_layers.iter();
485        let first = iter
486            .next()
487            .expect("ColBERT::new guarantees at least one Dense layer");
488        let mut out = first.forward(token_embeddings)?;
489        for layer in iter {
490            out = layer.forward(&out)?;
491        }
492        Ok(out)
493    }
494
495    /// Compute the post-truncation token count for each document sentence.
496    ///
497    /// Mirrors the prefix + prompt + tokenization pipeline used inside
498    /// [`encode`] for documents so the returned lengths match the
499    /// per-row valid-token count of the encoded tensor exactly. Used by
500    /// callers that need to slice the padded output tensor without
501    /// scanning for all-zero rows on the host.
502    pub fn document_token_lengths(
503        &mut self,
504        sentences: &[String],
505    ) -> Result<Vec<u32>, ColbertError> {
506        if sentences.is_empty() {
507            return Ok(Vec::new());
508        }
509        let _ = self.tokenizer.with_truncation(Some(
510            tokenizers::TruncationParams {
511                max_length: self.document_length,
512                ..Default::default()
513            },
514        ));
515        // `encode_batch_fast` is the tokenizer's per-call fast path; we
516        // skip padding here because padding doesn't contribute to the
517        // valid-token count we care about.
518        self.tokenizer.with_padding(None);
519
520        let prompt = self.document_prompt.as_str();
521        let prefix = self.document_prefix.as_str();
522        let prefixed_texts: Vec<String> =
523            if prompt.is_empty() && prefix.is_empty() {
524                sentences.to_vec()
525            } else {
526                sentences
527                    .iter()
528                    .map(|text| format!("{prefix}{prompt}{text}"))
529                    .collect()
530            };
531
532        let encodings =
533            self.tokenizer.encode_batch_fast(prefixed_texts, true)?;
534        Ok(encodings.iter().map(|e| e.get_ids().len() as u32).collect())
535    }
536
537    /// Encode documents and return `(Tensor, per_doc_valid_token_counts)`.
538    ///
539    /// The lengths are returned in the caller-supplied order and index
540    /// directly into the returned 3D tensor's `axis=1`: row `i` has the
541    /// first `lengths[i]` rows populated with L2-normalized embeddings,
542    /// and the remaining rows (up to the batch's padded length) zeroed
543    /// by [`finalize_embeddings`].
544    ///
545    /// Intended for the docbert indexing path, which slices the tensor
546    /// per-doc before serializing embeddings. Returning real counts
547    /// lets callers skip the previous O(padded_tokens · dim) per-doc
548    /// all-zero scan.
549    pub fn encode_documents_with_lengths(
550        &mut self,
551        sentences: &[String],
552    ) -> Result<(Tensor, Vec<u32>), ColbertError> {
553        let lengths = self.document_token_lengths(sentences)?;
554        let embeddings = self.encode(sentences, false)?;
555        Ok((embeddings, lengths))
556    }
557
558    /// Encodes a batch of sentences (queries or documents) into embeddings.
559    ///
560    /// On CPU, this method leverages Rayon for parallel batch processing
561    /// to accelerate encoding. On accelerators (GPU), it processes batches sequentially.
562    pub fn encode(
563        &mut self,
564        sentences: &[String],
565        is_query: bool,
566    ) -> Result<Tensor, ColbertError> {
567        if sentences.is_empty() {
568            return Err(ColbertError::Operation(
569                "Input sentences cannot be empty.".into(),
570            ));
571        }
572
573        let prompt = if is_query {
574            &self.query_prompt
575        } else {
576            &self.document_prompt
577        };
578        let prompted: Vec<String>;
579        let sentences: &[String] = if prompt.is_empty() {
580            sentences
581        } else {
582            prompted =
583                sentences.iter().map(|s| format!("{prompt}{s}")).collect();
584            &prompted
585        };
586
587        if self.device.is_cpu() {
588            let mut tokenized_batches = Vec::new();
589            for batch_sentences in sentences.chunks(self.batch_size) {
590                tokenized_batches
591                    .push(self.tokenize(batch_sentences, is_query)?);
592            }
593
594            let all_embeddings = tokenized_batches
595                .into_par_iter()
596                .map(
597                    |(
598                        token_ids,
599                        attention_mask,
600                        token_type_ids,
601                        max_valid_len,
602                    )|
603                     -> Result<Tensor, ColbertError> {
604                        let token_embeddings = self.model.forward(
605                            &token_ids,
606                            &attention_mask,
607                            &token_type_ids,
608                        )?;
609                        let token_embeddings =
610                            if token_embeddings.is_contiguous() {
611                                token_embeddings
612                            } else {
613                                token_embeddings.contiguous()?
614                            };
615                        let projected_embeddings =
616                            self.project(&token_embeddings)?;
617
618                        self.finalize_embeddings(
619                            &projected_embeddings,
620                            &attention_mask,
621                            max_valid_len,
622                            is_query,
623                        )
624                        .map_err(ColbertError::from)
625                    },
626                )
627                .collect::<Result<Vec<_>, _>>()?;
628
629            return concatenate_embedding_batches(all_embeddings)
630                .map_err(ColbertError::from);
631        }
632
633        // Fallback to sequential processing for GPU, WASM, or other devices.
634        if !is_query && sentences.len() > self.batch_size {
635            let texts_with_prefix: Vec<_> = sentences
636                .iter()
637                .map(|text| format!("{}{}", self.document_prefix, text))
638                .collect();
639            let _ = self.tokenizer.with_truncation(Some(
640                tokenizers::TruncationParams {
641                    max_length: self.document_length,
642                    ..Default::default()
643                },
644            ));
645            self.tokenizer.with_padding(None);
646
647            let encodings =
648                self.tokenizer.encode_batch_fast(texts_with_prefix, true)?;
649            let mut indexed_encodings: Vec<(usize, Encoding)> =
650                encodings.into_iter().enumerate().collect();
651            indexed_encodings.sort_unstable_by_key(|(_, encoding)| {
652                Reverse(encoding.get_ids().len())
653            });
654
655            let mut inverse = vec![0u32; indexed_encodings.len()];
656            for (sorted_idx, (original_idx, _)) in
657                indexed_encodings.iter().enumerate()
658            {
659                inverse[*original_idx] = sorted_idx as u32;
660            }
661            let inverse_len = inverse.len();
662            let mut sorted_encodings: Vec<Encoding> = indexed_encodings
663                .into_iter()
664                .map(|(_, encoding)| encoding)
665                .collect();
666
667            let mut all_embeddings = Vec::with_capacity(
668                sorted_encodings.len().div_ceil(self.batch_size),
669            );
670            let padding = PaddingParams {
671                strategy: PaddingStrategy::BatchLongest,
672                ..Default::default()
673            };
674            let max_tokens_per_batch =
675                self.batch_size * self.document_length.max(1);
676            let mut batch_start = 0usize;
677            while batch_start < sorted_encodings.len() {
678                let first_len =
679                    sorted_encodings[batch_start].get_ids().len().max(1);
680                let batch_cap = (max_tokens_per_batch / first_len).max(1);
681                let batch_end =
682                    (batch_start + batch_cap).min(sorted_encodings.len());
683                let batch_encodings =
684                    &mut sorted_encodings[batch_start..batch_end];
685                let first_len = batch_encodings
686                    .first()
687                    .map_or(0, |encoding| encoding.get_ids().len());
688                let last_len = batch_encodings
689                    .last()
690                    .map_or(0, |encoding| encoding.get_ids().len());
691                let has_padding = first_len != last_len;
692                if has_padding {
693                    pad_encodings(batch_encodings, &padding)?;
694                }
695                let (token_ids, attention_mask, token_type_ids, max_valid_len) =
696                    self.tensorize_encodings(batch_encodings, false)?;
697
698                let token_embeddings = {
699                    #[cfg(feature = "cuda")]
700                    {
701                        let valid_lens = if has_padding {
702                            Some(
703                                batch_encodings
704                                    .iter()
705                                    .map(|encoding| encoding.get_ids().len())
706                                    .collect::<Vec<_>>(),
707                            )
708                        } else {
709                            None
710                        };
711
712                        if !has_padding {
713                            if let BaseModel::ModernBert(model) = &self.model {
714                                model.forward_unmasked(&token_ids)?
715                            } else {
716                                self.model.forward(
717                                    &token_ids,
718                                    &attention_mask,
719                                    &token_type_ids,
720                                )?
721                            }
722                        } else if let (
723                            BaseModel::ModernBert(model),
724                            Some(valid_lens),
725                        ) = (&self.model, valid_lens.as_ref())
726                        {
727                            model
728                                .forward_varlen_padded(&token_ids, valid_lens)?
729                        } else {
730                            self.model.forward(
731                                &token_ids,
732                                &attention_mask,
733                                &token_type_ids,
734                            )?
735                        }
736                    }
737                    #[cfg(not(feature = "cuda"))]
738                    {
739                        self.model.forward(
740                            &token_ids,
741                            &attention_mask,
742                            &token_type_ids,
743                        )?
744                    }
745                };
746                let token_embeddings = if token_embeddings.is_contiguous() {
747                    token_embeddings
748                } else {
749                    token_embeddings.contiguous()?
750                };
751                let projected_embeddings = self.project(&token_embeddings)?;
752                let final_embeddings = self.finalize_embeddings(
753                    &projected_embeddings,
754                    &attention_mask,
755                    max_valid_len,
756                    false,
757                )?;
758                all_embeddings.push(final_embeddings);
759                batch_start = batch_end;
760            }
761
762            let embeddings = concatenate_embedding_batches(all_embeddings)
763                .map_err(ColbertError::from)?;
764            let restore_indices =
765                Tensor::from_vec(inverse, inverse_len, &self.device)?;
766            return embeddings
767                .index_select(&restore_indices, 0)
768                .map_err(ColbertError::from);
769        }
770
771        let mut all_embeddings =
772            Vec::with_capacity(sentences.len().div_ceil(self.batch_size));
773        for batch_sentences in sentences.chunks(self.batch_size) {
774            let (token_ids, attention_mask, token_type_ids, max_valid_len) =
775                self.tokenize(batch_sentences, is_query)?;
776
777            let token_embeddings = self.model.forward(
778                &token_ids,
779                &attention_mask,
780                &token_type_ids,
781            )?;
782            let token_embeddings = if token_embeddings.is_contiguous() {
783                token_embeddings
784            } else {
785                token_embeddings.contiguous()?
786            };
787
788            let projected_embeddings = self.project(&token_embeddings)?;
789
790            let final_embeddings = self.finalize_embeddings(
791                &projected_embeddings,
792                &attention_mask,
793                max_valid_len,
794                is_query,
795            )?;
796
797            all_embeddings.push(final_embeddings);
798        }
799
800        concatenate_embedding_batches(all_embeddings)
801            .map_err(ColbertError::from)
802    }
803
804    /// Calculates the similarity scores between query and document embeddings.
805    pub fn similarity(
806        &self,
807        queries_embeddings: &Tensor,
808        documents_embeddings: &Tensor,
809    ) -> Result<Similarities, ColbertError> {
810        compute_similarities(queries_embeddings, documents_embeddings)
811    }
812
813    /// Computes the raw, un-reduced similarity matrix between query and document embeddings.
814    pub fn raw_similarity(
815        &self,
816        queries_embeddings: &Tensor,
817        documents_embeddings: &Tensor,
818    ) -> Result<Tensor, ColbertError> {
819        compute_raw_similarity(queries_embeddings, documents_embeddings)
820    }
821
822    fn tensorize_encodings(
823        &self,
824        encodings: &[Encoding],
825        is_query: bool,
826    ) -> Result<(Tensor, Tensor, Tensor, usize), ColbertError> {
827        let device = &self.device;
828        let batch_size = encodings.len();
829        if batch_size == 0 {
830            return Err(ColbertError::Operation(
831                "Input sentences cannot be empty.".into(),
832            ));
833        }
834
835        // Collect tokenization outputs into flat vectors. For documents, the padded sequence
836        // length already equals the batch max valid length. For non-expansion queries, compute the
837        // max valid length while we are already walking the CPU-side attention masks so the CUDA
838        // path can skip its own mask-length readback.
839        let seq_len = encodings.first().map_or(0, |e| e.get_ids().len());
840        let needs_query_valid_len = is_query
841            && !self.do_query_expansion
842            && !self.attend_to_expansion_tokens;
843        let needs_token_type_ids = matches!(&self.model, BaseModel::Bert(_));
844        let mut max_valid_len = if needs_query_valid_len {
845            1
846        } else {
847            seq_len.max(1)
848        };
849        let flat_len = batch_size * seq_len;
850        let mut ids_vec = Vec::<u32>::with_capacity(flat_len);
851        let mut mask_vec = Vec::<u32>::with_capacity(flat_len);
852        let mut type_ids_vec =
853            needs_token_type_ids.then(|| Vec::<u32>::with_capacity(flat_len));
854        for enc in encodings {
855            ids_vec.extend(enc.get_ids());
856            let attention = enc.get_attention_mask();
857            if needs_query_valid_len {
858                let mut valid_len = 0usize;
859                for &mask in attention {
860                    valid_len += mask as usize;
861                    mask_vec.push(mask);
862                }
863                max_valid_len = max_valid_len.max(valid_len.max(1));
864            } else {
865                mask_vec.extend(attention);
866            }
867            if let Some(type_ids_vec) = type_ids_vec.as_mut() {
868                type_ids_vec.extend(enc.get_type_ids());
869            }
870        }
871
872        let token_ids =
873            Tensor::from_vec(ids_vec, (batch_size, seq_len), device)?;
874        let mut attention_mask =
875            Tensor::from_vec(mask_vec, (batch_size, seq_len), device)?;
876        let token_type_ids = match type_ids_vec {
877            Some(type_ids_vec) => {
878                Tensor::from_vec(type_ids_vec, (batch_size, seq_len), device)?
879            }
880            None => Tensor::zeros((1, 1), DType::U32, device)?,
881        };
882
883        if is_query && self.attend_to_expansion_tokens {
884            attention_mask = attention_mask.ones_like()?;
885        }
886
887        Ok((token_ids, attention_mask, token_type_ids, max_valid_len))
888    }
889
890    /// Tokenizes a batch of texts, applying specific logic for queries and documents.
891    pub(crate) fn tokenize(
892        &mut self,
893        texts: &[String],
894        is_query: bool,
895    ) -> Result<(Tensor, Tensor, Tensor, usize), ColbertError> {
896        let (prefix, max_length) = if is_query {
897            (self.query_prefix.as_str(), self.query_length)
898        } else {
899            (self.document_prefix.as_str(), self.document_length)
900        };
901
902        let texts_with_prefix: Vec<_> = texts
903            .iter()
904            .map(|text| format!("{}{}", prefix, text))
905            .collect();
906
907        let _ = self.tokenizer.with_truncation(Some(
908            tokenizers::TruncationParams {
909                max_length,
910                ..Default::default()
911            },
912        ));
913
914        let padding_params = if is_query {
915            PaddingParams {
916                strategy: PaddingStrategy::Fixed(max_length),
917                pad_id: self.mask_token_id,
918                pad_token: self.mask_token.clone(),
919                ..Default::default()
920            }
921        } else {
922            PaddingParams {
923                strategy: PaddingStrategy::BatchLongest,
924                ..Default::default()
925            }
926        };
927        self.tokenizer.with_padding(Some(padding_params));
928
929        let encodings =
930            self.tokenizer.encode_batch_fast(texts_with_prefix, true)?;
931        self.tensorize_encodings(&encodings, is_query)
932    }
933}
934
935/// Test-only device selector shared by `mod tests` and `mod hegel_tests`.
936///
937/// Prefers CUDA when the `cuda` feature is on, Metal when `metal` is on,
938/// and falls back to CPU — with a runtime fall-back to CPU when the
939/// preferred accelerator can't be initialised (typical on CI builds of
940/// `--features cuda` without a GPU). Matches the selection order in
941/// `ColBERT::with_device`, so tests run on the same backend production
942/// encodes would use.
943#[cfg(test)]
944fn test_device() -> Device {
945    #[cfg(feature = "cuda")]
946    {
947        if let Ok(d) = Device::new_cuda(0) {
948            return d;
949        }
950    }
951    #[cfg(feature = "metal")]
952    {
953        if let Ok(d) = Device::new_metal(0) {
954            return d;
955        }
956    }
957    Device::Cpu
958}
959
960#[cfg(test)]
961mod tests {
962    use candle_core::{DType, Tensor};
963
964    use super::{
965        concatenate_embedding_batches,
966        filter_normalize_and_pad_compact,
967        normalize_and_mask_padded,
968        normalize_mask_and_truncate_right_padded,
969    };
970
971    #[test]
972    fn fast_document_path_matches_compact_path_for_right_padded_masks() {
973        let device = super::test_device();
974        let embeddings = Tensor::from_vec(
975            vec![
976                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // doc 1
977                9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, // doc 2
978            ],
979            (2, 4, 2),
980            &device,
981        )
982        .unwrap();
983        let attention_mask =
984            Tensor::from_vec(vec![1u32, 1, 1, 1, 1, 1, 0, 0], (2, 4), &device)
985                .unwrap();
986
987        let compact = filter_normalize_and_pad_compact(
988            &embeddings,
989            &attention_mask,
990            &device,
991        )
992        .unwrap();
993        let fast =
994            normalize_and_mask_padded(&embeddings, &attention_mask).unwrap();
995
996        let compact = compact.to_vec3::<f32>().unwrap();
997        let fast = fast.to_vec3::<f32>().unwrap();
998
999        assert_eq!(compact.len(), fast.len());
1000        for (compact_doc, fast_doc) in compact.iter().zip(fast.iter()) {
1001            assert_eq!(compact_doc.len(), fast_doc.len());
1002            for (compact_row, fast_row) in
1003                compact_doc.iter().zip(fast_doc.iter())
1004            {
1005                assert_eq!(compact_row.len(), fast_row.len());
1006                for (compact_value, fast_value) in
1007                    compact_row.iter().zip(fast_row.iter())
1008                {
1009                    assert!((compact_value - fast_value).abs() < 1e-6);
1010                }
1011            }
1012        }
1013    }
1014
1015    #[test]
1016    fn fast_query_path_matches_compact_path_for_right_padded_masks() {
1017        let device = super::test_device();
1018        let embeddings = Tensor::from_vec(
1019            vec![
1020                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // q1
1021                9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, // q2
1022            ],
1023            (2, 4, 2),
1024            &device,
1025        )
1026        .unwrap();
1027        let attention_mask =
1028            Tensor::from_vec(vec![1u32, 1, 1, 0, 1, 1, 0, 0], (2, 4), &device)
1029                .unwrap();
1030
1031        let compact = filter_normalize_and_pad_compact(
1032            &embeddings,
1033            &attention_mask,
1034            &device,
1035        )
1036        .unwrap();
1037        let fast = normalize_mask_and_truncate_right_padded(
1038            &embeddings,
1039            &attention_mask,
1040            3,
1041        )
1042        .unwrap();
1043
1044        assert_eq!(
1045            compact.to_vec3::<f32>().unwrap(),
1046            fast.to_vec3::<f32>().unwrap()
1047        );
1048    }
1049
1050    #[test]
1051    fn fast_document_path_zeroes_masked_rows() {
1052        let device = super::test_device();
1053        let embeddings = Tensor::from_vec(
1054            vec![1.0f32, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
1055            (1, 4, 2),
1056            &device,
1057        )
1058        .unwrap();
1059        let attention_mask =
1060            Tensor::from_vec(vec![1u32, 1, 0, 0], (1, 4), &device).unwrap();
1061
1062        let fast = normalize_and_mask_padded(&embeddings, &attention_mask)
1063            .unwrap()
1064            .to_vec3::<f32>()
1065            .unwrap();
1066
1067        assert!((fast[0][0][0] - 1.0).abs() < 1e-6);
1068        assert!((fast[0][0][1] - 0.0).abs() < 1e-6);
1069        assert!((fast[0][1][0] - 0.0).abs() < 1e-6);
1070        assert!((fast[0][1][1] - 1.0).abs() < 1e-6);
1071        assert_eq!(fast[0][2], vec![0.0, 0.0]);
1072        assert_eq!(fast[0][3], vec![0.0, 0.0]);
1073    }
1074
1075    #[test]
1076    fn concatenate_embedding_batches_pads_variable_sequence_lengths() {
1077        let device = super::test_device();
1078        let first = Tensor::zeros((64, 514, 128), DType::F32, &device).unwrap();
1079        let second =
1080            Tensor::zeros((64, 519, 128), DType::F32, &device).unwrap();
1081
1082        assert!(Tensor::cat(&[&first, &second], 0).is_err());
1083
1084        let combined =
1085            concatenate_embedding_batches(vec![first, second]).unwrap();
1086        assert_eq!(combined.dims3().unwrap(), (128, 519, 128));
1087    }
1088}
1089
1090#[cfg(test)]
1091mod hegel_tests {
1092    //! Hegel property tests for `model.rs` internals.
1093    //!
1094    //! Covers two tiers that need `pub(crate)` access and therefore can't live
1095    //! in `tests/properties.rs`:
1096    //!
1097    //! - **C** — `compute_similarities` / `compute_raw_similarity`: MaxSim
1098    //!   differential against a hand-rolled reference, shape contract,
1099    //!   zero-doc-token monotonicity, and query-scaling linearity.
1100    //! - **E** — masking and concatenation helpers: the right-padded fast
1101    //!   path must match the compact path, `normalize_and_mask_padded` must
1102    //!   zero masked rows and bound unmasked rows, and
1103    //!   `concatenate_embedding_batches` must preserve the batch/dim invariants
1104    //!   while zero-padding short batches up to the longest token length.
1105    use candle_core::{Device, Tensor};
1106    use candle_nn::{Linear, Module};
1107    use hegel::{TestCase, generators as gs};
1108
1109    use super::{
1110        DenseLayer,
1111        compute_raw_similarity,
1112        compute_similarities,
1113        concatenate_embedding_batches,
1114        filter_normalize_and_pad_compact,
1115        normalize_and_mask_padded,
1116        normalize_mask_and_truncate_right_padded,
1117        test_device,
1118    };
1119
1120    // -----------------------------------------------------------------------
1121    // Shared generators
1122    // -----------------------------------------------------------------------
1123
1124    /// Draws an `(b, s, d)` embedding tensor plus an `(b, s)` attention mask.
1125    /// The mask is arbitrary 0/1 — used for properties that don't require
1126    /// right-padding (E1, E2).
1127    #[hegel::composite]
1128    fn embeddings_with_free_mask(
1129        tc: TestCase,
1130        dev: Device,
1131    ) -> (Tensor, Tensor) {
1132        let b: usize =
1133            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1134        let s: usize =
1135            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1136        let d: usize =
1137            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1138        let emb_data: Vec<f32> = tc.draw(
1139            gs::vecs(
1140                gs::floats::<f32>()
1141                    .min_value(-5.0)
1142                    .max_value(5.0)
1143                    .allow_nan(false)
1144                    .allow_infinity(false),
1145            )
1146            .min_size(b * s * d)
1147            .max_size(b * s * d),
1148        );
1149        let mask_data: Vec<u32> = tc.draw(
1150            gs::vecs(gs::integers::<u32>().min_value(0).max_value(1))
1151                .min_size(b * s)
1152                .max_size(b * s),
1153        );
1154        let embeddings = Tensor::from_vec(emb_data, (b, s, d), &dev).unwrap();
1155        let mask = Tensor::from_vec(mask_data, (b, s), &dev).unwrap();
1156        (embeddings, mask)
1157    }
1158
1159    /// Draws an `(b, s, d)` embedding tensor plus a right-padded attention
1160    /// mask (each row is a prefix of 1s then 0s). Returns `(emb, mask,
1161    /// max_valid_len)`. This is exactly the precondition under which the
1162    /// fast and compact paths must agree (E3, E4).
1163    #[hegel::composite]
1164    fn embeddings_with_right_padded_mask(
1165        tc: TestCase,
1166        dev: Device,
1167    ) -> (Tensor, Tensor, usize) {
1168        let b: usize =
1169            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1170        let s: usize =
1171            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1172        let d: usize =
1173            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1174        let emb_data: Vec<f32> = tc.draw(
1175            gs::vecs(
1176                gs::floats::<f32>()
1177                    .min_value(-5.0)
1178                    .max_value(5.0)
1179                    .allow_nan(false)
1180                    .allow_infinity(false),
1181            )
1182            .min_size(b * s * d)
1183            .max_size(b * s * d),
1184        );
1185        let mut mask_flat = Vec::<u32>::with_capacity(b * s);
1186        let mut max_valid = 0usize;
1187        for _ in 0..b {
1188            let valid: usize =
1189                tc.draw(gs::integers::<usize>().min_value(0).max_value(s));
1190            max_valid = max_valid.max(valid);
1191            for j in 0..s {
1192                mask_flat.push(u32::from(j < valid));
1193            }
1194        }
1195        let embeddings = Tensor::from_vec(emb_data, (b, s, d), &dev).unwrap();
1196        let mask = Tensor::from_vec(mask_flat, (b, s), &dev).unwrap();
1197        (embeddings, mask, max_valid)
1198    }
1199
1200    /// Draws a non-empty `Vec<Tensor>` with matching `(batch, dim)` and
1201    /// varying per-batch sequence length. Used to exercise
1202    /// `concatenate_embedding_batches`.
1203    #[hegel::composite]
1204    fn embedding_batch_list(tc: TestCase, dev: Device) -> Vec<Tensor> {
1205        let n_batches: usize =
1206            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1207        let batch: usize =
1208            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1209        let dim: usize =
1210            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1211        let finite = || {
1212            gs::floats::<f32>()
1213                .min_value(-3.0)
1214                .max_value(3.0)
1215                .allow_nan(false)
1216                .allow_infinity(false)
1217        };
1218        let mut out = Vec::with_capacity(n_batches);
1219        for _ in 0..n_batches {
1220            let tokens: usize =
1221                tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1222            let data: Vec<f32> = tc.draw(
1223                gs::vecs(finite())
1224                    .min_size(batch * tokens * dim)
1225                    .max_size(batch * tokens * dim),
1226            );
1227            out.push(
1228                Tensor::from_vec(data, (batch, tokens, dim), &dev).unwrap(),
1229            );
1230        }
1231        out
1232    }
1233
1234    /// Draws query and document embeddings that share the last dim but can
1235    /// differ in batch size and token count. Used for every C-tier property.
1236    #[hegel::composite]
1237    fn query_doc_pair(tc: TestCase, dev: Device) -> (Tensor, Tensor) {
1238        let dim: usize =
1239            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1240        let q_batch: usize =
1241            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1242        let q_tokens: usize =
1243            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1244        let d_batch: usize =
1245            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1246        let d_tokens: usize =
1247            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1248        let finite = || {
1249            gs::floats::<f32>()
1250                .min_value(-1.0)
1251                .max_value(1.0)
1252                .allow_nan(false)
1253                .allow_infinity(false)
1254        };
1255        let q_data: Vec<f32> = tc.draw(
1256            gs::vecs(finite())
1257                .min_size(q_batch * q_tokens * dim)
1258                .max_size(q_batch * q_tokens * dim),
1259        );
1260        let d_data: Vec<f32> = tc.draw(
1261            gs::vecs(finite())
1262                .min_size(d_batch * d_tokens * dim)
1263                .max_size(d_batch * d_tokens * dim),
1264        );
1265        let q =
1266            Tensor::from_vec(q_data, (q_batch, q_tokens, dim), &dev).unwrap();
1267        let d =
1268            Tensor::from_vec(d_data, (d_batch, d_tokens, dim), &dev).unwrap();
1269        (q, d)
1270    }
1271
1272    // -----------------------------------------------------------------------
1273    // E — masking / concat helpers
1274    // -----------------------------------------------------------------------
1275
1276    /// E1: `normalize_and_mask_padded` zeros every row whose mask is 0.
1277    /// E2: rows whose mask is 1 have squared L2 norm ≤ 1 + ε.
1278    /// These are one property in two assertions because the generator is
1279    /// shared.
1280    #[hegel::test(test_cases = 200)]
1281    fn normalize_and_mask_padded_respects_mask(tc: TestCase) {
1282        let dev = test_device();
1283        let (emb, mask) = tc.draw(embeddings_with_free_mask(dev));
1284        let out = normalize_and_mask_padded(&emb, &mask).unwrap();
1285        assert_eq!(out.dims(), emb.dims(), "shape must be preserved");
1286
1287        let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
1288        let mask_v: Vec<Vec<u32>> = mask.to_vec2::<u32>().unwrap();
1289        for (b_idx, row_block) in out_v.iter().enumerate() {
1290            for (s_idx, row) in row_block.iter().enumerate() {
1291                let bit = mask_v[b_idx][s_idx];
1292                if bit == 0 {
1293                    for v in row {
1294                        assert_eq!(
1295                            *v, 0.0,
1296                            "masked row at ({b_idx},{s_idx}) not zeroed",
1297                        );
1298                    }
1299                } else {
1300                    let n2: f32 = row.iter().map(|v| v * v).sum();
1301                    assert!(
1302                        n2 <= 1.0 + 1e-4,
1303                        "unmasked row at ({b_idx},{s_idx}) has n²={n2}",
1304                    );
1305                }
1306            }
1307        }
1308    }
1309
1310    /// E3: `normalize_mask_and_truncate_right_padded` output shape is
1311    /// `(batch, max(max_len, 1), dim)`.
1312    #[hegel::test(test_cases = 200)]
1313    fn truncate_right_padded_has_expected_shape(tc: TestCase) {
1314        let dev = test_device();
1315        let (emb, mask, max_valid) =
1316            tc.draw(embeddings_with_right_padded_mask(dev));
1317        let (b, _, d) = emb.dims3().unwrap();
1318        let out =
1319            normalize_mask_and_truncate_right_padded(&emb, &mask, max_valid)
1320                .unwrap();
1321        assert_eq!(out.dim(0).unwrap(), b);
1322        assert_eq!(out.dim(1).unwrap(), max_valid.max(1));
1323        assert_eq!(out.dim(2).unwrap(), d);
1324    }
1325
1326    /// E4: under the right-padded-mask precondition the fast path and the
1327    /// compact path produce the same tensor. This is the single highest-value
1328    /// property in the suite — a bug that diverges the two would silently
1329    /// corrupt document embeddings.
1330    #[hegel::test(test_cases = 200)]
1331    fn truncate_right_padded_matches_compact(tc: TestCase) {
1332        let dev = test_device();
1333        let (emb, mask, max_valid) =
1334            tc.draw(embeddings_with_right_padded_mask(dev.clone()));
1335        let fast =
1336            normalize_mask_and_truncate_right_padded(&emb, &mask, max_valid)
1337                .unwrap();
1338        let compact =
1339            filter_normalize_and_pad_compact(&emb, &mask, &dev).unwrap();
1340
1341        // When every row in a given batch is masked out, the compact path
1342        // emits one zero row while the fast path emits `max(max_valid, 1)`
1343        // zero rows. Both are legitimate zero-padding layouts; only compare
1344        // the rows that the compact path actually produced.
1345        let (fast_b, fast_s, fast_d) = fast.dims3().unwrap();
1346        let (comp_b, comp_s, comp_d) = compact.dims3().unwrap();
1347        assert_eq!(fast_b, comp_b);
1348        assert_eq!(fast_d, comp_d);
1349        let common = fast_s.min(comp_s);
1350        let fast_cmp = fast.narrow(1, 0, common).unwrap();
1351        let comp_cmp = compact.narrow(1, 0, common).unwrap();
1352
1353        let fv: Vec<Vec<Vec<f32>>> = fast_cmp.to_vec3::<f32>().unwrap();
1354        let cv: Vec<Vec<Vec<f32>>> = comp_cmp.to_vec3::<f32>().unwrap();
1355        for (fb, cb) in fv.iter().zip(cv.iter()) {
1356            for (fr, cr) in fb.iter().zip(cb.iter()) {
1357                for (fv, cv) in fr.iter().zip(cr.iter()) {
1358                    assert!(
1359                        (fv - cv).abs() < 1e-5,
1360                        "fast vs compact divergence: {fv} vs {cv}",
1361                    );
1362                }
1363            }
1364        }
1365    }
1366
1367    /// E5: `concatenate_embedding_batches` is identity on a single-element
1368    /// input — the fast-path clone returns the tensor unchanged.
1369    #[hegel::test(test_cases = 100)]
1370    fn concatenate_single_is_identity(tc: TestCase) {
1371        let dev = test_device();
1372        let list = tc.draw(embedding_batch_list(dev));
1373        let only = list.into_iter().next().unwrap();
1374        let clone = only.to_vec3::<f32>().unwrap();
1375        let out = concatenate_embedding_batches(vec![only.clone()]).unwrap();
1376        let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
1377        assert_eq!(clone, out_v);
1378    }
1379
1380    /// E6: concatenation preserves `dim`, sums `batch` across inputs, and
1381    /// takes the max `tokens` across inputs. E7: every row beyond a batch's
1382    /// original token count is zero.
1383    #[hegel::test(test_cases = 150)]
1384    fn concatenate_shape_and_zero_padding(tc: TestCase) {
1385        let dev = test_device();
1386        let list = tc.draw(embedding_batch_list(dev));
1387        let expected_batch: usize =
1388            list.iter().map(|t| t.dim(0).unwrap()).sum();
1389        let expected_tokens: usize =
1390            list.iter().map(|t| t.dim(1).unwrap()).max().unwrap();
1391        let expected_dim = list[0].dim(2).unwrap();
1392
1393        let originals: Vec<Vec<Vec<Vec<f32>>>> =
1394            list.iter().map(|t| t.to_vec3::<f32>().unwrap()).collect();
1395
1396        let out = concatenate_embedding_batches(list).unwrap();
1397        assert_eq!(out.dim(0).unwrap(), expected_batch);
1398        assert_eq!(out.dim(1).unwrap(), expected_tokens);
1399        assert_eq!(out.dim(2).unwrap(), expected_dim);
1400
1401        let out_v: Vec<Vec<Vec<f32>>> = out.to_vec3::<f32>().unwrap();
1402        let mut row = 0usize;
1403        for orig_batch in originals {
1404            let tokens_here = orig_batch[0].len();
1405            for orig_row in orig_batch {
1406                let out_row = &out_v[row];
1407                // Unpadded region must match the input verbatim.
1408                for (t, ot) in orig_row.iter().enumerate() {
1409                    assert_eq!(&out_row[t], ot);
1410                }
1411                // Padded region beyond the batch's own tokens is zero.
1412                for (t, pad_row) in out_row.iter().enumerate().skip(tokens_here)
1413                {
1414                    for v in pad_row {
1415                        assert_eq!(
1416                            *v, 0.0,
1417                            "pad region at (row={row}, t={t}) not zero",
1418                        );
1419                    }
1420                }
1421                row += 1;
1422            }
1423        }
1424    }
1425
1426    // -----------------------------------------------------------------------
1427    // C — similarity / raw_similarity
1428    // -----------------------------------------------------------------------
1429
1430    fn naive_raw_similarity(q: &Tensor, d: &Tensor) -> Vec<Vec<Vec<Vec<f32>>>> {
1431        let qv: Vec<Vec<Vec<f32>>> = q.to_vec3::<f32>().unwrap();
1432        let dv: Vec<Vec<Vec<f32>>> = d.to_vec3::<f32>().unwrap();
1433        qv.iter()
1434            .map(|query| {
1435                dv.iter()
1436                    .map(|doc| {
1437                        query
1438                            .iter()
1439                            .map(|qt| {
1440                                doc.iter()
1441                                    .map(|dt| {
1442                                        qt.iter()
1443                                            .zip(dt.iter())
1444                                            .map(|(a, b)| a * b)
1445                                            .sum::<f32>()
1446                                    })
1447                                    .collect::<Vec<f32>>()
1448                            })
1449                            .collect::<Vec<Vec<f32>>>()
1450                    })
1451                    .collect::<Vec<Vec<Vec<f32>>>>()
1452            })
1453            .collect()
1454    }
1455
1456    fn naive_max_sim(q: &Tensor, d: &Tensor) -> Vec<Vec<f32>> {
1457        naive_raw_similarity(q, d)
1458            .iter()
1459            .map(|query| {
1460                query
1461                    .iter()
1462                    .map(|doc| {
1463                        doc.iter()
1464                            .map(|per_qtok| {
1465                                per_qtok
1466                                    .iter()
1467                                    .copied()
1468                                    .fold(f32::NEG_INFINITY, f32::max)
1469                            })
1470                            .sum::<f32>()
1471                    })
1472                    .collect::<Vec<f32>>()
1473            })
1474            .collect()
1475    }
1476
1477    fn approx_eq_matrix(a: &[Vec<f32>], b: &[Vec<f32>], tol: f32) {
1478        assert_eq!(a.len(), b.len());
1479        for (ra, rb) in a.iter().zip(b.iter()) {
1480            assert_eq!(ra.len(), rb.len());
1481            for (x, y) in ra.iter().zip(rb.iter()) {
1482                assert!(
1483                    (x - y).abs() < tol,
1484                    "matrix drift: {x} vs {y} (tol={tol})",
1485                );
1486            }
1487        }
1488    }
1489
1490    /// C1: `compute_similarities` agrees with a hand-rolled MaxSim reference.
1491    #[hegel::test(test_cases = 200)]
1492    fn similarity_matches_naive_maxsim(tc: TestCase) {
1493        let dev = test_device();
1494        let (q, d) = tc.draw(query_doc_pair(dev));
1495        let got = compute_similarities(&q, &d).unwrap();
1496        let want = naive_max_sim(&q, &d);
1497        approx_eq_matrix(&got.data, &want, 1e-4);
1498    }
1499
1500    /// C2: `compute_raw_similarity` equals the pointwise `Q · Dᵀ` reference.
1501    /// Candle only exposes `to_vec0`…`to_vec3`, so we reshape the 4-D output
1502    /// `(nq, nd, qt, dt)` down to 3-D `(nq*nd, qt, dt)` and walk the flat
1503    /// reference in the same order.
1504    #[hegel::test(test_cases = 150)]
1505    fn raw_similarity_matches_naive(tc: TestCase) {
1506        let dev = test_device();
1507        let (q, d) = tc.draw(query_doc_pair(dev));
1508        let raw = compute_raw_similarity(&q, &d).unwrap();
1509        let (nq, nd, qt, dt) = raw.dims4().unwrap();
1510        let flat = raw.reshape((nq * nd, qt, dt)).unwrap();
1511        let got: Vec<Vec<Vec<f32>>> = flat.to_vec3::<f32>().unwrap();
1512        let want = naive_raw_similarity(&q, &d);
1513
1514        let mut idx = 0usize;
1515        for query_block in &want {
1516            for doc_block in query_block {
1517                let got_slab = &got[idx];
1518                idx += 1;
1519                assert_eq!(got_slab.len(), doc_block.len());
1520                for (g_row, w_row) in got_slab.iter().zip(doc_block.iter()) {
1521                    assert_eq!(g_row.len(), w_row.len());
1522                    for (x, y) in g_row.iter().zip(w_row.iter()) {
1523                        assert!(
1524                            (x - y).abs() < 1e-4,
1525                            "raw sim drift: {x} vs {y}",
1526                        );
1527                    }
1528                }
1529            }
1530        }
1531        assert_eq!(idx, nq * nd);
1532    }
1533
1534    /// C3: output shape is `(n_queries, n_documents)` — the plumbing must not
1535    /// drop or duplicate rows.
1536    #[hegel::test(test_cases = 100)]
1537    fn similarity_shape_contract(tc: TestCase) {
1538        let dev = test_device();
1539        let (q, d) = tc.draw(query_doc_pair(dev));
1540        let nq = q.dim(0).unwrap();
1541        let nd = d.dim(0).unwrap();
1542        let out = compute_similarities(&q, &d).unwrap();
1543        assert_eq!(out.data.len(), nq);
1544        for row in &out.data {
1545            assert_eq!(row.len(), nd);
1546        }
1547    }
1548
1549    /// C4: appending a zero-valued token row to every document cannot reduce
1550    /// the similarity — `max_k` now includes `0.0` as an option, so the per-
1551    /// query-token max is non-decreasing and the sum follows.
1552    #[hegel::test(test_cases = 150)]
1553    fn zero_doc_token_is_non_decreasing(tc: TestCase) {
1554        let dev = test_device();
1555        let (q, d) = tc.draw(query_doc_pair(dev.clone()));
1556        let (db, dt, dd) = d.dims3().unwrap();
1557        let zeros = Tensor::zeros((db, 1, dd), d.dtype(), &dev).unwrap();
1558        let d_padded = Tensor::cat(&[&d, &zeros], 1).unwrap();
1559        assert_eq!(d_padded.dim(1).unwrap(), dt + 1);
1560
1561        let before = compute_similarities(&q, &d).unwrap();
1562        let after = compute_similarities(&q, &d_padded).unwrap();
1563        for (rb, ra) in before.data.iter().zip(after.data.iter()) {
1564            for (vb, va) in rb.iter().zip(ra.iter()) {
1565                assert!(
1566                    *va + 1e-4 >= *vb,
1567                    "zero-doc-token decreased similarity: {vb} → {va}",
1568                );
1569            }
1570        }
1571    }
1572
1573    // -----------------------------------------------------------------------
1574    // Tier F — Dense projection chain
1575    //
1576    // The PyLate Dense module pipeline is the slice of the loader most
1577    // exposed to silent dimension drift: the model card promises a final
1578    // `out_features`, but the real number is whatever the *last* loaded
1579    // Dense layer emits. These properties pin the chain semantics with
1580    // differential checks against linear-algebra equivalents (so the
1581    // reference doesn't share the SUT's iteration logic):
1582    //
1583    // - **F1** — no-residual layer collapses to a plain `Linear`.
1584    // - **F2** — residual sum is equivalent to a single `Linear` whose
1585    //   weight is `linear + residual`.
1586    // - **F3** — a two-layer chain composes like `Linear(W2 @ W1)`.
1587    // - **F4** — chain output dim equals the last layer's `out_features`,
1588    //   regardless of intermediate widths.
1589    // -----------------------------------------------------------------------
1590
1591    /// Generator: an `(out, in)` weight matrix in a numerically friendly
1592    /// range so f32 matmul drift stays below the assert tolerance.
1593    #[hegel::composite]
1594    fn weight_matrix(
1595        tc: TestCase,
1596        out_features: usize,
1597        in_features: usize,
1598        dev: Device,
1599    ) -> Tensor {
1600        let n = out_features * in_features;
1601        let data: Vec<f32> = tc.draw(
1602            gs::vecs(
1603                gs::floats::<f32>()
1604                    .min_value(-1.0)
1605                    .max_value(1.0)
1606                    .allow_nan(false)
1607                    .allow_infinity(false),
1608            )
1609            .min_size(n)
1610            .max_size(n),
1611        );
1612        Tensor::from_vec(data, (out_features, in_features), &dev).unwrap()
1613    }
1614
1615    /// Generator: an `(batch, tokens, dim)` activation tensor, also bounded.
1616    #[hegel::composite]
1617    fn activations(
1618        tc: TestCase,
1619        batch: usize,
1620        tokens: usize,
1621        dim: usize,
1622        dev: Device,
1623    ) -> Tensor {
1624        let n = batch * tokens * dim;
1625        let data: Vec<f32> = tc.draw(
1626            gs::vecs(
1627                gs::floats::<f32>()
1628                    .min_value(-1.0)
1629                    .max_value(1.0)
1630                    .allow_nan(false)
1631                    .allow_infinity(false),
1632            )
1633            .min_size(n)
1634            .max_size(n),
1635        );
1636        Tensor::from_vec(data, (batch, tokens, dim), &dev).unwrap()
1637    }
1638
1639    /// Element-wise max-abs distance between two tensors as a scalar `f32`.
1640    /// Crashes on shape mismatch — that's the property we want to fail loudly.
1641    fn max_abs_diff(a: &Tensor, b: &Tensor) -> f32 {
1642        let diff = (a - b).unwrap().abs().unwrap();
1643        let flat: Vec<f32> = diff.flatten_all().unwrap().to_vec1().unwrap();
1644        flat.into_iter().fold(0.0f32, f32::max)
1645    }
1646
1647    /// F1: a residual-less Dense layer must produce exactly what
1648    /// `candle_nn::Linear::new(weight, None)` produces — the residual
1649    /// branch is the only behaviour change relative to a plain linear.
1650    #[hegel::test(test_cases = 100)]
1651    fn dense_layer_without_residual_matches_plain_linear(tc: TestCase) {
1652        let dev = test_device();
1653        let in_dim: usize =
1654            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1655        let out_dim: usize =
1656            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1657        let batch: usize =
1658            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1659        let tokens: usize =
1660            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1661
1662        let w = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1663        let x = tc.draw(activations(batch, tokens, in_dim, dev));
1664
1665        let layer = DenseLayer {
1666            linear: Linear::new(w.clone(), None),
1667            residual: None,
1668        };
1669        let plain = Linear::new(w, None);
1670
1671        let got = layer.forward(&x).unwrap();
1672        let want = plain.forward(&x).unwrap();
1673        assert_eq!(got.dims(), want.dims());
1674        assert!(
1675            max_abs_diff(&got, &want) < 1e-5,
1676            "no-residual DenseLayer diverged from plain Linear",
1677        );
1678    }
1679
1680    /// F2: with residual, `DenseLayer { linear: W1, residual: W2 }` must
1681    /// equal `Linear(W1 + W2)` because both branches are bias-free linear
1682    /// maps over the same input — addition distributes over matmul.
1683    #[hegel::test(test_cases = 200)]
1684    fn dense_layer_with_residual_matches_summed_weights(tc: TestCase) {
1685        let dev = test_device();
1686        let in_dim: usize =
1687            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1688        let out_dim: usize =
1689            tc.draw(gs::integers::<usize>().min_value(1).max_value(8));
1690        let batch: usize =
1691            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1692        let tokens: usize =
1693            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1694
1695        let w_linear = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1696        let w_residual = tc.draw(weight_matrix(out_dim, in_dim, dev.clone()));
1697        let x = tc.draw(activations(batch, tokens, in_dim, dev));
1698
1699        let layer = DenseLayer {
1700            linear: Linear::new(w_linear.clone(), None),
1701            residual: Some(Linear::new(w_residual.clone(), None)),
1702        };
1703        let summed = Linear::new((&w_linear + &w_residual).unwrap(), None);
1704
1705        let got = layer.forward(&x).unwrap();
1706        let want = summed.forward(&x).unwrap();
1707        assert_eq!(got.dims(), want.dims());
1708        assert!(
1709            max_abs_diff(&got, &want) < 1e-4,
1710            "residual DenseLayer diverged from Linear(linear + residual)",
1711        );
1712    }
1713
1714    /// F3: a two-layer no-residual chain must be equivalent to a single
1715    /// `Linear` whose weight is the product of the layer weights, because
1716    /// matmul is associative: `(x @ W1.T) @ W2.T = x @ (W2 @ W1).T`. This
1717    /// catches any reversed-iteration bug in `ColBERT::project`.
1718    #[hegel::test(test_cases = 200)]
1719    fn two_layer_chain_equivalent_to_composed_weights(tc: TestCase) {
1720        let dev = test_device();
1721        let in_dim: usize =
1722            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1723        let mid_dim: usize =
1724            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1725        let out_dim: usize =
1726            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1727        let batch: usize =
1728            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1729        let tokens: usize =
1730            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1731
1732        let w1 = tc.draw(weight_matrix(mid_dim, in_dim, dev.clone()));
1733        let w2 = tc.draw(weight_matrix(out_dim, mid_dim, dev.clone()));
1734        let x = tc.draw(activations(batch, tokens, in_dim, dev));
1735
1736        let layers = [
1737            DenseLayer {
1738                linear: Linear::new(w1.clone(), None),
1739                residual: None,
1740            },
1741            DenseLayer {
1742                linear: Linear::new(w2.clone(), None),
1743                residual: None,
1744            },
1745        ];
1746
1747        // SUT mirrors `ColBERT::project`'s left-to-right fold without
1748        // depending on the ColBERT struct (which would require loading a
1749        // full transformer just to test its projection chain).
1750        let mut iter = layers.iter();
1751        let first = iter.next().unwrap();
1752        let mut chain_out = first.forward(&x).unwrap();
1753        for layer in iter {
1754            chain_out = layer.forward(&chain_out).unwrap();
1755        }
1756
1757        // Reference: composed-weight `Linear` (out_dim x in_dim).
1758        let composed_weight = w2.matmul(&w1).unwrap();
1759        let composed = Linear::new(composed_weight, None);
1760        let reference = composed.forward(&x).unwrap();
1761
1762        assert_eq!(chain_out.dims(), reference.dims());
1763        assert!(
1764            max_abs_diff(&chain_out, &reference) < 1e-3,
1765            "two-layer chain diverged from composed-weight Linear",
1766        );
1767    }
1768
1769    /// F4: the chain's output dim is whatever the last layer's `linear`
1770    /// emits, regardless of intermediate widths or whether intermediate
1771    /// layers carry a residual branch. This is the property the LateOn
1772    /// fix turns on: 768 → 1536 → 768 → 128 must end at 128, not 1536.
1773    #[hegel::test(test_cases = 100)]
1774    fn chain_output_dim_matches_last_layer_out_features(tc: TestCase) {
1775        let dev = test_device();
1776        let in_dim: usize =
1777            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1778        let mid_dim: usize =
1779            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1780        let final_dim: usize =
1781            tc.draw(gs::integers::<usize>().min_value(1).max_value(6));
1782        let batch: usize =
1783            tc.draw(gs::integers::<usize>().min_value(1).max_value(3));
1784        let tokens: usize =
1785            tc.draw(gs::integers::<usize>().min_value(1).max_value(4));
1786        let mid_has_residual: bool = tc.draw(gs::booleans());
1787
1788        let w1 = tc.draw(weight_matrix(mid_dim, in_dim, dev.clone()));
1789        let w1_res = mid_has_residual
1790            .then(|| tc.draw(weight_matrix(mid_dim, in_dim, dev.clone())));
1791        let w2 = tc.draw(weight_matrix(final_dim, mid_dim, dev.clone()));
1792        let x = tc.draw(activations(batch, tokens, in_dim, dev));
1793
1794        let layers = [
1795            DenseLayer {
1796                linear: Linear::new(w1, None),
1797                residual: w1_res.map(|w| Linear::new(w, None)),
1798            },
1799            DenseLayer {
1800                linear: Linear::new(w2, None),
1801                residual: None,
1802            },
1803        ];
1804
1805        let mut iter = layers.iter();
1806        let first = iter.next().unwrap();
1807        let mut out = first.forward(&x).unwrap();
1808        for layer in iter {
1809            out = layer.forward(&out).unwrap();
1810        }
1811        assert_eq!(out.dims(), &[batch, tokens, final_dim]);
1812    }
1813
1814    /// C5: scaling queries uniformly by `k > 0` scales the similarity matrix
1815    /// by `k`. MaxSim is `Σ_t max_k dot(k·Q[i,t], D[j,k]) = k · Σ_t max_k
1816    /// dot(Q[i,t], D[j,k])` because `k > 0` preserves the argmax of each
1817    /// per-q-token inner dot-product row.
1818    #[hegel::test(test_cases = 150)]
1819    fn similarity_linear_in_positive_query_scale(tc: TestCase) {
1820        let dev = test_device();
1821        let (q, d) = tc.draw(query_doc_pair(dev));
1822        let k: f32 = tc.draw(
1823            gs::floats::<f32>()
1824                .min_value(0.25)
1825                .max_value(4.0)
1826                .allow_nan(false)
1827                .allow_infinity(false),
1828        );
1829        let q_scaled = q.affine(f64::from(k), 0.0).unwrap();
1830
1831        let base = compute_similarities(&q, &d).unwrap();
1832        let scaled = compute_similarities(&q_scaled, &d).unwrap();
1833        for (rb, rs) in base.data.iter().zip(scaled.data.iter()) {
1834            for (vb, vs) in rb.iter().zip(rs.iter()) {
1835                assert!(
1836                    (*vs - vb * k).abs() < 1e-3,
1837                    "scale-linearity drift: k·{vb}={} vs {vs} (k={k})",
1838                    vb * k,
1839                );
1840            }
1841        }
1842    }
1843}