1use crate::Tensor;
6use std::collections::HashMap;
7
/// Token embedding table: maps integer token ids to dense `hidden_size`-wide vectors.
pub struct Embedding {
    /// Flattened `vocab_size * hidden_size` lookup table; row `i` occupies
    /// `[i * hidden_size, (i + 1) * hidden_size)`. Public, so callers that replace it are
    /// responsible for keeping that length invariant.
    pub weight: Tensor,
    /// Number of rows (valid token ids are `0..vocab_size`).
    vocab_size: usize,
    /// Width of each embedding row.
    hidden_size: usize,
}
17
18impl Embedding {
19 pub fn new(vocab_size: usize, hidden_size: usize) -> Self {
21 use super::init::{get_init_seed, rand_normal_seeded};
22 Self {
24 weight: Tensor::from_vec(
25 rand_normal_seeded(vocab_size * hidden_size, get_init_seed(), "embed_tokens"),
26 true,
27 ),
28 vocab_size,
29 hidden_size,
30 }
31 }
32
33 pub fn from_params(
39 params: &HashMap<String, Tensor>,
40 name: &str,
41 vocab_size: usize,
42 hidden_size: usize,
43 ) -> Option<Self> {
44 let weight = params.get(name)?.clone();
45 let expected = vocab_size * hidden_size;
46 if weight.len() != expected {
47 eprintln!(
48 "[PMAT-326] Embedding '{name}': shape mismatch — got {} elements, expected {expected} ({vocab_size}x{hidden_size})",
49 weight.len()
50 );
51 return None;
52 }
53 Some(Self { weight, vocab_size, hidden_size })
54 }
55
56 pub fn forward(&self, token_ids: &[u32]) -> Tensor {
64 contract_pre_embedding_lookup!(token_ids);
65 let mut output = Vec::with_capacity(token_ids.len() * self.hidden_size);
66
67 for &token_id in token_ids {
68 let idx = token_id as usize;
69 if idx >= self.vocab_size {
70 eprintln!(
72 "Warning: Embedding::forward token_id {} >= vocab_size {}. N-09 OOB escape.",
73 token_id, self.vocab_size
74 );
75 output.extend(std::iter::repeat_n(0.0, self.hidden_size));
76 } else {
77 let start = idx * self.hidden_size;
78 let end = start + self.hidden_size;
79 output.extend_from_slice(
80 &self.weight.data().as_slice().expect("embedding weight must be contiguous")
81 [start..end],
82 );
83 }
84 }
85
86 let result = Tensor::from_vec(output, true);
87 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
88 result
89 }
90
91 pub fn vocab_size(&self) -> usize {
93 self.vocab_size
94 }
95
96 pub fn hidden_size(&self) -> usize {
98 self.hidden_size
99 }
100}
101
/// Absolute position embedding table indexed by sequence position.
pub struct LearnedPositionEmbedding {
    /// Flattened `max_positions * hidden_size` table; position `p` occupies
    /// `[p * hidden_size, (p + 1) * hidden_size)`.
    pub weight: Tensor,
    /// Number of distinct positions stored in `weight`.
    max_positions: usize,
    /// Width of each position vector.
    hidden_size: usize,
}
119
120impl LearnedPositionEmbedding {
121 pub fn new(max_positions: usize, hidden_size: usize) -> Self {
123 let scale = (1.0 / hidden_size as f32).sqrt();
124 Self {
125 weight: Tensor::from_vec(
126 (0..max_positions * hidden_size)
127 .map(|i| (i as f32 * 0.0731).sin() * scale)
128 .collect(),
129 true,
130 ),
131 max_positions,
132 hidden_size,
133 }
134 }
135
136 pub fn from_params(
138 params: &HashMap<String, Tensor>,
139 name: &str,
140 max_positions: usize,
141 hidden_size: usize,
142 ) -> Option<Self> {
143 let weight = params.get(name)?.clone();
144 let expected = max_positions * hidden_size;
145 if weight.len() != expected {
146 eprintln!(
147 "[ENC-003] LearnedPositionEmbedding '{name}': shape mismatch — \
148 got {} elements, expected {expected} ({max_positions}×{hidden_size})",
149 weight.len()
150 );
151 return None;
152 }
153 Some(Self { weight, max_positions, hidden_size })
154 }
155
156 pub fn forward(&self, seq_len: usize) -> Tensor {
160 let clamped_len = seq_len.min(self.max_positions);
161 let weight_slice = &self.weight.data().as_slice().expect("position weight contiguous")
162 [..clamped_len * self.hidden_size];
163 if seq_len <= self.max_positions {
165 Tensor::from_vec(weight_slice.to_vec(), true)
166 } else {
167 let mut output = weight_slice.to_vec();
168 let last_start = (self.max_positions - 1) * self.hidden_size;
169 let last_end = last_start + self.hidden_size;
170 let last_pos = &self.weight.data().as_slice().expect("position weight contiguous")
171 [last_start..last_end];
172 for _ in self.max_positions..seq_len {
173 output.extend_from_slice(last_pos);
174 }
175 Tensor::from_vec(output, true)
176 }
177 }
178
179 pub fn max_positions(&self) -> usize {
181 self.max_positions
182 }
183
184 pub fn hidden_size(&self) -> usize {
186 self.hidden_size
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_forward() {
        let embed = Embedding::new(100, 8);
        let tokens = vec![0, 5, 10];
        let output = embed.forward(&tokens);
        assert_eq!(output.len(), 3 * 8);
    }

    #[test]
    fn test_embedding_out_of_vocab() {
        let embed = Embedding::new(100, 8);
        let tokens = vec![0, 200];
        let output = embed.forward(&tokens);
        assert_eq!(output.len(), 2 * 8);
        let data = output.data();
        // Token 200 is outside the 100-entry vocabulary, so its row must be zero-filled.
        for i in 8..16 {
            assert_eq!(data[i], 0.0);
        }
    }

    #[test]
    fn test_embedding_vocab_and_hidden_size() {
        let embed = Embedding::new(500, 16);
        assert_eq!(embed.vocab_size(), 500);
        assert_eq!(embed.hidden_size(), 16);
    }

    #[test]
    fn test_embedding_single_token() {
        let embed = Embedding::new(100, 8);
        let tokens = vec![42];
        let output = embed.forward(&tokens);
        assert_eq!(output.len(), 8);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_embedding_requires_grad() {
        let embed = Embedding::new(100, 8);
        assert!(embed.weight.requires_grad());
    }

    #[test]
    fn test_embedding_from_params() {
        let mut params = HashMap::new();
        params.insert("embed.weight".to_string(), Tensor::from_vec(vec![0.1; 100 * 8], true));
        let embed = Embedding::from_params(&params, "embed.weight", 100, 8);
        assert!(embed.is_some());
        let embed = embed.expect("operation should succeed");
        assert_eq!(embed.vocab_size(), 100);
        assert_eq!(embed.hidden_size(), 8);
    }

    #[test]
    fn test_embedding_from_params_missing() {
        let params: HashMap<String, Tensor> = HashMap::new();
        let embed = Embedding::from_params(&params, "missing.weight", 100, 8);
        assert!(embed.is_none());
    }

    // ENC-003: learned absolute position embeddings.

    #[test]
    fn enc_003_learned_position_embedding_shape() {
        let pos_embed = LearnedPositionEmbedding::new(514, 768);
        assert_eq!(pos_embed.max_positions(), 514);
        assert_eq!(pos_embed.hidden_size(), 768);
        let output = pos_embed.forward(10);
        assert_eq!(output.len(), 10 * 768);
    }

    #[test]
    fn enc_003_learned_position_embedding_deterministic() {
        let pe1 = LearnedPositionEmbedding::new(128, 32);
        let pe2 = LearnedPositionEmbedding::new(128, 32);
        let o1 = pe1.forward(10);
        let o2 = pe2.forward(10);
        assert_eq!(
            o1.data().as_slice().expect("contiguous"),
            o2.data().as_slice().expect("contiguous"),
        );
    }

    #[test]
    fn enc_003_learned_position_embedding_clamp_beyond_max() {
        let pe = LearnedPositionEmbedding::new(4, 8);
        let output = pe.forward(6);
        assert_eq!(output.len(), 6 * 8);
        let data = output.data();
        // Positions 4 and 5 exceed max_positions = 4 and must repeat position 3.
        let slice = data.as_slice().expect("contiguous");
        let pos3 = &slice[3 * 8..4 * 8];
        let pos4 = &slice[4 * 8..5 * 8];
        let pos5 = &slice[5 * 8..6 * 8];
        assert_eq!(pos3, pos4);
        assert_eq!(pos3, pos5);
    }

    #[test]
    fn enc_003_learned_position_from_params() {
        let mut params = HashMap::new();
        params.insert("pos.weight".to_string(), Tensor::from_vec(vec![0.1; 128 * 32], true));
        let pe = LearnedPositionEmbedding::from_params(&params, "pos.weight", 128, 32);
        assert!(pe.is_some());
    }

    #[test]
    fn enc_003_learned_position_from_params_rejects_wrong_shape() {
        let mut params = HashMap::new();
        params.insert("pos.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
        let pe = LearnedPositionEmbedding::from_params(&params, "pos.weight", 128, 32);
        assert!(pe.is_none());
    }

    // FALSIFY-E7: initialisation and construction contract.

    #[test]
    fn falsify_e7a_init_produces_valid_embedding() {
        let embed = Embedding::new(100, 64);
        let data = embed.weight.data();
        let slice = data.as_slice().expect("data as slice");

        let nan_count = slice.iter().filter(|v| v.is_nan()).count();
        assert_eq!(nan_count, 0, "FALSIFY-E7a: Init must not produce NaN");

        let inf_count = slice.iter().filter(|v| v.is_infinite()).count();
        assert_eq!(inf_count, 0, "FALSIFY-E7a: Init must not produce Inf");

        let zero_count = slice.iter().filter(|v| v.abs() < 1e-10).count();
        let zero_pct = 100.0 * zero_count as f64 / slice.len() as f64;
        assert!(zero_pct < 50.0,
            "FALSIFY-E7a: Init has {zero_pct:.1}% zeros — exceeds embedding contract threshold (50%)");

        let min = slice.iter().copied().fold(f32::INFINITY, f32::min);
        let max = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        assert!(
            (max - min).abs() > 1e-6,
            "FALSIFY-E7a: Init values are constant ({min}..{max}) — degenerate embedding"
        );
    }

    #[test]
    fn falsify_e7b_shape_matches_dimensions() {
        let vocab_size = 151;
        let hidden_size = 32;
        let embed = Embedding::new(vocab_size, hidden_size);
        assert_eq!(
            embed.weight.len(),
            vocab_size * hidden_size,
            "FALSIFY-E7b: Embedding length must be vocab_size * hidden_size"
        );
    }

    #[test]
    fn falsify_e7c_from_params_rejects_wrong_shape() {
        let mut params = HashMap::new();
        params.insert("embed.weight".to_string(), Tensor::from_vec(vec![0.1; 50], true));
        let embed = Embedding::from_params(&params, "embed.weight", 100, 8);
        assert!(
            embed.is_none(),
            "FALSIFY-E7c: PMAT-326 fix — from_params MUST reject wrong-shape embedding"
        );
    }

    #[test]
    fn falsify_e7d_oob_token_produces_zeros_not_panic() {
        let embed = Embedding::new(100, 8);
        let tokens = vec![0, 999];
        let output = embed.forward(&tokens);
        assert_eq!(output.len(), 2 * 8);
        let data = output.data();
        let token0_l2: f32 = (0..8).map(|i| data[i] * data[i]).sum::<f32>().sqrt();
        assert!(token0_l2 > 1e-6, "Token 0 should have non-zero embedding");
        let token999_l2: f32 = (8..16).map(|i| data[i] * data[i]).sum::<f32>().sqrt();
        assert!(token999_l2 < 1e-10, "OOB token should be zero-filled");
    }

    #[test]
    fn falsify_e7e_init_deterministic() {
        let embed1 = Embedding::new(100, 64);
        let embed2 = Embedding::new(100, 64);
        let d1 = embed1.weight.data();
        let d2 = embed2.weight.data();
        assert_eq!(
            d1.as_slice().expect("operation should succeed"),
            d2.as_slice().expect("operation should succeed"),
            "FALSIFY-E7e: Same vocab+hidden must produce identical initialization"
        );
    }

    // EM-00x: forward() lookup contract.

    #[test]
    fn falsify_em_001_forward_output_shape() {
        let embed = Embedding::new(100, 32);

        for seq_len in [1, 3, 10, 50] {
            let tokens: Vec<u32> = (0..seq_len).collect();
            let output = embed.forward(&tokens);
            assert_eq!(
                output.len(),
                seq_len as usize * 32,
                "FALSIFIED EM-001: forward({seq_len} tokens) produced {} elements, expected {}",
                output.len(),
                seq_len as usize * 32
            );
        }
    }

    #[test]
    fn falsify_em_001b_forward_empty_input() {
        let embed = Embedding::new(100, 32);
        let output = embed.forward(&[]);
        assert_eq!(output.len(), 0, "FALSIFIED EM-001b: empty input should produce 0 elements");
    }

    #[test]
    fn falsify_em_002_oob_safety() {
        let vocab_size = 50;
        let hidden = 8;
        let embed = Embedding::new(vocab_size, hidden);

        // Every id here is >= vocab_size, so the whole output must be zero.
        let oob_output = embed.forward(&[999, 50, 100]);
        let oob_data = oob_output.data();
        for (i, &v) in oob_data.iter().enumerate() {
            assert!(v.abs() < 1e-10, "FALSIFIED EM-002: OOB output[{i}] = {v}, expected 0.0");
        }

        // Mixed batch: valid rows must be untouched by the neighbouring OOB id.
        let mixed_output = embed.forward(&[0, 999, 49]);
        let mixed_data = mixed_output.data();
        let weight_data = embed.weight.data();

        for d in 0..hidden {
            assert_eq!(
                mixed_data[d], weight_data[d],
                "FALSIFIED EM-002: valid token 0 corrupted at dim {d}"
            );
        }

        for d in 0..hidden {
            assert!(
                mixed_data[hidden + d].abs() < 1e-10,
                "FALSIFIED EM-002: OOB token 999 at dim {d} = {}, expected 0.0",
                mixed_data[hidden + d]
            );
        }

        for d in 0..hidden {
            assert_eq!(
                mixed_data[2 * hidden + d],
                weight_data[49 * hidden + d],
                "FALSIFIED EM-002: valid boundary token 49 corrupted at dim {d}"
            );
        }
    }

    #[test]
    fn falsify_em_003_forward_determinism() {
        let embed = Embedding::new(100, 64);
        let tokens = vec![5u32, 42, 0, 99, 17];

        let o1 = embed.forward(&tokens);
        let o2 = embed.forward(&tokens);

        assert_eq!(
            o1.data().as_slice().expect("operation should succeed"),
            o2.data().as_slice().expect("operation should succeed"),
            "FALSIFIED EM-003: forward() is non-deterministic"
        );
    }

    #[test]
    fn falsify_em_004_forward_finite_output() {
        let embed = Embedding::new(200, 16);
        let tokens: Vec<u32> = (0..200).collect();
        let output = embed.forward(&tokens);
        let data = output.data();

        let nan_count = data.iter().filter(|v| v.is_nan()).count();
        let inf_count = data.iter().filter(|v| v.is_infinite()).count();

        assert_eq!(
            nan_count, 0,
            "FALSIFIED EM-004: forward output contains {nan_count} NaN values"
        );
        assert_eq!(
            inf_count, 0,
            "FALSIFIED EM-004: forward output contains {inf_count} Inf values"
        );
    }

    #[test]
    fn falsify_em_005_forward_value_correctness() {
        let embed = Embedding::new(50, 8);
        let tokens = vec![0u32, 10, 49];
        let output = embed.forward(&tokens);
        let out_data = output.data();
        let weight_data = embed.weight.data();

        for i in 0..8 {
            assert_eq!(
                out_data[i], weight_data[i],
                "FALSIFIED EM-005: output[{i}] != weight[{i}] for token 0"
            );
        }
        // Token 10 starts at row offset 10 * 8 = 80 in the flattened table.
        for i in 0..8 {
            assert_eq!(
                out_data[8 + i],
                weight_data[80 + i],
                "FALSIFIED EM-005: output[{}] != weight[{}] for token 10",
                8 + i,
                80 + i
            );
        }
    }

    // EMB-00x: embedding contract invariants.

    #[test]
    fn falsify_emb_001_lookup_determinism() {
        let embed = Embedding::new(200, 48);
        for t in [0u32, 1, 42, 100, 199] {
            let v1 = embed.forward(&[t]);
            let v2 = embed.forward(&[t]);
            assert_eq!(
                v1.data(),
                v2.data(),
                "FALSIFIED EMB-001: embed({t}) != embed({t}) — non-deterministic lookup"
            );
        }
    }

    #[test]
    fn falsify_emb_002_shape_preservation() {
        for (v, d) in [(100, 32), (200, 64), (500, 128), (50, 16)] {
            let embed = Embedding::new(v, d);
            let output = embed.forward(&[0, 1, 2]);
            assert_eq!(
                output.data().len(),
                3 * d,
                "FALSIFIED EMB-002: vocab={v}, d_model={d}, output len={} != 3*{d}",
                output.data().len()
            );
        }
    }

    #[test]
    fn falsify_emb_004_vocabulary_bounds() {
        let vocab = 50;
        let d = 16;
        let embed = Embedding::new(vocab, d);

        // The last valid id must map to a real, non-zero row.
        let valid_output = embed.forward(&[vocab as u32 - 1]);
        let valid_norm: f32 = valid_output.data().iter().map(|v| v * v).sum();
        assert!(
            valid_norm > 0.0,
            "FALSIFIED EMB-004: valid token {} produced zero embedding",
            vocab - 1
        );

        // The first invalid id must be zero-filled.
        let oob_output = embed.forward(&[vocab as u32]);
        let oob_norm: f32 = oob_output.data().iter().map(|v| v * v).sum();
        assert!(
            oob_norm == 0.0,
            "FALSIFIED EMB-004: OOB token {vocab} produced non-zero (norm={oob_norm})"
        );
    }

    #[test]
    fn falsify_emb_005_forward_non_zero() {
        let embed = Embedding::new(100, 64);
        let tokens = vec![0u32, 42, 99];
        let output = embed.forward(&tokens);
        let data = output.data();

        let l2_norm: f32 = data.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(l2_norm > 1e-6, "FALSIFIED EMB-005: forward output is all-zero (L2={l2_norm})");
    }

    mod em_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_em_001_prop_output_shape(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200, 500]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 48, 64]),
                seq_len in 1_usize..32,
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let tokens: Vec<u32> = (0..seq_len).map(|i| (i % vocab_size) as u32).collect();
                let output = embed.forward(&tokens);
                prop_assert_eq!(
                    output.len(), seq_len * hidden_size,
                    "FALSIFIED EM-001-prop: len={} != {}*{}={} (v={})",
                    output.len(), seq_len, hidden_size, seq_len * hidden_size, vocab_size
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]
            #[test]
            fn falsify_em_003_prop_determinism(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 64]),
                token_ids in proptest::collection::vec(0_u32..49, 1..16),
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let out1 = embed.forward(&token_ids);
                let out2 = embed.forward(&token_ids);
                prop_assert_eq!(
                    out1.data(), out2.data(),
                    "FALSIFIED EM-003-prop: two calls differ (v={}, h={})",
                    vocab_size, hidden_size
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_em_004_prop_finite(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 64]),
                token_ids in proptest::collection::vec(0_u32..49, 1..16),
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let output = embed.forward(&token_ids);
                for (i, v) in output.data().iter().enumerate() {
                    prop_assert!(
                        v.is_finite(),
                        "FALSIFIED EM-004-prop: output[{}]={} not finite (v={}, h={})",
                        i, v, vocab_size, hidden_size
                    );
                }
            }
        }
    }

    mod emb_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_emb_001_prop_determinism(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 64]),
                token_id in 0_u32..49,
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let v1 = embed.forward(&[token_id]);
                let v2 = embed.forward(&[token_id]);
                prop_assert_eq!(
                    v1.data(), v2.data(),
                    "FALSIFIED EMB-001-prop: embed({}) non-deterministic (v={}, h={})",
                    token_id, vocab_size, hidden_size
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_emb_002_prop_shape(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200, 500]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 48, 64, 128]),
                seq_len in 1_usize..16,
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let tokens: Vec<u32> = (0..seq_len).map(|i| (i % vocab_size) as u32).collect();
                let output = embed.forward(&tokens);
                prop_assert_eq!(
                    output.data().len(), seq_len * hidden_size,
                    "FALSIFIED EMB-002-prop: data len={} != {}*{}={} (v={})",
                    output.data().len(), seq_len, hidden_size, seq_len * hidden_size, vocab_size
                );
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]
            #[test]
            fn falsify_emb_005_prop_non_zero(
                vocab_size in prop::sample::select(vec![50_usize, 100, 200]),
                hidden_size in prop::sample::select(vec![16_usize, 32, 64]),
                token_ids in proptest::collection::vec(0_u32..49, 1..8),
            ) {
                let embed = Embedding::new(vocab_size, hidden_size);
                let output = embed.forward(&token_ids);
                let l2_norm: f32 = output.data().iter().map(|v| v * v).sum::<f32>().sqrt();
                prop_assert!(
                    l2_norm > 1e-6,
                    "FALSIFIED EMB-005-prop: output all-zero (L2={}, v={}, h={})",
                    l2_norm, vocab_size, hidden_size
                );
            }
        }
    }
}