// quantrs2_core/qml/quantum_transformer.rs
//! Quantum Transformer with Attention Mechanisms
//!
//! This module implements quantum transformers with attention mechanisms for
//! processing sequential quantum data. It includes:
//! - Multi-head quantum attention
//! - Quantum positional encoding
//! - Quantum feed-forward networks
//! - Layer normalization for quantum states
//!
//! # Theoretical Background
//!
//! Quantum transformers extend classical transformer architectures to the quantum domain,
//! leveraging quantum superposition and entanglement for enhanced representation learning.
//! The attention mechanism is implemented using quantum circuits that compute attention
//! scores via quantum interference patterns.
//!
//! # References
//!
//! - "Quantum Attention Networks"
//! - "Self-Attention in Quantum Computing"
//! - "Quantum Transformers for Natural Language Processing"
22
23use crate::{
24    error::{QuantRS2Error, QuantRS2Result},
25    gate::GateOp,
26    qubit::QubitId,
27};
28use scirs2_core::ndarray::{Array1, Array2, Array3};
29use scirs2_core::random::prelude::*;
30use scirs2_core::Complex64;
31use std::f64::consts::PI;
32
/// Configuration for quantum transformer
///
/// Shared by every transformer layer in a model; `num_heads * head_dim`
/// determines the internal attention width, while `num_qubits` is the
/// input/output feature dimension.
#[derive(Debug, Clone)]
pub struct QuantumTransformerConfig {
    /// Number of qubits for data representation (the model/feature dimension)
    pub num_qubits: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dimension of each attention head
    pub head_dim: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Dimension of feed-forward network
    pub ffn_dim: usize,
    /// Dropout rate for regularization
    /// NOTE(review): not read anywhere in this module yet — presumably
    /// reserved for training-time use; confirm before relying on it.
    pub dropout_rate: f64,
    /// Maximum sequence length
    pub max_seq_length: usize,
    /// Whether to use layer normalization
    pub use_layer_norm: bool,
}
53
impl Default for QuantumTransformerConfig {
    /// Small configuration suitable for quick experiments and unit tests:
    /// 4 qubits, 2 attention heads of dimension 2, 2 layers, FFN width 8.
    fn default() -> Self {
        Self {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        }
    }
}
68
/// Quantum attention mechanism using quantum circuits
///
/// The parameter matrices hold rotation angles: `query_params`, `key_params`
/// and `value_params` have shape `(num_heads * head_dim, num_qubits)`, and
/// `output_params` has shape `(num_qubits, num_heads * head_dim)` (see
/// [`QuantumAttention::new`]).
#[derive(Debug, Clone)]
pub struct QuantumAttention {
    /// Number of qubits (input/output feature dimension)
    num_qubits: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Query projection angles
    query_params: Array2<f64>,
    /// Key projection angles
    key_params: Array2<f64>,
    /// Value projection angles
    value_params: Array2<f64>,
    /// Output projection angles
    output_params: Array2<f64>,
}
87
88impl QuantumAttention {
89    /// Create a new quantum attention mechanism
90    pub fn new(num_qubits: usize, num_heads: usize, head_dim: usize) -> QuantRS2Result<Self> {
91        if num_qubits < 2 {
92            return Err(QuantRS2Error::InvalidInput(
93                "Quantum attention requires at least 2 qubits".to_string(),
94            ));
95        }
96
97        if num_heads == 0 || head_dim == 0 {
98            return Err(QuantRS2Error::InvalidInput(
99                "Number of heads and head dimension must be positive".to_string(),
100            ));
101        }
102
103        let total_dim = num_heads * head_dim;
104        let mut rng = thread_rng();
105
106        // Xavier initialization for quantum parameters
107        let scale = (2.0 / (num_qubits as f64)).sqrt();
108
109        let query_params =
110            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));
111
112        let key_params =
113            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));
114
115        let value_params =
116            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));
117
118        let output_params =
119            Array2::from_shape_fn((num_qubits, total_dim), |_| rng.random_range(-scale..scale));
120
121        Ok(Self {
122            num_qubits,
123            num_heads,
124            head_dim,
125            query_params,
126            key_params,
127            value_params,
128            output_params,
129        })
130    }
131
132    /// Compute quantum attention scores using quantum interference
133    pub fn attention_scores(
134        &self,
135        query: &Array2<Complex64>,
136        key: &Array2<Complex64>,
137    ) -> QuantRS2Result<Array2<f64>> {
138        let seq_len = query.shape()[0];
139        let mut scores = Array2::zeros((seq_len, seq_len));
140
141        // Compute attention scores via quantum state overlap
142        for i in 0..seq_len {
143            for j in 0..seq_len {
144                let q = query.row(i);
145                let k = key.row(j);
146
147                // Quantum inner product (fidelity)
148                let mut score = Complex64::new(0.0, 0.0);
149                for (qi, ki) in q.iter().zip(k.iter()) {
150                    score += qi.conj() * ki;
151                }
152
153                // Scale by sqrt(head_dim) as in classical transformers
154                let scaled_score = score.norm() / (self.head_dim as f64).sqrt();
155                scores[[i, j]] = scaled_score;
156            }
157        }
158
159        Ok(scores)
160    }
161
162    /// Apply softmax to attention scores
163    pub fn softmax(&self, scores: &Array2<f64>) -> Array2<f64> {
164        let seq_len = scores.shape()[0];
165        let mut softmax_scores = Array2::zeros((seq_len, seq_len));
166
167        for i in 0..seq_len {
168            let row = scores.row(i);
169            let max_score = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
170
171            // Compute exp(score - max) for numerical stability
172            let mut exp_scores = Array1::zeros(seq_len);
173            let mut sum_exp = 0.0;
174
175            for (j, &score) in row.iter().enumerate() {
176                let exp_val = (score - max_score).exp();
177                exp_scores[j] = exp_val;
178                sum_exp += exp_val;
179            }
180
181            // Normalize
182            for j in 0..seq_len {
183                softmax_scores[[i, j]] = exp_scores[j] / sum_exp;
184            }
185        }
186
187        softmax_scores
188    }
189
190    /// Apply quantum attention to input
191    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
192        let seq_len = input.shape()[0];
193
194        // Project to query, key, value using quantum rotations
195        let query = self.project_qkv(input, &self.query_params)?;
196        let key = self.project_qkv(input, &self.key_params)?;
197        let value = self.project_qkv(input, &self.value_params)?;
198
199        // Compute attention scores
200        let scores = self.attention_scores(&query, &key)?;
201        let attention_weights = self.softmax(&scores);
202
203        // Apply attention to values
204        let total_dim = self.num_heads * self.head_dim;
205        let mut output = Array2::zeros((seq_len, total_dim));
206
207        for i in 0..seq_len {
208            for j in 0..seq_len {
209                let weight = attention_weights[[i, j]];
210                for k in 0..total_dim {
211                    output[[i, k]] = output[[i, k]] + value[[j, k]] * weight;
212                }
213            }
214        }
215
216        // Project back to original dimension
217        self.project_output(&output)
218    }
219
220    /// Project input to query/key/value space
221    fn project_qkv(
222        &self,
223        input: &Array2<Complex64>,
224        params: &Array2<f64>,
225    ) -> QuantRS2Result<Array2<Complex64>> {
226        let seq_len = input.shape()[0];
227        let out_dim = params.shape()[0];
228        let mut output = Array2::zeros((seq_len, out_dim));
229
230        for i in 0..seq_len {
231            for j in 0..out_dim {
232                let mut sum = Complex64::new(0.0, 0.0);
233                for k in 0..self.num_qubits {
234                    // Quantum rotation based projection
235                    let angle = params[[j, k]];
236                    let rotation = Complex64::new(angle.cos(), angle.sin());
237                    sum += input[[i, k]] * rotation;
238                }
239                output[[i, j]] = sum;
240            }
241        }
242
243        Ok(output)
244    }
245
246    /// Project output back to original dimension
247    fn project_output(
248        &self,
249        attention_out: &Array2<Complex64>,
250    ) -> QuantRS2Result<Array2<Complex64>> {
251        let seq_len = attention_out.shape()[0];
252        let mut output = Array2::zeros((seq_len, self.num_qubits));
253
254        for i in 0..seq_len {
255            for j in 0..self.num_qubits {
256                let mut sum = Complex64::new(0.0, 0.0);
257                for k in 0..(self.num_heads * self.head_dim) {
258                    let angle = self.output_params[[j, k]];
259                    let rotation = Complex64::new(angle.cos(), angle.sin());
260                    sum += attention_out[[i, k]] * rotation;
261                }
262                output[[i, j]] = sum;
263            }
264        }
265
266        Ok(output)
267    }
268}
269
/// Quantum positional encoding for sequence information
///
/// Holds a precomputed `(max_seq_length, num_qubits)` table of sinusoidal
/// values that are applied to input states as phase shifts.
#[derive(Debug, Clone)]
pub struct QuantumPositionalEncoding {
    /// Maximum sequence length
    max_seq_length: usize,
    /// Number of qubits (feature dimension of encoded sequences)
    num_qubits: usize,
    /// Precomputed phase table, shape `(max_seq_length, num_qubits)`
    encoding: Array2<f64>,
}
280
281impl QuantumPositionalEncoding {
282    /// Create new quantum positional encoding
283    pub fn new(max_seq_length: usize, num_qubits: usize) -> Self {
284        let mut encoding = Array2::zeros((max_seq_length, num_qubits));
285
286        // Quantum sinusoidal positional encoding
287        for pos in 0..max_seq_length {
288            for i in 0..num_qubits {
289                if i % 2 == 0 {
290                    let freq = 1.0 / 10000_f64.powf(i as f64 / num_qubits as f64);
291                    encoding[[pos, i]] = (pos as f64 * freq).sin();
292                } else {
293                    let freq = 1.0 / 10000_f64.powf((i - 1) as f64 / num_qubits as f64);
294                    encoding[[pos, i]] = (pos as f64 * freq).cos();
295                }
296            }
297        }
298
299        Self {
300            max_seq_length,
301            num_qubits,
302            encoding,
303        }
304    }
305
306    /// Add positional encoding to input quantum states
307    pub fn encode(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
308        let seq_len = input.shape()[0];
309
310        if seq_len > self.max_seq_length {
311            return Err(QuantRS2Error::InvalidInput(format!(
312                "Sequence length {} exceeds maximum {}",
313                seq_len, self.max_seq_length
314            )));
315        }
316
317        let mut output = input.clone();
318
319        // Add positional encoding using quantum phase shifts
320        for i in 0..seq_len {
321            for j in 0..self.num_qubits {
322                let phase = self.encoding[[i, j]];
323                let phase_shift = Complex64::new(phase.cos(), phase.sin());
324                output[[i, j]] = output[[i, j]] * phase_shift;
325            }
326        }
327
328        Ok(output)
329    }
330}
331
/// Quantum feed-forward network
///
/// Two rotation-based linear maps (`input_dim -> hidden_dim -> input_dim`)
/// with a nonlinear activation between them; angles stored in `w1`/`w2`.
#[derive(Debug, Clone)]
pub struct QuantumFeedForward {
    /// Input (and output) dimension
    input_dim: usize,
    /// Hidden dimension
    hidden_dim: usize,
    /// First layer rotation angles, shape `(hidden_dim, input_dim)`
    w1: Array2<f64>,
    /// Second layer rotation angles, shape `(input_dim, hidden_dim)`
    w2: Array2<f64>,
}
344
345impl QuantumFeedForward {
346    /// Create new quantum feed-forward network
347    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
348        let mut rng = thread_rng();
349        let scale1 = (2.0 / input_dim as f64).sqrt();
350        let scale2 = (2.0 / hidden_dim as f64).sqrt();
351
352        let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
353            rng.random_range(-scale1..scale1)
354        });
355
356        let w2 = Array2::from_shape_fn((input_dim, hidden_dim), |_| {
357            rng.random_range(-scale2..scale2)
358        });
359
360        Self {
361            input_dim,
362            hidden_dim,
363            w1,
364            w2,
365        }
366    }
367
368    /// Forward pass through quantum FFN
369    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
370        let seq_len = input.shape()[0];
371
372        // First layer with quantum activation
373        let mut hidden = Array2::zeros((seq_len, self.hidden_dim));
374        for i in 0..seq_len {
375            for j in 0..self.hidden_dim {
376                let mut sum = Complex64::new(0.0, 0.0);
377                for k in 0..self.input_dim {
378                    let angle = self.w1[[j, k]];
379                    let rotation = Complex64::new(angle.cos(), angle.sin());
380                    sum += input[[i, k]] * rotation;
381                }
382                // Quantum ReLU-like activation
383                hidden[[i, j]] = self.quantum_activation(sum);
384            }
385        }
386
387        // Second layer
388        let mut output = Array2::zeros((seq_len, self.input_dim));
389        for i in 0..seq_len {
390            for j in 0..self.input_dim {
391                let mut sum = Complex64::new(0.0, 0.0);
392                for k in 0..self.hidden_dim {
393                    let angle = self.w2[[j, k]];
394                    let rotation = Complex64::new(angle.cos(), angle.sin());
395                    sum += hidden[[i, k]] * rotation;
396                }
397                output[[i, j]] = sum;
398            }
399        }
400
401        Ok(output)
402    }
403
404    /// Quantum activation function
405    fn quantum_activation(&self, z: Complex64) -> Complex64 {
406        // Quantum version of ReLU using amplitude amplification
407        let amplitude = z.norm();
408        let phase = z.arg();
409
410        if amplitude > 0.0 {
411            // Amplify based on magnitude
412            let amplified = amplitude.tanh();
413            Complex64::new(amplified * phase.cos(), amplified * phase.sin())
414        } else {
415            Complex64::new(0.0, 0.0)
416        }
417    }
418}
419
/// Complete quantum transformer layer
///
/// Attention and feed-forward sub-layers, each wrapped with a residual
/// connection and (optionally) layer normalization.
#[derive(Debug, Clone)]
pub struct QuantumTransformerLayer {
    /// Multi-head attention
    attention: QuantumAttention,
    /// Feed-forward network
    ffn: QuantumFeedForward,
    /// Layer configuration (controls `use_layer_norm` and dimensions)
    config: QuantumTransformerConfig,
}
430
431impl QuantumTransformerLayer {
432    /// Create new quantum transformer layer
433    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
434        let attention =
435            QuantumAttention::new(config.num_qubits, config.num_heads, config.head_dim)?;
436
437        let ffn = QuantumFeedForward::new(config.num_qubits, config.ffn_dim);
438
439        Ok(Self {
440            attention,
441            ffn,
442            config,
443        })
444    }
445
446    /// Forward pass through transformer layer
447    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
448        // Multi-head attention with residual connection
449        let attention_out = self.attention.forward(input)?;
450        let after_attention = self.add_residual(input, &attention_out);
451
452        // Layer normalization (if enabled)
453        let normalized = if self.config.use_layer_norm {
454            self.layer_norm(&after_attention)?
455        } else {
456            after_attention
457        };
458
459        // Feed-forward with residual connection
460        let ffn_out = self.ffn.forward(&normalized)?;
461        let output = self.add_residual(&normalized, &ffn_out);
462
463        // Final layer normalization
464        if self.config.use_layer_norm {
465            self.layer_norm(&output)
466        } else {
467            Ok(output)
468        }
469    }
470
471    /// Add residual connection
472    fn add_residual(
473        &self,
474        input: &Array2<Complex64>,
475        residual: &Array2<Complex64>,
476    ) -> Array2<Complex64> {
477        input + residual
478    }
479
480    /// Quantum layer normalization
481    fn layer_norm(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
482        let seq_len = input.shape()[0];
483        let num_features = input.shape()[1];
484        let mut output = Array2::zeros((seq_len, num_features));
485
486        for i in 0..seq_len {
487            let row = input.row(i);
488
489            // Compute mean and variance of quantum state amplitudes
490            let mut mean_real = 0.0;
491            let mut mean_imag = 0.0;
492            for val in row {
493                mean_real += val.re;
494                mean_imag += val.im;
495            }
496            mean_real /= num_features as f64;
497            mean_imag /= num_features as f64;
498            let mean = Complex64::new(mean_real, mean_imag);
499
500            let mut variance = 0.0;
501            for val in row {
502                let diff = val - mean;
503                variance += diff.norm_sqr();
504            }
505            variance /= num_features as f64;
506
507            let std = (variance + 1e-5).sqrt();
508
509            // Normalize
510            for j in 0..num_features {
511                output[[i, j]] = (input[[i, j]] - mean) / std;
512            }
513        }
514
515        Ok(output)
516    }
517}
518
/// Complete quantum transformer model
///
/// Applies positional encoding followed by a stack of `num_layers`
/// transformer layers.
#[derive(Debug, Clone)]
pub struct QuantumTransformer {
    /// Configuration
    config: QuantumTransformerConfig,
    /// Positional encoding
    pos_encoding: QuantumPositionalEncoding,
    /// Transformer layers
    layers: Vec<QuantumTransformerLayer>,
}
529
530impl QuantumTransformer {
531    /// Create new quantum transformer
532    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
533        let pos_encoding = QuantumPositionalEncoding::new(config.max_seq_length, config.num_qubits);
534
535        let mut layers = Vec::with_capacity(config.num_layers);
536        for _ in 0..config.num_layers {
537            layers.push(QuantumTransformerLayer::new(config.clone())?);
538        }
539
540        Ok(Self {
541            config,
542            pos_encoding,
543            layers,
544        })
545    }
546
547    /// Forward pass through transformer
548    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
549        // Add positional encoding
550        let mut x = self.pos_encoding.encode(input)?;
551
552        // Pass through transformer layers
553        for layer in &self.layers {
554            x = layer.forward(&x)?;
555        }
556
557        Ok(x)
558    }
559
560    /// Get configuration
561    pub const fn config(&self) -> &QuantumTransformerConfig {
562        &self.config
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `(rows, cols)` test sequence of real amplitudes `(i + j) * 0.1`.
    fn ramp_sequence(rows: usize, cols: usize) -> Array2<Complex64> {
        Array2::from_shape_fn((rows, cols), |(i, j)| {
            Complex64::new((i + j) as f64 * 0.1, 0.0)
        })
    }

    #[test]
    fn test_quantum_attention() {
        let attention = QuantumAttention::new(4, 2, 2).expect("Failed to create QuantumAttention");

        let output = attention
            .forward(&ramp_sequence(3, 4))
            .expect("Attention forward pass should succeed");

        // Attention preserves the (seq_len, num_qubits) shape.
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = QuantumPositionalEncoding::new(64, 4);
        let uniform = Array2::from_elem((3, 4), Complex64::new(1.0, 0.0));

        let encoded = pos_enc
            .encode(&uniform)
            .expect("Positional encoding should succeed");

        assert_eq!(encoded.shape(), &[3, 4]);
    }

    #[test]
    fn test_quantum_transformer() {
        // Values match `QuantumTransformerConfig::default()`, spelled out so
        // the test is explicit about the model under test.
        let config = QuantumTransformerConfig {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        };

        let transformer =
            QuantumTransformer::new(config).expect("Failed to create QuantumTransformer");

        let output = transformer
            .forward(&ramp_sequence(3, 4))
            .expect("Transformer forward pass should succeed");

        assert_eq!(output.shape(), &[3, 4]);
    }
}