1use ndarray::{s, Array1, Array2, Array3, Axis};
7use num_complex::Complex64;
8use std::collections::HashMap;
9use std::f64::consts::PI;
10
11use crate::autodiff::DifferentiableParam;
12use crate::error::{MLError, Result};
13use crate::utils::VariationalCircuit;
14use quantrs2_circuit::prelude::*;
15use quantrs2_core::gate::{multi::*, single::*, GateOp};
16
/// Multi-head self-attention whose query/key/value/output projections are
/// parameterized quantum circuits.
#[derive(Debug, Clone)]
pub struct QuantumSelfAttention {
    /// Total embedding dimension of the model.
    embed_dim: usize,
    /// Number of parallel attention heads.
    num_heads: usize,
    /// Dimension handled by each head (`embed_dim / num_heads`).
    head_dim: usize,
    /// Qubits used to amplitude-encode one head's sub-vector.
    qubits_per_head: usize,
    /// Quantum projection applied to queries.
    query_circuit: QuantumProjection,
    /// Quantum projection applied to keys.
    key_circuit: QuantumProjection,
    /// Quantum projection applied to values.
    value_circuit: QuantumProjection,
    /// Quantum projection applied to the concatenated head outputs.
    output_circuit: QuantumProjection,
    /// Dropout probability (stored; not read by `forward` as written).
    dropout_rate: f64,
    /// Softmax temperature, set to sqrt(head_dim) in `new`.
    temperature: f64,
}
41
/// A variational-circuit projection from `input_dim` to `output_dim`.
#[derive(Debug, Clone)]
struct QuantumProjection {
    /// Dimensionality of the classical input vector.
    input_dim: usize,
    /// Dimensionality of the decoded output vector.
    output_dim: usize,
    /// Qubit count, derived from max(input_dim, output_dim) in `new`.
    num_qubits: usize,
    /// The parameterized projection circuit.
    circuit: VariationalCircuit,
    /// Named circuit parameters (left empty by `new`).
    parameters: HashMap<String, f64>,
}
56
57impl QuantumProjection {
58 fn new(input_dim: usize, output_dim: usize) -> Self {
60 let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
61 let circuit = Self::build_projection_circuit(num_qubits);
62
63 Self {
64 input_dim,
65 output_dim,
66 num_qubits,
67 circuit,
68 parameters: HashMap::new(),
69 }
70 }
71
72 fn build_projection_circuit(num_qubits: usize) -> VariationalCircuit {
74 let mut circuit = VariationalCircuit::new(num_qubits);
75
76 for q in 0..num_qubits {
78 circuit.add_gate("RY", vec![q], vec![format!("encode_{}", q)]);
79 }
80
81 for q in 0..num_qubits - 1 {
83 circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
84 }
85 if num_qubits > 2 {
86 circuit.add_gate("CNOT", vec![num_qubits - 1, 0], vec![]);
87 }
88
89 for q in 0..num_qubits {
91 circuit.add_gate("RX", vec![q], vec![format!("rx_{}", q)]);
92 circuit.add_gate("RZ", vec![q], vec![format!("rz_{}", q)]);
93 }
94
95 for q in (0..num_qubits - 1).step_by(2) {
97 circuit.add_gate("CZ", vec![q, q + 1], vec![]);
98 }
99 for q in (1..num_qubits - 1).step_by(2) {
100 circuit.add_gate("CZ", vec![q, q + 1], vec![]);
101 }
102
103 for q in 0..num_qubits {
105 circuit.add_gate("RY", vec![q], vec![format!("final_{}", q)]);
106 }
107
108 circuit
109 }
110
111 fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
113 let encoded = self.encode_input(input)?;
115
116 let output_state = self.execute_circuit(&encoded)?;
118
119 self.decode_output(&output_state)
121 }
122
123 fn encode_input(&self, input: &Array1<f64>) -> Result<Vec<Complex64>> {
125 let state_dim = 2_usize.pow(self.num_qubits as u32);
126 let mut quantum_state = vec![Complex64::new(0.0, 0.0); state_dim];
127
128 let norm: f64 = input.iter().map(|x| x * x).sum::<f64>().sqrt();
130 if norm < 1e-10 {
131 return Err(MLError::InvalidInput("Zero norm input".to_string()));
132 }
133
134 for (i, &val) in input.iter().enumerate() {
135 if i < state_dim {
136 quantum_state[i] = Complex64::new(val / norm, 0.0);
137 }
138 }
139
140 Ok(quantum_state)
141 }
142
143 fn execute_circuit(&self, input_state: &[Complex64]) -> Result<Vec<Complex64>> {
145 let state_dim = input_state.len();
148 let mut output_state = input_state.to_vec();
149
150 for i in 0..state_dim {
152 let phase = (i as f64) * 0.1;
153 output_state[i] *= Complex64::new(phase.cos(), phase.sin());
154 }
155
156 Ok(output_state)
157 }
158
159 fn decode_output(&self, quantum_state: &[Complex64]) -> Result<Array1<f64>> {
161 let mut output = Array1::zeros(self.output_dim);
162
163 for i in 0..self.output_dim.min(quantum_state.len()) {
165 output[i] = quantum_state[i].norm();
166 }
167
168 Ok(output)
169 }
170}
171
172impl QuantumSelfAttention {
173 pub fn new(embed_dim: usize, num_heads: usize, dropout_rate: f64) -> Self {
175 assert!(
176 embed_dim % num_heads == 0,
177 "embed_dim must be divisible by num_heads"
178 );
179
180 let head_dim = embed_dim / num_heads;
181 let qubits_per_head = (head_dim as f64).log2().ceil() as usize;
182
183 Self {
184 embed_dim,
185 num_heads,
186 head_dim,
187 qubits_per_head,
188 query_circuit: QuantumProjection::new(embed_dim, embed_dim),
189 key_circuit: QuantumProjection::new(embed_dim, embed_dim),
190 value_circuit: QuantumProjection::new(embed_dim, embed_dim),
191 output_circuit: QuantumProjection::new(embed_dim, embed_dim),
192 dropout_rate,
193 temperature: (head_dim as f64).sqrt(),
194 }
195 }
196
197 pub fn forward(
199 &self,
200 query: &Array2<f64>,
201 key: &Array2<f64>,
202 value: &Array2<f64>,
203 mask: Option<&Array2<bool>>,
204 ) -> Result<Array2<f64>> {
205 let batch_size = query.nrows();
206 let seq_len = query.ncols() / self.embed_dim;
207
208 let q = self.project_to_heads(query, &self.query_circuit)?;
210 let k = self.project_to_heads(key, &self.key_circuit)?;
211 let v = self.project_to_heads(value, &self.value_circuit)?;
212
213 let attention_scores = self.compute_attention_scores(&q, &k)?;
215
216 let masked_scores = if let Some(mask) = mask {
218 self.apply_mask(&attention_scores, mask)?
219 } else {
220 attention_scores
221 };
222
223 let attention_weights = self.quantum_softmax(&masked_scores)?;
225
226 let attended_values = self.apply_attention(&attention_weights, &v)?;
228
229 self.project_output(&attended_values)
231 }
232
233 fn project_to_heads(
235 &self,
236 input: &Array2<f64>,
237 projection: &QuantumProjection,
238 ) -> Result<Array3<f64>> {
239 let batch_size = input.nrows();
240 let seq_len = input.ncols() / self.embed_dim;
241
242 let mut output = Array3::zeros((batch_size, self.num_heads, seq_len * self.head_dim));
243
244 for b in 0..batch_size {
245 for s in 0..seq_len {
246 let start = s * self.embed_dim;
247 let end = start + self.embed_dim;
248 let input_vec = input.row(b).slice(s![start..end]).to_owned();
249
250 let projected = projection.forward(&input_vec)?;
251
252 for h in 0..self.num_heads {
254 let head_start = h * self.head_dim;
255 let head_end = head_start + self.head_dim;
256
257 for i in 0..self.head_dim {
258 if head_start + i < projected.len() {
259 output[[b, h, s * self.head_dim + i]] = projected[head_start + i];
260 }
261 }
262 }
263 }
264 }
265
266 Ok(output)
267 }
268
269 fn compute_attention_scores(
271 &self,
272 query: &Array3<f64>,
273 key: &Array3<f64>,
274 ) -> Result<Array3<f64>> {
275 let batch_size = query.shape()[0];
276 let seq_len = query.shape()[2] / self.head_dim;
277
278 let mut scores = Array3::zeros((batch_size, self.num_heads, seq_len * seq_len));
279
280 for b in 0..batch_size {
282 for h in 0..self.num_heads {
283 for i in 0..seq_len {
284 for j in 0..seq_len {
285 let q_start = i * self.head_dim;
286 let q_end = q_start + self.head_dim;
287 let k_start = j * self.head_dim;
288 let k_end = k_start + self.head_dim;
289
290 let q_vec = query.slice(s![b, h, q_start..q_end]);
291 let k_vec = key.slice(s![b, h, k_start..k_end]);
292
293 let score =
295 self.quantum_inner_product(&q_vec.to_owned(), &k_vec.to_owned())?;
296 scores[[b, h, i * seq_len + j]] = score / self.temperature;
297 }
298 }
299 }
300 }
301
302 Ok(scores)
303 }
304
305 fn quantum_inner_product(&self, vec1: &Array1<f64>, vec2: &Array1<f64>) -> Result<f64> {
307 let num_qubits = self.qubits_per_head * 2 + 1; let mut circuit = VariationalCircuit::new(num_qubits);
310
311 for i in 0..self.qubits_per_head {
313 if i < vec1.len() {
314 let angle1 = vec1[i] * PI;
315 circuit.add_gate("RY", vec![i], vec![angle1.to_string()]);
316 }
317 if i < vec2.len() {
318 let angle2 = vec2[i] * PI;
319 circuit.add_gate(
320 "RY",
321 vec![i + self.qubits_per_head],
322 vec![angle2.to_string()],
323 );
324 }
325 }
326
327 circuit.add_gate("H", vec![num_qubits - 1], vec![]);
329
330 for i in 0..self.qubits_per_head {
332 circuit.add_gate(
333 "CSWAP",
334 vec![num_qubits - 1, i, i + self.qubits_per_head],
335 vec![],
336 );
337 }
338
339 circuit.add_gate("H", vec![num_qubits - 1], vec![]);
341
342 Ok(vec1.dot(vec2))
345 }
346
347 fn quantum_softmax(&self, scores: &Array3<f64>) -> Result<Array3<f64>> {
349 let mut output = scores.clone();
350
351 for b in 0..scores.shape()[0] {
353 for h in 0..scores.shape()[1] {
354 let head_scores = scores.slice(s![b, h, ..]);
355 let seq_len = (head_scores.len() as f64).sqrt() as usize;
356
357 for i in 0..seq_len {
358 let start = i * seq_len;
359 let end = start + seq_len;
360 let row_scores = head_scores.slice(s![start..end]);
361
362 let softmax_vals = self.quantum_softmax_circuit(&row_scores.to_owned())?;
364
365 for j in 0..seq_len {
366 output[[b, h, start + j]] = softmax_vals[j];
367 }
368 }
369 }
370 }
371
372 Ok(output)
373 }
374
375 fn quantum_softmax_circuit(&self, logits: &Array1<f64>) -> Result<Vec<f64>> {
377 let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
379 let exp_logits: Vec<f64> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
380 let sum_exp: f64 = exp_logits.iter().sum();
381
382 Ok(exp_logits.into_iter().map(|x| x / sum_exp).collect())
383 }
384
385 fn apply_attention(&self, weights: &Array3<f64>, values: &Array3<f64>) -> Result<Array3<f64>> {
387 let batch_size = weights.shape()[0];
388 let num_heads = weights.shape()[1];
389 let seq_len = (weights.shape()[2] as f64).sqrt() as usize;
390
391 let mut output = Array3::zeros((batch_size, num_heads, seq_len * self.head_dim));
392
393 for b in 0..batch_size {
394 for h in 0..num_heads {
395 for i in 0..seq_len {
396 for j in 0..seq_len {
397 let weight = weights[[b, h, i * seq_len + j]];
398
399 for d in 0..self.head_dim {
400 output[[b, h, i * self.head_dim + d]] +=
401 weight * values[[b, h, j * self.head_dim + d]];
402 }
403 }
404 }
405 }
406 }
407
408 Ok(output)
409 }
410
411 fn apply_mask(&self, scores: &Array3<f64>, mask: &Array2<bool>) -> Result<Array3<f64>> {
413 let mut masked_scores = scores.clone();
414
415 for b in 0..scores.shape()[0] {
416 for h in 0..scores.shape()[1] {
417 for (idx, &is_masked) in mask.iter().enumerate() {
418 if is_masked && idx < scores.shape()[2] {
419 masked_scores[[b, h, idx]] = -1e9; }
421 }
422 }
423 }
424
425 Ok(masked_scores)
426 }
427
428 fn project_output(&self, attended: &Array3<f64>) -> Result<Array2<f64>> {
430 let batch_size = attended.shape()[0];
431 let seq_len = attended.shape()[2] / self.head_dim;
432
433 let mut output = Array2::zeros((batch_size, seq_len * self.embed_dim));
434
435 for b in 0..batch_size {
436 for s in 0..seq_len {
437 let mut concat = Array1::zeros(self.embed_dim);
439 for h in 0..self.num_heads {
440 for d in 0..self.head_dim {
441 concat[h * self.head_dim + d] = attended[[b, h, s * self.head_dim + d]];
442 }
443 }
444
445 let projected = self.output_circuit.forward(&concat)?;
447
448 for d in 0..self.embed_dim {
449 output[[b, s * self.embed_dim + d]] = projected[d];
450 }
451 }
452 }
453
454 Ok(output)
455 }
456}
457
/// One transformer block: quantum self-attention followed by a two-stage
/// quantum feed-forward network, each wrapped in a residual connection and
/// layer normalization.
#[derive(Debug)]
pub struct QuantumTransformerBlock {
    /// Multi-head quantum self-attention sub-layer.
    self_attention: QuantumSelfAttention,
    /// Hidden dimension of the feed-forward sub-layer.
    ff_dim: usize,
    /// First feed-forward stage: embed_dim -> ff_dim.
    ff1: QuantumFeedForward,
    /// Second feed-forward stage: ff_dim -> embed_dim.
    ff2: QuantumFeedForward,
    /// Layer norm applied after the attention residual.
    layer_norm1: LayerNorm,
    /// Layer norm applied after the feed-forward residual.
    layer_norm2: LayerNorm,
    /// Dropout probability (stored; not read by `forward` as written).
    dropout_rate: f64,
}
475
/// A quantum feed-forward layer backed by a variational circuit.
#[derive(Debug)]
struct QuantumFeedForward {
    /// Expected input dimensionality.
    input_dim: usize,
    /// Produced output dimensionality.
    output_dim: usize,
    /// Parameterized circuit (built in `new`; unused by the current `forward`).
    circuit: VariationalCircuit,
}
483
484impl QuantumFeedForward {
485 fn new(input_dim: usize, output_dim: usize) -> Self {
486 let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
487 let circuit = Self::build_ff_circuit(num_qubits);
488
489 Self {
490 input_dim,
491 output_dim,
492 circuit,
493 }
494 }
495
496 fn build_ff_circuit(num_qubits: usize) -> VariationalCircuit {
497 let mut circuit = VariationalCircuit::new(num_qubits);
498
499 for layer in 0..3 {
501 for q in 0..num_qubits {
503 circuit.add_gate("RY", vec![q], vec![format!("ff_ry_{}_{}", layer, q)]);
504 circuit.add_gate("RZ", vec![q], vec![format!("ff_rz_{}_{}", layer, q)]);
505 }
506
507 for i in 0..num_qubits {
509 for j in i + 1..num_qubits {
510 circuit.add_gate("CZ", vec![i, j], vec![]);
511 }
512 }
513 }
514
515 circuit
516 }
517
518 fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
519 let mut output = Array1::zeros(self.output_dim);
521
522 for i in 0..self.output_dim {
524 if i < input.len() {
525 output[i] = (input[i] * 2.0 * PI).sin() * 0.5 + 0.5;
526 }
527 }
528
529 Ok(output)
530 }
531}
532
/// Row-wise layer normalization without learned scale/shift parameters.
#[derive(Debug)]
struct LayerNorm {
    /// Size of the normalized feature dimension (stored; not read in `forward`).
    normalized_shape: usize,
    /// Small constant added to the variance for numerical stability.
    epsilon: f64,
}
539
540impl LayerNorm {
541 fn new(normalized_shape: usize) -> Self {
542 Self {
543 normalized_shape,
544 epsilon: 1e-5,
545 }
546 }
547
548 fn forward(&self, input: &Array2<f64>) -> Array2<f64> {
549 let mean = input.mean_axis(Axis(1)).unwrap();
550 let variance = input.var_axis(Axis(1), 0.0);
551
552 let mut output = input.clone();
553 for i in 0..input.nrows() {
554 let std = (variance[i] + self.epsilon).sqrt();
555 for j in 0..input.ncols() {
556 output[[i, j]] = (input[[i, j]] - mean[i]) / std;
557 }
558 }
559
560 output
561 }
562}
563
564impl QuantumTransformerBlock {
565 pub fn new(embed_dim: usize, num_heads: usize, ff_dim: usize, dropout_rate: f64) -> Self {
567 Self {
568 self_attention: QuantumSelfAttention::new(embed_dim, num_heads, dropout_rate),
569 ff_dim,
570 ff1: QuantumFeedForward::new(embed_dim, ff_dim),
571 ff2: QuantumFeedForward::new(ff_dim, embed_dim),
572 layer_norm1: LayerNorm::new(embed_dim),
573 layer_norm2: LayerNorm::new(embed_dim),
574 dropout_rate,
575 }
576 }
577
578 pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
580 let attended = self.self_attention.forward(input, input, input, mask)?;
582 let residual1 = &attended + input;
583 let norm1 = self.layer_norm1.forward(&residual1);
584
585 let batch_size = norm1.nrows();
587 let seq_dim = norm1.ncols();
588 let seq_len = seq_dim / self.self_attention.embed_dim;
589
590 let mut ff_output = Array2::zeros((batch_size, seq_dim));
591
592 for b in 0..batch_size {
593 for s in 0..seq_len {
594 let start = s * self.self_attention.embed_dim;
595 let end = start + self.self_attention.embed_dim;
596
597 let input_slice = norm1.slice(s![b, start..end]).to_owned();
598 let hidden = self.ff1.forward(&input_slice)?;
599 let output = self.ff2.forward(&hidden)?;
600
601 for i in 0..self.self_attention.embed_dim {
602 ff_output[[b, start + i]] = output[i];
603 }
604 }
605 }
606
607 let residual2 = &ff_output + &norm1;
608 let output = self.layer_norm2.forward(&residual2);
609
610 Ok(output)
611 }
612}
613
/// A stack of quantum transformer blocks with sinusoidal positional encoding.
#[derive(Debug)]
pub struct QuantumTransformer {
    /// Embedding dimension shared by all blocks.
    embed_dim: usize,
    /// Number of stacked transformer blocks.
    num_layers: usize,
    /// The transformer blocks, applied in order.
    blocks: Vec<QuantumTransformerBlock>,
    /// Sinusoidal positional-encoding generator.
    positional_encoding: PositionalEncoding,
}
626
/// Generator for classic sinusoidal positional encodings.
#[derive(Debug)]
struct PositionalEncoding {
    /// Maximum supported sequence length (stored; not enforced by `encode`).
    max_length: usize,
    /// Embedding dimension of each position vector.
    embed_dim: usize,
}
633
634impl PositionalEncoding {
635 fn new(max_length: usize, embed_dim: usize) -> Self {
636 Self {
637 max_length,
638 embed_dim,
639 }
640 }
641
642 fn encode(&self, seq_len: usize) -> Array2<f64> {
643 let mut encoding = Array2::zeros((seq_len, self.embed_dim));
644
645 for pos in 0..seq_len {
646 for i in 0..self.embed_dim {
647 let angle = if i % 2 == 0 {
648 (pos as f64) / 10000_f64.powf((i as f64) / (self.embed_dim as f64))
649 } else {
650 (pos as f64) / 10000_f64.powf(((i - 1) as f64) / (self.embed_dim as f64))
651 };
652
653 encoding[[pos, i]] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
654 }
655 }
656
657 encoding
658 }
659}
660
661impl QuantumTransformer {
662 pub fn new(
664 embed_dim: usize,
665 num_layers: usize,
666 num_heads: usize,
667 ff_dim: usize,
668 max_length: usize,
669 dropout_rate: f64,
670 ) -> Self {
671 let blocks = (0..num_layers)
672 .map(|_| QuantumTransformerBlock::new(embed_dim, num_heads, ff_dim, dropout_rate))
673 .collect();
674
675 Self {
676 embed_dim,
677 num_layers,
678 blocks,
679 positional_encoding: PositionalEncoding::new(max_length, embed_dim),
680 }
681 }
682
683 pub fn forward(&self, input: &Array2<f64>, mask: Option<&Array2<bool>>) -> Result<Array2<f64>> {
685 let seq_len = input.ncols() / self.embed_dim;
686
687 let pos_encoding = self.positional_encoding.encode(seq_len);
689 let mut encoded = input.clone();
690
691 for i in 0..input.nrows() {
692 for s in 0..seq_len {
693 for d in 0..self.embed_dim {
694 encoded[[i, s * self.embed_dim + d]] += pos_encoding[[s, d]];
695 }
696 }
697 }
698
699 let mut output = encoded;
701 for block in &self.blocks {
702 output = block.forward(&output, mask)?;
703 }
704
705 Ok(output)
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // A projection should produce a vector of exactly its output dimension.
    #[test]
    fn test_quantum_projection() {
        let proj = QuantumProjection::new(8, 8);
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);

        let output = proj.forward(&input).unwrap();
        assert_eq!(output.len(), 8);
    }

    // Self-attention must preserve the flattened (batch, seq * embed) shape.
    #[test]
    fn test_quantum_self_attention() {
        let attention = QuantumSelfAttention::new(16, 4, 0.1);

        let batch_size = 2;
        let seq_len = 3;
        let embed_dim = 16;

        // Fill with distinct, non-zero values so encoding does not hit the
        // zero-norm error path.
        let mut input = Array2::zeros((batch_size, seq_len * embed_dim));
        for i in 0..batch_size {
            for j in 0..seq_len * embed_dim {
                input[[i, j]] = 0.1 + (i * seq_len * embed_dim + j) as f64 * 0.01;
            }
        }

        let output = attention.forward(&input, &input, &input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }

    // A full block (attention + FF + norms) must also be shape-preserving.
    #[test]
    fn test_quantum_transformer_block() {
        let block = QuantumTransformerBlock::new(8, 2, 16, 0.1);

        let batch_size = 1;
        let seq_len = 2;
        let embed_dim = 8;

        let input = Array2::ones((batch_size, seq_len * embed_dim));
        let output = block.forward(&input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }

    // The encoding table has shape (seq_len, embed_dim) and distinct
    // positions must receive distinct encodings.
    #[test]
    fn test_positional_encoding() {
        let pos_enc = PositionalEncoding::new(100, 16);
        let encoding = pos_enc.encode(10);

        assert_eq!(encoding.shape(), &[10, 16]);

        let pos0 = encoding.row(0);
        let pos1 = encoding.row(1);
        let diff: f64 = (&pos1 - &pos0).iter().map(|x| x.abs()).sum();
        assert!(diff > 0.0);
    }

    // End-to-end: a stacked transformer keeps the input's flattened shape.
    #[test]
    fn test_quantum_transformer() {
        let transformer = QuantumTransformer::new(8, 2, 2, 16, 100, 0.1);

        let batch_size = 1;
        let seq_len = 3;
        let embed_dim = 8;

        let input = Array2::zeros((batch_size, seq_len * embed_dim));
        let output = transformer.forward(&input, None).unwrap();

        assert_eq!(output.shape(), &[batch_size, seq_len * embed_dim]);
    }
}