Skip to main content

scirs2_text/
transformer.rs

1//! # Transformer Architecture Module
2//!
3//! This module provides a complete implementation of the Transformer architecture,
4//! the foundation of modern language models like BERT, GPT, and T5. It includes
5//! all essential components for building state-of-the-art NLP models.
6//!
7//! ## Overview
8//!
9//! The Transformer architecture revolutionized natural language processing by
10//! introducing the self-attention mechanism. This module implements:
11//!
12//! - **Multi-Head Attention**: Core attention mechanism with multiple attention heads
//! - **Positional Encoding**: Fixed sinusoidal position representations
14//! - **Encoder-Decoder Architecture**: Full transformer with both encoder and decoder stacks
//! - **Layer Normalization**: Learnable per-feature normalization around each sub-layer
16//! - **Feed-Forward Networks**: Position-wise fully connected layers
//! - **Token Embeddings**: Learnable token-id to vector lookup table
18//!
19//! ## Quick Start
20//!
21//! ```rust
22//! use scirs2_text::transformer::{TransformerModel, TransformerConfig};
23//!
24//! // Configure the transformer
25//! let config = TransformerConfig {
26//!     d_model: 512,           // Model dimension
27//!     nheads: 8,             // Number of attention heads
28//!     d_ff: 2048,             // Feed-forward dimension
29//!     n_encoder_layers: 6,    // Number of encoder layers
30//!     n_decoder_layers: 6,    // Number of decoder layers
31//!     max_seqlen: 512,       // Maximum sequence length
32//!     dropout: 0.1,           // Dropout rate
33//!     vocab_size: 10000,      // Vocabulary size
34//! };
35//!
36//! // Create the model
37//! let vocabulary = (0..10000).map(|i| format!("token_{}", i)).collect();
//! let transformer = TransformerModel::new(config, vocabulary).expect("Operation failed");
39//!
40//! // Example input sequences (string tokens)
41//! let src_tokens = vec!["token_1".to_string(), "token_2".to_string(), "token_3".to_string()];
42//!
43//! // Encode the tokens
44//! let output = transformer.encode_tokens(&src_tokens).expect("Operation failed");
45//! println!("Model output shape: {:?}", output.shape());
46//! ```
47//!
48//! ## Building Individual Components
49//!
50//! ### Multi-Head Attention
51//!
52//! ```rust
53//! use scirs2_text::transformer::MultiHeadAttention;
54//! use scirs2_core::ndarray::Array2;
55//!
56//! let d_model = 512;
57//! let nheads = 8;
//! let attention = MultiHeadAttention::new(d_model, nheads).expect("Operation failed");
59//!
//! // Create dummy input (seqlen=10, d_model=512); inputs are unbatched 2-D arrays
61//! let input = Array2::zeros((10, 512));
62//! let output = attention.forward(input.view(), input.view(), input.view(), None).expect("Operation failed");
63//! ```
64//!
65//! ### Positional Encoding
66//!
67//! ```rust
68//! use scirs2_text::transformer::PositionalEncoding;
69//! use scirs2_core::ndarray::Array2;
70//!
71//! let d_model = 512;
72//! let max_len = 1000;
//! let pos_encoding = PositionalEncoding::new(max_len, d_model);
74//!
75//! // Apply positional encoding to embeddings
76//! let seqlen = 20;
77//! let embeddings = Array2::<f64>::zeros((seqlen, d_model));
78//! let positional_encodings = pos_encoding.get_encoding(seqlen).expect("Operation failed");
79//! println!("Embeddings shape: {:?}", embeddings.shape());
80//! println!("Positional encodings shape: {:?}", positional_encodings.shape());
81//! ```
82//!
83//! ### Complete Encoder
84//!
85//! ```rust
86//! use scirs2_text::transformer::{TransformerEncoder, TransformerConfig};
87//! use scirs2_core::ndarray::Array2;
88//!
89//! let config = TransformerConfig {
90//!     d_model: 256,
91//!     nheads: 4,
92//!     d_ff: 1024,
93//!     n_encoder_layers: 3,
94//!     dropout: 0.1,
95//!     ..Default::default()
96//! };
97//!
98//! let encoder = TransformerEncoder::new(config).expect("Operation failed");
99//! let input = Array2::zeros((50, 256)); // (seqlen, d_model)
100//! let encoded = encoder.encode(input.view(), None).expect("Operation failed");
101//! ```
102//!
103//! ## Advanced Usage
104//!
105//! ### Custom Attention Patterns
106//!
107//! ```rust
108//! use scirs2_text::transformer::MultiHeadAttention;
109//! use scirs2_core::ndarray::Array2;
110//!
//! let attention = MultiHeadAttention::new(512, 8).expect("Operation failed");
112//!
113//! // Create attention mask for autoregressive generation
114//! let seqlen = 10;
115//! let mut mask = Array2::from_elem((seqlen, seqlen), false);
116//! for i in 0..seqlen {
117//!     for j in (i+1)..seqlen {
118//!         mask[[i, j]] = true; // Mask future positions
119//!     }
120//! }
121//!
122//! let query = Array2::zeros((seqlen, 512));
123//! let key = Array2::zeros((seqlen, 512));
124//! let value = Array2::zeros((seqlen, 512));
125//! let output = attention.forward(query.view(), key.view(), value.view(), Some(mask.view())).expect("Operation failed");
126//! ```
127//!
128//! ### Layer-wise Learning Rate Decay
129//!
130//! ```rust
131//! use scirs2_text::transformer::{TransformerModel, TransformerConfig};
132//!
133//! # let config = TransformerConfig::default();
134//! # let vocabulary: Vec<String> = (0..config.vocab_size).map(|i| format!("token_{}", i)).collect();
135//! // Apply different learning rates to different layers  
136//! let mut model = TransformerModel::new(config, vocabulary).expect("Operation failed");
137//!
138//! // Typically: deeper layers get smaller learning rates
139//! let base_lr = 1e-4;
140//! // Note: Layer parameters would be accessed through training APIs
141//! println!("Base learning rate: {}", base_lr);
142//! ```
143//!
144//! ## Architecture Details
145//!
146//! ### Attention Mechanism
147//!
148//! The multi-head attention computes:
149//!
150//! ```text
151//! Attention(Q, K, V) = softmax(QK^T / √d_k)V
152//! MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
153//! where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
154//! ```
155//!
156//! ### Positional Encoding
157//!
158//! Uses sinusoidal functions to encode position information:
159//!
160//! ```text
161//! PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
162//! PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
163//! ```
164//!
165//! ### Layer Structure
166//!
167//! Each encoder/decoder layer follows the pattern:
168//!
169//! ```text
170//! x = LayerNorm(x + SelfAttention(x))
171//! x = LayerNorm(x + FeedForward(x))
172//! ```
173//!
//! ## Performance Optimization
//!
//! General techniques for scaling transformer workloads (most are not implemented in this module):
//!
176//! 1. **Gradient Checkpointing**: Trade memory for computation in deep models
177//! 2. **Mixed Precision**: Use FP16 for faster training with minimal quality loss
178//! 3. **Key-Value Caching**: Cache attention keys and values during inference
179//! 4. **Attention Patterns**: Use sparse attention for very long sequences
180//! 5. **Model Parallelism**: Split large models across multiple GPUs
181//!
182//! ## Common Use Cases
183//!
184//! - **Machine Translation**: Encoder-decoder for seq2seq tasks
185//! - **Language Modeling**: Decoder-only for autoregressive generation
186//! - **Text Classification**: Encoder with classification head
187//! - **Question Answering**: Encoder with span prediction heads
188//! - **Text Summarization**: Encoder-decoder with copy mechanism
189//!
190//! ## Best Practices
191//!
192//! 1. **Warmup Learning Rate**: Start with small LR and gradually increase
193//! 2. **Layer Normalization**: Pre-norm generally works better than post-norm
194//! 3. **Residual Connections**: Essential for training deep networks
195//! 4. **Attention Dropout**: Apply dropout to attention weights, not just outputs
196//! 5. **Weight Initialization**: Use Xavier/Glorot initialization for stability
197
198use crate::error::{Result, TextError};
199use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
200use scirs2_core::random::{Rng, RngExt};
201use statrs::statistics::Statistics;
202use std::collections::HashMap;
203
204/// Configuration for transformer models
205#[derive(Debug, Clone)]
206pub struct TransformerConfig {
207    /// Model dimension (embedding size)
208    pub d_model: usize,
209    /// Number of attention heads
210    pub nheads: usize,
211    /// Feed-forward network dimension
212    pub d_ff: usize,
213    /// Number of encoder layers
214    pub n_encoder_layers: usize,
215    /// Number of decoder layers
216    pub n_decoder_layers: usize,
217    /// Maximum sequence length
218    pub max_seqlen: usize,
219    /// Dropout rate
220    pub dropout: f64,
221    /// Vocabulary size
222    pub vocab_size: usize,
223}
224
225impl Default for TransformerConfig {
226    fn default() -> Self {
227        Self {
228            d_model: 512,
229            nheads: 8,
230            d_ff: 2048,
231            n_encoder_layers: 6,
232            n_decoder_layers: 6,
233            max_seqlen: 512,
234            dropout: 0.1,
235            vocab_size: 10000,
236        }
237    }
238}
239
240/// Position encoding for transformer models
241pub struct PositionalEncoding {
242    encodings: Array2<f64>,
243    max_len: usize,
244    #[allow(dead_code)]
245    d_model: usize,
246}
247
248impl PositionalEncoding {
249    /// Create new positional encoding
250    pub fn new(_max_len: usize, dmodel: usize) -> Self {
251        let mut encodings = Array2::<f64>::zeros((_max_len, dmodel));
252
253        for pos in 0.._max_len {
254            for i in (0..dmodel).step_by(2) {
255                let angle = pos as f64 / (10000.0_f64).powf(i as f64 / dmodel as f64);
256                encodings[[pos, i]] = angle.sin();
257                if i + 1 < dmodel {
258                    encodings[[pos, i + 1]] = angle.cos();
259                }
260            }
261        }
262
263        Self {
264            encodings,
265            max_len: _max_len,
266            d_model: dmodel,
267        }
268    }
269
270    /// Get position encoding for given sequence length
271    pub fn get_encoding(&self, seqlen: usize) -> Result<ArrayView2<f64>> {
272        if seqlen > self.max_len {
273            return Err(TextError::InvalidInput(format!(
274                "Sequence length {} exceeds maximum {}",
275                seqlen, self.max_len
276            )));
277        }
278        Ok(self.encodings.slice(s![0..seqlen, ..]))
279    }
280
281    /// Get a view of the full encodings matrix (for serialization)
282    pub fn get_encodings(&self) -> &Array2<f64> {
283        &self.encodings
284    }
285
286    /// Set the encodings matrix from loaded weights, validating shape
287    pub fn set_encodings(&mut self, encodings: Array2<f64>) -> Result<()> {
288        let shape = encodings.shape();
289        if shape[0] != self.max_len || shape[1] != self.d_model {
290            return Err(TextError::InvalidInput(format!(
291                "Positional encoding shape {:?} does not match expected ({}, {})",
292                shape, self.max_len, self.d_model
293            )));
294        }
295        self.encodings = encodings;
296        Ok(())
297    }
298}
299
300/// Multi-head attention mechanism
301pub struct MultiHeadAttention {
302    d_model: usize,
303    nheads: usize,
304    d_k: usize,
305    w_q: Array2<f64>,
306    w_k: Array2<f64>,
307    w_v: Array2<f64>,
308    w_o: Array2<f64>,
309}
310
311impl MultiHeadAttention {
312    /// Create new multi-head attention layer
313    pub fn new(d_model: usize, nheads: usize) -> Result<Self> {
314        if !d_model.is_multiple_of(nheads) {
315            return Err(TextError::InvalidInput(
316                "d_model must be divisible by nheads".to_string(),
317            ));
318        }
319
320        let d_k = d_model / nheads;
321
322        // Initialize weight matrices with Xavier initialization
323        let scale = (2.0 / d_model as f64).sqrt();
324
325        let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
326            scirs2_core::random::rng().random_range(-scale..scale)
327        });
328        let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
329            scirs2_core::random::rng().random_range(-scale..scale)
330        });
331        let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
332            scirs2_core::random::rng().random_range(-scale..scale)
333        });
334        let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
335            scirs2_core::random::rng().random_range(-scale..scale)
336        });
337
338        Ok(Self {
339            d_model,
340            nheads,
341            d_k,
342            w_q,
343            w_k,
344            w_v,
345            w_o,
346        })
347    }
348
349    /// Compute scaled dot-product attention
350    fn scaled_dot_product_attention(
351        &self,
352        q: ArrayView2<f64>,
353        k: ArrayView2<f64>,
354        v: ArrayView2<f64>,
355        mask: Option<ArrayView2<bool>>,
356    ) -> Result<Array2<f64>> {
357        let d_k = self.d_k as f64;
358
359        // Compute attention scores: Q * K^T / sqrt(d_k)
360        let scores = q.dot(&k.t()) / d_k.sqrt();
361
362        // Apply mask if provided
363        let mut masked_scores = scores;
364        if let Some(mask) = mask {
365            for ((i, j), &should_mask) in mask.indexed_iter() {
366                if should_mask {
367                    masked_scores[[i, j]] = f64::NEG_INFINITY;
368                }
369            }
370        }
371
372        // Apply softmax
373        let attention_weights = self.softmax_2d(&masked_scores)?;
374
375        // Apply attention to values
376        Ok(attention_weights.dot(&v))
377    }
378
379    /// Apply softmax to 2D array along last axis
380    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
381        let mut result = x.clone();
382
383        for mut row in result.rows_mut() {
384            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
385            row.mapv_inplace(|x| (x - max_val).exp());
386            let sum: f64 = row.sum();
387            if sum > 0.0 {
388                row /= sum;
389            }
390        }
391
392        Ok(result)
393    }
394
395    /// Forward pass through multi-head attention
396    pub fn forward(
397        &self,
398        query: ArrayView2<f64>,
399        key: ArrayView2<f64>,
400        value: ArrayView2<f64>,
401        mask: Option<ArrayView2<bool>>,
402    ) -> Result<Array2<f64>> {
403        let _seqlen = query.shape()[0];
404
405        // Linear projections
406        let q = query.dot(&self.w_q);
407        let k = key.dot(&self.w_k);
408        let v = value.dot(&self.w_v);
409
410        // Reshape for multi-head attention
411        let q_heads = self.reshape_for_heads(&q)?;
412        let k_heads = self.reshape_for_heads(&k)?;
413        let v_heads = self.reshape_for_heads(&v)?;
414
415        // Apply attention for each head
416        let mut head_outputs = Vec::new();
417        for head in 0..self.nheads {
418            let q_head = q_heads.slice(s![head, .., ..]);
419            let k_head = k_heads.slice(s![head, .., ..]);
420            let v_head = v_heads.slice(s![head, .., ..]);
421
422            let head_output = self.scaled_dot_product_attention(q_head, k_head, v_head, mask)?;
423            head_outputs.push(head_output);
424        }
425
426        // Concatenate heads
427        let concatenated = self.concatenate_heads(&head_outputs)?;
428
429        // Final linear projection
430        Ok(concatenated.dot(&self.w_o))
431    }
432
433    /// Reshape tensor for multi-head attention
434    fn reshape_for_heads(&self, x: &Array2<f64>) -> Result<Array3<f64>> {
435        let (seqlen, d_model) = x.dim();
436        let reshaped = x
437            .clone()
438            .into_shape_with_order((seqlen, self.nheads, self.d_k))
439            .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;
440
441        // Transpose to (nheads, seqlen, d_k)
442        Ok(reshaped.permuted_axes([1, 0, 2]))
443    }
444
445    /// Concatenate attention heads
446    fn concatenate_heads(&self, heads: &[Array2<f64>]) -> Result<Array2<f64>> {
447        if heads.is_empty() {
448            return Err(TextError::InvalidInput("No heads provided".to_string()));
449        }
450
451        let seqlen = heads[0].shape()[0];
452        let mut result = Array2::zeros((seqlen, self.d_model));
453
454        for (i, head) in heads.iter().enumerate() {
455            let start_col = i * self.d_k;
456            let end_col = start_col + self.d_k;
457            result.slice_mut(s![.., start_col..end_col]).assign(head);
458        }
459
460        Ok(result)
461    }
462
463    /// Get attention weight matrices for serialization
464    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array2<f64>, &Array2<f64>) {
465        (&self.w_q, &self.w_k, &self.w_v, &self.w_o)
466    }
467
468    /// Set attention weight matrices from loaded weights
469    pub fn set_weights(
470        &mut self,
471        w_q: Array2<f64>,
472        w_k: Array2<f64>,
473        w_v: Array2<f64>,
474        w_o: Array2<f64>,
475    ) -> Result<()> {
476        if w_q.shape() != [self.d_model, self.d_model] {
477            return Err(TextError::InvalidInput("Invalid w_q shape".to_string()));
478        }
479        if w_k.shape() != [self.d_model, self.d_model] {
480            return Err(TextError::InvalidInput("Invalid w_k shape".to_string()));
481        }
482        if w_v.shape() != [self.d_model, self.d_model] {
483            return Err(TextError::InvalidInput("Invalid w_v shape".to_string()));
484        }
485        if w_o.shape() != [self.d_model, self.d_model] {
486            return Err(TextError::InvalidInput("Invalid w_o shape".to_string()));
487        }
488
489        self.w_q = w_q;
490        self.w_k = w_k;
491        self.w_v = w_v;
492        self.w_o = w_o;
493        Ok(())
494    }
495}
496
497/// Feed-forward network layer
498pub struct FeedForward {
499    w1: Array2<f64>,
500    w2: Array2<f64>,
501    b1: Array1<f64>,
502    b2: Array1<f64>,
503}
504
505impl FeedForward {
506    /// Create new feed-forward layer
507    pub fn new(_dmodel: usize, dff: usize) -> Self {
508        let scale = (2.0 / _dmodel as f64).sqrt();
509
510        let w1 = Array2::from_shape_fn((_dmodel, dff), |_| {
511            scirs2_core::random::rng().random_range(-scale..scale)
512        });
513        let w2 = Array2::from_shape_fn((dff, _dmodel), |_| {
514            scirs2_core::random::rng().random_range(-scale..scale)
515        });
516        let b1 = Array1::zeros(dff);
517        let b2 = Array1::zeros(_dmodel);
518
519        Self { w1, w2, b1, b2 }
520    }
521
522    /// Forward pass through feed-forward network
523    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
524        // First linear transformation + ReLU
525        let hidden = x.dot(&self.w1) + &self.b1;
526        let activated = hidden.mapv(|x| x.max(0.0)); // ReLU activation
527
528        // Second linear transformation
529        activated.dot(&self.w2) + &self.b2
530    }
531
532    /// Get feed-forward weight matrices for serialization
533    pub fn get_weights(&self) -> (&Array2<f64>, &Array2<f64>, &Array1<f64>, &Array1<f64>) {
534        (&self.w1, &self.w2, &self.b1, &self.b2)
535    }
536
537    /// Set feed-forward weight matrices from loaded weights
538    pub fn set_weights(
539        &mut self,
540        w1: Array2<f64>,
541        w2: Array2<f64>,
542        b1: Array1<f64>,
543        b2: Array1<f64>,
544    ) -> Result<()> {
545        if w1.shape()[1] != w2.shape()[0] {
546            return Err(TextError::InvalidInput(
547                "Weight matrix dimensions don't match".to_string(),
548            ));
549        }
550        if b1.len() != w1.shape()[1] {
551            return Err(TextError::InvalidInput(
552                "Bias b1 size doesn't match w1".to_string(),
553            ));
554        }
555        if b2.len() != w2.shape()[1] {
556            return Err(TextError::InvalidInput(
557                "Bias b2 size doesn't match w2".to_string(),
558            ));
559        }
560
561        self.w1 = w1;
562        self.w2 = w2;
563        self.b1 = b1;
564        self.b2 = b2;
565        Ok(())
566    }
567}
568
569/// Layer normalization
570pub struct LayerNorm {
571    gamma: Array1<f64>,
572    beta: Array1<f64>,
573    eps: f64,
574}
575
576impl LayerNorm {
577    /// Create new layer normalization
578    pub fn new(_dmodel: usize, eps: f64) -> Self {
579        Self {
580            gamma: Array1::ones(_dmodel),
581            beta: Array1::zeros(_dmodel),
582            eps,
583        }
584    }
585
586    /// Apply layer normalization
587    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
588        let mut result = Array2::zeros(x.raw_dim());
589
590        for (i, row) in x.rows().into_iter().enumerate() {
591            let mean = row.mean();
592            let var = row.mapv(|x| (x - mean).powi(2)).mean();
593            let std = (var + self.eps).sqrt();
594
595            let normalized = row.mapv(|x| (x - mean) / std);
596            let scaled = &normalized * &self.gamma + &self.beta;
597
598            result.row_mut(i).assign(&scaled);
599        }
600
601        result
602    }
603
604    /// Get layer normalization parameters for serialization
605    pub fn get_params(&self) -> (&Array1<f64>, &Array1<f64>) {
606        (&self.gamma, &self.beta)
607    }
608
609    /// Set layer normalization parameters from loaded weights
610    pub fn set_params(&mut self, gamma: Array1<f64>, beta: Array1<f64>) -> Result<()> {
611        if gamma.len() != beta.len() {
612            return Err(TextError::InvalidInput(
613                "Gamma and beta must have same length".to_string(),
614            ));
615        }
616        if gamma.len() != self.gamma.len() {
617            return Err(TextError::InvalidInput(
618                "Parameter size doesn't match layer dimension".to_string(),
619            ));
620        }
621
622        self.gamma = gamma;
623        self.beta = beta;
624        Ok(())
625    }
626}
627
628/// Transformer encoder layer
629pub struct TransformerEncoderLayer {
630    self_attention: MultiHeadAttention,
631    feed_forward: FeedForward,
632    norm1: LayerNorm,
633    norm2: LayerNorm,
634    #[allow(dead_code)]
635    dropout: f64,
636}
637
638impl TransformerEncoderLayer {
639    /// Create new transformer encoder layer
640    pub fn new(config: &TransformerConfig) -> Result<Self> {
641        Ok(Self {
642            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
643            feed_forward: FeedForward::new(config.d_model, config.d_ff),
644            norm1: LayerNorm::new(config.d_model, 1e-6),
645            norm2: LayerNorm::new(config.d_model, 1e-6),
646            dropout: config.dropout,
647        })
648    }
649
650    /// Forward pass through encoder layer
651    pub fn forward(
652        &self,
653        x: ArrayView2<f64>,
654        mask: Option<ArrayView2<bool>>,
655    ) -> Result<Array2<f64>> {
656        // Self-attention with residual connection and layer norm
657        let attn_output = self.self_attention.forward(x, x, x, mask)?;
658        let x = &self.norm1.forward(x) + &attn_output;
659
660        // Feed-forward with residual connection and layer norm
661        let ff_output = self.feed_forward.forward(x.view());
662        let output = &self.norm2.forward(x.view()) + &ff_output;
663
664        Ok(output)
665    }
666
667    /// Get mutable access to layer components for weight loading
668    pub fn get_components_mut(
669        &mut self,
670    ) -> (
671        &mut MultiHeadAttention,
672        &mut FeedForward,
673        &mut LayerNorm,
674        &mut LayerNorm,
675    ) {
676        (
677            &mut self.self_attention,
678            &mut self.feed_forward,
679            &mut self.norm1,
680            &mut self.norm2,
681        )
682    }
683
684    /// Get access to layer components for weight access
685    pub fn get_components(&self) -> (&MultiHeadAttention, &FeedForward, &LayerNorm, &LayerNorm) {
686        (
687            &self.self_attention,
688            &self.feed_forward,
689            &self.norm1,
690            &self.norm2,
691        )
692    }
693}
694
695/// Complete transformer encoder
696pub struct TransformerEncoder {
697    layers: Vec<TransformerEncoderLayer>,
698    position_encoding: PositionalEncoding,
699    config: TransformerConfig,
700}
701
702impl TransformerEncoder {
703    /// Create new transformer encoder
704    pub fn new(config: TransformerConfig) -> Result<Self> {
705        let mut layers = Vec::new();
706        for _ in 0..config.n_encoder_layers {
707            layers.push(TransformerEncoderLayer::new(&config)?);
708        }
709
710        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);
711
712        Ok(Self {
713            layers,
714            position_encoding,
715            config,
716        })
717    }
718
719    /// Encode input sequence
720    pub fn encode(
721        &self,
722        embeddings: ArrayView2<f64>,
723        mask: Option<ArrayView2<bool>>,
724    ) -> Result<Array2<f64>> {
725        let seqlen = embeddings.shape()[0];
726
727        // Add positional encoding
728        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
729        let mut x = embeddings.to_owned() + pos_enc;
730
731        // Pass through encoder layers
732        for layer in &self.layers {
733            x = layer.forward(x.view(), mask)?;
734        }
735
736        Ok(x)
737    }
738
739    /// Get configuration
740    pub fn config(&self) -> &TransformerConfig {
741        &self.config
742    }
743
744    /// Get mutable access to encoder layers for weight loading
745    pub fn get_layers_mut(&mut self) -> &mut Vec<TransformerEncoderLayer> {
746        &mut self.layers
747    }
748
749    /// Get access to encoder layers for weight access
750    pub fn get_layers(&self) -> &Vec<TransformerEncoderLayer> {
751        &self.layers
752    }
753
754    /// Get the positional encoding matrix (for serialization)
755    pub fn get_position_encoding(&self) -> &Array2<f64> {
756        self.position_encoding.get_encodings()
757    }
758
759    /// Set the positional encoding matrix from loaded weights
760    pub fn set_position_encoding(&mut self, encodings: Array2<f64>) -> Result<()> {
761        self.position_encoding.set_encodings(encodings)
762    }
763}
764
765/// Transformer decoder layer with self-attention, cross-attention, and feed-forward
766pub struct TransformerDecoderLayer {
767    self_attention: MultiHeadAttention,
768    cross_attention: MultiHeadAttention,
769    feed_forward: FeedForward,
770    norm1: LayerNorm,
771    norm2: LayerNorm,
772    norm3: LayerNorm,
773    #[allow(dead_code)]
774    dropout: f64,
775}
776
777impl TransformerDecoderLayer {
778    /// Create new decoder layer
779    pub fn new(config: &TransformerConfig) -> Result<Self> {
780        Ok(Self {
781            self_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
782            cross_attention: MultiHeadAttention::new(config.d_model, config.nheads)?,
783            feed_forward: FeedForward::new(config.d_model, config.d_ff),
784            norm1: LayerNorm::new(config.d_model, 1e-6),
785            norm2: LayerNorm::new(config.d_model, 1e-6),
786            norm3: LayerNorm::new(config.d_model, 1e-6),
787            dropout: config.dropout,
788        })
789    }
790
791    /// Forward pass with encoder output for cross-attention
792    pub fn forward(
793        &self,
794        x: ArrayView2<f64>,
795        encoder_output: ArrayView2<f64>,
796        self_attn_mask: Option<ArrayView2<bool>>,
797        cross_attn_mask: Option<ArrayView2<bool>>,
798    ) -> Result<Array2<f64>> {
799        // Self-attention with residual connection and layer norm
800        let self_attn_out = self.self_attention.forward(x, x, x, self_attn_mask)?;
801        let x = self.norm1.forward((x.to_owned() + self_attn_out).view());
802
803        // Cross-attention with encoder _output
804        let cross_attn_out = self.cross_attention.forward(
805            x.view(),
806            encoder_output,
807            encoder_output,
808            cross_attn_mask,
809        )?;
810        let x = self.norm2.forward((x + cross_attn_out).view());
811
812        // Feed-forward with residual connection and layer norm
813        let ff_out = self.feed_forward.forward(x.view());
814        let _output = self.norm3.forward((x + ff_out).view());
815
816        Ok(_output)
817    }
818}
819
820/// Transformer decoder stack
821pub struct TransformerDecoder {
822    layers: Vec<TransformerDecoderLayer>,
823    position_encoding: PositionalEncoding,
824    config: TransformerConfig,
825}
826
827impl TransformerDecoder {
828    /// Create new decoder
829    pub fn new(config: TransformerConfig) -> Result<Self> {
830        let mut layers = Vec::new();
831        for _ in 0..config.n_decoder_layers {
832            layers.push(TransformerDecoderLayer::new(&config)?);
833        }
834
835        let position_encoding = PositionalEncoding::new(config.max_seqlen, config.d_model);
836
837        Ok(Self {
838            layers,
839            position_encoding,
840            config,
841        })
842    }
843
844    /// Forward pass through decoder
845    pub fn forward(
846        &self,
847        embeddings: ArrayView2<f64>,
848        encoder_output: ArrayView2<f64>,
849        self_attn_mask: Option<ArrayView2<bool>>,
850        cross_attn_mask: Option<ArrayView2<bool>>,
851    ) -> Result<Array2<f64>> {
852        let seqlen = embeddings.shape()[0];
853
854        // Add positional encoding
855        let pos_enc = self.position_encoding.get_encoding(seqlen)?;
856        let mut x = embeddings.to_owned() + pos_enc;
857
858        // Pass through decoder layers
859        for layer in &self.layers {
860            x = layer.forward(x.view(), encoder_output, self_attn_mask, cross_attn_mask)?;
861        }
862
863        Ok(x)
864    }
865
866    /// Get configuration
867    pub fn config(&self) -> &TransformerConfig {
868        &self.config
869    }
870}
871
872/// Token embedding layer
873pub struct TokenEmbedding {
874    embeddings: Array2<f64>,
875    vocab_size: usize,
876    d_model: usize,
877}
878
879impl TokenEmbedding {
880    /// Create new token embedding layer
881    pub fn new(_vocab_size: usize, dmodel: usize) -> Self {
882        let scale = (1.0 / dmodel as f64).sqrt();
883        let embeddings = Array2::from_shape_fn((_vocab_size, dmodel), |_| {
884            scirs2_core::random::rng().random_range(-scale..scale)
885        });
886
887        Self {
888            embeddings,
889            vocab_size: _vocab_size,
890            d_model: dmodel,
891        }
892    }
893
894    /// Get embeddings for token IDs
895    pub fn forward(&self, tokenids: &[usize]) -> Result<Array2<f64>> {
896        let mut result = Array2::zeros((tokenids.len(), self.d_model));
897
898        for (i, &token_id) in tokenids.iter().enumerate() {
899            if token_id >= self.vocab_size {
900                return Err(TextError::InvalidInput(format!(
901                    "Token ID {} exceeds vocabulary size {}",
902                    token_id, self.vocab_size
903                )));
904            }
905            result.row_mut(i).assign(&self.embeddings.row(token_id));
906        }
907
908        Ok(result)
909    }
910
911    /// Get access to the embedding matrix for serialization
912    pub fn get_embeddings(&self) -> &Array2<f64> {
913        &self.embeddings
914    }
915
916    /// Set the embedding matrix from loaded weights
917    pub fn set_embeddings(&mut self, embeddings: Array2<f64>) -> Result<()> {
918        if embeddings.shape()[0] != self.vocab_size || embeddings.shape()[1] != self.d_model {
919            return Err(TextError::InvalidInput(format!(
920                "Embedding shape {:?} doesn't match expected ({}, {})",
921                embeddings.shape(),
922                self.vocab_size,
923                self.d_model
924            )));
925        }
926        self.embeddings = embeddings;
927        Ok(())
928    }
929}
930
931/// Complete transformer model for text processing
932pub struct TransformerModel {
933    /// Model configuration
934    pub config: TransformerConfig,
935    /// Token embedding layer
936    pub token_embedding: TokenEmbedding,
937    /// Transformer encoder
938    pub encoder: TransformerEncoder,
939    /// Optional transformer decoder
940    pub decoder: Option<TransformerDecoder>,
941    vocab_to_id: HashMap<String, usize>,
942    id_to_vocab: HashMap<usize, String>,
943}
944
945impl TransformerModel {
946    /// Create new transformer model
947    pub fn new(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
948        let vocab_size = vocabulary.len();
949        if vocab_size != config.vocab_size {
950            return Err(TextError::InvalidInput(format!(
951                "Vocabulary size {} doesn't match config {}",
952                vocab_size, config.vocab_size
953            )));
954        }
955
956        let mut vocab_to_id = HashMap::new();
957        let mut id_to_vocab = HashMap::new();
958
959        for (id, token) in vocabulary.into_iter().enumerate() {
960            vocab_to_id.insert(token.clone(), id);
961            id_to_vocab.insert(id, token);
962        }
963
964        Ok(Self {
965            config: config.clone(),
966            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
967            encoder: TransformerEncoder::new(config)?,
968            decoder: None, // Encoder-only model
969            vocab_to_id,
970            id_to_vocab,
971        })
972    }
973
974    /// Encode text tokens to contextual embeddings
975    pub fn encode_tokens(&self, tokens: &[String]) -> Result<Array2<f64>> {
976        // Convert tokens to IDs
977        let tokenids: Result<Vec<usize>> = tokens
978            .iter()
979            .map(|token| {
980                self.vocab_to_id
981                    .get(token)
982                    .cloned()
983                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
984            })
985            .collect();
986        let tokenids = tokenids?;
987
988        // Get token embeddings
989        let embeddings = self.token_embedding.forward(&tokenids)?;
990
991        // Encode with transformer
992        self.encoder.encode(embeddings.view(), None)
993    }
994
995    /// Create new encoder-decoder transformer model
996    pub fn new_encoder_decoder(config: TransformerConfig, vocabulary: Vec<String>) -> Result<Self> {
997        let vocab_size = vocabulary.len();
998        if vocab_size != config.vocab_size {
999            return Err(TextError::InvalidInput(format!(
1000                "Vocabulary size {} doesn't match config {}",
1001                vocab_size, config.vocab_size
1002            )));
1003        }
1004
1005        let mut vocab_to_id = HashMap::new();
1006        let mut id_to_vocab = HashMap::new();
1007
1008        for (id, token) in vocabulary.into_iter().enumerate() {
1009            vocab_to_id.insert(token.clone(), id);
1010            id_to_vocab.insert(id, token);
1011        }
1012
1013        Ok(Self {
1014            config: config.clone(),
1015            token_embedding: TokenEmbedding::new(config.vocab_size, config.d_model),
1016            encoder: TransformerEncoder::new(config.clone())?,
1017            decoder: Some(TransformerDecoder::new(config)?),
1018            vocab_to_id,
1019            id_to_vocab,
1020        })
1021    }
1022
1023    /// Perform encoder-decoder forward pass
1024    pub fn encode_decode(
1025        &self,
1026        input_tokens: &[String],
1027        target_tokens: &[String],
1028    ) -> Result<Array2<f64>> {
1029        let decoder = self
1030            .decoder
1031            .as_ref()
1032            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;
1033
1034        // Encode input
1035        let encoder_output = self.encode_tokens(input_tokens)?;
1036
1037        // Convert target _tokens to IDs and embeddings
1038        let target_ids: Result<Vec<usize>> = target_tokens
1039            .iter()
1040            .map(|token| {
1041                self.vocab_to_id
1042                    .get(token)
1043                    .copied()
1044                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {token}")))
1045            })
1046            .collect();
1047        let target_ids = target_ids?;
1048
1049        let target_embeddings = self.token_embedding.forward(&target_ids)?;
1050
1051        // Generate causal mask for decoder self-attention
1052        let seqlen = target_tokens.len();
1053        let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
1054        for i in 0..seqlen {
1055            for j in (i + 1)..seqlen {
1056                causal_mask[[i, j]] = true; // Mask future positions
1057            }
1058        }
1059
1060        // Decode
1061        decoder.forward(
1062            target_embeddings.view(),
1063            encoder_output.view(),
1064            Some(causal_mask.view()),
1065            None,
1066        )
1067    }
1068
1069    /// Generate text using the decoder (for generation tasks)
1070    pub fn generate(
1071        &self,
1072        input_tokens: &[String],
1073        max_length: usize,
1074        start_token: &str,
1075    ) -> Result<Vec<String>> {
1076        let decoder = self
1077            .decoder
1078            .as_ref()
1079            .ok_or_else(|| TextError::InvalidInput("Model has no decoder".to_string()))?;
1080
1081        // Encode input
1082        let encoder_output = self.encode_tokens(input_tokens)?;
1083
1084        // Start with the start _token
1085        let mut generated_tokens = vec![start_token.to_string()];
1086
1087        for _ in 0..max_length {
1088            // Convert current _tokens to embeddings
1089            let current_ids: Result<Vec<usize>> = generated_tokens
1090                .iter()
1091                .map(|_token| {
1092                    self.vocab_to_id
1093                        .get(_token)
1094                        .copied()
1095                        .ok_or_else(|| TextError::InvalidInput(format!("Unknown token: {_token}")))
1096                })
1097                .collect();
1098            let current_ids = current_ids?;
1099
1100            let current_embeddings = self.token_embedding.forward(&current_ids)?;
1101
1102            // Generate causal mask
1103            let seqlen = generated_tokens.len();
1104            let mut causal_mask = Array2::from_elem((seqlen, seqlen), false);
1105            for i in 0..seqlen {
1106                for j in (i + 1)..seqlen {
1107                    causal_mask[[i, j]] = true;
1108                }
1109            }
1110
1111            // Decode
1112            let decoder_output = decoder.forward(
1113                current_embeddings.view(),
1114                encoder_output.view(),
1115                Some(causal_mask.view()),
1116                None,
1117            )?;
1118
1119            // Get the last timestep output
1120            let last_output = decoder_output.row(decoder_output.nrows() - 1);
1121
1122            // Simple greedy selection (find _token with highest logit)
1123            let mut best_token_id = 0;
1124            let mut best_score = last_output[0];
1125            for (i, &score) in last_output.iter().enumerate() {
1126                if score > best_score {
1127                    best_score = score;
1128                    best_token_id = i;
1129                }
1130            }
1131
1132            // Convert _token ID back to string
1133            if let Some(_token) = self.id_to_vocab.get(&best_token_id) {
1134                generated_tokens.push(_token.clone());
1135
1136                // Stop if we hit an end _token (you might want to customize this)
1137                if _token == "</s>" || _token == "<eos>" {
1138                    break;
1139                }
1140            } else {
1141                break;
1142            }
1143        }
1144
1145        Ok(generated_tokens)
1146    }
1147
1148    /// Get vocabulary mapping
1149    pub fn vocabulary(&self) -> (&HashMap<String, usize>, &HashMap<usize, String>) {
1150        (&self.vocab_to_id, &self.id_to_vocab)
1151    }
1152}
1153
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by the model-level tests.
    fn tiny_config(vocab_size: usize) -> TransformerConfig {
        TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 1,
            n_decoder_layers: 1,
            vocab_size,
            ..Default::default()
        }
    }

    /// Four-token vocabulary used by the model-level tests.
    fn tiny_vocab() -> Vec<String> {
        ["<s>", "</s>", "hello", "world"]
            .iter()
            .map(|t| t.to_string())
            .collect()
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(10, 4);
        let encoding = pos_enc.get_encoding(5).expect("Operation failed");
        assert_eq!(encoding.shape(), &[5, 4]);

        // Test that positions are different
        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        assert!(pos0
            .iter()
            .zip(pos1.iter())
            .any(|(a, b)| (a - b).abs() > 1e-6));
    }

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(8, 2).expect("Operation failed");
        let seqlen = 4;
        let d_model = 8;

        let input = Array2::ones((seqlen, d_model));
        let output = mha
            .forward(input.view(), input.view(), input.view(), None)
            .expect("Operation failed");

        assert_eq!(output.shape(), &[seqlen, d_model]);
    }

    #[test]
    fn test_transformer_encoder() {
        let config = TransformerConfig {
            d_model: 8,
            nheads: 2,
            d_ff: 16,
            n_encoder_layers: 2,
            ..Default::default()
        };

        let encoder = TransformerEncoder::new(config).expect("Operation failed");
        let input = Array2::ones((4, 8));
        let output = encoder
            .encode(input.view(), None)
            .expect("Operation failed");

        assert_eq!(output.shape(), &[4, 8]);
    }

    #[test]
    fn test_token_embedding_lookup_and_bounds() {
        let emb = TokenEmbedding::new(5, 3);
        let scale = (1.0f64 / 3.0).sqrt();
        assert!(emb.get_embeddings().iter().all(|v| v.abs() <= scale));

        let out = emb.forward(&[0, 4, 2]).expect("Operation failed");
        assert_eq!(out.shape(), &[3, 3]);
        assert_eq!(out.row(1), emb.get_embeddings().row(4));

        // IDs at or beyond the vocabulary size are rejected.
        assert!(emb.forward(&[5]).is_err());
        // An empty sequence yields an empty (0, d_model) matrix.
        let empty = emb.forward(&[]).expect("Operation failed");
        assert_eq!(empty.shape(), &[0, 3]);
    }

    #[test]
    fn test_token_embedding_set_embeddings_checks_shape() {
        let mut emb = TokenEmbedding::new(5, 3);
        assert!(emb.set_embeddings(Array2::zeros((4, 3))).is_err());
        assert!(emb.set_embeddings(Array2::zeros((5, 3))).is_ok());
        assert!(emb.get_embeddings().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_transformer_model_rejects_vocab_size_mismatch() {
        assert!(TransformerModel::new(tiny_config(5), tiny_vocab()).is_err());
        assert!(TransformerModel::new_encoder_decoder(tiny_config(5), tiny_vocab()).is_err());
    }

    #[test]
    fn test_transformer_model_encode_tokens() {
        let model = TransformerModel::new(tiny_config(4), tiny_vocab()).expect("Operation failed");
        let tokens = vec!["hello".to_string(), "world".to_string()];

        let encoded = model.encode_tokens(&tokens).expect("Operation failed");
        assert_eq!(encoded.shape(), &[2, 8]);

        assert!(model.encode_tokens(&["unknown".to_string()]).is_err());
        // Encoder-only models cannot run the decoder path.
        assert!(model.encode_decode(&tokens, &tokens).is_err());
    }

    #[test]
    fn test_transformer_model_encode_decode_and_generate() {
        let model = TransformerModel::new_encoder_decoder(tiny_config(4), tiny_vocab())
            .expect("Operation failed");
        let src = vec!["hello".to_string(), "world".to_string()];
        let tgt = vec!["<s>".to_string(), "hello".to_string(), "world".to_string()];

        let decoded = model.encode_decode(&src, &tgt).expect("Operation failed");
        assert_eq!(decoded.shape(), &[3, 8]);

        let generated = model.generate(&src, 3, "<s>").expect("Operation failed");
        assert_eq!(generated[0], "<s>");
        assert!(generated.len() <= 4);
        let (vocab_to_id, _) = model.vocabulary();
        assert!(generated.iter().all(|t| vocab_to_id.contains_key(t)));
    }
}