Skip to main content

entrenar/transformer/
encoder.rs

1//! Complete encoder model (BERT/RoBERTa/CodeBERT)
2//!
3//! Provides the full encoder pipeline:
4//! ```text
5//! token_ids → Embedding + PositionEmbedding + TokenTypeEmbedding → LayerNorm
6//!           → N × EncoderBlock → [seq_len, hidden_size]
7//! ```
8//!
9//! For classification, use CLS pooling on the output (position 0).
10
11use crate::autograd::add;
12use crate::error::{Error, Result};
13use crate::Tensor;
14use std::collections::HashMap;
15use std::path::Path;
16
17use super::config::TransformerConfig;
18use super::embedding::{Embedding, LearnedPositionEmbedding};
19use super::encoder_block::EncoderBlock;
20use super::norm::LayerNorm;
21use super::weights::{load_safetensors_weights, Architecture};
22
23/// Complete encoder model (BERT/RoBERTa/CodeBERT).
24pub struct EncoderModel {
25    /// Configuration
26    pub config: TransformerConfig,
27    /// Token embedding
28    pub embed_tokens: Embedding,
29    /// Position embedding (learned absolute)
30    pub position_embeddings: LearnedPositionEmbedding,
31    /// Token type embedding (segment A/B, optional for RoBERTa but present in weights)
32    pub token_type_embeddings: Option<Embedding>,
33    /// Post-embedding LayerNorm
34    pub embeddings_layernorm: LayerNorm,
35    /// Encoder layers
36    pub layers: Vec<EncoderBlock>,
37}
38
39impl EncoderModel {
40    /// Create new encoder with default initialization.
41    pub fn new(config: &TransformerConfig) -> Self {
42        let max_positions = config.max_position_embeddings;
43        let eps = config.rms_norm_eps;
44        let layers = (0..config.num_hidden_layers).map(|i| EncoderBlock::new(config, i)).collect();
45
46        Self {
47            config: config.clone(),
48            embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
49            position_embeddings: LearnedPositionEmbedding::new(max_positions, config.hidden_size),
50            token_type_embeddings: Some(Embedding::new(2, config.hidden_size)),
51            embeddings_layernorm: LayerNorm::new(config.hidden_size, eps),
52            layers,
53        }
54    }
55
56    /// Create encoder from pre-trained parameters (after RoBERTa name mapping).
57    ///
58    /// Expected parameter names:
59    /// - `encoder.embed_tokens.weight`
60    /// - `encoder.position_embeddings.weight`
61    /// - `encoder.token_type_embeddings.weight` (optional)
62    /// - `encoder.embeddings_layernorm.{weight,bias}`
63    /// - `encoder.layers.{i}.*`
64    pub fn from_params(
65        config: &TransformerConfig,
66        params: &HashMap<String, Tensor>,
67    ) -> Option<Self> {
68        let max_positions = config.max_position_embeddings;
69
70        let embed_tokens = Embedding::from_params(
71            params,
72            "encoder.embed_tokens.weight",
73            config.vocab_size,
74            config.hidden_size,
75        )?;
76
77        let position_embeddings = LearnedPositionEmbedding::from_params(
78            params,
79            "encoder.position_embeddings.weight",
80            max_positions,
81            config.hidden_size,
82        )?;
83
84        // Token type embeddings are optional (RoBERTa has them but sets to zero)
85        // CodeBERT has type_vocab_size=1, standard BERT/RoBERTa has 2.
86        // Infer from actual tensor shape rather than hardcoding.
87        let token_type_embeddings =
88            params.get("encoder.token_type_embeddings.weight").and_then(|tensor| {
89                let type_vocab_size = tensor.len() / config.hidden_size;
90                if type_vocab_size == 0 || tensor.len() != type_vocab_size * config.hidden_size {
91                    return None;
92                }
93                Embedding::from_params(
94                    params,
95                    "encoder.token_type_embeddings.weight",
96                    type_vocab_size,
97                    config.hidden_size,
98                )
99            });
100
101        let embeddings_layernorm = LayerNorm::from_params(
102            params,
103            "encoder.embeddings_layernorm",
104            config.rms_norm_eps,
105            config.hidden_size,
106        )?;
107
108        let layers: Option<Vec<EncoderBlock>> = (0..config.num_hidden_layers)
109            .map(|i| EncoderBlock::from_params(config, params, i))
110            .collect();
111        let layers = layers?;
112
113        Some(Self {
114            config: config.clone(),
115            embed_tokens,
116            position_embeddings,
117            token_type_embeddings,
118            embeddings_layernorm,
119            layers,
120        })
121    }
122
123    /// Load encoder from SafeTensors file(s).
124    pub fn from_safetensors(config: &TransformerConfig, model_path: &Path) -> Result<Self> {
125        let weights = load_safetensors_weights(model_path, Architecture::RoBERTa)?;
126        Self::from_params(config, &weights).ok_or_else(|| {
127            Error::ConfigError("Failed to construct encoder from loaded weights".into())
128        })
129    }
130
131    /// Forward pass: token_ids → hidden states [seq_len × hidden_size].
132    ///
133    /// # Arguments
134    /// * `token_ids` - Input token IDs
135    ///
136    /// # Returns
137    /// Hidden states tensor (seq_len * hidden_size, flattened)
138    pub fn forward(&self, token_ids: &[u32]) -> Tensor {
139        let seq_len = token_ids.len();
140        let h = self.config.hidden_size;
141
142        // Token embeddings
143        let token_emb = self.embed_tokens.forward(token_ids);
144
145        // Position embeddings
146        let pos_emb = self.position_embeddings.forward(seq_len);
147
148        // Add token + position embeddings
149        let mut combined = add(&token_emb, &pos_emb);
150
151        // Add token type embeddings (all zeros = segment A)
152        if let Some(ref tte) = self.token_type_embeddings {
153            let type_ids: Vec<u32> = vec![0; seq_len];
154            let type_emb = tte.forward(&type_ids);
155            combined = add(&combined, &type_emb);
156        }
157
158        // Post-embedding LayerNorm
159        let mut hidden = self.embeddings_layernorm.forward_batched(&combined, seq_len, h);
160
161        // Pass through encoder layers
162        for layer in &self.layers {
163            hidden = layer.forward(&hidden, seq_len);
164        }
165
166        hidden
167    }
168
169    /// Extract [CLS] embedding (position 0) from hidden states.
170    ///
171    /// For classification, the [CLS] token at position 0 attends bidirectionally
172    /// to all other tokens, making it a summary representation.
173    pub fn cls_embedding(&self, token_ids: &[u32]) -> Tensor {
174        let hidden = self.forward(token_ids);
175        let h = self.config.hidden_size;
176        let data = hidden.data();
177        let slice = data.as_slice().expect("hidden contiguous");
178        Tensor::from_vec(slice[..h].to_vec(), false)
179    }
180
181    /// Get total parameter count.
182    pub fn num_parameters(&self) -> usize {
183        let mut count = 0;
184        count += self.embed_tokens.vocab_size() * self.embed_tokens.hidden_size();
185        count += self.position_embeddings.weight.len();
186        if let Some(ref tte) = self.token_type_embeddings {
187            count += tte.vocab_size() * tte.hidden_size();
188        }
189        count += self.embeddings_layernorm.weight.len() * 2; // weight + bias
190        for layer in &self.layers {
191            count += layer.parameters().iter().map(|p| p.len()).sum::<usize>();
192        }
193        count
194    }
195}
196
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for [`EncoderModel`]: output shapes, determinism, parameter accounting
    //! and loading failures.

    use super::*;
    use crate::transformer::ModelArchitecture;

    /// Minimal encoder config (hidden 32, 2 layers, 4 heads, vocab 100, 32 positions).
    fn tiny_encoder_config() -> TransformerConfig {
        TransformerConfig {
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 64,
            vocab_size: 100,
            max_position_embeddings: 32,
            rms_norm_eps: 1e-5,
            architecture: ModelArchitecture::Encoder,
            ..TransformerConfig::tiny()
        }
    }

    #[test]
    fn clf_001_encoder_model_forward_shape() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![1, 2, 3, 4];
        let output = model.forward(&token_ids);
        assert_eq!(output.len(), 4 * config.hidden_size);
    }

    #[test]
    fn clf_001_encoder_model_forward_finite() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![10, 20, 30];
        let output = model.forward(&token_ids);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn clf_001_encoder_cls_embedding_shape() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![5, 10, 15];
        let cls = model.cls_embedding(&token_ids);
        assert_eq!(cls.len(), config.hidden_size);
    }

    #[test]
    fn clf_001_encoder_cls_embedding_deterministic() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![1, 2, 3];
        let cls1 = model.cls_embedding(&token_ids);
        let cls2 = model.cls_embedding(&token_ids);
        let d1 = cls1.data();
        let d2 = cls2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        assert_eq!(s1, s2, "CLS embedding must be deterministic");
    }

    #[test]
    fn clf_001_encoder_num_parameters() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let count = model.num_parameters();
        // Should be > 0 and reasonable
        assert!(count > 1000, "encoder should have substantial params, got {count}");
    }

    #[test]
    fn test_encoder_forward_single_token() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let output = model.forward(&[42]);
        assert_eq!(output.len(), config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_cls_embedding_finite() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let cls = model.cls_embedding(&[1, 2, 3, 4, 5]);
        let data = cls.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    /// The CLS embedding must be exactly the first `hidden_size` values of `forward`.
    #[test]
    fn test_encoder_cls_embedding_matches_first_row() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids = [3, 7, 11];
        let hidden = model.forward(&ids);
        let cls = model.cls_embedding(&ids);
        let hd = hidden.data();
        let cd = cls.data();
        let first_row = &hd.as_slice().unwrap()[..config.hidden_size];
        assert_eq!(first_row, cd.as_slice().unwrap());
    }

    #[test]
    fn test_encoder_config_stored() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.config.hidden_size, 32);
        assert_eq!(model.config.num_hidden_layers, 2);
        assert_eq!(model.config.vocab_size, 100);
    }

    #[test]
    fn test_encoder_layers_count() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.layers.len(), 2);
    }

    #[test]
    fn test_encoder_token_type_embeddings_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert!(model.token_type_embeddings.is_some());
    }

    #[test]
    fn test_encoder_from_params_missing_weights() {
        let config = tiny_encoder_config();
        let empty_params: HashMap<String, Tensor> = HashMap::new();
        let result = EncoderModel::from_params(&config, &empty_params);
        assert!(result.is_none(), "from_params should return None with empty params");
    }

    #[test]
    fn test_encoder_from_safetensors_missing_file() {
        let config = tiny_encoder_config();
        let result = EncoderModel::from_safetensors(&config, std::path::Path::new("/nonexistent"));
        assert!(result.is_err());
    }

    #[test]
    fn test_encoder_forward_different_seq_lens() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);

        for seq_len in [1, 2, 4, 8, 16] {
            let token_ids: Vec<u32> = (0..seq_len as u32).collect();
            let output = model.forward(&token_ids);
            assert_eq!(
                output.len(),
                seq_len * config.hidden_size,
                "Output mismatch for seq_len={seq_len}"
            );
        }
    }

    /// The reported total must equal every component counted exactly once.
    #[test]
    fn test_encoder_num_params_includes_all_components() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let total = model.num_parameters();

        // Embedding: vocab_size * hidden_size = 100 * 32 = 3200
        let embed_params = config.vocab_size * config.hidden_size;
        // Position: max_pos * hidden = 32 * 32 = 1024
        let pos_params = config.max_position_embeddings * config.hidden_size;
        // Token type: 2 * hidden = 64
        let tte_params = 2 * config.hidden_size;
        // LayerNorm: hidden * 2 = 64 (weight + bias)
        let ln_params = config.hidden_size * 2;
        let non_layer_params = embed_params + pos_params + tte_params + ln_params;

        let layer_params: usize = model
            .layers
            .iter()
            .map(|layer| layer.parameters().iter().map(|p| p.len()).sum::<usize>())
            .sum();

        assert!(layer_params > 0, "encoder blocks should own parameters");
        assert_eq!(
            total,
            non_layer_params + layer_params,
            "Total params ({total}) should equal non-layer ({non_layer_params}) \
             + layer ({layer_params}) params"
        );
    }

    #[test]
    fn test_encoder_forward_max_token_id() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        // Use token ID at the edge of vocab
        let output = model.forward(&[99]); // vocab_size = 100, max valid = 99
        assert_eq!(output.len(), config.hidden_size);
    }

    #[test]
    fn test_encoder_deterministic_across_calls() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids = vec![10, 20, 30, 40];

        let out1 = model.forward(&ids);
        let out2 = model.forward(&ids);

        let d1 = out1.data();
        let d2 = out2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        assert_eq!(s1, s2);
    }

    // ── Additional coverage tests ─────────────────────────────────

    #[test]
    fn test_encoder_forward_varying_vocab_ids() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        // Use a range of token IDs
        let ids: Vec<u32> = (0..20).collect();
        let output = model.forward(&ids);
        assert_eq!(output.len(), 20 * config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_from_params_partial_weights() {
        let config = tiny_encoder_config();
        let h = config.hidden_size;
        let v = config.vocab_size;
        let mut params: HashMap<String, Tensor> = HashMap::new();

        // Add only embed_tokens, not position_embeddings → should return None
        let embed_data = vec![0.0_f32; v * h];
        params
            .insert("encoder.embed_tokens.weight".to_string(), Tensor::from_vec(embed_data, false));

        let result = EncoderModel::from_params(&config, &params);
        assert!(result.is_none());
    }

    #[test]
    fn test_encoder_cls_embedding_different_inputs_differ() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let cls1 = model.cls_embedding(&[1, 2, 3]);
        let cls2 = model.cls_embedding(&[10, 20, 30]);
        let d1 = cls1.data();
        let d2 = cls2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        // Different inputs should (with overwhelming probability) produce different embeddings
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_encoder_position_embeddings_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(
            model.position_embeddings.weight.len(),
            config.max_position_embeddings * config.hidden_size
        );
    }

    #[test]
    fn test_encoder_embeddings_layernorm_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.embeddings_layernorm.weight.len(), config.hidden_size);
    }

    #[test]
    fn test_encoder_num_parameters_varies_with_config() {
        let config1 = tiny_encoder_config();
        let model1 = EncoderModel::new(&config1);

        let config2 = TransformerConfig {
            hidden_size: 64,
            num_hidden_layers: 4,
            num_attention_heads: 8,
            num_kv_heads: 8,
            intermediate_size: 128,
            vocab_size: 200,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-5,
            architecture: ModelArchitecture::Encoder,
            ..TransformerConfig::tiny()
        };
        let model2 = EncoderModel::new(&config2);

        // Larger model should have more parameters
        assert!(model2.num_parameters() > model1.num_parameters());
    }

    #[test]
    fn test_encoder_forward_two_tokens() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let output = model.forward(&[5, 10]);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_encoder_forward_at_max_position() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        // Use max_position_embeddings tokens
        let ids: Vec<u32> = (0..config.max_position_embeddings as u32).collect();
        let output = model.forward(&ids);
        assert_eq!(output.len(), config.max_position_embeddings * config.hidden_size);
    }

    #[test]
    fn test_encoder_no_token_type_embeddings() {
        let config = tiny_encoder_config();
        let mut model = EncoderModel::new(&config);
        // Remove token type embeddings
        model.token_type_embeddings = None;
        let output = model.forward(&[1, 2, 3]);
        assert_eq!(output.len(), 3 * config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_num_parameters_without_tte() {
        let config = tiny_encoder_config();
        let mut model = EncoderModel::new(&config);
        let with_tte = model.num_parameters();
        model.token_type_embeddings = None;
        let without_tte = model.num_parameters();
        assert!(with_tte > without_tte);
        // Difference should be 2 * hidden_size (token type embedding for 2 types)
        assert_eq!(with_tte - without_tte, 2 * config.hidden_size);
    }

    #[test]
    fn test_encoder_config_is_encoder() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert!(model.config.is_encoder());
    }
}