quantrs2_ml/attention.rs

//! Quantum attention mechanisms for transformer architectures.
//!
//! This module implements quantum versions of attention mechanisms including
//! multi-head attention, cross-attention, and quantum transformer blocks.
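//!
//! A minimal end-to-end sketch (an assumption-laden example: it presumes this
//! module is exposed as `quantrs2_ml::attention` and, as throughout this file,
//! that inputs are flattened row vectors of shape `(batch, seq_len * embed_dim)`):
//!
//! ```rust,ignore
//! use quantrs2_ml::attention::QuantumTransformer;
//! use scirs2_core::ndarray::Array2;
//!
//! // embed_dim = 8, 2 layers, 2 heads, ff_dim = 16, max_length = 100, dropout = 0.1
//! let transformer = QuantumTransformer::new(8, 2, 2, 16, 100, 0.1);
//! let input = Array2::from_elem((1, 3 * 8), 0.1); // batch = 1, seq_len = 3
//! let output = transformer.forward(&input, None).unwrap();
//! assert_eq!(output.dim(), (1, 3 * 8));
//! ```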

use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::autodiff::DifferentiableParam;
use crate::error::{MLError, Result};
use crate::utils::VariationalCircuit;
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::{multi::*, single::*, GateOp};

/// Quantum self-attention mechanism
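///
/// A minimal usage sketch (assuming the re-export path
/// `quantrs2_ml::attention::QuantumSelfAttention`; query, key, and value are
/// flattened to `(batch, seq_len * embed_dim)` as in the rest of this module):
///
/// ```rust,ignore
/// use quantrs2_ml::attention::QuantumSelfAttention;
/// use scirs2_core::ndarray::Array2;
///
/// let attention = QuantumSelfAttention::new(16, 4, 0.1);
/// let input = Array2::from_elem((2, 3 * 16), 0.1); // batch = 2, seq_len = 3
/// let output = attention.forward(&input, &input, &input, None).unwrap();
/// assert_eq!(output.dim(), (2, 3 * 16));
/// ```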
#[derive(Debug, Clone)]
pub struct QuantumSelfAttention {
    /// Embedding dimension
    embed_dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Head dimension
    head_dim: usize,
    /// Number of qubits per head
    qubits_per_head: usize,
    /// Query projection circuit
    query_circuit: QuantumProjection,
    /// Key projection circuit
    key_circuit: QuantumProjection,
    /// Value projection circuit
    value_circuit: QuantumProjection,
    /// Output projection circuit
    output_circuit: QuantumProjection,
    /// Dropout rate
    dropout_rate: f64,
    /// Temperature for attention scaling
    temperature: f64,
}

/// Quantum projection layer
#[derive(Debug, Clone)]
struct QuantumProjection {
    /// Input dimension
    input_dim: usize,
    /// Output dimension
    output_dim: usize,
    /// Number of qubits
    num_qubits: usize,
    /// Variational circuit
    circuit: VariationalCircuit,
    /// Parameters
    parameters: HashMap<String, f64>,
}

impl QuantumProjection {
    /// Create a new projection layer
    fn new(input_dim: usize, output_dim: usize) -> Self {
        // Use at least one qubit so the circuit-construction loops below
        // never underflow for 1-dimensional projections.
        let num_qubits = (((input_dim.max(output_dim)) as f64).log2().ceil() as usize).max(1);
        let circuit = Self::build_projection_circuit(num_qubits);

        Self {
            input_dim,
            output_dim,
            num_qubits,
            circuit,
            parameters: HashMap::new(),
        }
    }

    /// Build the projection circuit
    fn build_projection_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Layer 1: Feature encoding
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("encode_{}", q)]);
        }

        // Layer 2: Entangling layer
        for q in 0..num_qubits - 1 {
            circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
        }
        if num_qubits > 2 {
            circuit.add_gate("CNOT", vec![num_qubits - 1, 0], vec![]);
        }

        // Layer 3: Parameterized rotations
        for q in 0..num_qubits {
            circuit.add_gate("RX", vec![q], vec![format!("rx_{}", q)]);
            circuit.add_gate("RZ", vec![q], vec![format!("rz_{}", q)]);
        }

        // Layer 4: Second entangling layer
        for q in (0..num_qubits - 1).step_by(2) {
            circuit.add_gate("CZ", vec![q, q + 1], vec![]);
        }
        for q in (1..num_qubits - 1).step_by(2) {
            circuit.add_gate("CZ", vec![q, q + 1], vec![]);
        }

        // Layer 5: Final rotations
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("final_{}", q)]);
        }

        circuit
    }

    /// Project input through the quantum circuit
    fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Encode input
        let encoded = self.encode_input(input)?;

        // Execute circuit (simplified)
        let output_state = self.execute_circuit(&encoded)?;

        // Decode output
        self.decode_output(&output_state)
    }

    /// Encode classical input to quantum state
    fn encode_input(&self, input: &Array1<f64>) -> Result<Vec<Complex64>> {
        let state_dim = 2_usize.pow(self.num_qubits as u32);
        let mut quantum_state = vec![Complex64::new(0.0, 0.0); state_dim];

        // Amplitude encoding
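        // |ψ⟩ = Σ_i (x_i / ‖x‖) |i⟩: the L2-normalized input becomes the
        // amplitude vector over the first `input.len()` basis states.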
        let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            return Err(MLError::InvalidInput("Zero norm input".to_string()));
        }

        for (i, &val) in input.iter().enumerate() {
            if i < state_dim {
                quantum_state[i] = Complex64::new(val / norm, 0.0);
            }
        }

        Ok(quantum_state)
    }

    /// Execute the quantum circuit
    fn execute_circuit(&self, input_state: &[Complex64]) -> Result<Vec<Complex64>> {
        // Simplified circuit execution: a full state-vector simulation of
        // `self.circuit` is not performed here.
        let state_dim = input_state.len();
        let mut output_state = input_state.to_vec();

        // Placeholder transformation: apply a deterministic phase to each
        // amplitude in lieu of executing the variational circuit.
        for i in 0..state_dim {
            let phase = (i as f64) * 0.1;
            output_state[i] *= Complex64::new(phase.cos(), phase.sin());
        }

        Ok(output_state)
    }

    /// Decode quantum state to classical output
    fn decode_output(&self, quantum_state: &[Complex64]) -> Result<Array1<f64>> {
        let mut output = Array1::zeros(self.output_dim);

        // Extract amplitudes
        for i in 0..self.output_dim.min(quantum_state.len()) {
            output[i] = quantum_state[i].norm();
        }

        Ok(output)
    }
}

impl QuantumSelfAttention {
    /// Create a new quantum self-attention layer
    pub fn new(embed_dim: usize, num_heads: usize, dropout_rate: f64) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        // Use at least one qubit per head, even when head_dim == 1.
        let qubits_per_head = ((head_dim as f64).log2().ceil() as usize).max(1);

        Self {
            embed_dim,
            num_heads,
            head_dim,
            qubits_per_head,
            query_circuit: QuantumProjection::new(embed_dim, embed_dim),
            key_circuit: QuantumProjection::new(embed_dim, embed_dim),
            value_circuit: QuantumProjection::new(embed_dim, embed_dim),
            output_circuit: QuantumProjection::new(embed_dim, embed_dim),
            dropout_rate,
            temperature: (head_dim as f64).sqrt(),
        }
    }

    /// Forward pass through attention layer
    pub fn forward(
        &self,
        query: &Array2<f64>,
        key: &Array2<f64>,
        value: &Array2<f64>,
        mask: Option<&Array2<bool>>,
    ) -> Result<Array2<f64>> {
        // Project Q, K, V
        let q = self.project_to_heads(query, &self.query_circuit)?;
        let k = self.project_to_heads(key, &self.key_circuit)?;
        let v = self.project_to_heads(value, &self.value_circuit)?;

        // Compute attention scores
        let attention_scores = self.compute_attention_scores(&q, &k)?;

        // Apply mask if provided
        let masked_scores = if let Some(mask) = mask {
            self.apply_mask(&attention_scores, mask)?
        } else {
            attention_scores
        };

        // Apply softmax
        let attention_weights = self.quantum_softmax(&masked_scores)?;

        // Apply attention to values
        let attended_values = self.apply_attention(&attention_weights, &v)?;

        // Concatenate heads and project output
        self.project_output(&attended_values)
    }

    /// Project input to multi-head format
    fn project_to_heads(
        &self,
        input: &Array2<f64>,
        projection: &QuantumProjection,
    ) -> Result<Array3<f64>> {
        let batch_size = input.nrows();
        let seq_len = input.ncols() / self.embed_dim;

        let mut output = Array3::zeros((batch_size, self.num_heads, seq_len * self.head_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                let start = s * self.embed_dim;
                let end = start + self.embed_dim;
                let input_vec = input.row(b).slice(s![start..end]).to_owned();

                let projected = projection.forward(&input_vec)?;

                // Split into heads
                for h in 0..self.num_heads {
                    let head_start = h * self.head_dim;

                    for i in 0..self.head_dim {
                        if head_start + i < projected.len() {
                            output[[b, h, s * self.head_dim + i]] = projected[head_start + i];
                        }
                    }
                }
            }
        }

        Ok(output)
    }

    /// Compute quantum attention scores
    fn compute_attention_scores(
        &self,
        query: &Array3<f64>,
        key: &Array3<f64>,
    ) -> Result<Array3<f64>> {
        let batch_size = query.shape()[0];
        let seq_len = query.shape()[2] / self.head_dim;

        let mut scores = Array3::zeros((batch_size, self.num_heads, seq_len * seq_len));

        // Quantum dot product attention
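        // Scaled dot-product attention: score(i, j) = ⟨q_i, k_j⟩ / √d_head,
        // with the √d_head divisor stored in `self.temperature`.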
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let q_start = i * self.head_dim;
                        let q_end = q_start + self.head_dim;
                        let k_start = j * self.head_dim;
                        let k_end = k_start + self.head_dim;

                        let q_vec = query.slice(s![b, h, q_start..q_end]);
                        let k_vec = key.slice(s![b, h, k_start..k_end]);

                        // Quantum inner product
                        let score =
                            self.quantum_inner_product(&q_vec.to_owned(), &k_vec.to_owned())?;
                        scores[[b, h, i * seq_len + j]] = score / self.temperature;
                    }
                }
            }
        }

        Ok(scores)
    }

    /// Compute quantum inner product
    fn quantum_inner_product(&self, vec1: &Array1<f64>, vec2: &Array1<f64>) -> Result<f64> {
        // Build quantum circuit for inner product
        let num_qubits = self.qubits_per_head * 2 + 1; // Extra qubit for measurement
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Encode vectors
        for i in 0..self.qubits_per_head {
            if i < vec1.len() {
                let angle1 = vec1[i] * PI;
                circuit.add_gate("RY", vec![i], vec![angle1.to_string()]);
            }
            if i < vec2.len() {
                let angle2 = vec2[i] * PI;
                circuit.add_gate(
                    "RY",
                    vec![i + self.qubits_per_head],
                    vec![angle2.to_string()],
                );
            }
        }

        // Hadamard on ancilla
        circuit.add_gate("H", vec![num_qubits - 1], vec![]);

        // Controlled swap test
        for i in 0..self.qubits_per_head {
            circuit.add_gate(
                "CSWAP",
                vec![num_qubits - 1, i, i + self.qubits_per_head],
                vec![],
            );
        }

        // Hadamard on ancilla
        circuit.add_gate("H", vec![num_qubits - 1], vec![]);

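        // In a swap test, measuring the ancilla gives P(0) = (1 + |⟨ψ|φ⟩|²) / 2,
        // so the squared overlap could be estimated from repeated runs.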
        // The swap-test circuit above is only constructed here, not executed;
        // fall back to the classical dot product for the attention score.
        Ok(vec1.dot(vec2))
    }

    /// Quantum softmax implementation
    fn quantum_softmax(&self, scores: &Array3<f64>) -> Result<Array3<f64>> {
        let mut output = scores.clone();

        // Apply quantum softmax per attention head
        for b in 0..scores.shape()[0] {
            for h in 0..scores.shape()[1] {
                let head_scores = scores.slice(s![b, h, ..]);
                let seq_len = (head_scores.len() as f64).sqrt() as usize;

                for i in 0..seq_len {
                    let start = i * seq_len;
                    let end = start + seq_len;
                    let row_scores = head_scores.slice(s![start..end]);

                    // Quantum softmax circuit
                    let softmax_vals = self.quantum_softmax_circuit(&row_scores.to_owned())?;

                    for j in 0..seq_len {
                        output[[b, h, start + j]] = softmax_vals[j];
                    }
                }
            }
        }

        Ok(output)
    }

    /// Quantum circuit for softmax
    fn quantum_softmax_circuit(&self, logits: &Array1<f64>) -> Result<Vec<f64>> {
        // Classical softmax for now, numerically stabilized by subtracting the max logit
        let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_logits: Vec<f64> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum_exp: f64 = exp_logits.iter().sum();

        Ok(exp_logits.into_iter().map(|x| x / sum_exp).collect())
    }

    /// Apply attention weights to values
    fn apply_attention(&self, weights: &Array3<f64>, values: &Array3<f64>) -> Result<Array3<f64>> {
        let batch_size = weights.shape()[0];
        let num_heads = weights.shape()[1];
        let seq_len = (weights.shape()[2] as f64).sqrt() as usize;

        let mut output = Array3::zeros((batch_size, num_heads, seq_len * self.head_dim));

        for b in 0..batch_size {
            for h in 0..num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let weight = weights[[b, h, i * seq_len + j]];

                        for d in 0..self.head_dim {
                            output[[b, h, i * self.head_dim + d]] +=
                                weight * values[[b, h, j * self.head_dim + d]];
                        }
                    }
                }
            }
        }

        Ok(output)
    }

    /// Apply attention mask
    fn apply_mask(&self, scores: &Array3<f64>, mask: &Array2<bool>) -> Result<Array3<f64>> {
        let mut masked_scores = scores.clone();

        for b in 0..scores.shape()[0] {
            for h in 0..scores.shape()[1] {
                for (idx, &is_masked) in mask.iter().enumerate() {
                    if is_masked && idx < scores.shape()[2] {
                        masked_scores[[b, h, idx]] = -1e9; // Effectively -inf; softmax drives it to ~0
                    }
                }
            }
        }

        Ok(masked_scores)
    }

    /// Project concatenated heads to output
    fn project_output(&self, attended: &Array3<f64>) -> Result<Array2<f64>> {
        let batch_size = attended.shape()[0];
        let seq_len = attended.shape()[2] / self.head_dim;

        let mut output = Array2::zeros((batch_size, seq_len * self.embed_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                // Concatenate heads
                let mut concat = Array1::zeros(self.embed_dim);
                for h in 0..self.num_heads {
                    for d in 0..self.head_dim {
                        concat[h * self.head_dim + d] = attended[[b, h, s * self.head_dim + d]];
                    }
                }

                // Project through output circuit
                let projected = self.output_circuit.forward(&concat)?;

                for d in 0..self.embed_dim {
                    output[[b, s * self.embed_dim + d]] = projected[d];
                }
            }
        }

        Ok(output)
    }
}

/// Quantum transformer block
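///
/// A minimal usage sketch (assuming the re-export path
/// `quantrs2_ml::attention::QuantumTransformerBlock`; input is flattened to
/// `(batch, seq_len * embed_dim)`):
///
/// ```rust,ignore
/// use quantrs2_ml::attention::QuantumTransformerBlock;
/// use scirs2_core::ndarray::Array2;
///
/// let block = QuantumTransformerBlock::new(8, 2, 16, 0.1);
/// let input = Array2::from_elem((1, 2 * 8), 0.5); // batch = 1, seq_len = 2
/// let output = block.forward(&input, None).unwrap();
/// assert_eq!(output.dim(), (1, 2 * 8));
/// ```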
#[derive(Debug)]
pub struct QuantumTransformerBlock {
    /// Self-attention layer
    self_attention: QuantumSelfAttention,
    /// Feed-forward dimension
    ff_dim: usize,
    /// First feed-forward layer
    ff1: QuantumFeedForward,
    /// Second feed-forward layer
    ff2: QuantumFeedForward,
    /// Layer normalization (classical)
    layer_norm1: LayerNorm,
    layer_norm2: LayerNorm,
    /// Dropout rate
    dropout_rate: f64,
}

/// Quantum feed-forward layer
#[derive(Debug)]
struct QuantumFeedForward {
    input_dim: usize,
    output_dim: usize,
    circuit: VariationalCircuit,
}

impl QuantumFeedForward {
    fn new(input_dim: usize, output_dim: usize) -> Self {
        // Use at least one qubit, even for 1-dimensional layers.
        let num_qubits = (((input_dim.max(output_dim)) as f64).log2().ceil() as usize).max(1);
        let circuit = Self::build_ff_circuit(num_qubits);

        Self {
            input_dim,
            output_dim,
            circuit,
        }
    }

    fn build_ff_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Dense connectivity pattern
        for layer in 0..3 {
            // Rotation layer
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("ff_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("ff_rz_{}_{}", layer, q)]);
            }

            // All-to-all entangling
            for i in 0..num_qubits {
                for j in i + 1..num_qubits {
                    circuit.add_gate("CZ", vec![i, j], vec![]);
                }
            }
        }

        circuit
    }

    fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Simplified forward pass: `self.circuit` is not executed here.
        let mut output = Array1::zeros(self.output_dim);

        // Apply a bounded sinusoidal non-linearity element-wise
        for i in 0..self.output_dim {
            if i < input.len() {
                output[i] = (input[i] * 2.0 * PI).sin() * 0.5 + 0.5;
            }
        }

        Ok(output)
    }
}

/// Classical layer normalization
#[derive(Debug)]
struct LayerNorm {
    normalized_shape: usize,
    epsilon: f64,
}

impl LayerNorm {
    fn new(normalized_shape: usize) -> Self {
        Self {
            normalized_shape,
            epsilon: 1e-5,
        }
    }

    fn forward(&self, input: &Array2<f64>) -> Array2<f64> {
        let mean = input.mean_axis(Axis(1)).unwrap();
        let variance = input.var_axis(Axis(1), 0.0);

        let mut output = input.clone();
        for i in 0..input.nrows() {
            let std = (variance[i] + self.epsilon).sqrt();
            for j in 0..input.ncols() {
                output[[i, j]] = (input[[i, j]] - mean[i]) / std;
            }
        }

        output
    }
}

impl QuantumTransformerBlock {
    /// Create a new transformer block
    pub fn new(embed_dim: usize, num_heads: usize, ff_dim: usize, dropout_rate: f64) -> Self {
        Self {
            self_attention: QuantumSelfAttention::new(embed_dim, num_heads, dropout_rate),
            ff_dim,
            ff1: QuantumFeedForward::new(embed_dim, ff_dim),
            ff2: QuantumFeedForward::new(ff_dim, embed_dim),
            layer_norm1: LayerNorm::new(embed_dim),
            layer_norm2: LayerNorm::new(embed_dim),
            dropout_rate,
        }
    }

    /// Forward pass through transformer block
    pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
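        // Post-norm layout: each sublayer output is added back to its input
        // (residual connection) and then passed through LayerNorm.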
        // Self-attention with residual connection
        let attended = self.self_attention.forward(input, input, input, mask)?;
        let residual1 = &attended + input;
        let norm1 = self.layer_norm1.forward(&residual1);

        // Feed-forward with residual connection
        let batch_size = norm1.nrows();
        let seq_dim = norm1.ncols();
        let seq_len = seq_dim / self.self_attention.embed_dim;

        let mut ff_output = Array2::zeros((batch_size, seq_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                let start = s * self.self_attention.embed_dim;
                let end = start + self.self_attention.embed_dim;

                let input_slice = norm1.slice(s![b, start..end]).to_owned();
                let hidden = self.ff1.forward(&input_slice)?;
                let output = self.ff2.forward(&hidden)?;

                for i in 0..self.self_attention.embed_dim {
                    ff_output[[b, start + i]] = output[i];
                }
            }
        }

        let residual2 = &ff_output + &norm1;
        let output = self.layer_norm2.forward(&residual2);

        Ok(output)
    }
}

/// Quantum transformer model
#[derive(Debug)]
pub struct QuantumTransformer {
    /// Embedding dimension
    embed_dim: usize,
    /// Number of transformer blocks
    num_layers: usize,
    /// Transformer blocks
    blocks: Vec<QuantumTransformerBlock>,
    /// Positional encoding
    positional_encoding: PositionalEncoding,
}

/// Sinusoidal positional encoding (computed classically)
#[derive(Debug)]
struct PositionalEncoding {
    max_length: usize,
    embed_dim: usize,
}

impl PositionalEncoding {
    fn new(max_length: usize, embed_dim: usize) -> Self {
        Self {
            max_length,
            embed_dim,
        }
    }

    fn encode(&self, seq_len: usize) -> Array2<f64> {
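        // Standard sinusoidal encoding:
        //   PE(pos, 2i)   = sin(pos / 10000^(2i / embed_dim))
        //   PE(pos, 2i+1) = cos(pos / 10000^(2i / embed_dim))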
        let mut encoding = Array2::zeros((seq_len, self.embed_dim));

        for pos in 0..seq_len {
            for i in 0..self.embed_dim {
                let angle = if i % 2 == 0 {
                    (pos as f64) / 10000_f64.powf((i as f64) / (self.embed_dim as f64))
                } else {
                    (pos as f64) / 10000_f64.powf(((i - 1) as f64) / (self.embed_dim as f64))
                };

                encoding[[pos, i]] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
            }
        }

        encoding
    }
}

impl QuantumTransformer {
    /// Create a new quantum transformer
    pub fn new(
        embed_dim: usize,
        num_layers: usize,
        num_heads: usize,
        ff_dim: usize,
        max_length: usize,
        dropout_rate: f64,
    ) -> Self {
        let blocks = (0..num_layers)
            .map(|_| QuantumTransformerBlock::new(embed_dim, num_heads, ff_dim, dropout_rate))
            .collect();

        Self {
            embed_dim,
            num_layers,
            blocks,
            positional_encoding: PositionalEncoding::new(max_length, embed_dim),
        }
    }

    /// Forward pass through transformer
    pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
        let seq_len = input.ncols() / self.embed_dim;

        // Add positional encoding
        let pos_encoding = self.positional_encoding.encode(seq_len);
        let mut encoded = input.clone();

        for i in 0..input.nrows() {
            for s in 0..seq_len {
                for d in 0..self.embed_dim {
                    encoded[[i, s * self.embed_dim + d]] += pos_encoding[[s, d]];
                }
            }
        }

        // Pass through transformer blocks
        let mut output = encoded;
        for block in &self.blocks {
            output = block.forward(&output, mask)?;
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_projection() {
        let proj = QuantumProjection::new(8, 8);
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);

        let output = proj.forward(&input).unwrap();
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_quantum_self_attention() {
        let attention = QuantumSelfAttention::new(16, 4, 0.1);

        let batch_size = 2;
        let seq_len = 3;
        let embed_dim = 16;

        // Initialize input with non-zero values to avoid "Zero norm input" error
        let mut input = Array2::zeros((batch_size, seq_len * embed_dim));
        for i in 0..batch_size {
            for j in 0..seq_len * embed_dim {
                input[[i, j]] = 0.1 + (i * seq_len * embed_dim + j) as f64 * 0.01;
            }
        }

        let output = attention.forward(&input, &input, &input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }
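
    #[test]
    fn test_quantum_softmax_normalization() {
        // Added sanity check (assumes the classical softmax fallback inside
        // `quantum_softmax_circuit`): every attention row should sum to 1.
        let attention = QuantumSelfAttention::new(16, 4, 0.1);

        let batch_size = 1;
        let num_heads = 4;
        let seq_len = 3;

        let mut scores = Array3::zeros((batch_size, num_heads, seq_len * seq_len));
        for (idx, v) in scores.iter_mut().enumerate() {
            *v = (idx as f64) * 0.05 - 0.3;
        }

        let weights = attention.quantum_softmax(&scores).unwrap();

        for h in 0..num_heads {
            for i in 0..seq_len {
                let row_sum: f64 = (0..seq_len)
                    .map(|j| weights[[0, h, i * seq_len + j]])
                    .sum();
                assert!((row_sum - 1.0).abs() < 1e-9);
            }
        }
    }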

    #[test]
    fn test_quantum_transformer_block() {
        let block = QuantumTransformerBlock::new(8, 2, 16, 0.1);

        let batch_size = 1;
        let seq_len = 2;
        let embed_dim = 8;

        let input = Array2::ones((batch_size, seq_len * embed_dim));
        let output = block.forward(&input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(100, 16);
        let encoding = pos_enc.encode(10);

        assert_eq!(encoding.shape(), &[10, 16]);

        // Check that different positions have different encodings
        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        let diff: f64 = (&pos1 - &pos0).iter().map(|x| x.abs()).sum();
        assert!(diff > 0.0);
    }

    #[test]
    fn test_quantum_transformer() {
        let transformer = QuantumTransformer::new(8, 2, 2, 16, 100, 0.1);

        let batch_size = 1;
        let seq_len = 3;
        let embed_dim = 8;

        let input = Array2::zeros((batch_size, seq_len * embed_dim));
        let output = transformer.forward(&input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }
}