1use crate::{
24 error::{QuantRS2Error, QuantRS2Result},
25 gate::GateOp,
26 qubit::QubitId,
27};
28use scirs2_core::ndarray::{Array1, Array2, Array3};
29use scirs2_core::random::prelude::*;
30use scirs2_core::Complex64;
31use std::f64::consts::PI;
32
/// Configuration for a quantum transformer model and its sub-modules.
#[derive(Debug, Clone)]
pub struct QuantumTransformerConfig {
    /// Number of qubits per token — the model's feature width.
    pub num_qubits: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimension of each attention head (Q/K/V project to `num_heads * head_dim`).
    pub head_dim: usize,
    /// Number of stacked transformer layers.
    pub num_layers: usize,
    /// Hidden width of the feed-forward sub-layer.
    pub ffn_dim: usize,
    /// Dropout rate. NOTE(review): stored but never applied anywhere in this
    /// file — confirm whether dropout is implemented elsewhere or pending.
    pub dropout_rate: f64,
    /// Maximum sequence length supported by the positional encoding.
    pub max_seq_length: usize,
    /// Whether to apply layer normalization after each sub-layer.
    pub use_layer_norm: bool,
}
53
54impl Default for QuantumTransformerConfig {
55 fn default() -> Self {
56 Self {
57 num_qubits: 4,
58 num_heads: 2,
59 head_dim: 2,
60 num_layers: 2,
61 ffn_dim: 8,
62 dropout_rate: 0.1,
63 max_seq_length: 64,
64 use_layer_norm: true,
65 }
66 }
67}
68
/// Multi-head attention over sequences of complex amplitude vectors.
#[derive(Debug, Clone)]
pub struct QuantumAttention {
    /// Input/output feature width (columns expected in the input).
    num_qubits: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension of each head; used to scale attention scores.
    head_dim: usize,
    /// Query projection angles, shape `(num_heads * head_dim, num_qubits)`.
    query_params: Array2<f64>,
    /// Key projection angles, same shape as `query_params`.
    key_params: Array2<f64>,
    /// Value projection angles, same shape as `query_params`.
    value_params: Array2<f64>,
    /// Output projection angles, shape `(num_qubits, num_heads * head_dim)`.
    output_params: Array2<f64>,
}
87
88impl QuantumAttention {
89 pub fn new(num_qubits: usize, num_heads: usize, head_dim: usize) -> QuantRS2Result<Self> {
91 if num_qubits < 2 {
92 return Err(QuantRS2Error::InvalidInput(
93 "Quantum attention requires at least 2 qubits".to_string(),
94 ));
95 }
96
97 if num_heads == 0 || head_dim == 0 {
98 return Err(QuantRS2Error::InvalidInput(
99 "Number of heads and head dimension must be positive".to_string(),
100 ));
101 }
102
103 let total_dim = num_heads * head_dim;
104 let mut rng = thread_rng();
105
106 let scale = (2.0 / (num_qubits as f64)).sqrt();
108
109 let query_params =
110 Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));
111
112 let key_params =
113 Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));
114
115 let value_params =
116 Array2::from_shape_fn((total_dim, num_qubits), |_| rng.gen_range(-scale..scale));
117
118 let output_params =
119 Array2::from_shape_fn((num_qubits, total_dim), |_| rng.gen_range(-scale..scale));
120
121 Ok(Self {
122 num_qubits,
123 num_heads,
124 head_dim,
125 query_params,
126 key_params,
127 value_params,
128 output_params,
129 })
130 }
131
132 pub fn attention_scores(
134 &self,
135 query: &Array2<Complex64>,
136 key: &Array2<Complex64>,
137 ) -> QuantRS2Result<Array2<f64>> {
138 let seq_len = query.shape()[0];
139 let mut scores = Array2::zeros((seq_len, seq_len));
140
141 for i in 0..seq_len {
143 for j in 0..seq_len {
144 let q = query.row(i);
145 let k = key.row(j);
146
147 let mut score = Complex64::new(0.0, 0.0);
149 for (qi, ki) in q.iter().zip(k.iter()) {
150 score += qi.conj() * ki;
151 }
152
153 let scaled_score = score.norm() / (self.head_dim as f64).sqrt();
155 scores[[i, j]] = scaled_score;
156 }
157 }
158
159 Ok(scores)
160 }
161
162 pub fn softmax(&self, scores: &Array2<f64>) -> Array2<f64> {
164 let seq_len = scores.shape()[0];
165 let mut softmax_scores = Array2::zeros((seq_len, seq_len));
166
167 for i in 0..seq_len {
168 let row = scores.row(i);
169 let max_score = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
170
171 let mut exp_scores = Array1::zeros(seq_len);
173 let mut sum_exp = 0.0;
174
175 for (j, &score) in row.iter().enumerate() {
176 let exp_val = (score - max_score).exp();
177 exp_scores[j] = exp_val;
178 sum_exp += exp_val;
179 }
180
181 for j in 0..seq_len {
183 softmax_scores[[i, j]] = exp_scores[j] / sum_exp;
184 }
185 }
186
187 softmax_scores
188 }
189
190 pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
192 let seq_len = input.shape()[0];
193
194 let query = self.project_qkv(input, &self.query_params)?;
196 let key = self.project_qkv(input, &self.key_params)?;
197 let value = self.project_qkv(input, &self.value_params)?;
198
199 let scores = self.attention_scores(&query, &key)?;
201 let attention_weights = self.softmax(&scores);
202
203 let total_dim = self.num_heads * self.head_dim;
205 let mut output = Array2::zeros((seq_len, total_dim));
206
207 for i in 0..seq_len {
208 for j in 0..seq_len {
209 let weight = attention_weights[[i, j]];
210 for k in 0..total_dim {
211 output[[i, k]] = output[[i, k]] + value[[j, k]] * weight;
212 }
213 }
214 }
215
216 self.project_output(&output)
218 }
219
220 fn project_qkv(
222 &self,
223 input: &Array2<Complex64>,
224 params: &Array2<f64>,
225 ) -> QuantRS2Result<Array2<Complex64>> {
226 let seq_len = input.shape()[0];
227 let out_dim = params.shape()[0];
228 let mut output = Array2::zeros((seq_len, out_dim));
229
230 for i in 0..seq_len {
231 for j in 0..out_dim {
232 let mut sum = Complex64::new(0.0, 0.0);
233 for k in 0..self.num_qubits {
234 let angle = params[[j, k]];
236 let rotation = Complex64::new(angle.cos(), angle.sin());
237 sum += input[[i, k]] * rotation;
238 }
239 output[[i, j]] = sum;
240 }
241 }
242
243 Ok(output)
244 }
245
246 fn project_output(
248 &self,
249 attention_out: &Array2<Complex64>,
250 ) -> QuantRS2Result<Array2<Complex64>> {
251 let seq_len = attention_out.shape()[0];
252 let mut output = Array2::zeros((seq_len, self.num_qubits));
253
254 for i in 0..seq_len {
255 for j in 0..self.num_qubits {
256 let mut sum = Complex64::new(0.0, 0.0);
257 for k in 0..(self.num_heads * self.head_dim) {
258 let angle = self.output_params[[j, k]];
259 let rotation = Complex64::new(angle.cos(), angle.sin());
260 sum += attention_out[[i, k]] * rotation;
261 }
262 output[[i, j]] = sum;
263 }
264 }
265
266 Ok(output)
267 }
268}
269
/// Sinusoidal positional information applied as per-element phase rotations.
#[derive(Debug, Clone)]
pub struct QuantumPositionalEncoding {
    /// Longest sequence the precomputed table covers.
    max_seq_length: usize,
    /// Feature width (columns expected in encoded inputs).
    num_qubits: usize,
    /// Precomputed phases, shape `(max_seq_length, num_qubits)`.
    encoding: Array2<f64>,
}
280
281impl QuantumPositionalEncoding {
282 pub fn new(max_seq_length: usize, num_qubits: usize) -> Self {
284 let mut encoding = Array2::zeros((max_seq_length, num_qubits));
285
286 for pos in 0..max_seq_length {
288 for i in 0..num_qubits {
289 if i % 2 == 0 {
290 let freq = 1.0 / 10000_f64.powf(i as f64 / num_qubits as f64);
291 encoding[[pos, i]] = (pos as f64 * freq).sin();
292 } else {
293 let freq = 1.0 / 10000_f64.powf((i - 1) as f64 / num_qubits as f64);
294 encoding[[pos, i]] = (pos as f64 * freq).cos();
295 }
296 }
297 }
298
299 Self {
300 max_seq_length,
301 num_qubits,
302 encoding,
303 }
304 }
305
306 pub fn encode(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
308 let seq_len = input.shape()[0];
309
310 if seq_len > self.max_seq_length {
311 return Err(QuantRS2Error::InvalidInput(format!(
312 "Sequence length {} exceeds maximum {}",
313 seq_len, self.max_seq_length
314 )));
315 }
316
317 let mut output = input.clone();
318
319 for i in 0..seq_len {
321 for j in 0..self.num_qubits {
322 let phase = self.encoding[[i, j]];
323 let phase_shift = Complex64::new(phase.cos(), phase.sin());
324 output[[i, j]] = output[[i, j]] * phase_shift;
325 }
326 }
327
328 Ok(output)
329 }
330}
331
/// Two-layer feed-forward block with a phase-preserving nonlinearity.
#[derive(Debug, Clone)]
pub struct QuantumFeedForward {
    /// Input (and output) feature width.
    input_dim: usize,
    /// Hidden feature width.
    hidden_dim: usize,
    /// First-layer angles, shape `(hidden_dim, input_dim)`.
    w1: Array2<f64>,
    /// Second-layer angles, shape `(input_dim, hidden_dim)`.
    w2: Array2<f64>,
}
344
345impl QuantumFeedForward {
346 pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
348 let mut rng = thread_rng();
349 let scale1 = (2.0 / input_dim as f64).sqrt();
350 let scale2 = (2.0 / hidden_dim as f64).sqrt();
351
352 let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.gen_range(-scale1..scale1));
353
354 let w2 = Array2::from_shape_fn((input_dim, hidden_dim), |_| rng.gen_range(-scale2..scale2));
355
356 Self {
357 input_dim,
358 hidden_dim,
359 w1,
360 w2,
361 }
362 }
363
364 pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
366 let seq_len = input.shape()[0];
367
368 let mut hidden = Array2::zeros((seq_len, self.hidden_dim));
370 for i in 0..seq_len {
371 for j in 0..self.hidden_dim {
372 let mut sum = Complex64::new(0.0, 0.0);
373 for k in 0..self.input_dim {
374 let angle = self.w1[[j, k]];
375 let rotation = Complex64::new(angle.cos(), angle.sin());
376 sum += input[[i, k]] * rotation;
377 }
378 hidden[[i, j]] = self.quantum_activation(sum);
380 }
381 }
382
383 let mut output = Array2::zeros((seq_len, self.input_dim));
385 for i in 0..seq_len {
386 for j in 0..self.input_dim {
387 let mut sum = Complex64::new(0.0, 0.0);
388 for k in 0..self.hidden_dim {
389 let angle = self.w2[[j, k]];
390 let rotation = Complex64::new(angle.cos(), angle.sin());
391 sum += hidden[[i, k]] * rotation;
392 }
393 output[[i, j]] = sum;
394 }
395 }
396
397 Ok(output)
398 }
399
400 fn quantum_activation(&self, z: Complex64) -> Complex64 {
402 let amplitude = z.norm();
404 let phase = z.arg();
405
406 if amplitude > 0.0 {
407 let amplified = amplitude.tanh();
409 Complex64::new(amplified * phase.cos(), amplified * phase.sin())
410 } else {
411 Complex64::new(0.0, 0.0)
412 }
413 }
414}
415
/// One transformer layer: attention and feed-forward sub-layers, each with a
/// residual connection and optional layer normalization.
#[derive(Debug, Clone)]
pub struct QuantumTransformerLayer {
    /// Multi-head attention sub-layer.
    attention: QuantumAttention,
    /// Feed-forward sub-layer.
    ffn: QuantumFeedForward,
    /// Layer configuration (controls `use_layer_norm`, dimensions).
    config: QuantumTransformerConfig,
}
426
427impl QuantumTransformerLayer {
428 pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
430 let attention =
431 QuantumAttention::new(config.num_qubits, config.num_heads, config.head_dim)?;
432
433 let ffn = QuantumFeedForward::new(config.num_qubits, config.ffn_dim);
434
435 Ok(Self {
436 attention,
437 ffn,
438 config,
439 })
440 }
441
442 pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
444 let attention_out = self.attention.forward(input)?;
446 let after_attention = self.add_residual(input, &attention_out);
447
448 let normalized = if self.config.use_layer_norm {
450 self.layer_norm(&after_attention)?
451 } else {
452 after_attention
453 };
454
455 let ffn_out = self.ffn.forward(&normalized)?;
457 let output = self.add_residual(&normalized, &ffn_out);
458
459 if self.config.use_layer_norm {
461 self.layer_norm(&output)
462 } else {
463 Ok(output)
464 }
465 }
466
467 fn add_residual(
469 &self,
470 input: &Array2<Complex64>,
471 residual: &Array2<Complex64>,
472 ) -> Array2<Complex64> {
473 input + residual
474 }
475
476 fn layer_norm(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
478 let seq_len = input.shape()[0];
479 let num_features = input.shape()[1];
480 let mut output = Array2::zeros((seq_len, num_features));
481
482 for i in 0..seq_len {
483 let row = input.row(i);
484
485 let mut mean_real = 0.0;
487 let mut mean_imag = 0.0;
488 for val in row.iter() {
489 mean_real += val.re;
490 mean_imag += val.im;
491 }
492 mean_real /= num_features as f64;
493 mean_imag /= num_features as f64;
494 let mean = Complex64::new(mean_real, mean_imag);
495
496 let mut variance = 0.0;
497 for val in row.iter() {
498 let diff = val - mean;
499 variance += diff.norm_sqr();
500 }
501 variance /= num_features as f64;
502
503 let std = (variance + 1e-5).sqrt();
504
505 for j in 0..num_features {
507 output[[i, j]] = (input[[i, j]] - mean) / std;
508 }
509 }
510
511 Ok(output)
512 }
513}
514
/// Full transformer: positional encoding followed by a stack of layers.
#[derive(Debug, Clone)]
pub struct QuantumTransformer {
    /// Model configuration shared by all layers.
    config: QuantumTransformerConfig,
    /// Positional phase encoding applied before the first layer.
    pos_encoding: QuantumPositionalEncoding,
    /// Stacked transformer layers, applied in order.
    layers: Vec<QuantumTransformerLayer>,
}
525
526impl QuantumTransformer {
527 pub fn new(config: QuantumTransformerConfig) -> QuantRS2Result<Self> {
529 let pos_encoding = QuantumPositionalEncoding::new(config.max_seq_length, config.num_qubits);
530
531 let mut layers = Vec::with_capacity(config.num_layers);
532 for _ in 0..config.num_layers {
533 layers.push(QuantumTransformerLayer::new(config.clone())?);
534 }
535
536 Ok(Self {
537 config,
538 pos_encoding,
539 layers,
540 })
541 }
542
543 pub fn forward(&self, input: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
545 let mut x = self.pos_encoding.encode(input)?;
547
548 for layer in &self.layers {
550 x = layer.forward(&x)?;
551 }
552
553 Ok(x)
554 }
555
556 pub fn config(&self) -> &QuantumTransformerConfig {
558 &self.config
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x4 test input with value 0.1*(i + j) at position (i, j).
    fn ramp_input() -> Array2<Complex64> {
        Array2::from_shape_fn((3, 4), |(i, j)| Complex64::new((i + j) as f64 * 0.1, 0.0))
    }

    #[test]
    fn test_quantum_attention() {
        let attention = QuantumAttention::new(4, 2, 2).unwrap();
        let output = attention.forward(&ramp_input()).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_positional_encoding() {
        let pos_enc = QuantumPositionalEncoding::new(64, 4);
        let input = Array2::from_elem((3, 4), Complex64::new(1.0, 0.0));
        let encoded = pos_enc.encode(&input).unwrap();
        assert_eq!(encoded.shape(), &[3, 4]);
    }

    #[test]
    fn test_quantum_transformer() {
        // Explicit config matches QuantumTransformerConfig::default().
        let transformer = QuantumTransformer::new(QuantumTransformerConfig::default()).unwrap();
        let output = transformer.forward(&ramp_input()).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }
}