quantrs2_core/qml/quantum_transformer.rs

//! Quantum Transformer with Attention Mechanisms
//!
//! This module implements quantum transformers with attention mechanisms for
//! processing sequential quantum data. It includes:
//! - Multi-head quantum attention
//! - Quantum positional encoding
//! - Quantum feed-forward networks
//! - Layer normalization for quantum states
//!
//! # Theoretical Background
//!
//! Quantum transformers extend classical transformer architectures to the quantum domain,
//! leveraging quantum superposition and entanglement for enhanced representation learning.
//! The attention mechanism computes attention scores from quantum state overlaps
//! (fidelity-style inner products), exploiting interference between the encoded sequence states.
//!
//! # References
//!
//! - "Quantum Attention Networks" (2023)
//! - "Self-Attention in Quantum Computing" (2024)
//! - "Quantum Transformers for Natural Language Processing" (2024)

use crate::error::{QuantRS2Error, QuantRS2Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::Complex64;

/// Configuration for quantum transformer
#[derive(Debug, Clone)]
pub struct QuantumTransformerConfig {
    /// Number of qubits for data representation
    pub num_qubits: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dimension of each attention head
    pub head_dim: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Dimension of feed-forward network
    pub ffn_dim: usize,
    /// Dropout rate for regularization
    pub dropout_rate: f64,
    /// Maximum sequence length
    pub max_seq_length: usize,
    /// Whether to use layer normalization
    pub use_layer_norm: bool,
}

impl Default for QuantumTransformerConfig {
    fn default() -> Self {
        Self {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        }
    }
}

/// Quantum attention mechanism using quantum circuits
#[derive(Debug, Clone)]
pub struct QuantumAttention {
    /// Number of qubits
    num_qubits: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Query parameters
    query_params: Array2<f64>,
    /// Key parameters
    key_params: Array2<f64>,
    /// Value parameters
    value_params: Array2<f64>,
    /// Output projection parameters
    output_params: Array2<f64>,
}

impl QuantumAttention {
    /// Create a new quantum attention mechanism
    pub fn new(num_qubits: usize, num_heads: usize, head_dim: usize) -> QuantRS2Result<Self> {
        if num_qubits < 2 {
            return Err(QuantRS2Error::InvalidInput(
                "Quantum attention requires at least 2 qubits".to_string(),
            ));
        }

        if num_heads == 0 || head_dim == 0 {
            return Err(QuantRS2Error::InvalidInput(
                "Number of heads and head dimension must be positive".to_string(),
            ));
        }

        let total_dim = num_heads * head_dim;
        let mut rng = thread_rng();

        // Xavier initialization for quantum parameters
        let scale = (2.0 / (num_qubits as f64)).sqrt();

        let query_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));

        let key_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));

        let value_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));

        let output_params =
            Array2::from_shape_fn((num_qubits, total_dim), |_| rng.gen_range(-scale..scale));

        Ok(Self {
            num_qubits,
            num_heads,
            head_dim,
            query_params,
            key_params,
            value_params,
            output_params,
        })
    }

    /// Compute quantum attention scores using quantum interference
    pub fn attention_scores(
        &self,
        query: &Array2<Complex64>,
        key: &Array2<Complex64>,
    ) -> QuantRS2Result<Array2<f64>> {
        let seq_len = query.shape()[0];
        let mut scores = Array2::zeros((seq_len, seq_len));

        // Compute attention scores via quantum state overlap
        for i in 0..seq_len {
            for j in 0..seq_len {
                let q = query.row(i);
                let k = key.row(j);

                // Quantum inner product (fidelity)
                let mut score = Complex64::new(0.0, 0.0);
                for (qi, ki) in q.iter().zip(k.iter()) {
                    score += qi.conj() * ki;
                }

                // Scale by 1 / sqrt(head_dim), as in classical scaled dot-product attention
                let scaled_score = score.norm() / (self.head_dim as f64).sqrt();
                scores[[i, j]] = scaled_score;
            }
        }

        Ok(scores)
    }

    /// Apply softmax to attention scores
    pub fn softmax(&self, scores: &Array2<f64>) -> Array2<f64> {
        let seq_len = scores.shape()[0];
        let mut softmax_scores = Array2::zeros((seq_len, seq_len));

        for i in 0..seq_len {
            let row = scores.row(i);
            let max_score = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

            // Compute exp(score - max) for numerical stability
            let mut exp_scores = Array1::zeros(seq_len);
            let mut sum_exp = 0.0;

            for (j, &score) in row.iter().enumerate() {
                let exp_val = (score - max_score).exp();
                exp_scores[j] = exp_val;
                sum_exp += exp_val;
            }

            // Normalize
            for j in 0..seq_len {
                softmax_scores[[i, j]] = exp_scores[j] / sum_exp;
            }
        }

        softmax_scores
    }

    /// Apply quantum attention to input
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        // Project to query, key, value using quantum rotations
        let query = self.project_qkv(input, &self.query_params)?;
        let key = self.project_qkv(input, &self.key_params)?;
        let value = self.project_qkv(input, &self.value_params)?;

        // Compute attention scores
        let scores = self.attention_scores(&query, &key)?;
        let attention_weights = self.softmax(&scores);

        // Apply attention to values
        let total_dim = self.num_heads * self.head_dim;
        let mut output = Array2::zeros((seq_len, total_dim));

        for i in 0..seq_len {
            for j in 0..seq_len {
                let weight = attention_weights[[i, j]];
                for k in 0..total_dim {
                    output[[i, k]] = output[[i, k]] + value[[j, k]] * weight;
                }
            }
        }

        // Project back to original dimension
        self.project_output(&output)
    }

    /// Project input to query/key/value space
    fn project_qkv(
        &self,
        input: &Array2<Complex64>,
        params: &Array2<f64>,
    ) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];
        let out_dim = params.shape()[0];
        let mut output = Array2::zeros((seq_len, out_dim));

        for i in 0..seq_len {
            for j in 0..out_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.num_qubits {
                    // Quantum rotation based projection
                    let angle = params[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += input[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

    /// Project output back to original dimension
    fn project_output(
        &self,
        attention_out: &Array2<Complex64>,
    ) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = attention_out.shape()[0];
        let mut output = Array2::zeros((seq_len, self.num_qubits));

        for i in 0..seq_len {
            for j in 0..self.num_qubits {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..(self.num_heads * self.head_dim) {
                    let angle = self.output_params[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += attention_out[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }
}

/// Quantum positional encoding for sequence information
#[derive(Debug, Clone)]
pub struct QuantumPositionalEncoding {
    /// Maximum sequence length
    max_seq_length: usize,
    /// Number of qubits
    num_qubits: usize,
    /// Encoding parameters
    encoding: Array2<f64>,
}

impl QuantumPositionalEncoding {
    /// Create new quantum positional encoding
    pub fn new(max_seq_length: usize, num_qubits: usize) -> Self {
        let mut encoding = Array2::zeros((max_seq_length, num_qubits));

        // Quantum sinusoidal positional encoding
        for pos in 0..max_seq_length {
            for i in 0..num_qubits {
                if i % 2 == 0 {
                    let freq = 1.0 / 10000_f64.powf(i as f64 / num_qubits as f64);
                    encoding[[pos, i]] = (pos as f64 * freq).sin();
                } else {
                    let freq = 1.0 / 10000_f64.powf((i - 1) as f64 / num_qubits as f64);
                    encoding[[pos, i]] = (pos as f64 * freq).cos();
                }
            }
        }

        Self {
            max_seq_length,
            num_qubits,
            encoding,
        }
    }

    /// Add positional encoding to input quantum states
    pub fn encode(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        if seq_len > self.max_seq_length {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Sequence length {} exceeds maximum {}",
                seq_len, self.max_seq_length
            )));
        }

        let mut output = input.clone();

        // Add positional encoding using quantum phase shifts
        for i in 0..seq_len {
            for j in 0..self.num_qubits {
                let phase = self.encoding[[i, j]];
                let phase_shift = Complex64::new(phase.cos(), phase.sin());
                output[[i, j]] = output[[i, j]] * phase_shift;
            }
        }

        Ok(output)
    }
}

/// Quantum feed-forward network
#[derive(Debug, Clone)]
pub struct QuantumFeedForward {
    /// Input dimension
    input_dim: usize,
    /// Hidden dimension
    hidden_dim: usize,
    /// First layer parameters
    w1: Array2<f64>,
    /// Second layer parameters
    w2: Array2<f64>,
}

impl QuantumFeedForward {
    /// Create new quantum feed-forward network
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        let mut rng = thread_rng();
        let scale1 = (2.0 / input_dim as f64).sqrt();
        let scale2 = (2.0 / hidden_dim as f64).sqrt();

        let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.gen_range(-scale1..scale1));

        let w2 = Array2::from_shape_fn((input_dim, hidden_dim), |_| rng.gen_range(-scale2..scale2));

        Self {
            input_dim,
            hidden_dim,
            w1,
            w2,
        }
    }

    /// Forward pass through quantum FFN
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        // First layer with quantum activation
        let mut hidden = Array2::zeros((seq_len, self.hidden_dim));
        for i in 0..seq_len {
            for j in 0..self.hidden_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.input_dim {
                    let angle = self.w1[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += input[[i, k]] * rotation;
                }
                // Quantum ReLU-like activation
                hidden[[i, j]] = self.quantum_activation(sum);
            }
        }

        // Second layer
        let mut output = Array2::zeros((seq_len, self.input_dim));
        for i in 0..seq_len {
            for j in 0..self.input_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.hidden_dim {
                    let angle = self.w2[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += hidden[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

    /// Quantum activation function
    fn quantum_activation(&self, z: Complex64) -> Complex64 {
        // ReLU-like nonlinearity for quantum amplitudes: squash the magnitude
        // with tanh while preserving the phase
        let amplitude = z.norm();
        let phase = z.arg();

        if amplitude > 0.0 {
            // Bounded, saturating response to the magnitude
            let amplified = amplitude.tanh();
            Complex64::new(amplified * phase.cos(), amplified * phase.sin())
        } else {
            Complex64::new(0.0, 0.0)
        }
    }
}

/// Complete quantum transformer layer
#[derive(Debug, Clone)]
pub struct QuantumTransformerLayer {
    /// Multi-head attention
    attention: QuantumAttention,
    /// Feed-forward network
    ffn: QuantumFeedForward,
    /// Configuration
    config: QuantumTransformerConfig,
}

impl QuantumTransformerLayer {
    /// Create new quantum transformer layer
    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
        let attention =
            QuantumAttention::new(config.num_qubits, config.num_heads, config.head_dim)?;

        let ffn = QuantumFeedForward::new(config.num_qubits, config.ffn_dim);

        Ok(Self {
            attention,
            ffn,
            config,
        })
    }

    /// Forward pass through transformer layer
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        // Multi-head attention with residual connection
        let attention_out = self.attention.forward(input)?;
        let after_attention = self.add_residual(input, &attention_out);

        // Layer normalization (if enabled)
        let normalized = if self.config.use_layer_norm {
            self.layer_norm(&after_attention)?
        } else {
            after_attention
        };

        // Feed-forward with residual connection
        let ffn_out = self.ffn.forward(&normalized)?;
        let output = self.add_residual(&normalized, &ffn_out);

        // Final layer normalization
        if self.config.use_layer_norm {
            self.layer_norm(&output)
        } else {
            Ok(output)
        }
    }

    /// Add residual connection
    fn add_residual(
        &self,
        input: &Array2<Complex64>,
        residual: &Array2<Complex64>,
    ) -> Array2<Complex64> {
        input + residual
    }

    /// Quantum layer normalization
    fn layer_norm(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];
        let num_features = input.shape()[1];
        let mut output = Array2::zeros((seq_len, num_features));

        for i in 0..seq_len {
            let row = input.row(i);

            // Compute mean and variance of quantum state amplitudes
            let mut mean_real = 0.0;
            let mut mean_imag = 0.0;
            for val in row.iter() {
                mean_real += val.re;
                mean_imag += val.im;
            }
            mean_real /= num_features as f64;
            mean_imag /= num_features as f64;
            let mean = Complex64::new(mean_real, mean_imag);

            let mut variance = 0.0;
            for val in row.iter() {
                let diff = val - mean;
                variance += diff.norm_sqr();
            }
            variance /= num_features as f64;

            let std = (variance + 1e-5).sqrt();

            // Normalize
            for j in 0..num_features {
                output[[i, j]] = (input[[i, j]] - mean) / std;
            }
        }

        Ok(output)
    }
}

/// Complete quantum transformer model
#[derive(Debug, Clone)]
pub struct QuantumTransformer {
    /// Configuration
    config: QuantumTransformerConfig,
    /// Positional encoding
    pos_encoding: QuantumPositionalEncoding,
    /// Transformer layers
    layers: Vec<QuantumTransformerLayer>,
}

impl QuantumTransformer {
    /// Create new quantum transformer
    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
        let pos_encoding = QuantumPositionalEncoding::new(config.max_seq_length, config.num_qubits);

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(QuantumTransformerLayer::new(config.clone())?);
        }

        Ok(Self {
            config,
            pos_encoding,
            layers,
        })
    }

    /// Forward pass through transformer
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        // Add positional encoding
        let mut x = self.pos_encoding.encode(input)?;

        // Pass through transformer layers
        for layer in &self.layers {
            x = layer.forward(&x)?;
        }

        Ok(x)
    }

    /// Get configuration
    pub fn config(&self) -> &QuantumTransformerConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_attention() {
        let attention = QuantumAttention::new(4, 2, 2).unwrap();

        // Create test input (sequence of 3 quantum states)
        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new((i + j) as f64 * 0.1, 0.0);
            }
        }

        let output = attention.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }
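
    // Additional hedged check (not in the original file): the `softmax` helper
    // should return rows that are non-negative and sum to one.
    #[test]
    fn test_attention_softmax_rows_sum_to_one() {
        let attention = QuantumAttention::new(4, 2, 2).unwrap();

        let scores = Array2::from_shape_fn((3, 3), |(i, j)| (i + j) as f64 * 0.5);
        let weights = attention.softmax(&scores);

        for i in 0..3 {
            let row_sum: f64 = weights.row(i).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-10);
            assert!(weights.row(i).iter().all(|&w| w >= 0.0));
        }
    }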

    #[test]
    fn test_positional_encoding() {
        let pos_enc = QuantumPositionalEncoding::new(64, 4);

        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new(1.0, 0.0);
            }
        }

        let encoded = pos_enc.encode(&input).unwrap();
        assert_eq!(encoded.shape(), &[3, 4]);
    }
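
    // Additional hedged check (not in the original file): the feed-forward block
    // should map a (seq_len, input_dim) batch back to the same shape.
    #[test]
    fn test_quantum_feed_forward_shape() {
        let ffn = QuantumFeedForward::new(4, 8);

        let input = Array2::from_elem((3, 4), Complex64::new(0.5, 0.1));
        let output = ffn.forward(&input).unwrap();

        assert_eq!(output.shape(), &[3, 4]);
    }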

    #[test]
    fn test_quantum_transformer() {
        let config = QuantumTransformerConfig {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        };

        let transformer = QuantumTransformer::new(config).unwrap();

        // Create test input
        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new((i + j) as f64 * 0.1, 0.0);
            }
        }

        let output = transformer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }
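
    // Additional hedged check (not in the original file): encoding a sequence
    // longer than `max_seq_length` should be rejected with an error.
    #[test]
    fn test_positional_encoding_rejects_long_sequence() {
        let pos_enc = QuantumPositionalEncoding::new(2, 4);

        let input = Array2::from_elem((3, 4), Complex64::new(1.0, 0.0));
        assert!(pos_enc.encode(&input).is_err());
    }

    // Additional hedged check (not in the original file): layer normalization
    // should center each row, i.e. the per-row mean of the normalized
    // amplitudes should be (close to) zero.
    #[test]
    fn test_layer_norm_centers_rows() {
        let layer = QuantumTransformerLayer::new(QuantumTransformerConfig::default()).unwrap();

        let input = Array2::from_shape_fn((2, 4), |(i, j)| {
            Complex64::new((i + j) as f64, (i * j) as f64)
        });
        let normalized = layer.layer_norm(&input).unwrap();

        for i in 0..2 {
            let mean: Complex64 = normalized.row(i).iter().sum::<Complex64>() / 4.0;
            assert!(mean.norm() < 1e-10);
        }
    }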
}