quantrs2_ml/quantum_transformer.rs

//! Quantum Transformer Architectures
//!
//! This module implements quantum transformer models with quantum attention mechanisms,
//! position encoding, and multi-head attention for processing quantum and classical data
//! in transformer-style architectures.
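//!
//! # Example (sketch)
//!
//! A minimal construction-and-forward sketch using the types defined below. The block
//! is marked `ignore` because building the quantum sub-networks and simulating them
//! can be expensive; names and shapes follow the definitions in this module.
//!
//! ```ignore
//! use scirs2_core::ndarray::Array3;
//!
//! let config = QuantumTransformerConfig::small();
//! let model = QuantumTransformer::new(config)?;
//!
//! // [batch_size, seq_len, input_dim]
//! let input = Array3::<f64>::zeros((2, 4, 16));
//! let mask = create_causal_mask(2, 4);
//! let output = model.forward(&input, Some(&mask))?;
//! assert_eq!(output.dim(), (2, 4, model.config().model_dim));
//! ```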

use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{multi::*, single::*, GateOp};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, Axis};
use std::collections::HashMap;
use std::f64::consts::PI;

/// Quantum transformer model configuration
#[derive(Debug, Clone)]
pub struct QuantumTransformerConfig {
    /// Model dimension (d_model)
    pub model_dim: usize,

    /// Number of attention heads
    pub num_heads: usize,

    /// Feedforward dimension
    pub ff_dim: usize,

    /// Number of transformer layers
    pub num_layers: usize,

    /// Maximum sequence length
    pub max_seq_len: usize,

    /// Number of qubits for quantum computation
    pub num_qubits: usize,

    /// Dropout rate
    pub dropout_rate: f64,

    /// Attention mechanism type
    pub attention_type: QuantumAttentionType,

    /// Position encoding type
    pub position_encoding: PositionEncodingType,
}

/// Types of quantum attention mechanisms
#[derive(Debug, Clone)]
pub enum QuantumAttentionType {
    /// Full quantum attention with entanglement
    FullQuantum,

    /// Hybrid quantum-classical attention
    HybridQuantumClassical,

    /// Variational quantum attention
    VariationalQuantum,

    /// Quantum-enhanced multi-head attention
    QuantumEnhancedMultiHead,

    /// Quantum self-attention with superposition
    QuantumSelfAttention,
}

/// Position encoding types for quantum transformers
#[derive(Debug, Clone)]
pub enum PositionEncodingType {
    /// Sinusoidal position encoding
    Sinusoidal,

    /// Quantum position encoding with phase rotation
    QuantumPhase,

    /// Learnable quantum position encoding
    LearnableQuantum,

    /// Relative position encoding
    Relative,

    /// Rotary position embedding (RoPE)
    Rotary,
}

/// Quantum multi-head attention module
#[derive(Debug, Clone)]
pub struct QuantumMultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,

    /// Model dimension
    model_dim: usize,

    /// Head dimension
    head_dim: usize,

    /// Query projection layers
    query_layers: Vec<QuantumNeuralNetwork>,

    /// Key projection layers
    key_layers: Vec<QuantumNeuralNetwork>,

    /// Value projection layers
    value_layers: Vec<QuantumNeuralNetwork>,

    /// Output projection
    output_projection: QuantumNeuralNetwork,

    /// Attention type
    attention_type: QuantumAttentionType,

    /// Quantum circuit for attention computation
    attention_circuit: Circuit<16>,
}

/// Quantum position encoding module
#[derive(Debug, Clone)]
pub struct QuantumPositionEncoding {
    /// Encoding type
    encoding_type: PositionEncodingType,

    /// Model dimension
    model_dim: usize,

    /// Maximum sequence length
    max_seq_len: usize,

    /// Learnable parameters (for learnable encodings)
    learnable_params: Option<Array2<f64>>,

    /// Quantum circuits for position encoding
    encoding_circuits: Vec<Circuit<16>>,
}

/// Quantum feedforward network
#[derive(Debug, Clone)]
pub struct QuantumFeedForward {
    /// Input dimension
    input_dim: usize,

    /// Hidden dimension
    hidden_dim: usize,

    /// Output dimension
    output_dim: usize,

    /// First layer
    layer1: QuantumNeuralNetwork,

    /// Second layer
    layer2: QuantumNeuralNetwork,

    /// Activation function type
    activation: ActivationType,

    /// Dropout rate
    dropout_rate: f64,
}

/// Activation function types for quantum networks
#[derive(Debug, Clone)]
pub enum ActivationType {
    /// Quantum ReLU using amplitude encoding
    QuantumReLU,

    /// Quantum GELU approximation
    QuantumGELU,

    /// Quantum Swish activation
    QuantumSwish,

    /// Parameterized quantum activation
    ParameterizedQuantum,

    /// Classical activation applied to measurement outcomes
    ClassicalHybrid,
}

/// Single quantum transformer layer
#[derive(Debug, Clone)]
pub struct QuantumTransformerLayer {
    /// Multi-head attention module
    attention: QuantumMultiHeadAttention,

    /// Feedforward network
    feedforward: QuantumFeedForward,

    /// Layer normalization parameters
    norm1_scale: Array1<f64>,
    norm1_bias: Array1<f64>,
    norm2_scale: Array1<f64>,
    norm2_bias: Array1<f64>,

    /// Model dimension
    model_dim: usize,

    /// Dropout rate
    dropout_rate: f64,
}

/// Main quantum transformer model
#[derive(Debug, Clone)]
pub struct QuantumTransformer {
    /// Model configuration
    config: QuantumTransformerConfig,

    /// Position encoding module
    position_encoding: QuantumPositionEncoding,

    /// Transformer layers
    layers: Vec<QuantumTransformerLayer>,

    /// Input embedding layer
    input_embedding: QuantumNeuralNetwork,

    /// Output projection layer
    output_projection: QuantumNeuralNetwork,

    /// Layer normalization at output
    final_norm_scale: Array1<f64>,
    final_norm_bias: Array1<f64>,
}

/// Attention computation result
#[derive(Debug, Clone)]
pub struct AttentionOutput {
    /// Attention weights
    pub attention_weights: Array3<f64>,

    /// Output values
    pub output: Array3<f64>,

    /// Quantum state information
    pub quantum_info: QuantumAttentionInfo,
}

/// Quantum attention information
#[derive(Debug, Clone)]
pub struct QuantumAttentionInfo {
    /// Entanglement measures between positions
    pub entanglement_matrix: Array2<f64>,

    /// Quantum coherence scores
    pub coherence_scores: Array1<f64>,

    /// Superposition amplitudes
    pub superposition_amplitudes: Array2<f64>,

    /// Measurement probabilities
    pub measurement_probs: Array3<f64>,
}

impl QuantumTransformerConfig {
    /// Create default transformer configuration
    pub fn default() -> Self {
        Self {
            model_dim: 512,
            num_heads: 8,
            ff_dim: 2048,
            num_layers: 6,
            max_seq_len: 512,
            num_qubits: 10,
            dropout_rate: 0.1,
            attention_type: QuantumAttentionType::HybridQuantumClassical,
            position_encoding: PositionEncodingType::QuantumPhase,
        }
    }

    /// Create configuration for large model
    pub fn large() -> Self {
        Self {
            model_dim: 1024,
            num_heads: 16,
            ff_dim: 4096,
            num_layers: 12,
            max_seq_len: 1024,
            num_qubits: 16,
            dropout_rate: 0.1,
            attention_type: QuantumAttentionType::FullQuantum,
            position_encoding: PositionEncodingType::LearnableQuantum,
        }
    }

    /// Create configuration for small/efficient model
    pub fn small() -> Self {
        Self {
            model_dim: 256,
            num_heads: 4,
            ff_dim: 1024,
            num_layers: 4,
            max_seq_len: 256,
            num_qubits: 8,
            dropout_rate: 0.1,
            attention_type: QuantumAttentionType::VariationalQuantum,
            position_encoding: PositionEncodingType::Sinusoidal,
        }
    }
}
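
// Sketch: the configuration presets above are plain structs, so they can be customized
// field by field. The one structural constraint (enforced later in
// `QuantumMultiHeadAttention::new`) is that `model_dim` must be divisible by `num_heads`.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn customized_small_config_keeps_heads_divisible() {
        let mut config = QuantumTransformerConfig::small();
        config.num_heads = 8;
        config.max_seq_len = 128;
        assert_eq!(config.model_dim % config.num_heads, 0);
    }
}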

impl QuantumMultiHeadAttention {
    /// Create new quantum multi-head attention module
    pub fn new(
        num_heads: usize,
        model_dim: usize,
        attention_type: QuantumAttentionType,
        num_qubits: usize,
    ) -> Result<Self> {
        if model_dim % num_heads != 0 {
            return Err(MLError::ConfigurationError(
                "Model dimension must be divisible by number of heads".to_string(),
            ));
        }

        let head_dim = model_dim / num_heads;

        // Create projection layers for each head
        let mut query_layers = Vec::new();
        let mut key_layers = Vec::new();
        let mut value_layers = Vec::new();

        for _ in 0..num_heads {
            let q_layers = vec![
                QNNLayerType::EncodingLayer {
                    num_features: model_dim,
                },
                QNNLayerType::VariationalLayer {
                    num_params: head_dim * 2,
                },
                QNNLayerType::MeasurementLayer {
                    measurement_basis: "computational".to_string(),
                },
            ];
            query_layers.push(QuantumNeuralNetwork::new(
                q_layers, num_qubits, model_dim, head_dim,
            )?);

            let k_layers = vec![
                QNNLayerType::EncodingLayer {
                    num_features: model_dim,
                },
                QNNLayerType::VariationalLayer {
                    num_params: head_dim * 2,
                },
                QNNLayerType::MeasurementLayer {
                    measurement_basis: "computational".to_string(),
                },
            ];
            key_layers.push(QuantumNeuralNetwork::new(
                k_layers, num_qubits, model_dim, head_dim,
            )?);

            let v_layers = vec![
                QNNLayerType::EncodingLayer {
                    num_features: model_dim,
                },
                QNNLayerType::VariationalLayer {
                    num_params: head_dim * 2,
                },
                QNNLayerType::MeasurementLayer {
                    measurement_basis: "computational".to_string(),
                },
            ];
            value_layers.push(QuantumNeuralNetwork::new(
                v_layers, num_qubits, model_dim, head_dim,
            )?);
        }

        // Output projection layer
        let out_layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: model_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: model_dim,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let output_projection =
            QuantumNeuralNetwork::new(out_layers, num_qubits, model_dim, model_dim)?;

        // Create attention computation circuit
        let attention_circuit = Self::create_attention_circuit(num_qubits, &attention_type)?;

        Ok(Self {
            num_heads,
            model_dim,
            head_dim,
            query_layers,
            key_layers,
            value_layers,
            output_projection,
            attention_type,
            attention_circuit,
        })
    }

    /// Create quantum circuit for attention computation
    fn create_attention_circuit(
        num_qubits: usize,
        attention_type: &QuantumAttentionType,
    ) -> Result<Circuit<16>> {
        let mut circuit = Circuit::<16>::new();

        match attention_type {
            QuantumAttentionType::FullQuantum => {
                // Create fully quantum attention circuit
                // Initialize superposition
                for i in 0..num_qubits.min(16) {
                    circuit.h(i);
                }

                // Add entangling gates
                for i in 0..num_qubits.min(15) {
                    circuit.cnot(i, i + 1);
                }

                // Add parameterized rotations
                for i in 0..num_qubits.min(16) {
                    circuit.ry(i, 0.0); // Will be parameterized
                }
            }

            QuantumAttentionType::HybridQuantumClassical => {
                // Hybrid approach circuit
                let half_qubits = (num_qubits / 2).min(8);
                for i in 0..half_qubits {
                    circuit.h(i);
                }

                for i in 0..half_qubits - 1 {
                    circuit.cnot(i, i + 1);
                }

                for i in 0..num_qubits.min(16) {
                    circuit.rx(i, 0.0); // Will be parameterized
                }
            }

            QuantumAttentionType::VariationalQuantum => {
                // Variational quantum attention circuit
                for _layer in 0..3 {
                    for i in 0..num_qubits.min(16) {
                        circuit.ry(i, 0.0); // Will be parameterized
                        circuit.rz(i, 0.0); // Will be parameterized
                    }

                    for i in 0..num_qubits.min(15) {
                        circuit.cnot(i, i + 1);
                    }
                }
            }

            _ => {
                // Default attention circuit
                for i in 0..num_qubits.min(16) {
                    circuit.h(i);
                    circuit.ry(i, 0.0); // Will be parameterized
                }
            }
        }

        Ok(circuit)
    }

    /// Forward pass through quantum multi-head attention
    pub fn forward(
        &self,
        query: &Array3<f64>, // [batch_size, seq_len, model_dim]
        key: &Array3<f64>,
        value: &Array3<f64>,
        attention_mask: Option<&Array3<bool>>,
    ) -> Result<AttentionOutput> {
        let (batch_size, seq_len, model_dim) = query.dim();

        if model_dim != self.model_dim {
            return Err(MLError::DimensionMismatch(format!(
                "Expected model_dim {}, got {}",
                self.model_dim, model_dim
            )));
        }

        let mut all_head_outputs = Vec::new();
        let mut attention_weights_all = Array3::zeros((batch_size, seq_len, seq_len));
        let mut quantum_info = QuantumAttentionInfo {
            entanglement_matrix: Array2::zeros((seq_len, seq_len)),
            coherence_scores: Array1::zeros(seq_len),
            superposition_amplitudes: Array2::zeros((seq_len, self.head_dim)),
            measurement_probs: Array3::zeros((batch_size, seq_len, self.head_dim)),
        };

        // Process each attention head
        for head_idx in 0..self.num_heads {
            let head_output = self.compute_head_attention(
                query,
                key,
                value,
                head_idx,
                attention_mask,
                &mut quantum_info,
            )?;
            all_head_outputs.push(head_output.0);

            // Accumulate attention weights
            attention_weights_all = attention_weights_all + &head_output.1;
        }

        // Average attention weights across heads
        attention_weights_all = attention_weights_all / self.num_heads as f64;

        // Concatenate all head outputs
        let concatenated = self.concatenate_heads(&all_head_outputs)?;

        // Apply output projection
        let mut final_output = Array3::zeros((batch_size, seq_len, self.model_dim));
        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let input = concatenated.slice(s![batch_idx, seq_idx, ..]).to_owned();
                let output = self.output_projection.forward(&input)?;
                final_output
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&output);
            }
        }

        Ok(AttentionOutput {
            attention_weights: attention_weights_all,
            output: final_output,
            quantum_info,
        })
    }

    /// Compute attention for a single head
    fn compute_head_attention(
        &self,
        query: &Array3<f64>,
        key: &Array3<f64>,
        value: &Array3<f64>,
        head_idx: usize,
        attention_mask: Option<&Array3<bool>>,
        quantum_info: &mut QuantumAttentionInfo,
    ) -> Result<(Array3<f64>, Array3<f64>)> {
        let (batch_size, seq_len, _) = query.dim();

        // Project query, key, value through quantum networks
        let mut q_proj = Array3::zeros((batch_size, seq_len, self.head_dim));
        let mut k_proj = Array3::zeros((batch_size, seq_len, self.head_dim));
        let mut v_proj = Array3::zeros((batch_size, seq_len, self.head_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let q_input = query.slice(s![batch_idx, seq_idx, ..]).to_owned();
                let k_input = key.slice(s![batch_idx, seq_idx, ..]).to_owned();
                let v_input = value.slice(s![batch_idx, seq_idx, ..]).to_owned();

                let q_out = self.query_layers[head_idx].forward(&q_input)?;
                let k_out = self.key_layers[head_idx].forward(&k_input)?;
                let v_out = self.value_layers[head_idx].forward(&v_input)?;

                q_proj.slice_mut(s![batch_idx, seq_idx, ..]).assign(&q_out);
                k_proj.slice_mut(s![batch_idx, seq_idx, ..]).assign(&k_out);
                v_proj.slice_mut(s![batch_idx, seq_idx, ..]).assign(&v_out);
            }
        }

        // Compute quantum attention scores
        let attention_scores =
            self.compute_quantum_attention_scores(&q_proj, &k_proj, quantum_info)?;

        // Apply attention mask if provided
        let masked_scores = if let Some(mask) = attention_mask {
            self.apply_attention_mask(&attention_scores, mask)?
        } else {
            attention_scores
        };

        // Apply softmax to get attention weights
        let attention_weights = self.quantum_softmax(&masked_scores)?;

        // Apply attention to values
        let output = self.apply_attention_to_values(&attention_weights, &v_proj)?;

        Ok((output, attention_weights))
    }

    /// Compute quantum attention scores using quantum circuits
    fn compute_quantum_attention_scores(
        &self,
        query: &Array3<f64>,
        key: &Array3<f64>,
        quantum_info: &mut QuantumAttentionInfo,
    ) -> Result<Array3<f64>> {
        let (batch_size, seq_len, head_dim) = query.dim();
        let mut attention_scores = Array3::zeros((batch_size, seq_len, seq_len));

        match self.attention_type {
            QuantumAttentionType::FullQuantum => {
                // Use quantum interference for attention computation
                for batch_idx in 0..batch_size {
                    for i in 0..seq_len {
                        for j in 0..seq_len {
                            let score = self.quantum_dot_product(
                                &query.slice(s![batch_idx, i, ..]).to_owned(),
                                &key.slice(s![batch_idx, j, ..]).to_owned(),
                                quantum_info,
                                i,
                                j,
                            )?;
                            attention_scores[[batch_idx, i, j]] = score;
                        }
                    }
                }
            }

            QuantumAttentionType::HybridQuantumClassical => {
                // Hybrid computation
                for batch_idx in 0..batch_size {
                    let q_batch = query.slice(s![batch_idx, .., ..]);
                    let k_batch = key.slice(s![batch_idx, .., ..]);

                    // Classical dot product with quantum enhancement
                    let classical_scores = q_batch.dot(&k_batch.t()) / (head_dim as f64).sqrt();

                    // Apply quantum enhancement
                    for i in 0..seq_len {
                        for j in 0..seq_len {
                            let quantum_enhancement = self.compute_quantum_enhancement(
                                &query.slice(s![batch_idx, i, ..]).to_owned(),
                                &key.slice(s![batch_idx, j, ..]).to_owned(),
                            )?;
                            attention_scores[[batch_idx, i, j]] =
                                classical_scores[[i, j]] * (1.0 + quantum_enhancement);
                        }
                    }
                }
            }

            _ => {
                // Default classical computation with quantum post-processing
                for batch_idx in 0..batch_size {
                    let q_batch = query.slice(s![batch_idx, .., ..]);
                    let k_batch = key.slice(s![batch_idx, .., ..]);
                    let scores = q_batch.dot(&k_batch.t()) / (head_dim as f64).sqrt();
                    attention_scores
                        .slice_mut(s![batch_idx, .., ..])
                        .assign(&scores);
                }
            }
        }

        Ok(attention_scores)
    }

    /// Compute quantum dot product with entanglement tracking
    fn quantum_dot_product(
        &self,
        vec1: &Array1<f64>,
        vec2: &Array1<f64>,
        quantum_info: &mut QuantumAttentionInfo,
        pos1: usize,
        pos2: usize,
    ) -> Result<f64> {
        let dim = vec1.len();
        let num_qubits = self.attention_circuit.num_qubits();

        // Encode vectors into quantum states
        let mut circuit = self.attention_circuit.clone();

        // Encode first vector
        for i in 0..dim.min(num_qubits / 2) {
            let angle = vec1[i] * PI;
            circuit.ry(i, angle);
        }

        // Encode second vector
        for i in 0..dim.min(num_qubits / 2) {
            let angle = vec2[i] * PI;
            let qubit_idx = i + num_qubits / 2;
            if qubit_idx < num_qubits {
                circuit.ry(qubit_idx, angle);
            }
        }

        // Add entangling operations for interference
        for i in 0..num_qubits / 2 {
            let target = i + num_qubits / 2;
            if target < num_qubits {
                circuit.cnot(i, target);
            }
        }

        // Simulate and extract dot product
        let simulator = StateVectorSimulator::new();
        let register = simulator.run(&circuit)?;
        let state_probs = register.probabilities();

        // Compute expectation value as dot product
        let dot_product = self.extract_dot_product_from_state(&state_probs)?;

        // Update quantum information
        let entanglement = self.compute_entanglement(&state_probs)?;
        quantum_info.entanglement_matrix[[pos1, pos2]] = entanglement;

        if pos1 < quantum_info.coherence_scores.len() {
            quantum_info.coherence_scores[pos1] = self.compute_coherence(&state_probs)?;
        }

        Ok(dot_product)
    }

    /// Extract a dot-product proxy from measured basis-state probabilities
    fn extract_dot_product_from_state(&self, probs: &[f64]) -> Result<f64> {
        // Simplified extraction: alternating-sign sum over basis-state probabilities,
        // which corresponds to a single-qubit Z expectation value under the chosen
        // bit ordering.
        let mut dot_product = 0.0;

        for (i, &prob) in probs.iter().enumerate() {
            // Weight each computational basis state by its sign
            let weight = if i % 2 == 0 { 1.0 } else { -1.0 };
            dot_product += weight * prob;
        }

        Ok(dot_product)
    }

    /// Compute quantum enhancement factor
    fn compute_quantum_enhancement(&self, vec1: &Array1<f64>, vec2: &Array1<f64>) -> Result<f64> {
        // Quantum enhancement based on vector properties
        let norm1 = vec1.mapv(|x| x * x).sum().sqrt();
        let norm2 = vec2.mapv(|x| x * x).sum().sqrt();

        if norm1 < 1e-10 || norm2 < 1e-10 {
            return Ok(0.0);
        }

        // Compute quantum coherence enhancement
        let classical_dot = vec1.dot(vec2);
        let quantum_interference = (norm1 * norm2 - classical_dot.abs()).max(0.0);

        Ok(quantum_interference / (norm1 * norm2 + 1e-10))
    }

    /// Apply quantum softmax with temperature scaling
    fn quantum_softmax(&self, scores: &Array3<f64>) -> Result<Array3<f64>> {
        let mut weights = Array3::zeros(scores.dim());
        let temperature = 1.0; // Could be learnable parameter

        for batch_idx in 0..scores.dim().0 {
            for seq_idx in 0..scores.dim().1 {
                let row = scores.slice(s![batch_idx, seq_idx, ..]);
                let max_score = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

                // Apply quantum-enhanced softmax
                let mut exp_scores = Array1::zeros(row.len());
                let mut sum_exp = 0.0;

                for (i, &score) in row.iter().enumerate() {
                    let enhanced_score = score / temperature;
                    let quantum_factor = self.compute_quantum_softmax_factor(score)?;
                    let exp_val = (enhanced_score - max_score + quantum_factor).exp();
                    exp_scores[i] = exp_val;
                    sum_exp += exp_val;
                }

                // Normalize
                if sum_exp > 1e-10 {
                    exp_scores = exp_scores / sum_exp;
                }

                weights
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&exp_scores);
            }
        }

        Ok(weights)
    }

    /// Compute quantum factor for softmax enhancement
    fn compute_quantum_softmax_factor(&self, score: f64) -> Result<f64> {
        // Apply quantum superposition effect to softmax
        let quantum_phase = (score * PI).sin().abs();
        Ok(0.1 * quantum_phase) // Small quantum enhancement
    }

    /// Apply attention weights to values
    fn apply_attention_to_values(
        &self,
        attention_weights: &Array3<f64>,
        values: &Array3<f64>,
    ) -> Result<Array3<f64>> {
        let (batch_size, seq_len, head_dim) = values.dim();
        let mut output = Array3::zeros((batch_size, seq_len, head_dim));

        for batch_idx in 0..batch_size {
            let weights = attention_weights.slice(s![batch_idx, .., ..]);
            let vals = values.slice(s![batch_idx, .., ..]);

            let attended_values = weights.dot(&vals);
            output
                .slice_mut(s![batch_idx, .., ..])
                .assign(&attended_values);
        }

        Ok(output)
    }

    /// Apply attention mask
    fn apply_attention_mask(
        &self,
        scores: &Array3<f64>,
        mask: &Array3<bool>,
    ) -> Result<Array3<f64>> {
        let mut masked_scores = scores.clone();

        for ((i, j, k), &should_mask) in mask.indexed_iter() {
            if should_mask {
                masked_scores[[i, j, k]] = f64::NEG_INFINITY;
            }
        }

        Ok(masked_scores)
    }

    /// Concatenate outputs from all attention heads
    fn concatenate_heads(&self, head_outputs: &[Array3<f64>]) -> Result<Array3<f64>> {
        let (batch_size, seq_len, head_dim) = head_outputs[0].dim();
        let mut concatenated = Array3::zeros((batch_size, seq_len, self.model_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let mut concat_vec = Array1::zeros(self.model_dim);

                for (head_idx, head_output) in head_outputs.iter().enumerate() {
                    let start_idx = head_idx * head_dim;
                    let end_idx = start_idx + head_dim;

                    concat_vec
                        .slice_mut(s![start_idx..end_idx])
                        .assign(&head_output.slice(s![batch_idx, seq_idx, ..]));
                }

                concatenated
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&concat_vec);
            }
        }

        Ok(concatenated)
    }

    /// Compute an entanglement proxy from measured basis-state probabilities
    fn compute_entanglement(&self, probs: &[f64]) -> Result<f64> {
        // Simplified entanglement proxy: Shannon entropy of the measurement
        // distribution, normalized by the number of qubits.
        let num_qubits = (probs.len() as f64).log2() as usize;
        if num_qubits < 2 {
            return Ok(0.0);
        }

        let mut entropy = 0.0;
        for &prob in probs {
            if prob > 1e-10 {
                entropy -= prob * prob.ln();
            }
        }

        Ok(entropy / (num_qubits as f64))
    }

    /// Compute a quantum coherence proxy
    fn compute_coherence(&self, probs: &[f64]) -> Result<f64> {
        // Rough proxy: total probability mass outside the |0...0> basis state.
        // True L1 coherence would require the off-diagonal density-matrix elements,
        // which are not recoverable from measurement probabilities alone.
        let mut coherence = 0.0;

        for (i, &prob) in probs.iter().enumerate() {
            if i > 0 {
                coherence += prob;
            }
        }

        Ok(coherence)
    }
}
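
// Sketch: standalone use of the attention module with deliberately small dimensions.
// This is illustrative only and assumes the underlying `QuantumNeuralNetwork`
// construction accepts these sizes; it is `#[ignore]`d because every projection runs
// a simulated quantum network per token.
#[cfg(test)]
mod attention_example {
    use super::*;

    #[test]
    #[ignore]
    fn attention_forward_shapes() {
        let attention = QuantumMultiHeadAttention::new(
            2,
            8,
            QuantumAttentionType::HybridQuantumClassical,
            4,
        )
        .expect("attention construction");

        // Self-attention: query, key, and value share the same tensor.
        let x = Array3::<f64>::zeros((1, 3, 8));
        let result = attention.forward(&x, &x, &x, None).expect("attention forward");
        assert_eq!(result.output.dim(), (1, 3, 8));
        assert_eq!(result.attention_weights.dim(), (1, 3, 3));
    }
}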

impl QuantumPositionEncoding {
    /// Create new quantum position encoding
    pub fn new(
        encoding_type: PositionEncodingType,
        model_dim: usize,
        max_seq_len: usize,
        num_qubits: usize,
    ) -> Result<Self> {
        let mut encoding_circuits = Vec::new();
        let mut learnable_params = None;

        match encoding_type {
            PositionEncodingType::LearnableQuantum => {
                learnable_params = Some(Array2::zeros((max_seq_len, model_dim)));

                // Create quantum circuits for learnable position encoding
                for _ in 0..max_seq_len {
                    let mut circuit = Circuit::<16>::new();
                    for i in 0..num_qubits.min(16) {
                        circuit.ry(i, 0.0); // Will be parameterized
                        circuit.rz(i, 0.0); // Will be parameterized
                    }
                    encoding_circuits.push(circuit);
                }
            }

            PositionEncodingType::QuantumPhase => {
                // Create quantum phase encoding circuits
                for pos in 0..max_seq_len {
                    let mut circuit = Circuit::<16>::new();
                    for i in 0..num_qubits.min(16) {
                        let phase = 2.0 * PI * pos as f64
                            / (10000_f64.powf(2.0 * i as f64 / model_dim as f64));
                        circuit.h(i);
                        circuit.rz(i, phase);
                    }
                    encoding_circuits.push(circuit);
                }
            }

            _ => {
                // Default circuits for other encoding types
                for _ in 0..max_seq_len {
                    let mut circuit = Circuit::<16>::new();
                    for i in 0..num_qubits.min(16) {
                        circuit.h(i);
                        circuit.ry(i, 0.0); // Will be parameterized
                    }
                    encoding_circuits.push(circuit);
                }
            }
        }

        Ok(Self {
            encoding_type,
            model_dim,
            max_seq_len,
            learnable_params,
            encoding_circuits,
        })
    }

    /// Generate position encodings for input sequence
    pub fn forward(&self, seq_len: usize, batch_size: usize) -> Result<Array3<f64>> {
        let mut encodings = Array3::zeros((batch_size, seq_len, self.model_dim));

        match self.encoding_type {
            PositionEncodingType::Sinusoidal => {
                self.generate_sinusoidal_encoding(&mut encodings, seq_len)?;
            }

            PositionEncodingType::QuantumPhase => {
                self.generate_quantum_phase_encoding(&mut encodings, seq_len)?;
            }

            PositionEncodingType::LearnableQuantum => {
                self.generate_learnable_quantum_encoding(&mut encodings, seq_len)?;
            }

            PositionEncodingType::Relative => {
                self.generate_relative_encoding(&mut encodings, seq_len)?;
            }

            PositionEncodingType::Rotary => {
                self.generate_rotary_encoding(&mut encodings, seq_len)?;
            }
        }

        Ok(encodings)
    }

    /// Generate sinusoidal position encoding
    fn generate_sinusoidal_encoding(
        &self,
        encodings: &mut Array3<f64>,
        seq_len: usize,
    ) -> Result<()> {
        for pos in 0..seq_len {
            for i in 0..self.model_dim {
                let angle =
                    pos as f64 / 10000_f64.powf(2.0 * (i / 2) as f64 / self.model_dim as f64);

                let encoding_value = if i % 2 == 0 { angle.sin() } else { angle.cos() };

                // Apply to all batches
                for batch in 0..encodings.dim().0 {
                    encodings[[batch, pos, i]] = encoding_value;
                }
            }
        }

        Ok(())
    }

    /// Generate quantum phase position encoding
    fn generate_quantum_phase_encoding(
        &self,
        encodings: &mut Array3<f64>,
        seq_len: usize,
    ) -> Result<()> {
        let simulator = StateVectorSimulator::new();

        for pos in 0..seq_len {
            if pos < self.encoding_circuits.len() {
                let register = simulator.run(&self.encoding_circuits[pos])?;
                let state = register.probabilities();

                // Extract encoding from quantum state
                for i in 0..self.model_dim.min(state.len()) {
                    let encoding_value = state[i % state.len()];

                    for batch in 0..encodings.dim().0 {
                        encodings[[batch, pos, i]] = encoding_value;
                    }
                }
            }
        }

        Ok(())
    }

    /// Generate learnable quantum position encoding
    fn generate_learnable_quantum_encoding(
        &self,
        encodings: &mut Array3<f64>,
        seq_len: usize,
    ) -> Result<()> {
        if let Some(ref params) = self.learnable_params {
            for pos in 0..seq_len.min(params.nrows()) {
                for i in 0..self.model_dim.min(params.ncols()) {
                    let encoding_value = params[[pos, i]];

                    for batch in 0..encodings.dim().0 {
                        encodings[[batch, pos, i]] = encoding_value;
                    }
                }
            }
        }

        Ok(())
    }

    /// Generate relative position encoding
    fn generate_relative_encoding(
        &self,
        encodings: &mut Array3<f64>,
        seq_len: usize,
    ) -> Result<()> {
        // Simplified relative encoding
        for pos in 0..seq_len {
            for i in 0..self.model_dim {
                let relative_pos = (pos as f64 - seq_len as f64 / 2.0) / seq_len as f64;
                let encoding_value = relative_pos * (i as f64 / self.model_dim as f64);

                for batch in 0..encodings.dim().0 {
                    encodings[[batch, pos, i]] = encoding_value.tanh();
                }
            }
        }

        Ok(())
    }

    /// Generate rotary position embedding (RoPE)
    fn generate_rotary_encoding(&self, encodings: &mut Array3<f64>, seq_len: usize) -> Result<()> {
        for pos in 0..seq_len {
            for i in 0..(self.model_dim / 2) {
                let theta = pos as f64 / 10000_f64.powf(2.0 * i as f64 / self.model_dim as f64);

                let cos_val = theta.cos();
                let sin_val = theta.sin();

                for batch in 0..encodings.dim().0 {
                    encodings[[batch, pos, 2 * i]] = cos_val;
                    encodings[[batch, pos, 2 * i + 1]] = sin_val;
                }
            }
        }

        Ok(())
    }
}
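
// Sketch: the classical sinusoidal variant is deterministic, so its output can be
// checked directly. At position 0 every even dimension is sin(0) = 0 and every odd
// dimension is cos(0) = 1, independent of the batch index.
#[cfg(test)]
mod position_encoding_example {
    use super::*;

    #[test]
    fn sinusoidal_encoding_first_position() {
        let pe = QuantumPositionEncoding::new(PositionEncodingType::Sinusoidal, 8, 16, 4)
            .expect("position encoding construction");
        let enc = pe.forward(4, 2).expect("position encoding forward");

        assert_eq!(enc.dim(), (2, 4, 8));
        assert_eq!(enc[[0, 0, 0]], 0.0); // sin(0)
        assert_eq!(enc[[0, 0, 1]], 1.0); // cos(0)
    }
}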

impl QuantumFeedForward {
    /// Create new quantum feedforward network
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        num_qubits: usize,
        activation: ActivationType,
        dropout_rate: f64,
    ) -> Result<Self> {
        // First layer: input_dim -> hidden_dim
        let layer1_structure = vec![
            QNNLayerType::EncodingLayer {
                num_features: input_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: hidden_dim * 2,
            },
            QNNLayerType::EntanglementLayer {
                connectivity: "circular".to_string(),
            },
            QNNLayerType::VariationalLayer {
                num_params: hidden_dim,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let layer1 =
            QuantumNeuralNetwork::new(layer1_structure, num_qubits, input_dim, hidden_dim)?;

        // Second layer: hidden_dim -> output_dim
        let layer2_structure = vec![
            QNNLayerType::EncodingLayer {
                num_features: hidden_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: output_dim * 2,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let layer2 =
            QuantumNeuralNetwork::new(layer2_structure, num_qubits, hidden_dim, output_dim)?;

        Ok(Self {
            input_dim,
            hidden_dim,
            output_dim,
            layer1,
            layer2,
            activation,
            dropout_rate,
        })
    }

    /// Forward pass through feedforward network
    pub fn forward(&self, input: &Array2<f64>) -> Result<Array2<f64>> {
        let (batch_size, seq_len) = (input.nrows(), input.ncols() / self.input_dim);
        let mut output = Array2::zeros((batch_size, seq_len * self.output_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let start_idx = seq_idx * self.input_dim;
                let end_idx = start_idx + self.input_dim;

                let input_slice = input.slice(s![batch_idx, start_idx..end_idx]).to_owned();

                // First layer
                let hidden = self.layer1.forward(&input_slice)?;

                // Apply quantum activation
                let activated = self.apply_quantum_activation(&hidden)?;

                // Apply dropout (simplified)
                let dropped = self.apply_dropout(&activated)?;

                // Second layer
                let output_slice = self.layer2.forward(&dropped)?;

                let out_start = seq_idx * self.output_dim;
                let out_end = out_start + self.output_dim;
                output
                    .slice_mut(s![batch_idx, out_start..out_end])
                    .assign(&output_slice);
            }
        }

        Ok(output)
    }

    /// Apply quantum activation function
    fn apply_quantum_activation(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        match self.activation {
            ActivationType::QuantumReLU => {
                // Quantum ReLU using amplitude encoding
                Ok(input.mapv(|x| if x > 0.0 { x } else { 0.0 }))
            }

            ActivationType::QuantumGELU => {
                // Quantum GELU approximation
                Ok(input.mapv(|x| {
                    0.5 * x * (1.0 + (x * 0.7978845608 * (1.0 + 0.044715 * x * x)).tanh())
                }))
            }

            ActivationType::QuantumSwish => {
                // Quantum Swish activation
                Ok(input.mapv(|x| x / (1.0 + (-x).exp())))
            }

            ActivationType::ParameterizedQuantum => {
                // Parameterized quantum activation
                Ok(input.mapv(|x| (x * PI / 2.0).sin()))
            }

            ActivationType::ClassicalHybrid => {
                // Classical activation applied to quantum measurements
                Ok(input.mapv(|x| x.tanh()))
            }
        }
    }

    /// Apply dropout
    fn apply_dropout(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        // Simplified dropout - in practice would be training-dependent
        if self.dropout_rate > 0.0 {
            let scale = 1.0 / (1.0 - self.dropout_rate);
            Ok(input.mapv(|x| {
                if fastrand::f64() < self.dropout_rate {
                    0.0
                } else {
                    x * scale
                }
            }))
        } else {
            Ok(input.clone())
        }
    }
}
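
// Sketch: `QuantumFeedForward::forward` expects token vectors flattened along the last
// axis, i.e. an input of shape [batch_size, seq_len * input_dim], and returns
// [batch_size, seq_len * output_dim]. Marked `#[ignore]` (and assuming the QNN layers
// accept these small sizes) because it simulates two quantum networks per token.
#[cfg(test)]
mod feedforward_example {
    use super::*;

    #[test]
    #[ignore]
    fn feedforward_flattened_shapes() {
        let ff = QuantumFeedForward::new(4, 8, 4, 4, ActivationType::QuantumGELU, 0.0)
            .expect("feedforward construction");

        // 1 batch, 3 tokens, 4 features per token, flattened to 12 columns.
        let input = Array2::<f64>::zeros((1, 12));
        let output = ff.forward(&input).expect("feedforward forward");
        assert_eq!(output.dim(), (1, 12));
    }
}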

impl QuantumTransformerLayer {
    /// Create new quantum transformer layer
    pub fn new(
        model_dim: usize,
        num_heads: usize,
        ff_dim: usize,
        num_qubits: usize,
        attention_type: QuantumAttentionType,
        dropout_rate: f64,
    ) -> Result<Self> {
        let attention =
            QuantumMultiHeadAttention::new(num_heads, model_dim, attention_type, num_qubits)?;

        let feedforward = QuantumFeedForward::new(
            model_dim,
            ff_dim,
            model_dim,
            num_qubits,
            ActivationType::QuantumGELU,
            dropout_rate,
        )?;

        let norm1_scale = Array1::ones(model_dim);
        let norm1_bias = Array1::zeros(model_dim);
        let norm2_scale = Array1::ones(model_dim);
        let norm2_bias = Array1::zeros(model_dim);

        Ok(Self {
            attention,
            feedforward,
            norm1_scale,
            norm1_bias,
            norm2_scale,
            norm2_bias,
            model_dim,
            dropout_rate,
        })
    }

    /// Forward pass through transformer layer
    pub fn forward(
        &self,
        input: &Array3<f64>,
        attention_mask: Option<&Array3<bool>>,
    ) -> Result<Array3<f64>> {
        // Self-attention with residual connection and layer norm
        let attention_output = self
            .attention
            .forward(input, input, input, attention_mask)?;
        let attended = input + &attention_output.output;
        let normed1 = self.layer_norm(&attended, &self.norm1_scale, &self.norm1_bias)?;

        // Feedforward with residual connection and layer norm
        let ff_input = self.reshape_for_feedforward(&normed1)?;
        let ff_output = self.feedforward.forward(&ff_input)?;
        let ff_reshaped = self.reshape_from_feedforward(&ff_output, normed1.dim())?;
        let ff_residual = &normed1 + &ff_reshaped;
        let normed2 = self.layer_norm(&ff_residual, &self.norm2_scale, &self.norm2_bias)?;

        Ok(normed2)
    }

    /// Apply layer normalization
    fn layer_norm(
        &self,
        input: &Array3<f64>,
        scale: &Array1<f64>,
        bias: &Array1<f64>,
    ) -> Result<Array3<f64>> {
        let (batch_size, seq_len, model_dim) = input.dim();
        let mut output = Array3::zeros((batch_size, seq_len, model_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let input_slice = input.slice(s![batch_idx, seq_idx, ..]);

                // Compute mean and variance
                let mean = input_slice.mean().unwrap_or(0.0);
                let variance = input_slice
                    .mapv(|x| (x - mean).powi(2))
                    .mean()
                    .unwrap_or(1.0);
                let std = (variance + 1e-6).sqrt();

                // Normalize
                let normalized = input_slice.mapv(|x| (x - mean) / std);

                // Scale and shift
                let scaled = &normalized * scale + bias;
                output.slice_mut(s![batch_idx, seq_idx, ..]).assign(&scaled);
            }
        }

        Ok(output)
    }

    /// Reshape for feedforward processing
    fn reshape_for_feedforward(&self, input: &Array3<f64>) -> Result<Array2<f64>> {
        let (batch_size, seq_len, model_dim) = input.dim();
        let mut output = Array2::zeros((batch_size, seq_len * model_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let start_idx = seq_idx * model_dim;
                let end_idx = start_idx + model_dim;

                output
                    .slice_mut(s![batch_idx, start_idx..end_idx])
                    .assign(&input.slice(s![batch_idx, seq_idx, ..]));
            }
        }

        Ok(output)
    }

    /// Reshape from feedforward processing
    fn reshape_from_feedforward(
        &self,
        input: &Array2<f64>,
        target_shape: (usize, usize, usize),
    ) -> Result<Array3<f64>> {
        let (batch_size, seq_len, model_dim) = target_shape;
        let mut output = Array3::zeros((batch_size, seq_len, model_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let start_idx = seq_idx * model_dim;
                let end_idx = start_idx + model_dim;

                output
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&input.slice(s![batch_idx, start_idx..end_idx]));
            }
        }

        Ok(output)
    }
}
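
// Sketch: a full layer (attention and feedforward, each with a residual connection and
// layer norm) preserves the [batch_size, seq_len, model_dim] shape of its input. Marked
// `#[ignore]` and written under the assumption that the quantum sub-networks accept
// these small dimensions.
#[cfg(test)]
mod transformer_layer_example {
    use super::*;

    #[test]
    #[ignore]
    fn layer_forward_preserves_shape() {
        let layer = QuantumTransformerLayer::new(
            8,
            2,
            16,
            4,
            QuantumAttentionType::HybridQuantumClassical,
            0.0,
        )
        .expect("layer construction");

        let x = Array3::<f64>::zeros((1, 3, 8));
        let y = layer.forward(&x, None).expect("layer forward");
        assert_eq!(y.dim(), (1, 3, 8));
    }
}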

impl QuantumTransformer {
    /// Create new quantum transformer model
    pub fn new(config: QuantumTransformerConfig) -> Result<Self> {
        // Position encoding
        let position_encoding = QuantumPositionEncoding::new(
            config.position_encoding.clone(),
            config.model_dim,
            config.max_seq_len,
            config.num_qubits,
        )?;

        // Transformer layers
        let mut layers = Vec::new();
        for _ in 0..config.num_layers {
            let layer = QuantumTransformerLayer::new(
                config.model_dim,
                config.num_heads,
                config.ff_dim,
                config.num_qubits,
                config.attention_type.clone(),
                config.dropout_rate,
            )?;
            layers.push(layer);
        }

        // Input embedding layer
        let embedding_layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: config.model_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: config.model_dim,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let input_embedding = QuantumNeuralNetwork::new(
            embedding_layers,
            config.num_qubits,
            config.model_dim,
            config.model_dim,
        )?;

        // Output projection layer
        let output_layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: config.model_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: config.model_dim,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];
        let output_projection = QuantumNeuralNetwork::new(
            output_layers,
            config.num_qubits,
            config.model_dim,
            config.model_dim,
        )?;

        // Final layer normalization
        let final_norm_scale = Array1::ones(config.model_dim);
        let final_norm_bias = Array1::zeros(config.model_dim);

        Ok(Self {
            config,
            position_encoding,
            layers,
            input_embedding,
            output_projection,
            final_norm_scale,
            final_norm_bias,
        })
    }

    /// Forward pass through quantum transformer
    pub fn forward(
        &self,
        input: &Array3<f64>, // [batch_size, seq_len, input_dim]
        attention_mask: Option<&Array3<bool>>,
    ) -> Result<Array3<f64>> {
        let (batch_size, seq_len, input_dim) = input.dim();

        if seq_len > self.config.max_seq_len {
            return Err(MLError::ConfigurationError(format!(
                "Sequence length {} exceeds maximum {}",
                seq_len, self.config.max_seq_len
            )));
        }

        // Input embedding
        let mut embedded = Array3::zeros((batch_size, seq_len, self.config.model_dim));
        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let input_vec = input.slice(s![batch_idx, seq_idx, ..]).to_owned();

                // Pad or truncate to model_dim
                let mut padded_input = Array1::zeros(self.config.model_dim);
                let copy_len = input_dim.min(self.config.model_dim);
                padded_input
                    .slice_mut(s![..copy_len])
                    .assign(&input_vec.slice(s![..copy_len]));

                let embedding_output = self.input_embedding.forward(&padded_input)?;
                embedded
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&embedding_output);
            }
        }

        // Add position encoding
        let position_encodings = self.position_encoding.forward(seq_len, batch_size)?;
        let mut x = embedded + position_encodings;

        // Pass through transformer layers
        for layer in &self.layers {
            x = layer.forward(&x, attention_mask)?;
        }

        // Apply final layer normalization
        x = self.apply_final_layer_norm(&x)?;

        // Output projection
        let mut output = Array3::zeros((batch_size, seq_len, self.config.model_dim));
        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let input_vec = x.slice(s![batch_idx, seq_idx, ..]).to_owned();
                let projected_output = self.output_projection.forward(&input_vec)?;
                output
                    .slice_mut(s![batch_idx, seq_idx, ..])
                    .assign(&projected_output);
            }
        }

        Ok(output)
    }

    /// Apply final layer normalization
    fn apply_final_layer_norm(&self, input: &Array3<f64>) -> Result<Array3<f64>> {
        let (batch_size, seq_len, model_dim) = input.dim();
        let mut output = Array3::zeros((batch_size, seq_len, model_dim));

        for batch_idx in 0..batch_size {
            for seq_idx in 0..seq_len {
                let input_slice = input.slice(s![batch_idx, seq_idx, ..]);

                let mean = input_slice.mean().unwrap_or(0.0);
                let variance = input_slice
                    .mapv(|x| (x - mean).powi(2))
                    .mean()
                    .unwrap_or(1.0);
                let std = (variance + 1e-6).sqrt();

                let normalized = input_slice.mapv(|x| (x - mean) / std);
                let scaled = &normalized * &self.final_norm_scale + &self.final_norm_bias;

                output.slice_mut(s![batch_idx, seq_idx, ..]).assign(&scaled);
            }
        }

        Ok(output)
    }

    /// Get model configuration
    pub fn config(&self) -> &QuantumTransformerConfig {
        &self.config
    }

    /// Get number of parameters
    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        // Count parameters in all components
        total += self.input_embedding.parameters.len();
        total += self.output_projection.parameters.len();

        for layer in &self.layers {
            total += layer
                .attention
                .query_layers
                .iter()
                .map(|l| l.parameters.len())
                .sum::<usize>();
            total += layer
                .attention
                .key_layers
                .iter()
                .map(|l| l.parameters.len())
                .sum::<usize>();
            total += layer
                .attention
                .value_layers
                .iter()
                .map(|l| l.parameters.len())
                .sum::<usize>();
            total += layer.attention.output_projection.parameters.len();
            total += layer.feedforward.layer1.parameters.len();
            total += layer.feedforward.layer2.parameters.len();
            total += layer.norm1_scale.len() + layer.norm1_bias.len();
            total += layer.norm2_scale.len() + layer.norm2_bias.len();
        }

        if let Some(ref params) = self.position_encoding.learnable_params {
            total += params.len();
        }

        total += self.final_norm_scale.len() + self.final_norm_bias.len();

        total
    }
}

/// Helper function to create causal attention mask
pub fn create_causal_mask(batch_size: usize, seq_len: usize) -> Array3<bool> {
    let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);

    for batch_idx in 0..batch_size {
        for i in 0..seq_len {
            for j in (i + 1)..seq_len {
                mask[[batch_idx, i, j]] = true; // Mask future positions
            }
        }
    }

    mask
}

/// Helper function to create padding mask
pub fn create_padding_mask(
    batch_size: usize,
    seq_len: usize,
    actual_lengths: &[usize],
) -> Array3<bool> {
    let mut mask = Array3::from_elem((batch_size, seq_len, seq_len), false);

    for (batch_idx, &actual_len) in actual_lengths.iter().enumerate() {
        if batch_idx < batch_size {
            for i in 0..seq_len {
                for j in actual_len..seq_len {
                    mask[[batch_idx, i, j]] = true; // Mask padding positions
                }
            }
        }
    }

    mask
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_transformer_config() {
        let config = QuantumTransformerConfig::default();
        assert_eq!(config.model_dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.num_layers, 6);

        let large_config = QuantumTransformerConfig::large();
        assert_eq!(large_config.model_dim, 1024);
        assert_eq!(large_config.num_heads, 16);
    }

    #[test]
    fn test_quantum_multi_head_attention_creation() {
        let attention = QuantumMultiHeadAttention::new(
            8,
            512,
            QuantumAttentionType::HybridQuantumClassical,
            10,
        );

        assert!(attention.is_ok());
        let attn = attention.unwrap();
        assert_eq!(attn.num_heads, 8);
        assert_eq!(attn.model_dim, 512);
        assert_eq!(attn.head_dim, 64);
    }

    #[test]
    fn test_quantum_position_encoding() {
        let pos_enc = QuantumPositionEncoding::new(PositionEncodingType::Sinusoidal, 256, 512, 8);

        assert!(pos_enc.is_ok());
        let pe = pos_enc.unwrap();
        assert_eq!(pe.model_dim, 256);
        assert_eq!(pe.max_seq_len, 512);
    }

    #[test]
    fn test_quantum_feedforward() {
        let ff = QuantumFeedForward::new(256, 1024, 256, 8, ActivationType::QuantumGELU, 0.1);

        assert!(ff.is_ok());
        let feedforward = ff.unwrap();
        assert_eq!(feedforward.input_dim, 256);
        assert_eq!(feedforward.hidden_dim, 1024);
        assert_eq!(feedforward.output_dim, 256);
    }

    #[test]
    fn test_causal_mask_creation() {
        let mask = create_causal_mask(2, 4);
        assert_eq!(mask.dim(), (2, 4, 4));

        // Check that lower triangle is false (not masked)
        assert!(!mask[[0, 0, 0]]);
        assert!(!mask[[0, 1, 0]]);
        assert!(!mask[[0, 1, 1]]);

        // Check that upper triangle is true (masked)
        assert!(mask[[0, 0, 1]]);
        assert!(mask[[0, 0, 2]]);
        assert!(mask[[0, 1, 2]]);
    }

    #[test]
    fn test_padding_mask_creation() {
        let actual_lengths = vec![3, 2];
        let mask = create_padding_mask(2, 4, &actual_lengths);

        // Check padding is masked for first batch (length 3)
        assert!(!mask[[0, 0, 2]]); // Not masked (within length)
        assert!(mask[[0, 0, 3]]); // Masked (padding)

        // Check padding is masked for second batch (length 2)
        assert!(!mask[[1, 0, 1]]); // Not masked (within length)
        assert!(mask[[1, 0, 2]]); // Masked (padding)
        assert!(mask[[1, 0, 3]]); // Masked (padding)
    }
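
    // End-to-end smoke-test sketch: builds the small configuration and runs a short,
    // causally masked sequence through the full model. Marked `#[ignore]` because the
    // simulation cost grows quickly with qubit count, layer count, and sequence length.
    #[test]
    #[ignore]
    fn test_quantum_transformer_forward_smoke() {
        let config = QuantumTransformerConfig::small();
        let model = QuantumTransformer::new(config).expect("model construction");

        let input = Array3::<f64>::zeros((1, 4, 16));
        let mask = create_causal_mask(1, 4);
        let output = model.forward(&input, Some(&mask)).expect("forward pass");

        assert_eq!(output.dim(), (1, 4, model.config().model_dim));
        assert!(model.num_parameters() > 0);
    }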
}