quantrs2_ml/pytorch_api/transformer.rs

//! Transformer layers for PyTorch-like API

use super::layers::{QuantumLayerNorm, QuantumLinear};
use super::{Parameter, QuantumModule};
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Multi-head attention layer
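///
/// Stores the Q/K/V and output projection weights as [`Parameter`]s and applies
/// scaled dot-product attention to inputs of shape `[batch, seq_len, embed_dim]`.
///
/// Minimal self-attention sketch (not compiled; `embed_dim = 8` and `num_heads = 2`
/// are illustrative values, and `forward` comes from the [`QuantumModule`] trait):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let mut attn = QuantumMultiheadAttention::new(8, 2)?.dropout(0.1);
/// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 5, 8])), false);
/// let y = attn.forward(&x)?; // self-attention: query = key = value = x
/// assert_eq!(y.data.shape(), &[2, 5, 8]);
/// ```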
pub struct QuantumMultiheadAttention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    q_proj: Parameter,
    k_proj: Parameter,
    v_proj: Parameter,
    out_proj: Parameter,
    dropout: f64,
    training: bool,
}

impl QuantumMultiheadAttention {
    /// Create a new multi-head attention layer.
    ///
    /// `embed_dim` must be divisible by `num_heads`; each head attends over
    /// `embed_dim / num_heads` dimensions. Projection weights are initialized
    /// uniformly at random with scale `1/sqrt(embed_dim)`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(MLError::InvalidConfiguration(
                "embed_dim must be divisible by num_heads".to_string(),
            ));
        }

        let head_dim = embed_dim / num_heads;
        let scale = (1.0 / (embed_dim as f64)).sqrt();

        let q_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let k_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let v_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });
        let out_proj = ArrayD::from_shape_fn(IxDyn(&[embed_dim, embed_dim]), |_| {
            (fastrand::f64() * 2.0 - 1.0) * scale
        });

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            q_proj: Parameter::new(SciRS2Array::with_grad(q_proj), "q_proj"),
            k_proj: Parameter::new(SciRS2Array::with_grad(k_proj), "k_proj"),
            v_proj: Parameter::new(SciRS2Array::with_grad(v_proj), "v_proj"),
            out_proj: Parameter::new(SciRS2Array::with_grad(out_proj), "out_proj"),
            dropout: 0.0,
            training: true,
        })
    }

    /// Set the dropout probability (builder-style).
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Forward pass with separate query, key, and value tensors of shape
    /// `[batch, seq_len, embed_dim]`.
    ///
    /// Returns `(output, avg_attn_weights)`, where the attention weights are
    /// averaged over heads. If `attn_mask` (shape `[seq_len, seq_len]`) is
    /// provided, positions where the mask is `0.0` are excluded from attention.
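    ///
    /// Each head computes standard scaled dot-product attention,
    /// `softmax(Q K^T / sqrt(head_dim)) V`, over its own contiguous
    /// `head_dim`-wide slice of the embedding; head outputs are concatenated
    /// and passed through the output projection.
    ///
    /// Minimal sketch with a causal mask (not compiled; the mask values are
    /// illustrative, `1.0` allows attention and `0.0` blocks it):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{ArrayD, IxDyn};
    ///
    /// let attn = QuantumMultiheadAttention::new(8, 2)?;
    /// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 4, 8])), false);
    ///
    /// // Lower-triangular mask: token i may only attend to positions j <= i.
    /// let mask = ArrayD::from_shape_fn(IxDyn(&[4, 4]), |ix| {
    ///     if ix[1] <= ix[0] { 1.0 } else { 0.0 }
    /// });
    ///
    /// let (out, weights) = attn.forward_qkv(&x, &x, &x, Some(&mask))?;
    /// assert_eq!(out.data.shape(), &[1, 4, 8]);
    /// assert_eq!(weights.data.shape(), &[1, 4, 4]); // head-averaged attention
    /// ```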
    pub fn forward_qkv(
        &self,
        query: &SciRS2Array,
        key: &SciRS2Array,
        value: &SciRS2Array,
        attn_mask: Option<&ArrayD<f64>>,
    ) -> Result<(SciRS2Array, SciRS2Array)> {
        let shape = query.data.shape();
        let (batch_size, seq_len) = (shape[0], shape[1]);
        let scale = (self.head_dim as f64).sqrt();

        // Project Q, K, V
        let mut q = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        let mut k = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        let mut v = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));

        // Naive matrix multiplication for the projections (weights are stored as [out, in])
        for b in 0..batch_size {
            for s in 0..seq_len {
                for e_out in 0..self.embed_dim {
                    let mut q_sum = 0.0;
                    let mut k_sum = 0.0;
                    let mut v_sum = 0.0;
                    for e_in in 0..self.embed_dim {
                        q_sum += query.data[[b, s, e_in]] * self.q_proj.data.data[[e_out, e_in]];
                        k_sum += key.data[[b, s, e_in]] * self.k_proj.data.data[[e_out, e_in]];
                        v_sum += value.data[[b, s, e_in]] * self.v_proj.data.data[[e_out, e_in]];
                    }
                    q[[b, s, e_out]] = q_sum;
                    k[[b, s, e_out]] = k_sum;
                    v[[b, s, e_out]] = v_sum;
                }
            }
        }

        // Compute attention scores: Q @ K^T / sqrt(d_k), per head
        // (heads are contiguous head_dim-wide slices of the embedding dimension)
        let mut attn_scores = ArrayD::zeros(IxDyn(&[batch_size, self.num_heads, seq_len, seq_len]));

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let mut score = 0.0;
                        for d in 0..self.head_dim {
                            let q_idx = h * self.head_dim + d;
                            let k_idx = h * self.head_dim + d;
                            score += q[[b, i, q_idx]] * k[[b, j, k_idx]];
                        }
                        attn_scores[[b, h, i, j]] = score / scale;
                    }
                }
            }
        }

        // Apply attention mask if provided (mask value 0.0 blocks attention)
        if let Some(mask) = attn_mask {
            for b in 0..batch_size {
                for h in 0..self.num_heads {
                    for i in 0..seq_len {
                        for j in 0..seq_len {
                            if mask[[i, j]] == 0.0 {
                                attn_scores[[b, h, i, j]] = f64::NEG_INFINITY;
                            }
                        }
                    }
                }
            }
        }

        // Softmax over the key dimension (stabilized by subtracting the row max)
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    let max_score = (0..seq_len)
                        .map(|j| attn_scores[[b, h, i, j]])
                        .fold(f64::NEG_INFINITY, f64::max);
                    let mut sum_exp = 0.0;
                    for j in 0..seq_len {
                        attn_scores[[b, h, i, j]] = (attn_scores[[b, h, i, j]] - max_score).exp();
                        sum_exp += attn_scores[[b, h, i, j]];
                    }
                    for j in 0..seq_len {
                        attn_scores[[b, h, i, j]] /= sum_exp;
                    }
                }
            }
        }

        // Attention output: attn_weights @ V
        let mut attn_output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));

        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for d in 0..self.head_dim {
                        let mut sum = 0.0;
                        for j in 0..seq_len {
                            sum += attn_scores[[b, h, i, j]] * v[[b, j, h * self.head_dim + d]];
                        }
                        attn_output[[b, i, h * self.head_dim + d]] = sum;
                    }
                }
            }
        }

        // Output projection
        let mut output = ArrayD::zeros(IxDyn(&[batch_size, seq_len, self.embed_dim]));
        for b in 0..batch_size {
            for s in 0..seq_len {
                for e_out in 0..self.embed_dim {
                    let mut sum = 0.0;
                    for e_in in 0..self.embed_dim {
                        sum += attn_output[[b, s, e_in]] * self.out_proj.data.data[[e_out, e_in]];
                    }
                    output[[b, s, e_out]] = sum;
                }
            }
        }

        // Average attention weights across heads for the returned weights
        let mut avg_attn = ArrayD::zeros(IxDyn(&[batch_size, seq_len, seq_len]));
        for b in 0..batch_size {
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let mut sum = 0.0;
                    for h in 0..self.num_heads {
                        sum += attn_scores[[b, h, i, j]];
                    }
                    avg_attn[[b, i, j]] = sum / self.num_heads as f64;
                }
            }
        }

        Ok((
            SciRS2Array::new(output, query.requires_grad),
            SciRS2Array::new(avg_attn, false),
        ))
    }
}

impl QuantumModule for QuantumMultiheadAttention {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Self-attention: query = key = value = input
        let (output, _) = self.forward_qkv(input, input, input, None)?;
        Ok(output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.q_proj.clone(),
            self.k_proj.clone(),
            self.v_proj.clone(),
            self.out_proj.clone(),
        ]
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.q_proj.data.zero_grad();
        self.k_proj.data.zero_grad();
        self.v_proj.data.zero_grad();
        self.out_proj.data.zero_grad();
    }

    fn name(&self) -> &str {
        "MultiheadAttention"
    }
}

/// Transformer encoder layer
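///
/// Post-norm layout: self-attention, an add & layer-norm step, then a two-layer
/// feed-forward block with ReLU, and a second add & layer-norm step.
///
/// Minimal sketch (not compiled; `d_model = 16`, `nhead = 4`, and
/// `dim_feedforward = 32` are illustrative values):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let mut layer = QuantumTransformerEncoderLayer::new(16, 4, 32)?.dropout(0.1);
/// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 10, 16])), false);
/// let y = layer.forward(&x)?; // shape is preserved: [2, 10, 16]
/// ```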
pub struct QuantumTransformerEncoderLayer {
    self_attn: QuantumMultiheadAttention,
    linear1: QuantumLinear,
    linear2: QuantumLinear,
    norm1: QuantumLayerNorm,
    norm2: QuantumLayerNorm,
    dropout: f64,
    training: bool,
}

impl QuantumTransformerEncoderLayer {
    /// Create a new transformer encoder layer with model dimension `d_model`,
    /// `nhead` attention heads, and a feed-forward hidden size of `dim_feedforward`.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Result<Self> {
        Ok(Self {
            self_attn: QuantumMultiheadAttention::new(d_model, nhead)?,
            linear1: QuantumLinear::new(d_model, dim_feedforward)?,
            linear2: QuantumLinear::new(dim_feedforward, d_model)?,
            norm1: QuantumLayerNorm::new(vec![d_model]),
            norm2: QuantumLayerNorm::new(vec![d_model]),
            dropout: 0.1,
            training: true,
        })
    }

    /// Set the dropout probability (builder-style).
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}

impl QuantumModule for QuantumTransformerEncoderLayer {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        // Self-attention
        let attn_output = self.self_attn.forward(input)?;

        // Add & Norm
        let residual1 = SciRS2Array::new(&input.data + &attn_output.data, input.requires_grad);
        let normed1 = self.norm1.forward(&residual1)?;

        // Feedforward (linear -> ReLU -> linear)
        let ff_output = self.linear1.forward(&normed1)?;
        let ff_activated =
            SciRS2Array::new(ff_output.data.mapv(|x| x.max(0.0)), ff_output.requires_grad);
        let ff_output2 = self.linear2.forward(&ff_activated)?;

        // Add & Norm
        let residual2 = SciRS2Array::new(&normed1.data + &ff_output2.data, input.requires_grad);
        self.norm2.forward(&residual2)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.self_attn.parameters();
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
        self.self_attn.train(mode);
        self.linear1.train(mode);
        self.linear2.train(mode);
        self.norm1.train(mode);
        self.norm2.train(mode);
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {
        self.self_attn.zero_grad();
        self.linear1.zero_grad();
        self.linear2.zero_grad();
        self.norm1.zero_grad();
        self.norm2.zero_grad();
    }

    fn name(&self) -> &str {
        "TransformerEncoderLayer"
    }
}

/// Positional encoding for transformers
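///
/// Precomputes the standard sinusoidal table
/// `PE(pos, 2i) = sin(pos / 10000^(2i / d_model))`,
/// `PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))`
/// for up to `max_len` positions and adds it element-wise to the input in
/// `forward`.
///
/// Minimal sketch (not compiled; `d_model = 16` and `max_len = 512` are
/// illustrative values):
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// let mut pe = PositionalEncoding::new(16, 512);
/// let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[2, 10, 16])), false);
/// let y = pe.forward(&x)?; // y[b, s, d] = x[b, s, d] + encoding[s, d]
/// ```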
pub struct PositionalEncoding {
    d_model: usize,
    max_len: usize,
    dropout: f64,
    encoding: ArrayD<f64>,
    training: bool,
}

impl PositionalEncoding {
    /// Create a new sinusoidal positional-encoding table covering up to
    /// `max_len` positions of dimension `d_model`.
    pub fn new(d_model: usize, max_len: usize) -> Self {
        let mut encoding = ArrayD::zeros(IxDyn(&[max_len, d_model]));

        for pos in 0..max_len {
            for i in 0..d_model {
                let angle = pos as f64 / 10000.0_f64.powf(2.0 * (i / 2) as f64 / d_model as f64);
                encoding[[pos, i]] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
            }
        }

        Self {
            d_model,
            max_len,
            dropout: 0.1,
            encoding,
            training: true,
        }
    }

    /// Set the dropout probability (builder-style).
    pub fn dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}

impl QuantumModule for PositionalEncoding {
    fn forward(&mut self, input: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = input.data.shape();
        let seq_len = shape[1];

        let mut output = input.data.clone();

        // Add the precomputed encoding to each position (clamped to the table size)
        for b in 0..shape[0] {
            for s in 0..seq_len.min(self.max_len) {
                for d in 0..self.d_model.min(shape[2]) {
                    output[[b, s, d]] += self.encoding[[s, d]];
                }
            }
        }

        Ok(SciRS2Array::new(output, input.requires_grad))
    }

    fn parameters(&self) -> Vec<Parameter> {
        Vec::new() // Positional encoding is typically not learned
    }

    fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    fn training(&self) -> bool {
        self.training
    }

    fn zero_grad(&mut self) {}

    fn name(&self) -> &str {
        "PositionalEncoding"
    }
}
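
#[cfg(test)]
mod tests {
    // Minimal shape-only smoke tests for this module (a sketch; they rely on
    // the same `SciRS2Array::new` / `.data` API the layers above already use).
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};

    #[test]
    fn multihead_attention_preserves_shape() {
        let mut attn = QuantumMultiheadAttention::new(8, 2).expect("valid config");
        let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 4, 8])), false);
        let y = attn.forward(&x).expect("forward succeeds");
        assert_eq!(y.data.shape(), &[1, 4, 8]);
    }

    #[test]
    fn positional_encoding_preserves_shape() {
        let mut pe = PositionalEncoding::new(8, 16);
        let x = SciRS2Array::new(ArrayD::zeros(IxDyn(&[1, 4, 8])), false);
        let y = pe.forward(&x).expect("forward succeeds");
        assert_eq!(y.data.shape(), &[1, 4, 8]);
    }
}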