quantrs2_ml/attention.rs

//! Quantum attention mechanisms for transformer architectures.
//!
//! This module implements quantum versions of attention mechanisms including
//! multi-head attention, cross-attention, and quantum transformer blocks.
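//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest); the `quantrs2_ml::attention`
//! path assumes this module is exposed as `attention` from the crate root.
//!
//! ```ignore
//! use quantrs2_ml::attention::QuantumTransformer;
//! use scirs2_core::ndarray::Array2;
//!
//! // embed_dim 8, 2 layers, 2 heads, feed-forward dim 16, max length 100, dropout 0.1
//! let transformer = QuantumTransformer::new(8, 2, 2, 16, 100, 0.1);
//! let input = Array2::ones((1, 3 * 8)); // (batch, seq_len * embed_dim)
//! let output = transformer.forward(&input, None).expect("forward pass");
//! ```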

use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::autodiff::DifferentiableParam;
use crate::error::{MLError, Result};
use crate::utils::VariationalCircuit;
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::{multi::*, single::*, GateOp};

/// Quantum self-attention mechanism
#[derive(Debug, Clone)]
pub struct QuantumSelfAttention {
    /// Embedding dimension
    embed_dim: usize,
    /// Number of attention heads
    num_heads: usize,
    /// Head dimension
    head_dim: usize,
    /// Number of qubits per head
    qubits_per_head: usize,
    /// Query projection circuit
    query_circuit: QuantumProjection,
    /// Key projection circuit
    key_circuit: QuantumProjection,
    /// Value projection circuit
    value_circuit: QuantumProjection,
    /// Output projection circuit
    output_circuit: QuantumProjection,
    /// Dropout rate
    dropout_rate: f64,
    /// Temperature for attention scaling
    temperature: f64,
}

/// Quantum projection layer
#[derive(Debug, Clone)]
struct QuantumProjection {
    /// Input dimension
    input_dim: usize,
    /// Output dimension
    output_dim: usize,
    /// Number of qubits
    num_qubits: usize,
    /// Variational circuit
    circuit: VariationalCircuit,
    /// Parameters
    parameters: HashMap<String, f64>,
}

impl QuantumProjection {
    /// Create a new projection layer
    fn new(input_dim: usize, output_dim: usize) -> Self {
        let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
        let circuit = Self::build_projection_circuit(num_qubits);

        Self {
            input_dim,
            output_dim,
            num_qubits,
            circuit,
            parameters: HashMap::new(),
        }
    }

    /// Build the projection circuit
    fn build_projection_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Layer 1: Feature encoding
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("encode_{}", q)]);
        }

        // Layer 2: Entangling layer
        for q in 0..num_qubits - 1 {
            circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
        }
        if num_qubits > 2 {
            circuit.add_gate("CNOT", vec![num_qubits - 1, 0], vec![]);
        }

        // Layer 3: Parameterized rotations
        for q in 0..num_qubits {
            circuit.add_gate("RX", vec![q], vec![format!("rx_{}", q)]);
            circuit.add_gate("RZ", vec![q], vec![format!("rz_{}", q)]);
        }

        // Layer 4: Second entangling layer
        for q in (0..num_qubits - 1).step_by(2) {
            circuit.add_gate("CZ", vec![q, q + 1], vec![]);
        }
        for q in (1..num_qubits - 1).step_by(2) {
            circuit.add_gate("CZ", vec![q, q + 1], vec![]);
        }

        // Layer 5: Final rotations
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("final_{}", q)]);
        }

        circuit
    }

    /// Project input through the quantum circuit
    fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Encode input
        let encoded = self.encode_input(input)?;

        // Execute circuit (simplified)
        let output_state = self.execute_circuit(&encoded)?;

        // Decode output
        self.decode_output(&output_state)
    }

    /// Encode classical input to quantum state
    fn encode_input(&self, input: &Array1<f64>) -> Result<Vec<Complex64>> {
        let state_dim = 2_usize.pow(self.num_qubits as u32);
        let mut quantum_state = vec![Complex64::new(0.0, 0.0); state_dim];

        // Amplitude encoding
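        // Amplitude encoding: |psi> = (1 / ||x||) * sum_k x_k |k>; components beyond the
        // 2^num_qubits available amplitudes are truncated.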
        let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            return Err(MLError::InvalidInput("Zero norm input".to_string()));
        }

        for (i, &val) in input.iter().enumerate() {
            if i < state_dim {
                quantum_state[i] = Complex64::new(val / norm, 0.0);
            }
        }

        Ok(quantum_state)
    }

    /// Execute the quantum circuit
    fn execute_circuit(&self, input_state: &[Complex64]) -> Result<Vec<Complex64>> {
        // Simplified circuit execution
        // In practice, would use actual quantum simulation
        let state_dim = input_state.len();
        let mut output_state = input_state.to_vec();

        // Apply some transformation
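        // NOTE: this placeholder only applies a diagonal phase |k> -> e^{i * 0.1 * k} |k>,
        // so the amplitude magnitudes read out by `decode_output` are left unchanged.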
        for i in 0..state_dim {
            let phase = (i as f64) * 0.1;
            output_state[i] *= Complex64::new(phase.cos(), phase.sin());
        }

        Ok(output_state)
    }

    /// Decode quantum state to classical output
    fn decode_output(&self, quantum_state: &[Complex64]) -> Result<Array1<f64>> {
        let mut output = Array1::zeros(self.output_dim);

        // Extract amplitudes
        for i in 0..self.output_dim.min(quantum_state.len()) {
            output[i] = quantum_state[i].norm();
        }

        Ok(output)
    }
}

impl QuantumSelfAttention {
    /// Create a new quantum self-attention layer
    pub fn new(embed_dim: usize, num_heads: usize, dropout_rate: f64) -> Self {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim must be divisible by num_heads"
        );

        let head_dim = embed_dim / num_heads;
        let qubits_per_head = (head_dim as f64).log2().ceil() as usize;

        Self {
            embed_dim,
            num_heads,
            head_dim,
            qubits_per_head,
            query_circuit: QuantumProjection::new(embed_dim, embed_dim),
            key_circuit: QuantumProjection::new(embed_dim, embed_dim),
            value_circuit: QuantumProjection::new(embed_dim, embed_dim),
            output_circuit: QuantumProjection::new(embed_dim, embed_dim),
            dropout_rate,
            temperature: (head_dim as f64).sqrt(),
        }
    }

    /// Forward pass through attention layer
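    ///
    /// Computes scaled dot-product attention, `softmax(Q · K^T / sqrt(head_dim)) · V`, with
    /// the Q, K and V projections produced by quantum projection circuits. `query`, `key`
    /// and `value` are flattened to shape `(batch, seq_len * embed_dim)`.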
    pub fn forward(
        &self,
        query: &Array2<f64>,
        key: &Array2<f64>,
        value: &Array2<f64>,
        mask: Option<&Array2<bool>>,
    ) -> Result<Array2<f64>> {
        // Project Q, K, V
        let q = self.project_to_heads(query, &self.query_circuit)?;
        let k = self.project_to_heads(key, &self.key_circuit)?;
        let v = self.project_to_heads(value, &self.value_circuit)?;

        // Compute attention scores
        let attention_scores = self.compute_attention_scores(&q, &k)?;

        // Apply mask if provided
        let masked_scores = if let Some(mask) = mask {
            self.apply_mask(&attention_scores, mask)?
        } else {
            attention_scores
        };

        // Apply softmax
        let attention_weights = self.quantum_softmax(&masked_scores)?;

        // Apply attention to values
        let attended_values = self.apply_attention(&attention_weights, &v)?;

        // Concatenate heads and project output
        self.project_output(&attended_values)
    }

    /// Project input to multi-head format
    fn project_to_heads(
        &self,
        input: &Array2<f64>,
        projection: &QuantumProjection,
    ) -> Result<Array3<f64>> {
        let batch_size = input.nrows();
        let seq_len = input.ncols() / self.embed_dim;

        let mut output = Array3::zeros((batch_size, self.num_heads, seq_len * self.head_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                let start = s * self.embed_dim;
                let end = start + self.embed_dim;
                let input_vec = input.row(b).slice(s![start..end]).to_owned();

                let projected = projection.forward(&input_vec)?;

                // Split into heads
                for h in 0..self.num_heads {
                    let head_start = h * self.head_dim;

                    for i in 0..self.head_dim {
                        if head_start + i < projected.len() {
                            output[[b, h, s * self.head_dim + i]] = projected[head_start + i];
                        }
                    }
                }
            }
        }

        Ok(output)
    }

    /// Compute quantum attention scores
    fn compute_attention_scores(
        &self,
        query: &Array3<f64>,
        key: &Array3<f64>,
    ) -> Result<Array3<f64>> {
        let batch_size = query.shape()[0];
        let seq_len = query.shape()[2] / self.head_dim;

        let mut scores = Array3::zeros((batch_size, self.num_heads, seq_len * seq_len));

        // Quantum dot product attention
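        // Each raw score is the inner product of q_i and k_j scaled by 1 / temperature,
        // where `temperature` was set to sqrt(head_dim) in `new`.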
        for b in 0..batch_size {
            for h in 0..self.num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let q_start = i * self.head_dim;
                        let q_end = q_start + self.head_dim;
                        let k_start = j * self.head_dim;
                        let k_end = k_start + self.head_dim;

                        let q_vec = query.slice(s![b, h, q_start..q_end]);
                        let k_vec = key.slice(s![b, h, k_start..k_end]);

                        // Quantum inner product
                        let score =
                            self.quantum_inner_product(&q_vec.to_owned(), &k_vec.to_owned())?;
                        scores[[b, h, i * seq_len + j]] = score / self.temperature;
                    }
                }
            }
        }

        Ok(scores)
    }

    /// Compute quantum inner product
    fn quantum_inner_product(&self, vec1: &Array1<f64>, vec2: &Array1<f64>) -> Result<f64> {
        // Build quantum circuit for inner product
        let num_qubits = self.qubits_per_head * 2 + 1; // Extra qubit for measurement
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Encode vectors
        for i in 0..self.qubits_per_head {
            if i < vec1.len() {
                let angle1 = vec1[i] * PI;
                circuit.add_gate("RY", vec![i], vec![angle1.to_string()]);
            }
            if i < vec2.len() {
                let angle2 = vec2[i] * PI;
                circuit.add_gate(
                    "RY",
                    vec![i + self.qubits_per_head],
                    vec![angle2.to_string()],
                );
            }
        }

        // Hadamard on ancilla
        circuit.add_gate("H", vec![num_qubits - 1], vec![]);

        // Controlled swap test
        for i in 0..self.qubits_per_head {
            circuit.add_gate(
                "CSWAP",
                vec![num_qubits - 1, i, i + self.qubits_per_head],
                vec![],
            );
        }

        // Hadamard on ancilla
        circuit.add_gate("H", vec![num_qubits - 1], vec![]);

        // Measuring the ancilla estimates the squared magnitude of the inner product
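        // Swap-test statistics: P(ancilla = 0) = 1/2 + |<q|k>|^2 / 2, so 2 * P(0) - 1
        // recovers |<q|k>|^2 but not its sign. Simplified here: return the signed
        // classical dot product instead of executing the circuit.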
        Ok(vec1.dot(vec2))
    }

    /// Quantum softmax implementation
    fn quantum_softmax(&self, scores: &Array3<f64>) -> Result<Array3<f64>> {
        let mut output = scores.clone();

        // Apply quantum softmax per attention head
        for b in 0..scores.shape()[0] {
            for h in 0..scores.shape()[1] {
                let head_scores = scores.slice(s![b, h, ..]);
                let seq_len = (head_scores.len() as f64).sqrt() as usize;

                for i in 0..seq_len {
                    let start = i * seq_len;
                    let end = start + seq_len;
                    let row_scores = head_scores.slice(s![start..end]);

                    // Quantum softmax circuit
                    let softmax_vals = self.quantum_softmax_circuit(&row_scores.to_owned())?;

                    for j in 0..seq_len {
                        output[[b, h, start + j]] = softmax_vals[j];
                    }
                }
            }
        }

        Ok(output)
    }

    /// Quantum circuit for softmax
    fn quantum_softmax_circuit(&self, logits: &Array1<f64>) -> Result<Vec<f64>> {
        // Classical softmax for now
        let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_logits: Vec<f64> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum_exp: f64 = exp_logits.iter().sum();

        Ok(exp_logits.into_iter().map(|x| x / sum_exp).collect())
    }

    /// Apply attention weights to values
    fn apply_attention(&self, weights: &Array3<f64>, values: &Array3<f64>) -> Result<Array3<f64>> {
        let batch_size = weights.shape()[0];
        let num_heads = weights.shape()[1];
        let seq_len = (weights.shape()[2] as f64).sqrt() as usize;

        let mut output = Array3::zeros((batch_size, num_heads, seq_len * self.head_dim));

        for b in 0..batch_size {
            for h in 0..num_heads {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let weight = weights[[b, h, i * seq_len + j]];

                        for d in 0..self.head_dim {
                            output[[b, h, i * self.head_dim + d]] +=
                                weight * values[[b, h, j * self.head_dim + d]];
                        }
                    }
                }
            }
        }

        Ok(output)
    }

    /// Apply attention mask
    fn apply_mask(&self, scores: &Array3<f64>, mask: &Array2<bool>) -> Result<Array3<f64>> {
        let mut masked_scores = scores.clone();

        for b in 0..scores.shape()[0] {
            for h in 0..scores.shape()[1] {
                for (idx, &is_masked) in mask.iter().enumerate() {
                    if is_masked && idx < scores.shape()[2] {
                        masked_scores[[b, h, idx]] = -1e9; // Large negative value
                    }
                }
            }
        }

        Ok(masked_scores)
    }

    /// Project concatenated heads to output
    fn project_output(&self, attended: &Array3<f64>) -> Result<Array2<f64>> {
        let batch_size = attended.shape()[0];
        let seq_len = attended.shape()[2] / self.head_dim;

        let mut output = Array2::zeros((batch_size, seq_len * self.embed_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                // Concatenate heads
                let mut concat = Array1::zeros(self.embed_dim);
                for h in 0..self.num_heads {
                    for d in 0..self.head_dim {
                        concat[h * self.head_dim + d] = attended[[b, h, s * self.head_dim + d]];
                    }
                }

                // Project through output circuit
                let projected = self.output_circuit.forward(&concat)?;

                for d in 0..self.embed_dim {
                    output[[b, s * self.embed_dim + d]] = projected[d];
                }
            }
        }

        Ok(output)
    }
}

/// Quantum transformer block
#[derive(Debug)]
pub struct QuantumTransformerBlock {
    /// Self-attention layer
    self_attention: QuantumSelfAttention,
    /// Feed-forward dimension
    ff_dim: usize,
    /// First feed-forward layer
    ff1: QuantumFeedForward,
    /// Second feed-forward layer
    ff2: QuantumFeedForward,
    /// Layer normalization (classical)
    layer_norm1: LayerNorm,
    layer_norm2: LayerNorm,
    /// Dropout rate
    dropout_rate: f64,
}

/// Quantum feed-forward layer
#[derive(Debug)]
struct QuantumFeedForward {
    input_dim: usize,
    output_dim: usize,
    circuit: VariationalCircuit,
}

impl QuantumFeedForward {
    fn new(input_dim: usize, output_dim: usize) -> Self {
        let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
        let circuit = Self::build_ff_circuit(num_qubits);

        Self {
            input_dim,
            output_dim,
            circuit,
        }
    }

    fn build_ff_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Dense connectivity pattern
        for layer in 0..3 {
            // Rotation layer
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("ff_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("ff_rz_{}_{}", layer, q)]);
            }

            // All-to-all entangling
            for i in 0..num_qubits {
                for j in i + 1..num_qubits {
                    circuit.add_gate("CZ", vec![i, j], vec![]);
                }
            }
        }

        circuit
    }

    fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Simplified forward pass
        let mut output = Array1::zeros(self.output_dim);

        // Apply non-linear transformation
        for i in 0..self.output_dim {
            if i < input.len() {
                output[i] = (input[i] * 2.0 * PI).sin() * 0.5 + 0.5;
            }
        }

        Ok(output)
    }
}

/// Classical layer normalization
#[derive(Debug)]
struct LayerNorm {
    normalized_shape: usize,
    epsilon: f64,
}

impl LayerNorm {
    fn new(normalized_shape: usize) -> Self {
        Self {
            normalized_shape,
            epsilon: 1e-5,
        }
    }

    fn forward(&self, input: &Array2<f64>) -> Array2<f64> {
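        // Normalize each row as a whole (zero mean, unit variance across all of its
        // features); this simplified version has no learnable scale or shift parameters.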
        let mean = input
            .mean_axis(Axis(1))
            .expect("Input array should not be empty for mean computation");
        let variance = input.var_axis(Axis(1), 0.0);

        let mut output = input.clone();
        for i in 0..input.nrows() {
            let std = (variance[i] + self.epsilon).sqrt();
            for j in 0..input.ncols() {
                output[[i, j]] = (input[[i, j]] - mean[i]) / std;
            }
        }

        output
    }
}

impl QuantumTransformerBlock {
    /// Create a new transformer block
    pub fn new(embed_dim: usize, num_heads: usize, ff_dim: usize, dropout_rate: f64) -> Self {
        Self {
            self_attention: QuantumSelfAttention::new(embed_dim, num_heads, dropout_rate),
            ff_dim,
            ff1: QuantumFeedForward::new(embed_dim, ff_dim),
            ff2: QuantumFeedForward::new(ff_dim, embed_dim),
            layer_norm1: LayerNorm::new(embed_dim),
            layer_norm2: LayerNorm::new(embed_dim),
            dropout_rate,
        }
    }

    /// Forward pass through transformer block
    pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
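        // Post-norm arrangement: each sublayer output is added back to its input and then
        // normalized, i.e. LayerNorm(x + Sublayer(x)).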
        // Self-attention with residual connection
        let attended = self.self_attention.forward(input, input, input, mask)?;
        let residual1 = &attended + input;
        let norm1 = self.layer_norm1.forward(&residual1);

        // Feed-forward with residual connection
        let batch_size = norm1.nrows();
        let seq_dim = norm1.ncols();
        let seq_len = seq_dim / self.self_attention.embed_dim;

        let mut ff_output = Array2::zeros((batch_size, seq_dim));

        for b in 0..batch_size {
            for s in 0..seq_len {
                let start = s * self.self_attention.embed_dim;
                let end = start + self.self_attention.embed_dim;

                let input_slice = norm1.slice(s![b, start..end]).to_owned();
                let hidden = self.ff1.forward(&input_slice)?;
                let output = self.ff2.forward(&hidden)?;

                for i in 0..self.self_attention.embed_dim {
                    ff_output[[b, start + i]] = output[i];
                }
            }
        }

        let residual2 = &ff_output + &norm1;
        let output = self.layer_norm2.forward(&residual2);

        Ok(output)
    }
}

/// Quantum transformer model
#[derive(Debug)]
pub struct QuantumTransformer {
    /// Embedding dimension
    embed_dim: usize,
    /// Number of transformer blocks
    num_layers: usize,
    /// Transformer blocks
    blocks: Vec<QuantumTransformerBlock>,
    /// Positional encoding
    positional_encoding: PositionalEncoding,
}

/// Sinusoidal positional encoding (classical)
#[derive(Debug)]
struct PositionalEncoding {
    max_length: usize,
    embed_dim: usize,
}

impl PositionalEncoding {
    fn new(max_length: usize, embed_dim: usize) -> Self {
        Self {
            max_length,
            embed_dim,
        }
    }

    fn encode(&self, seq_len: usize) -> Array2<f64> {
        let mut encoding = Array2::zeros((seq_len, self.embed_dim));

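        // Standard sinusoidal encoding:
        //   PE(pos, 2k)     = sin(pos / 10000^(2k / embed_dim))
        //   PE(pos, 2k + 1) = cos(pos / 10000^(2k / embed_dim))
        // Using `i` for even indices and `i - 1` for odd ones below yields the same
        // exponent 2k / embed_dim for each (sin, cos) pair.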
        for pos in 0..seq_len {
            for i in 0..self.embed_dim {
                let angle = if i % 2 == 0 {
                    (pos as f64) / 10000_f64.powf((i as f64) / (self.embed_dim as f64))
                } else {
                    (pos as f64) / 10000_f64.powf(((i - 1) as f64) / (self.embed_dim as f64))
                };

                encoding[[pos, i]] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
            }
        }

        encoding
    }
}

impl QuantumTransformer {
    /// Create a new quantum transformer
    pub fn new(
        embed_dim: usize,
        num_layers: usize,
        num_heads: usize,
        ff_dim: usize,
        max_length: usize,
        dropout_rate: f64,
    ) -> Self {
        let blocks = (0..num_layers)
            .map(|_| QuantumTransformerBlock::new(embed_dim, num_heads, ff_dim, dropout_rate))
            .collect();

        Self {
            embed_dim,
            num_layers,
            blocks,
            positional_encoding: PositionalEncoding::new(max_length, embed_dim),
        }
    }

    /// Forward pass through transformer
    pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
        let seq_len = input.ncols() / self.embed_dim;

        // Add positional encoding
        let pos_encoding = self.positional_encoding.encode(seq_len);
        let mut encoded = input.clone();

        for i in 0..input.nrows() {
            for s in 0..seq_len {
                for d in 0..self.embed_dim {
                    encoded[[i, s * self.embed_dim + d]] += pos_encoding[[s, d]];
                }
            }
        }

        // Pass through transformer blocks
        let mut output = encoded;
        for block in &self.blocks {
            output = block.forward(&output, mask)?;
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_projection() {
        let proj = QuantumProjection::new(8, 8);
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);

        let output = proj
            .forward(&input)
            .expect("Projection forward should succeed");
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_quantum_self_attention() {
        let attention = QuantumSelfAttention::new(16, 4, 0.1);

        let batch_size = 2;
        let seq_len = 3;
        let embed_dim = 16;

        // Initialize input with non-zero values to avoid "Zero norm input" error
        let mut input = Array2::zeros((batch_size, seq_len * embed_dim));
        for i in 0..batch_size {
            for j in 0..seq_len * embed_dim {
                input[[i, j]] = 0.1 + (i * seq_len * embed_dim + j) as f64 * 0.01;
            }
        }

        let output = attention
            .forward(&input, &input, &input, None)
            .expect("Attention forward should succeed");

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }

    #[test]
    fn test_quantum_transformer_block() {
        let block = QuantumTransformerBlock::new(8, 2, 16, 0.1);

        let batch_size = 1;
        let seq_len = 2;
        let embed_dim = 8;

        let input = Array2::ones((batch_size, seq_len * embed_dim));
        let output = block
            .forward(&input, None)
            .expect("Transformer block forward should succeed");

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(100, 16);
        let encoding = pos_enc.encode(10);

        assert_eq!(encoding.shape(), &[10, 16]);

        // Check that different positions have different encodings
        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        let diff: f64 = (&pos1 - &pos0).iter().map(|x| x.abs()).sum();
        assert!(diff > 0.0);
    }

    #[test]
    fn test_quantum_transformer() {
        let transformer = QuantumTransformer::new(8, 2, 2, 16, 100, 0.1);

        let batch_size = 1;
        let seq_len = 3;
        let embed_dim = 8;

        let input = Array2::zeros((batch_size, seq_len * embed_dim));
        let output = transformer
            .forward(&input, None)
            .expect("Transformer forward should succeed");

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }
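
    // A minimal sanity-check sketch for the (currently classical) softmax used on the
    // attention scores: it should return non-negative values that sum to 1 up to
    // floating-point error.
    #[test]
    fn test_softmax_normalization() {
        let attention = QuantumSelfAttention::new(8, 2, 0.0);
        let logits = Array1::from_vec(vec![0.5, -1.0, 2.0, 0.0]);

        let probs = attention
            .quantum_softmax_circuit(&logits)
            .expect("Softmax should succeed");

        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
        assert!(probs.iter().all(|&p| p >= 0.0));
    }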
}