1use crate::error::{Result, TextError};
199use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
200use scirs2_core::random::{Rng, RngExt};
201use statrs::statistics::Statistics;
202use std::collections::HashMap;
203
/// Hyper-parameters shared by every component of the transformer.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Width of token embeddings and hidden states.
    pub d_model: usize,
    /// Number of attention heads; must divide `d_model`.
    pub nheads: usize,
    /// Hidden width of each position-wise feed-forward network.
    pub d_ff: usize,
    /// Number of stacked encoder layers.
    pub n_encoder_layers: usize,
    /// Number of stacked decoder layers.
    pub n_decoder_layers: usize,
    /// Longest sequence the positional encodings cover.
    pub max_seqlen: usize,
    /// Dropout rate recorded on each layer (currently not applied in the forward passes).
    pub dropout: f64,
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
}

impl Default for TransformerConfig {
    /// A 512-wide, 8-head, 6 + 6 layer configuration with a 10 000-token vocabulary.
    fn default() -> Self {
        TransformerConfig {
            vocab_size: 10_000,
            max_seqlen: 512,
            d_model: 512,
            nheads: 8,
            d_ff: 2048,
            n_encoder_layers: 6,
            n_decoder_layers: 6,
            dropout: 0.1,
        }
    }
}
239
240pub struct PositionalEncoding {
242 encodings: Array2<f64>,
243 max_len: usize,
244 #[allow(dead_code)]
245 d_model: usize,
246}
247
248impl PositionalEncoding {
249 pub fn new(_max_len: usize, dmodel: usize) -> Self {
251 let mut encodings = Array2::<f64>::zeros((_max_len, dmodel));
252
253 for pos in 0.._max_len {
254 for i in (0..dmodel).step_by(2) {
255 let angle = pos as f64 / (10000.0_f64).powf(i as f64 / dmodel as f64);
256 encodings[[pos, i]] = angle.sin();
257 if i + 1 < dmodel {
258 encodings[[pos, i + 1]] = angle.cos();
259 }
260 }
261 }
262
263 Self {
264 encodings,
265 max_len: _max_len,
266 d_model: dmodel,
267 }
268 }
269
270 pub fn get_encoding(&self, seqlen: usize) -> Result<ArrayView2<f64>> {
272 if seqlen > self.max_len {
273 return Err(TextError::InvalidInput(format!(
274 "Sequence length {} exceeds maximum {}",
275 seqlen, self.max_len
276 )));
277 }
278 Ok(self.encodings.slice(s![0..seqlen, ..]))
279 }
280
281 pub fn get_encodings(&self) -> &Array2<f64> {
283 &self.encodings
284 }
285
286 pub fn set_encodings(&mut self, encodings: Array2<f64>) -> Result<()> {
288 let shape = encodings.shape();
289 if shape[0] != self.max_len || shape[1] != self.d_model {
290 return Err(TextError::InvalidInput(format!(
291 "Positional encoding shape {:?} does not match expected ({}, {})",
292 shape, self.max_len, self.d_model
293 )));
294 }
295 self.encodings = encodings;
296 Ok(())
297 }
298}
299
300pub struct MultiHeadAttention {
302 d_model: usize,
303 nheads: usize,
304 d_k: usize,
305 w_q: Array2<f64>,
306 w_k: Array2<f64>,
307 w_v: Array2<f64>,
308 w_o: Array2<f64>,
309}
310
311impl MultiHeadAttention {
312 pub fn new(d_model: usize, nheads: usize) -> Result<Self> {
314 if !d_model.is_multiple_of(nheads) {
315 return Err(TextError::InvalidInput(
316 "d_model must be divisible by nheads".to_string(),
317 ));
318 }
319
320 let d_k = d_model / nheads;
321
322 let scale = (2.0 / d_model as f64).sqrt();
324
325 let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
326 scirs2_core::random::rng().random_range(-scale..scale)
327 });
328 let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
329 scirs2_core::random::rng().random_range(-scale..scale)
330 });
331 let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
332 scirs2_core::random::rng().random_range(-scale..scale)
333 });
334 let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
335 scirs2_core::random::rng().random_range(-scale..scale)
336 });
337
338 Ok(Self {
339 d_model,
340 nheads,
341 d_k,
342 w_q,
343 w_k,
344 w_v,
345 w_o,
346 })
347 }
348
349 fn scaled_dot_product_attention(
351 &self,
352 q: ArrayView2<f64>,
353 k: ArrayView2<f64>,
354 v: ArrayView2<f64>,
355 mask: Option<ArrayView2<bool>>,
356 ) -> Result<Array2<f64>> {
357 let d_k = self.d_k as f64;
358
359 let scores = q.dot(&k.t()) / d_k.sqrt();
361
362 let mut masked_scores = scores;
364 if let Some(mask) = mask {
365 for ((i, j), &should_mask) in mask.indexed_iter() {
366 if should_mask {
367 masked_scores[[i, j]] = f64::NEG_INFINITY;
368 }
369 }
370 }
371
372 let attention_weights = self.softmax_2d(&masked_scores)?;
374
375 Ok(attention_weights.dot(&v))
377 }
378
379 fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
381 let mut result = x.clone();
382
383 for mut row in result.rows_mut() {
384 let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
385 row.mapv_inplace(|x| (x - max_val).exp());
386 let sum: f64 = row.sum();
387 if sum > 0.0 {
388 row /= sum;
389 }
390 }
391
392 Ok(result)
393 }
394
395 pub fn forward(
397 &self,
398 query: ArrayView2<f64>,
399 key: ArrayView2<f64>,
400 value: ArrayView2<f64>,
401 mask: Option<ArrayView2<bool>>,
402 ) -> Result<Array2<f64>> {
403 let _seqlen = query.shape()[0];
404
405 let q = query.dot(&self.w_q);
407 let k = key.dot(&self.w_k);
408 let v = value.dot(&self.w_v);
409
410 let q_heads = self.reshape_for_heads(&q)?;
412 let k_heads = self.reshape_for_heads(&k)?;
413 let v_heads = self.reshape_for_heads(&v)?;
414
415 let mut head_outputs = Vec::new();
417 for head in 0..self.nheads {
418 let q_head = q_heads.slice(s![head, .., ..]);
419 let k_head = k_heads.slice(s![head, .., ..]);
420 let v_head = v_heads.slice(s![head, .., ..]);
421
422 let head_output = self.scaled_dot_product_attention(q_head, k_head, v_head, mask)?;
423 head_outputs.push(head_output);
424 }
425
426 let concatenated = self.concatenate_heads(&head_outputs)?;
428
429 Ok(concatenated.dot(&self.w_o))
431 }
432
433 fn reshape_for_heads(&self, x: &Array2<f64>) -> Result<Array3<f64>> {
435 let (seqlen, d_model) = x.dim();
436 let reshaped = x
437 .clone()
438 .into_shape_with_order((seqlen, self.nheads, self.d_k))
439 .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;
440
441 Ok(reshaped.permuted_axes([1, 0, 2]))
443 }
444
445 fn concatenate_heads(&self, heads: &[Array2<f64>]) -> Result<Array2<f64>> {
447 if heads.is_empty() {
448 return Err(TextError::InvalidInput("No heads provided".to_string()));
449 }
450
451 let seqlen = heads[0].shape()[0];
452 let mut result = Array2::zeros((seqlen, self.d_model));
453
454 for (i, head) in heads.iter().enumerate() {
455 let start_col = i * self.d_k;
456 let end_col = start_col + self.d_k;
457 result.slice_mut(s![.., start_col..end_col]).assign(head);
458 }
459
460 Ok(result)
461 }
462
463 pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array2<f64>, &Array2<f64>) {
465 (&self.w_q, &self.w_k, &self.w_v, &self.w_o)
466 }
467
468 pub fn set_weights(
470 &mut self,
471 w_q: Array2<f64>,
472 w_k: Array2<f64>,
473 w_v: Array2<f64>,
474 w_o: Array2<f64>,
475 ) -> Result<()> {
476 if w_q.shape() != [self.d_model, self.d_model] {
477 return Err(TextError::InvalidInput("Invalid w_q shape".to_string()));
478 }
479 if w_k.shape() != [self.d_model, self.d_model] {
480 return Err(TextError::InvalidInput("Invalid w_k shape".to_string()));
481 }
482 if w_v.shape() != [self.d_model, self.d_model] {
483 return Err(TextError::InvalidInput("Invalid w_v shape".to_string()));
484 }
485 if w_o.shape() != [self.d_model, self.d_model] {
486 return Err(TextError::InvalidInput("Invalid w_o shape".to_string()));
487 }
488
489 self.w_q = w_q;
490 self.w_k = w_k;
491 self.w_v = w_v;
492 self.w_o = w_o;
493 Ok(())
494 }
495}
496
497pub struct FeedForward {
499 w1: Array2<f64>,
500 w2: Array2<f64>,
501 b1: Array1<f64>,
502 b2: Array1<f64>,
503}
504
505impl FeedForward {
506 pub fn new(_dmodel: usize, dff: usize) -> Self {
508 let scale = (2.0 / _dmodel as f64).sqrt();
509
510 let w1 = Array2::from_shape_fn((_dmodel, dff), |_| {
511 scirs2_core::random::rng().random_range(-scale..scale)
512 });
513 let w2 = Array2::from_shape_fn((dff, _dmodel), |_| {
514 scirs2_core::random::rng().random_range(-scale..scale)
515 });
516 let b1 = Array1::zeros(dff);
517 let b2 = Array1::zeros(_dmodel);
518
519 Self { w1, w2, b1, b2 }
520 }
521
522 pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
524 let hidden = x.dot(&self.w1) + &self.b1;
526 let activated = hidden.mapv(|x| x.max(0.0)); activated.dot(&self.w2) + &self.b2
530 }
531
532 pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array1<f64>, &Array1<f64>) {
534 (&self.w1, &self.w2, &self.b1, &self.b2)
535 }
536
537 pub fn set_weights(
539 &mut self,
540 w1: Array2<f64>,
541 w2: Array2<f64>,
542 b1: Array1<f64>,
543 b2: Array1<f64>,
544 ) -> Result<()> {
545 if w1.shape()[1] != w2.shape()[0] {
546 return Err(TextError::InvalidInput(
547 "Weight matrix dimensions don't match".to_string(),
548 ));
549 }
550 if b1.len() != w1.shape()[1] {
551 return Err(TextError::InvalidInput(
552 "Bias b1 size doesn't match w1".to_string(),
553 ));
554 }
555 if b2.len() != w2.shape()[1] {
556 return Err(TextError::InvalidInput(
557 "Bias b2 size doesn't match w2".to_string(),
558 ));
559 }
560
561 self.w1 = w1;
562 self.w2 = w2;
563 self.b1 = b1;
564 self.b2 = b2;
565 Ok(())
566 }
567}
568
569pub struct LayerNorm {
571 gamma: Array1<f64>,
572 beta: Array1<f64>,
573 eps: f64,
574}
575
576impl LayerNorm {
577 pub fn new(_dmodel: usize, eps: f64) -> Self {
579 Self {
580 gamma: Array1::ones(_dmodel),
581 beta: Array1::zeros(_dmodel),
582 eps,
583 }
584 }
585
586 pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
588 let mut result = Array2::zeros(x.raw_dim());
589
590 for (i, row) in x.rows().into_iter().enumerate() {
591 let mean = row.mean();
592 let var = row.mapv(|x| (x - mean).powi(2)).mean();
593 let std = (var + self.eps).sqrt();
594
595 let normalized = row.mapv(|x| (x - mean) / std);
596 let scaled = &normalized * &self.gamma + &self.beta;
597
598 result.row_mut(i).assign(&scaled);
599 }
600
601 result
602 }
603
604 pub fn get_params(&self) -> (&Array1<f64>, &Array1<f64>) {
606 (&self.gamma, &self.beta)
607 }
608
609 pub fn set_params(&mut self, gamma: Array1<f64>, beta: Array1<f64>) -> Result<()> {
611 if gamma.len() != beta.len() {
612 return Err(TextError::InvalidInput(
613 "Gamma and beta must have same length".to_string(),
614 ));
615 }
616 if gamma.len() != self.gamma.len() {
617 return Err(TextError::InvalidInput(
618 "Parameter size doesn't match layer dimension".to_string(),
619 ));
620 }
621
622 self.gamma = gamma;
623 self.beta = beta;
624 Ok(())
625 }
626}
627
628pub struct TransformerEncoderLayer {
630 self_attention: MultiHeadAttention,
631 feed_forward: FeedForward,
632 norm1: LayerNorm,
633 norm2: LayerNorm,
634 #[allow(dead_code)]
635 dropout: f64,
636}
637
638impl TransformerEncoderLayer {
639 pub fn new(config: &TransformerConfig) -> Result<Self> {
641 Ok(Self {
642 self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
643 feed_forward: FeedForward::new(config.d_model, config.d_ff),
644 norm1: LayerNorm::new(config.d_model, 1e-6),
645 norm2: LayerNorm::new(config.d_model, 1e-6),
646 dropout: config.dropout,
647 })
648 }
649
650 pub fn forward(
652 &self,
653 x: ArrayView2<f64>,
654 mask: Option<ArrayView2<bool>>,
655 ) -> Result<Array2<f64>> {
656 let attn_output = self.self_attention.forward(x, x, x, mask)?;
658 let x = &self.norm1.forward(x) + &attn_output;
659
660 let ff_output = self.feed_forward.forward(x.view());
662 let output = &self.norm2.forward(x.view()) + &ff_output;
663
664 Ok(output)
665 }
666
667 pub fn get_components_mut(
669 &mut self,
670 ) -> (
671 &mut MultiHeadAttention,
672 &mut FeedForward,
673 &mut LayerNorm,
674 &mut LayerNorm,
675 ) {
676 (
677 &mut self.self_attention,
678 &mut self.feed_forward,
679 &mut self.norm1,
680 &mut self.norm2,
681 )
682 }
683
684 pub fn get_components(&self) -> (&MultiHeadAttention, &FeedForward, &LayerNorm, &LayerNorm) {
686 (
687 &self.self_attention,
688 &self.feed_forward,
689 &self.norm1,
690 &self.norm2,
691 )
692 }
693}
694
695pub struct TransformerEncoder {
697 layers: Vec<TransformerEncoderLayer>,
698 position_encoding: PositionalEncoding,
699 config: TransformerConfig,
700}
701
702impl TransformerEncoder {
703 pub fn new(config: TransformerConfig) -> Result<Self> {
705 let mut layers = Vec::new();
706 for _ in 0..config.n_encoder_layers {
707 layers.push(TransformerEncoderLayer::new(&config)?);
708 }
709
710 let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);
711
712 Ok(Self {
713 layers,
714 position_encoding,
715 config,
716 })
717 }
718
719 pub fn encode(
721 &self,
722 embeddings: ArrayView2<f64>,
723 mask: Option<ArrayView2<bool>>,
724 ) -> Result<Array2<f64>> {
725 let seqlen = embeddings.shape()[0];
726
727 let pos_enc = self.position_encoding.get_encoding(seqlen)?;
729 let mut x = embeddings.to_owned() + pos_enc;
730
731 for layer in &self.layers {
733 x = layer.forward(x.view(), mask)?;
734 }
735
736 Ok(x)
737 }
738
739 pub fn config(&self) -> &TransformerConfig {
741 &self.config
742 }
743
744 pub fn get_layers_mut(&mut self) -> &mut Vec<TransformerEncoderLayer> {
746 &mut self.layers
747 }
748
749 pub fn get_layers(&self) -> &Vec<TransformerEncoderLayer> {
751 &self.layers
752 }
753
754 pub fn get_position_encoding(&self) -> &Array2<f64> {
756 self.position_encoding.get_encodings()
757 }
758
759 pub fn set_position_encoding(&mut self, encodings: Array2<f64>) -> Result<()> {
761 self.position_encoding.set_encodings(encodings)
762 }
763}
764
765pub struct TransformerDecoderLayer {
767 self_attention: MultiHeadAttention,
768 cross_attention: MultiHeadAttention,
769 feed_forward: FeedForward,
770 norm1: LayerNorm,
771 norm2: LayerNorm,
772 norm3: LayerNorm,
773 #[allow(dead_code)]
774 dropout: f64,
775}
776
777impl TransformerDecoderLayer {
778 pub fn new(config: &TransformerConfig) -> Result<Self> {
780 Ok(Self {
781 self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
782 cross_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
783 feed_forward: FeedForward::new(config.d_model, config.d_ff),
784 norm1: LayerNorm::new(config.d_model, 1e-6),
785 norm2: LayerNorm::new(config.d_model, 1e-6),
786 norm3: LayerNorm::new(config.d_model, 1e-6),
787 dropout: config.dropout,
788 })
789 }
790
791 pub fn forward(
793 &self,
794 x: ArrayView2<f64>,
795 encoder_output: ArrayView2<f64>,
796 self_attn_mask: Option<ArrayView2<bool>>,
797 cross_attn_mask: Option<ArrayView2<bool>>,
798 ) -> Result<Array2<f64>> {
799 let self_attn_out = self.self_attention.forward(x, x, x, self_attn_mask)?;
801 let x = self.norm1.forward((x.to_owned() + self_attn_out).view());
802
803 let cross_attn_out = self.cross_attention.forward(
805 x.view(),
806 encoder_output,
807 encoder_output,
808 cross_attn_mask,
809 )?;
810 let x = self.norm2.forward((x + cross_attn_out).view());
811
812 let ff_out = self.feed_forward.forward(x.view());
814 let _output = self.norm3.forward((x + ff_out).view());
815
816 Ok(_output)
817 }
818}
819
820pub struct TransformerDecoder {
822 layers: Vec<TransformerDecoderLayer>,
823 position_encoding: PositionalEncoding,
824 config: TransformerConfig,
825}
826
827impl TransformerDecoder {
828 pub fn new(config: TransformerConfig) -> Result<Self> {
830 let mut layers = Vec::new();
831 for _ in 0..config.n_decoder_layers {
832 layers.push(TransformerDecoderLayer::new(&config)?);
833 }
834
835 let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);
836
837 Ok(Self {
838 layers,
839 position_encoding,
840 config,
841 })
842 }
843
844 pub fn forward(
846 &self,
847 embeddings: ArrayView2<f64>,
848 encoder_output: ArrayView2<f64>,
849 self_attn_mask: Option<ArrayView2<bool>>,
850 cross_attn_mask: Option<ArrayView2<bool>>,
851 ) -> Result<Array2<f64>> {
852 let seqlen = embeddings.shape()[0];
853
854 let pos_enc = self.position_encoding.get_encoding(seqlen)?;
856 let mut x = embeddings.to_owned() + pos_enc;
857
858 for layer in &self.layers {
860 x = layer.forward(x.view(), encoder_output, self_attn_mask, cross_attn_mask)?;
861 }
862
863 Ok(x)
864 }
865
866 pub fn config(&self) -> &TransformerConfig {
868 &self.config
869 }
870}
871
872pub struct TokenEmbedding {
874 embeddings: Array2<f64>,
875 vocab_size: usize,
876 d_model: usize,
877}
878
879impl TokenEmbedding {
880 pub fn new(_vocab_size: usize, dmodel: usize) -> Self {
882 let scale = (1.0 / dmodel as f64).sqrt();
883 let embeddings = Array2::from_shape_fn((_vocab_size, dmodel), |_| {
884 scirs2_core::random::rng().random_range(-scale..scale)
885 });
886
887 Self {
888 embeddings,
889 vocab_size: _vocab_size,
890 d_model: dmodel,
891 }
892 }
893
894 pub fn forward(&self, tokenids: &[usize]) -> Result<Array2<f64>> {
896 let mut result = Array2::zeros((tokenids.len(), self.d_model));
897
898 for (i, &token_id) in tokenids.iter().enumerate() {
899 if token_id >= self.vocab_size {
900 return Err(TextError::InvalidInput(format!(
901 "Token ID {} exceeds vocabulary size {}",
902 token_id, self.vocab_size
903 )));
904 }
905 result.row_mut(i).assign(&self.embeddings.row(token_id));
906 }
907
908 Ok(result)
909 }
910
911 pub fn get_embeddings(&self) -> &Array2<f64> {
913 &self.embeddings
914 }
915
916 pub fn set_embeddings(&mut self, embeddings: Array2<f64>) -> Result<()> {
918 if embeddings.shape()[0] != self.vocab_size || embeddings.shape()[1] != self.d_model {
919 return Err(TextError::InvalidInput(format!(
920 "Embedding shape {:?} doesn't match expected ({}, {})",
921 embeddings.shape(),
922 self.vocab_size,
923 self.d_model
924 )));
925 }
926 self.embeddings = embeddings;
927 Ok(())
928 }
929}
930
931pub struct TransformerModel {
933 pub config: TransformerConfig,
935 pub token_embedding: TokenEmbedding,
937 pub encoder: TransformerEncoder,
939 pub decoder: Option<TransformerDecoder>,
941 vocab_to_id: HashMap<String, usize>,
942 id_to_vocab: HashMap<usize, String>,
943}
944
945impl TransformerModel {
946 pub fn new(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
948 let vocab_size = vocabulary.len();
949 if vocab_size != config.vocab_size {
950 return Err(TextError::InvalidInput(format!(
951 "Vocabulary size {} doesn't match config {}",
952 vocab_size, config.vocab_size
953 )));
954 }
955
956 let mut vocab_to_id = HashMap::new();
957 let mut id_to_vocab = HashMap::new();
958
959 for (id, token) in vocabulary.into_iter().enumerate() {
960 vocab_to_id.insert(token.clone(), id);
961 id_to_vocab.insert(id, token);
962 }
963
964 Ok(Self {
965 config: config.clone(),
966 token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
967 encoder: TransformerEncoder::new(config)?,
968 decoder: None, vocab_to_id,
970 id_to_vocab,
971 })
972 }
973
974 pub fn encode_tokens(&self, tokens: &[String]) -> Result<Array2<f64>> {
976 let tokenids: Result<Vec<usize>> = tokens
978 .iter()
979 .map(|token| {
980 self.vocab_to_id
981 .get(token)
982 .cloned()
983 .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
984 })
985 .collect();
986 let tokenids = tokenids?;
987
988 let embeddings = self.token_embedding.forward(&tokenids)?;
990
991 self.encoder.encode(embeddings.view(), None)
993 }
994
995 pub fn new_encoder_decoder(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
997 let vocab_size = vocabulary.len();
998 if vocab_size != config.vocab_size {
999 return Err(TextError::InvalidInput(format!(
1000 "Vocabulary size {} doesn't match config {}",
1001 vocab_size, config.vocab_size
1002 )));
1003 }
1004
1005 let mut vocab_to_id = HashMap::new();
1006 let mut id_to_vocab = HashMap::new();
1007
1008 for (id, token) in vocabulary.into_iter().enumerate() {
1009 vocab_to_id.insert(token.clone(), id);
1010 id_to_vocab.insert(id, token);
1011 }
1012
1013 Ok(Self {
1014 config: config.clone(),
1015 token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
1016 encoder: TransformerEncoder::new(config.clone())?,
1017 decoder: Some(TransformerDecoder::new(config)?),
1018 vocab_to_id,
1019 id_to_vocab,
1020 })
1021 }
1022
1023 pub fn encode_decode(
1025 &self,
1026 input_tokens: &[String],
1027 target_tokens: &[String],
1028 ) -> Result<Array2<f64>> {
1029 let decoder = self
1030 .decoder
1031 .as_ref()
1032 .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;
1033
1034 let encoder_output = self.encode_tokens(input_tokens)?;
1036
1037 let target_ids: Result<Vec<usize>> = target_tokens
1039 .iter()
1040 .map(|token| {
1041 self.vocab_to_id
1042 .get(token)
1043 .copied()
1044 .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
1045 })
1046 .collect();
1047 let target_ids = target_ids?;
1048
1049 let target_embeddings = self.token_embedding.forward(&target_ids)?;
1050
1051 let seqlen = target_tokens.len();
1053 let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
1054 for i in 0..seqlen {
1055 for j in (i + 1)..seqlen {
1056 causal_mask[[i, j]] = true; }
1058 }
1059
1060 decoder.forward(
1062 target_embeddings.view(),
1063 encoder_output.view(),
1064 Some(causal_mask.view()),
1065 None,
1066 )
1067 }
1068
1069 pub fn generate(
1071 &self,
1072 input_tokens: &[String],
1073 max_length: usize,
1074 start_token: &str,
1075 ) -> Result<Vec<String>> {
1076 let decoder = self
1077 .decoder
1078 .as_ref()
1079 .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;
1080
1081 let encoder_output = self.encode_tokens(input_tokens)?;
1083
1084 let mut generated_tokens = vec![start_token.to_string()];
1086
1087 for _ in 0..max_length {
1088 let current_ids: Result<Vec<usize>> = generated_tokens
1090 .iter()
1091 .map(|_token| {
1092 self.vocab_to_id
1093 .get(_token)
1094 .copied()
1095 .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {_token}")))
1096 })
1097 .collect();
1098 let current_ids = current_ids?;
1099
1100 let current_embeddings = self.token_embedding.forward(¤t_ids)?;
1101
1102 let seqlen = generated_tokens.len();
1104 let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
1105 for i in 0..seqlen {
1106 for j in (i + 1)..seqlen {
1107 causal_mask[[i, j]] = true;
1108 }
1109 }
1110
1111 let decoder_output = decoder.forward(
1113 current_embeddings.view(),
1114 encoder_output.view(),
1115 Some(causal_mask.view()),
1116 None,
1117 )?;
1118
1119 let last_output = decoder_output.row(decoder_output.nrows() - 1);
1121
1122 let mut best_token_id = 0;
1124 let mut best_score = last_output[0];
1125 for (i, &score) in last_output.iter().enumerate() {
1126 if score > best_score {
1127 best_score = score;
1128 best_token_id = i;
1129 }
1130 }
1131
1132 if let Some(_token) = self.id_to_vocab.get(&best_token_id) {
1134 generated_tokens.push(_token.clone());
1135
1136 if _token == "</s>" || _token == "<eos>" {
1138 break;
1139 }
1140 } else {
1141 break;
1142 }
1143 }
1144
1145 Ok(generated_tokens)
1146 }
1147
1148 pub fn vocabulary(&self) -> (&HashMap<String, usize>, &HashMap<usize, String>) {
1150 (&self.vocab_to_id, &self.id_to_vocab)
1151 }
1152}
1153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(10, 4);
        let encoding = pos_enc.get_encoding(5).expect("Operation failed");
        assert_eq!(encoding.shape(), &[5, 4]);

        // Different positions must receive different encodings.
        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        assert!(pos0
            .iter()
            .zip(pos1.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6));
    }

    #[test]
    fn test_positional_encoding_rejects_long_sequence() {
        let pos_enc = PositionalEncoding::new(10, 4);
        assert!(pos_enc.get_encoding(10).is_ok());
        assert!(pos_enc.get_encoding(11).is_err());
    }

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 2).expect("Operation failed");
        let seqlen = 4;
        let d_model = 8;

        let input = Array2::ones((seqlen, d_model));
        let output = mha
            .forward(input.view(), input.view(), input.view(), None)
            .expect("Operation failed");

        assert_eq!(output.shape(), &[seqlen, d_model]);
    }

    #[test]
    fn test_multi_head_attention_rejects_indivisible_heads() {
        assert!(MultiHeadAttention::new(8, 3).is_err());
    }

    #[test]
    fn test_token_embedding_rejects_out_of_vocab_id() {
        let embedding = TokenEmbedding::new(5, 4);
        let rows = embedding.forward(&[0, 4]).expect("ids in range");
        assert_eq!(rows.shape(), &[2, 4]);
        assert!(embedding.forward(&[5]).is_err());
    }

    #[test]
    fn test_transformer_encoder() {
        let config = TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 2,
            ..Default::default()
        };

        let encoder = TransformerEncoder::new(config).expect("Operation failed");
        let input = Array2::ones((4, 8));
        let output = encoder
            .encode(input.view(), None)
            .expect("Operation failed");

        assert_eq!(output.shape(), &[4, 8]);
    }
}