use crate::error::{QuantRS2Error, QuantRS2Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::Complex64;

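/// Configuration for a [`QuantumTransformer`].
///
/// Illustrative usage (marked `ignore`, so it is a sketch rather than a
/// compiled doctest):
///
/// ```ignore
/// let config = QuantumTransformerConfig::default();
/// assert_eq!(config.num_qubits, 4);
/// ```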
#[derive(Debug, Clone)]
pub struct QuantumTransformerConfig {
    /// Number of qubits per token representation.
    pub num_qubits: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Number of stacked transformer layers.
    pub num_layers: usize,
    /// Hidden dimension of the feed-forward sub-layer.
    pub ffn_dim: usize,
    /// Dropout rate (stored but not currently applied in the forward pass).
    pub dropout_rate: f64,
    /// Maximum sequence length supported by the positional encoding.
    pub max_seq_length: usize,
    /// Whether to apply layer normalization after each sub-layer.
    pub use_layer_norm: bool,
}

impl Default for QuantumTransformerConfig {
    fn default() -> Self {
        Self {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        }
    }
}

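/// Quantum-inspired multi-head attention.
///
/// Queries, keys, and values are produced by phase-rotation projections of
/// the input. The heads are not split explicitly; all heads share a single
/// feature space of size `num_heads * head_dim`.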
#[derive(Debug, Clone)]
pub struct QuantumAttention {
    /// Number of qubits per token representation.
    num_qubits: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension of each attention head.
    head_dim: usize,
    /// Query rotation angles, shape `(num_heads * head_dim, num_qubits)`.
    query_params: Array2<f64>,
    /// Key rotation angles, shape `(num_heads * head_dim, num_qubits)`.
    key_params: Array2<f64>,
    /// Value rotation angles, shape `(num_heads * head_dim, num_qubits)`.
    value_params: Array2<f64>,
    /// Output rotation angles, shape `(num_qubits, num_heads * head_dim)`.
    output_params: Array2<f64>,
}

impl QuantumAttention {
    /// Create a new attention module, initializing all rotation angles
    /// uniformly in `[-scale, scale)` with `scale = sqrt(2 / num_qubits)`.
    pub fn new(num_qubits: usize, num_heads: usize, head_dim: usize) -> QuantRS2Result<Self> {
        if num_qubits < 2 {
            return Err(QuantRS2Error::InvalidInput(
                "Quantum attention requires at least 2 qubits".to_string(),
            ));
        }

        if num_heads == 0 || head_dim == 0 {
            return Err(QuantRS2Error::InvalidInput(
                "Number of heads and head dimension must be positive".to_string(),
            ));
        }

        let total_dim = num_heads * head_dim;
        let mut rng = thread_rng();

        let scale = (2.0 / (num_qubits as f64)).sqrt();

        let query_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));

        let key_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));

        let value_params =
            Array2::from_shape_fn((total_dim, num_qubits), |_| rng.random_range(-scale..scale));

        let output_params =
            Array2::from_shape_fn((num_qubits, total_dim), |_| rng.random_range(-scale..scale));

        Ok(Self {
            num_qubits,
            num_heads,
            head_dim,
            query_params,
            key_params,
            value_params,
            output_params,
        })
    }

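    /// Compute scaled attention scores between query and key sequences.
    ///
    /// Each score is the magnitude of the complex inner product,
    /// `|<q_i, k_j>| / sqrt(head_dim)`, so the result is a real
    /// `(seq_len, seq_len)` matrix.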
    pub fn attention_scores(
        &self,
        query: &Array2<Complex64>,
        key: &Array2<Complex64>,
    ) -> QuantRS2Result<Array2<f64>> {
        let seq_len = query.shape()[0];
        let mut scores = Array2::zeros((seq_len, seq_len));

        for i in 0..seq_len {
            for j in 0..seq_len {
                let q = query.row(i);
                let k = key.row(j);

                // Complex inner product <q, k>.
                let mut score = Complex64::new(0.0, 0.0);
                for (qi, ki) in q.iter().zip(k.iter()) {
                    score += qi.conj() * ki;
                }

                // Scale by sqrt(head_dim), as in classical attention.
                let scaled_score = score.norm() / (self.head_dim as f64).sqrt();
                scores[[i, j]] = scaled_score;
            }
        }

        Ok(scores)
    }

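    /// Row-wise softmax over a score matrix.
    ///
    /// The row maximum is subtracted before exponentiation for numerical
    /// stability.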
    pub fn softmax(&self, scores: &Array2<f64>) -> Array2<f64> {
        let seq_len = scores.shape()[0];
        let mut softmax_scores = Array2::zeros((seq_len, seq_len));

        for i in 0..seq_len {
            let row = scores.row(i);
            let max_score = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);

            let mut exp_scores = Array1::zeros(seq_len);
            let mut sum_exp = 0.0;

            for (j, &score) in row.iter().enumerate() {
                let exp_val = (score - max_score).exp();
                exp_scores[j] = exp_val;
                sum_exp += exp_val;
            }

            for j in 0..seq_len {
                softmax_scores[[i, j]] = exp_scores[j] / sum_exp;
            }
        }

        softmax_scores
    }

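    /// Full attention forward pass: project the input to queries, keys, and
    /// values, compute softmax-normalized attention weights, take the
    /// weighted sum of values, and project back to `num_qubits` features.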
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        let query = self.project_qkv(input, &self.query_params)?;
        let key = self.project_qkv(input, &self.key_params)?;
        let value = self.project_qkv(input, &self.value_params)?;

        let scores = self.attention_scores(&query, &key)?;
        let attention_weights = self.softmax(&scores);

        // Weighted sum of values: output_i = sum_j w_ij * v_j.
        let total_dim = self.num_heads * self.head_dim;
        let mut output = Array2::zeros((seq_len, total_dim));

        for i in 0..seq_len {
            for j in 0..seq_len {
                let weight = attention_weights[[i, j]];
                for k in 0..total_dim {
                    output[[i, k]] += value[[j, k]] * weight;
                }
            }
        }

        self.project_output(&output)
    }

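    /// Project the input through a bank of phase rotations: each parameter
    /// angle `theta` contributes a factor `e^{i*theta}` to the corresponding
    /// input amplitude, and the results are summed per output feature.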
    fn project_qkv(
        &self,
        input: &Array2<Complex64>,
        params: &Array2<f64>,
    ) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];
        let out_dim = params.shape()[0];
        let mut output = Array2::zeros((seq_len, out_dim));

        for i in 0..seq_len {
            for j in 0..out_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.num_qubits {
                    let angle = params[[j, k]];
                    // e^{i*angle} as a unit complex rotation.
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += input[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

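    /// Project the concatenated attention output back down to `num_qubits`
    /// features, using the same phase-rotation scheme as `project_qkv`.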
    fn project_output(
        &self,
        attention_out: &Array2<Complex64>,
    ) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = attention_out.shape()[0];
        let mut output = Array2::zeros((seq_len, self.num_qubits));

        for i in 0..seq_len {
            for j in 0..self.num_qubits {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..(self.num_heads * self.head_dim) {
                    let angle = self.output_params[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += attention_out[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }
}

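/// Sinusoidal positional encoding applied as per-qubit phase shifts.
///
/// The table follows the classical transformer recipe (`sin` on even
/// indices, `cos` on odd indices, geometrically spaced frequencies
/// `1 / 10000^(i / num_qubits)`), but each entry is used as a phase angle
/// rather than an additive offset.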
#[derive(Debug, Clone)]
pub struct QuantumPositionalEncoding {
    /// Maximum sequence length the table covers.
    max_seq_length: usize,
    /// Number of qubits (features) per position.
    num_qubits: usize,
    /// Precomputed phase angles, shape `(max_seq_length, num_qubits)`.
    encoding: Array2<f64>,
}

impl QuantumPositionalEncoding {
    /// Precompute the sinusoidal phase table.
    pub fn new(max_seq_length: usize, num_qubits: usize) -> Self {
        let mut encoding = Array2::zeros((max_seq_length, num_qubits));

        for pos in 0..max_seq_length {
            for i in 0..num_qubits {
                if i % 2 == 0 {
                    let freq = 1.0 / 10000_f64.powf(i as f64 / num_qubits as f64);
                    encoding[[pos, i]] = (pos as f64 * freq).sin();
                } else {
                    let freq = 1.0 / 10000_f64.powf((i - 1) as f64 / num_qubits as f64);
                    encoding[[pos, i]] = (pos as f64 * freq).cos();
                }
            }
        }

        Self {
            max_seq_length,
            num_qubits,
            encoding,
        }
    }

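    /// Apply the positional phases to an input sequence, multiplying each
    /// amplitude by `e^{i*phi}` where `phi` is the precomputed table entry.
    /// Fails if the sequence is longer than `max_seq_length`.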
    pub fn encode(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        if seq_len > self.max_seq_length {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Sequence length {} exceeds maximum {}",
                seq_len, self.max_seq_length
            )));
        }

        let mut output = input.clone();

        for i in 0..seq_len {
            for j in 0..self.num_qubits {
                let phase = self.encoding[[i, j]];
                let phase_shift = Complex64::new(phase.cos(), phase.sin());
                output[[i, j]] *= phase_shift;
            }
        }

        Ok(output)
    }
}

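/// Position-wise feed-forward network with a quantum-flavored activation.
///
/// Both linear maps use phase rotations (`e^{i*theta}` per weight), and the
/// hidden layer applies `tanh` to the amplitude while preserving the phase.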
#[derive(Debug, Clone)]
pub struct QuantumFeedForward {
    /// Input (and output) feature dimension.
    input_dim: usize,
    /// Hidden feature dimension.
    hidden_dim: usize,
    /// First-layer rotation angles, shape `(hidden_dim, input_dim)`.
    w1: Array2<f64>,
    /// Second-layer rotation angles, shape `(input_dim, hidden_dim)`.
    w2: Array2<f64>,
}

impl QuantumFeedForward {
    /// Create a feed-forward module with He-style uniform initialization,
    /// `scale = sqrt(2 / fan_in)` per layer.
    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
        let mut rng = thread_rng();
        let scale1 = (2.0 / input_dim as f64).sqrt();
        let scale2 = (2.0 / hidden_dim as f64).sqrt();

        let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
            rng.random_range(-scale1..scale1)
        });

        let w2 = Array2::from_shape_fn((input_dim, hidden_dim), |_| {
            rng.random_range(-scale2..scale2)
        });

        Self {
            input_dim,
            hidden_dim,
            w1,
            w2,
        }
    }

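    /// Forward pass: rotate into the hidden dimension, apply the quantum
    /// activation, then rotate back to the input dimension.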
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];

        // First layer: input_dim -> hidden_dim, with activation.
        let mut hidden = Array2::zeros((seq_len, self.hidden_dim));
        for i in 0..seq_len {
            for j in 0..self.hidden_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.input_dim {
                    let angle = self.w1[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += input[[i, k]] * rotation;
                }
                hidden[[i, j]] = self.quantum_activation(sum);
            }
        }

        // Second layer: hidden_dim -> input_dim, no activation.
        let mut output = Array2::zeros((seq_len, self.input_dim));
        for i in 0..seq_len {
            for j in 0..self.input_dim {
                let mut sum = Complex64::new(0.0, 0.0);
                for k in 0..self.hidden_dim {
                    let angle = self.w2[[j, k]];
                    let rotation = Complex64::new(angle.cos(), angle.sin());
                    sum += hidden[[i, k]] * rotation;
                }
                output[[i, j]] = sum;
            }
        }

        Ok(output)
    }

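    /// Nonlinearity acting on the polar form of a complex value: the
    /// amplitude is squashed with `tanh` while the phase is left unchanged,
    /// i.e. `z = r * e^{i*phi}` maps to `tanh(r) * e^{i*phi}`.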
    fn quantum_activation(&self, z: Complex64) -> Complex64 {
        let amplitude = z.norm();
        let phase = z.arg();

        if amplitude > 0.0 {
            let amplified = amplitude.tanh();
            Complex64::new(amplified * phase.cos(), amplified * phase.sin())
        } else {
            Complex64::new(0.0, 0.0)
        }
    }
}

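/// A single transformer layer: quantum attention and a feed-forward network,
/// each wrapped in a residual connection and optional layer normalization
/// (post-norm, as in the original transformer).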
#[derive(Debug, Clone)]
pub struct QuantumTransformerLayer {
    /// Multi-head quantum attention sub-layer.
    attention: QuantumAttention,
    /// Position-wise feed-forward sub-layer.
    ffn: QuantumFeedForward,
    /// Layer configuration (shared with the parent transformer).
    config: QuantumTransformerConfig,
}

impl QuantumTransformerLayer {
    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
        let attention =
            QuantumAttention::new(config.num_qubits, config.num_heads, config.head_dim)?;

        let ffn = QuantumFeedForward::new(config.num_qubits, config.ffn_dim);

        Ok(Self {
            attention,
            ffn,
            config,
        })
    }

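    /// Layer forward pass: attention with residual connection, optional
    /// layer norm, then feed-forward with residual connection and another
    /// optional layer norm.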
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let attention_out = self.attention.forward(input)?;
        let after_attention = self.add_residual(input, &attention_out);

        let normalized = if self.config.use_layer_norm {
            self.layer_norm(&after_attention)?
        } else {
            after_attention
        };

        let ffn_out = self.ffn.forward(&normalized)?;
        let output = self.add_residual(&normalized, &ffn_out);

        if self.config.use_layer_norm {
            self.layer_norm(&output)
        } else {
            Ok(output)
        }
    }

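    /// Residual connection: element-wise sum of the sub-layer input and
    /// output.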
    fn add_residual(
        &self,
        input: &Array2<Complex64>,
        residual: &Array2<Complex64>,
    ) -> Array2<Complex64> {
        input + residual
    }

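    /// Layer normalization extended to complex values: each row is centered
    /// by its complex mean and divided by
    /// `sqrt(mean(|x - mean|^2) + 1e-5)`.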
    fn layer_norm(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let seq_len = input.shape()[0];
        let num_features = input.shape()[1];
        let mut output = Array2::zeros((seq_len, num_features));

        for i in 0..seq_len {
            let row = input.row(i);

            // Complex mean of the row.
            let mut mean_real = 0.0;
            let mut mean_imag = 0.0;
            for val in row {
                mean_real += val.re;
                mean_imag += val.im;
            }
            mean_real /= num_features as f64;
            mean_imag /= num_features as f64;
            let mean = Complex64::new(mean_real, mean_imag);

            // Variance of |x - mean|.
            let mut variance = 0.0;
            for val in row {
                let diff = val - mean;
                variance += diff.norm_sqr();
            }
            variance /= num_features as f64;

            let std = (variance + 1e-5).sqrt();

            for j in 0..num_features {
                output[[i, j]] = (input[[i, j]] - mean) / std;
            }
        }

        Ok(output)
    }
}

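/// A stack of [`QuantumTransformerLayer`]s with sinusoidal phase-based
/// positional encoding.
///
/// Minimal usage sketch (marked `ignore`; it assumes `Array2` and
/// `Complex64` are in scope, and the `?` error handling is illustrative
/// only):
///
/// ```ignore
/// let transformer = QuantumTransformer::new(QuantumTransformerConfig::default())?;
/// // Input: a (seq_len, num_qubits) matrix of complex amplitudes.
/// let input = Array2::from_elem((3, 4), Complex64::new(1.0, 0.0));
/// let output = transformer.forward(&input)?;
/// assert_eq!(output.shape(), &[3, 4]);
/// ```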
#[derive(Debug, Clone)]
pub struct QuantumTransformer {
    /// Model configuration.
    config: QuantumTransformerConfig,
    /// Positional encoding applied before the first layer.
    pos_encoding: QuantumPositionalEncoding,
    /// The stacked transformer layers.
    layers: Vec<QuantumTransformerLayer>,
}

impl QuantumTransformer {
    pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
        let pos_encoding =
            QuantumPositionalEncoding::new(config.max_seq_length, config.num_qubits);

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(QuantumTransformerLayer::new(config.clone())?);
        }

        Ok(Self {
            config,
            pos_encoding,
            layers,
        })
    }

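    /// Full forward pass: apply the positional encoding, then run each layer
    /// in sequence. Input and output have shape `(seq_len, num_qubits)`.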
    pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
        let mut x = self.pos_encoding.encode(input)?;

        for layer in &self.layers {
            x = layer.forward(&x)?;
        }

        Ok(x)
    }

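    /// Borrow the model configuration.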
    pub const fn config(&self) -> &QuantumTransformerConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_attention() {
        let attention = QuantumAttention::new(4, 2, 2).expect("Failed to create QuantumAttention");

        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new((i + j) as f64 * 0.1, 0.0);
            }
        }

        let output = attention
            .forward(&input)
            .expect("Attention forward pass should succeed");
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = QuantumPositionalEncoding::new(64, 4);

        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new(1.0, 0.0);
            }
        }

        let encoded = pos_enc
            .encode(&input)
            .expect("Positional encoding should succeed");
        assert_eq!(encoded.shape(), &[3, 4]);
    }

    #[test]
    fn test_quantum_transformer() {
        let config = QuantumTransformerConfig {
            num_qubits: 4,
            num_heads: 2,
            head_dim: 2,
            num_layers: 2,
            ffn_dim: 8,
            dropout_rate: 0.1,
            max_seq_length: 64,
            use_layer_norm: true,
        };

        let transformer =
            QuantumTransformer::new(config).expect("Failed to create QuantumTransformer");

        let mut input = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                input[[i, j]] = Complex64::new((i + j) as f64 * 0.1, 0.0);
            }
        }

        let output = transformer
            .forward(&input)
            .expect("Transformer forward pass should succeed");
        assert_eq!(output.shape(), &[3, 4]);
    }
}