1use crate::autograd::add;
12use crate::error::{Error, Result};
13use crate::Tensor;
14use std::collections::HashMap;
15use std::path::Path;
16
17use super::config::TransformerConfig;
18use super::embedding::{Embedding, LearnedPositionEmbedding};
19use super::encoder_block::EncoderBlock;
20use super::norm::LayerNorm;
21use super::weights::{load_safetensors_weights, Architecture};
22
/// Transformer encoder: token + learned position (+ optional token-type) embeddings,
/// an embedding `LayerNorm`, then a stack of [`EncoderBlock`]s.
pub struct EncoderModel {
    /// Configuration the model was built from (cloned at construction time).
    pub config: TransformerConfig,
    /// Token-id lookup table, built as `vocab_size` x `hidden_size`.
    pub embed_tokens: Embedding,
    /// Learned absolute position table, built as `max_position_embeddings` x `hidden_size`.
    pub position_embeddings: LearnedPositionEmbedding,
    /// Optional segment (token-type) table; when present, row 0 is added to every position.
    pub token_type_embeddings: Option<Embedding>,
    /// Normalisation applied to the summed embeddings before the first encoder layer.
    pub embeddings_layernorm: LayerNorm,
    /// Encoder layers, applied in order.
    pub layers: Vec<EncoderBlock>,
}
38
39impl EncoderModel {
40 pub fn new(config: &TransformerConfig) -> Self {
42 let max_positions = config.max_position_embeddings;
43 let eps = config.rms_norm_eps;
44 let layers = (0..config.num_hidden_layers).map(|i| EncoderBlock::new(config, i)).collect();
45
46 Self {
47 config: config.clone(),
48 embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
49 position_embeddings: LearnedPositionEmbedding::new(max_positions, config.hidden_size),
50 token_type_embeddings: Some(Embedding::new(2, config.hidden_size)),
51 embeddings_layernorm: LayerNorm::new(config.hidden_size, eps),
52 layers,
53 }
54 }
55
56 pub fn from_params(
65 config: &TransformerConfig,
66 params: &HashMap<String, Tensor>,
67 ) -> Option<Self> {
68 let max_positions = config.max_position_embeddings;
69
70 let embed_tokens = Embedding::from_params(
71 params,
72 "encoder.embed_tokens.weight",
73 config.vocab_size,
74 config.hidden_size,
75 )?;
76
77 let position_embeddings = LearnedPositionEmbedding::from_params(
78 params,
79 "encoder.position_embeddings.weight",
80 max_positions,
81 config.hidden_size,
82 )?;
83
84 let token_type_embeddings =
88 params.get("encoder.token_type_embeddings.weight").and_then(|tensor| {
89 let type_vocab_size = tensor.len() / config.hidden_size;
90 if type_vocab_size == 0 || tensor.len() != type_vocab_size * config.hidden_size {
91 return None;
92 }
93 Embedding::from_params(
94 params,
95 "encoder.token_type_embeddings.weight",
96 type_vocab_size,
97 config.hidden_size,
98 )
99 });
100
101 let embeddings_layernorm = LayerNorm::from_params(
102 params,
103 "encoder.embeddings_layernorm",
104 config.rms_norm_eps,
105 config.hidden_size,
106 )?;
107
108 let layers: Option<Vec<EncoderBlock>> = (0..config.num_hidden_layers)
109 .map(|i| EncoderBlock::from_params(config, params, i))
110 .collect();
111 let layers = layers?;
112
113 Some(Self {
114 config: config.clone(),
115 embed_tokens,
116 position_embeddings,
117 token_type_embeddings,
118 embeddings_layernorm,
119 layers,
120 })
121 }
122
123 pub fn from_safetensors(config: &TransformerConfig, model_path: &Path) -> Result<Self> {
125 let weights = load_safetensors_weights(model_path, Architecture::RoBERTa)?;
126 Self::from_params(config, &weights).ok_or_else(|| {
127 Error::ConfigError("Failed to construct encoder from loaded weights".into())
128 })
129 }
130
131 pub fn forward(&self, token_ids: &[u32]) -> Tensor {
139 let seq_len = token_ids.len();
140 let h = self.config.hidden_size;
141
142 let token_emb = self.embed_tokens.forward(token_ids);
144
145 let pos_emb = self.position_embeddings.forward(seq_len);
147
148 let mut combined = add(&token_emb, &pos_emb);
150
151 if let Some(ref tte) = self.token_type_embeddings {
153 let type_ids: Vec<u32> = vec![0; seq_len];
154 let type_emb = tte.forward(&type_ids);
155 combined = add(&combined, &type_emb);
156 }
157
158 let mut hidden = self.embeddings_layernorm.forward_batched(&combined, seq_len, h);
160
161 for layer in &self.layers {
163 hidden = layer.forward(&hidden, seq_len);
164 }
165
166 hidden
167 }
168
169 pub fn cls_embedding(&self, token_ids: &[u32]) -> Tensor {
174 let hidden = self.forward(token_ids);
175 let h = self.config.hidden_size;
176 let data = hidden.data();
177 let slice = data.as_slice().expect("hidden contiguous");
178 Tensor::from_vec(slice[..h].to_vec(), false)
179 }
180
181 pub fn num_parameters(&self) -> usize {
183 let mut count = 0;
184 count += self.embed_tokens.vocab_size() * self.embed_tokens.hidden_size();
185 count += self.position_embeddings.weight.len();
186 if let Some(ref tte) = self.token_type_embeddings {
187 count += tte.vocab_size() * tte.hidden_size();
188 }
189 count += self.embeddings_layernorm.weight.len() * 2; for layer in &self.layers {
191 count += layer.parameters().iter().map(|p| p.len()).sum::<usize>();
192 }
193 count
194 }
195}
196
#[cfg(test)]
// Tests use `unwrap` on contiguity checks of freshly built tensors; failure there is a bug.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::transformer::ModelArchitecture;

    fn tiny_encoder_config() -> TransformerConfig {
        TransformerConfig {
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 4,
            intermediate_size: 64,
            vocab_size: 100,
            max_position_embeddings: 32,
            rms_norm_eps: 1e-5,
            architecture: ModelArchitecture::Encoder,
            ..TransformerConfig::tiny()
        }
    }

    #[test]
    fn clf_001_encoder_model_forward_shape() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![1, 2, 3, 4];
        let output = model.forward(&token_ids);
        assert_eq!(output.len(), 4 * config.hidden_size);
    }

    #[test]
    fn clf_001_encoder_model_forward_finite() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![10, 20, 30];
        let output = model.forward(&token_ids);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn clf_001_encoder_cls_embedding_shape() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![5, 10, 15];
        let cls = model.cls_embedding(&token_ids);
        assert_eq!(cls.len(), config.hidden_size);
    }

    #[test]
    fn clf_001_encoder_cls_embedding_deterministic() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let token_ids = vec![1, 2, 3];
        let cls1 = model.cls_embedding(&token_ids);
        let cls2 = model.cls_embedding(&token_ids);
        let d1 = cls1.data();
        let d2 = cls2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        assert_eq!(s1, s2, "CLS embedding must be deterministic");
    }

    #[test]
    fn clf_001_encoder_num_parameters() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let count = model.num_parameters();
        assert!(count > 1000, "encoder should have substantial params, got {count}");
    }

    #[test]
    fn test_encoder_forward_single_token() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let output = model.forward(&[42]);
        assert_eq!(output.len(), config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_cls_embedding_finite() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let cls = model.cls_embedding(&[1, 2, 3, 4, 5]);
        let data = cls.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_cls_embedding_matches_first_position() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids: [u32; 3] = [7, 8, 9];
        let full = model.forward(&ids);
        let cls = model.cls_embedding(&ids);
        let full_data = full.data();
        let cls_data = cls.data();
        let full_slice = full_data.as_slice().unwrap();
        let cls_slice = cls_data.as_slice().unwrap();
        assert_eq!(cls_slice, &full_slice[..config.hidden_size]);
    }

    #[test]
    #[should_panic]
    fn test_encoder_cls_embedding_empty_input_panics() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let _ = model.cls_embedding(&[]);
    }

    #[test]
    fn test_encoder_config_stored() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.config.hidden_size, 32);
        assert_eq!(model.config.num_hidden_layers, 2);
        assert_eq!(model.config.vocab_size, 100);
    }

    #[test]
    fn test_encoder_layers_count() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.layers.len(), 2);
    }

    #[test]
    fn test_encoder_token_type_embeddings_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert!(model.token_type_embeddings.is_some());
    }

    #[test]
    fn test_encoder_from_params_missing_weights() {
        let config = tiny_encoder_config();
        let empty_params: HashMap<String, Tensor> = HashMap::new();
        let result = EncoderModel::from_params(&config, &empty_params);
        assert!(result.is_none(), "from_params should return None with empty params");
    }

    #[test]
    fn test_encoder_from_safetensors_missing_file() {
        let config = tiny_encoder_config();
        let result = EncoderModel::from_safetensors(&config, std::path::Path::new("/nonexistent"));
        assert!(result.is_err());
    }

    #[test]
    fn test_encoder_forward_different_seq_lens() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);

        for seq_len in [1, 2, 4, 8, 16] {
            let token_ids: Vec<u32> = (0..seq_len as u32).collect();
            let output = model.forward(&token_ids);
            assert_eq!(
                output.len(),
                seq_len * config.hidden_size,
                "Output mismatch for seq_len={seq_len}"
            );
        }
    }

    #[test]
    fn test_encoder_num_params_includes_all_components() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let total = model.num_parameters();

        let embed_params = config.vocab_size * config.hidden_size;
        let pos_params = config.max_position_embeddings * config.hidden_size;
        let tte_params = 2 * config.hidden_size;
        let ln_params = config.hidden_size * 2;

        let non_layer_params = embed_params + pos_params + tte_params + ln_params;
        assert!(
            total > non_layer_params,
            "Total params ({total}) should exceed non-layer params ({non_layer_params})"
        );
    }

    #[test]
    fn test_encoder_forward_max_token_id() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        // 99 is the last valid id for vocab_size = 100.
        let output = model.forward(&[99]);
        assert_eq!(output.len(), config.hidden_size);
    }

    #[test]
    fn test_encoder_deterministic_across_calls() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids = vec![10, 20, 30, 40];

        let out1 = model.forward(&ids);
        let out2 = model.forward(&ids);

        let d1 = out1.data();
        let d2 = out2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_encoder_forward_varying_vocab_ids() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids: Vec<u32> = (0..20).collect();
        let output = model.forward(&ids);
        assert_eq!(output.len(), 20 * config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_from_params_partial_weights() {
        let config = tiny_encoder_config();
        let h = config.hidden_size;
        let v = config.vocab_size;
        let mut params: HashMap<String, Tensor> = HashMap::new();

        // Only the token embedding is supplied; every other required tensor is missing.
        let embed_data = vec![0.0_f32; v * h];
        params
            .insert("encoder.embed_tokens.weight".to_string(), Tensor::from_vec(embed_data, false));

        let result = EncoderModel::from_params(&config, &params);
        assert!(result.is_none());
    }

    #[test]
    fn test_encoder_cls_embedding_different_inputs_differ() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let cls1 = model.cls_embedding(&[1, 2, 3]);
        let cls2 = model.cls_embedding(&[10, 20, 30]);
        let d1 = cls1.data();
        let d2 = cls2.data();
        let s1 = d1.as_slice().unwrap();
        let s2 = d2.as_slice().unwrap();
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_encoder_position_embeddings_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(
            model.position_embeddings.weight.len(),
            config.max_position_embeddings * config.hidden_size
        );
    }

    #[test]
    fn test_encoder_embeddings_layernorm_present() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert_eq!(model.embeddings_layernorm.weight.len(), config.hidden_size);
    }

    #[test]
    fn test_encoder_num_parameters_varies_with_config() {
        let config1 = tiny_encoder_config();
        let model1 = EncoderModel::new(&config1);

        let config2 = TransformerConfig {
            hidden_size: 64,
            num_hidden_layers: 4,
            num_attention_heads: 8,
            num_kv_heads: 8,
            intermediate_size: 128,
            vocab_size: 200,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-5,
            architecture: ModelArchitecture::Encoder,
            ..TransformerConfig::tiny()
        };
        let model2 = EncoderModel::new(&config2);

        assert!(model2.num_parameters() > model1.num_parameters());
    }

    #[test]
    fn test_encoder_forward_two_tokens() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let output = model.forward(&[5, 10]);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_encoder_forward_at_max_position() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        let ids: Vec<u32> = (0..config.max_position_embeddings as u32).collect();
        let output = model.forward(&ids);
        assert_eq!(output.len(), config.max_position_embeddings * config.hidden_size);
    }

    #[test]
    fn test_encoder_no_token_type_embeddings() {
        let config = tiny_encoder_config();
        let mut model = EncoderModel::new(&config);
        model.token_type_embeddings = None;
        let output = model.forward(&[1, 2, 3]);
        assert_eq!(output.len(), 3 * config.hidden_size);
        let data = output.data();
        let slice = data.as_slice().unwrap();
        assert!(slice.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_encoder_num_parameters_without_tte() {
        let config = tiny_encoder_config();
        let mut model = EncoderModel::new(&config);
        let with_tte = model.num_parameters();
        model.token_type_embeddings = None;
        let without_tte = model.num_parameters();
        assert!(with_tte > without_tte);
        assert_eq!(with_tte - without_tte, 2 * config.hidden_size);
    }

    #[test]
    fn test_encoder_config_is_encoder() {
        let config = tiny_encoder_config();
        let model = EncoderModel::new(&config);
        assert!(model.config.is_encoder());
    }
}