1use crate::error::{MLError, Result};
6use crate::torchquantum::{
7 CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
8};
9use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
10
11#[derive(Debug, Clone)]
13pub struct TQRXX {
14 params: Option<TQParameter>,
15 has_params: bool,
16 trainable: bool,
17 inverse: bool,
18 static_mode: bool,
19}
20
21impl TQRXX {
22 pub fn new(has_params: bool, trainable: bool) -> Self {
23 let params = if has_params {
24 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rxx_theta"))
25 } else {
26 None
27 };
28
29 Self {
30 params,
31 has_params,
32 trainable,
33 inverse: false,
34 static_mode: false,
35 }
36 }
37
38 pub fn with_init_params(mut self, theta: f64) -> Self {
39 if let Some(ref mut p) = self.params {
40 p.data[[0, 0]] = theta;
41 }
42 self
43 }
44}
45
46impl Default for TQRXX {
47 fn default() -> Self {
48 Self::new(true, true)
49 }
50}
51
52impl TQModule for TQRXX {
53 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
54 Err(MLError::InvalidConfiguration(
55 "Use apply() instead of forward() for operators".to_string(),
56 ))
57 }
58
59 fn parameters(&self) -> Vec<TQParameter> {
60 self.params.iter().cloned().collect()
61 }
62
63 fn n_wires(&self) -> Option<usize> {
64 Some(2)
65 }
66
67 fn set_n_wires(&mut self, _n_wires: usize) {}
68
69 fn is_static_mode(&self) -> bool {
70 self.static_mode
71 }
72
73 fn static_on(&mut self) {
74 self.static_mode = true;
75 }
76
77 fn static_off(&mut self) {
78 self.static_mode = false;
79 }
80
81 fn name(&self) -> &str {
82 "RXX"
83 }
84
85 fn zero_grad(&mut self) {
86 if let Some(ref mut p) = self.params {
87 p.zero_grad();
88 }
89 }
90}
91
92impl TQOperator for TQRXX {
93 fn num_wires(&self) -> WiresEnum {
94 WiresEnum::Fixed(2)
95 }
96
97 fn num_params(&self) -> NParamsEnum {
98 NParamsEnum::Fixed(1)
99 }
100
101 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
102 let theta = params
103 .and_then(|p| p.first().copied())
104 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
105 .unwrap_or(0.0);
106
107 let theta = if self.inverse { -theta } else { theta };
108 let half_theta = theta / 2.0;
109 let c = half_theta.cos();
110 let s = half_theta.sin();
111
112 Array2::from_shape_vec(
117 (4, 4),
118 vec![
119 CType::new(c, 0.0),
120 CType::new(0.0, 0.0),
121 CType::new(0.0, 0.0),
122 CType::new(0.0, -s),
123 CType::new(0.0, 0.0),
124 CType::new(c, 0.0),
125 CType::new(0.0, -s),
126 CType::new(0.0, 0.0),
127 CType::new(0.0, 0.0),
128 CType::new(0.0, -s),
129 CType::new(c, 0.0),
130 CType::new(0.0, 0.0),
131 CType::new(0.0, -s),
132 CType::new(0.0, 0.0),
133 CType::new(0.0, 0.0),
134 CType::new(c, 0.0),
135 ],
136 )
137 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
138 }
139
140 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
141 self.apply_with_params(qdev, wires, None)
142 }
143
144 fn apply_with_params(
145 &mut self,
146 qdev: &mut TQDevice,
147 wires: &[usize],
148 params: Option<&[f64]>,
149 ) -> Result<()> {
150 if wires.len() < 2 {
151 return Err(MLError::InvalidConfiguration(
152 "RXX gate requires exactly 2 wires".to_string(),
153 ));
154 }
155 let matrix = self.get_matrix(params);
156 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
157
158 if qdev.record_op {
159 qdev.record_operation(OpHistoryEntry {
160 name: "rxx".to_string(),
161 wires: wires.to_vec(),
162 params: params.map(|p| p.to_vec()),
163 inverse: self.inverse,
164 trainable: self.trainable,
165 });
166 }
167
168 Ok(())
169 }
170
171 fn has_params(&self) -> bool {
172 self.has_params
173 }
174
175 fn trainable(&self) -> bool {
176 self.trainable
177 }
178
179 fn inverse(&self) -> bool {
180 self.inverse
181 }
182
183 fn set_inverse(&mut self, inverse: bool) {
184 self.inverse = inverse;
185 }
186}
187
188#[derive(Debug, Clone)]
190pub struct TQRYY {
191 params: Option<TQParameter>,
192 has_params: bool,
193 trainable: bool,
194 inverse: bool,
195 static_mode: bool,
196}
197
198impl TQRYY {
199 pub fn new(has_params: bool, trainable: bool) -> Self {
200 let params = if has_params {
201 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "ryy_theta"))
202 } else {
203 None
204 };
205
206 Self {
207 params,
208 has_params,
209 trainable,
210 inverse: false,
211 static_mode: false,
212 }
213 }
214
215 pub fn with_init_params(mut self, theta: f64) -> Self {
216 if let Some(ref mut p) = self.params {
217 p.data[[0, 0]] = theta;
218 }
219 self
220 }
221}
222
223impl Default for TQRYY {
224 fn default() -> Self {
225 Self::new(true, true)
226 }
227}
228
229impl TQModule for TQRYY {
230 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
231 Err(MLError::InvalidConfiguration(
232 "Use apply() instead of forward() for operators".to_string(),
233 ))
234 }
235
236 fn parameters(&self) -> Vec<TQParameter> {
237 self.params.iter().cloned().collect()
238 }
239
240 fn n_wires(&self) -> Option<usize> {
241 Some(2)
242 }
243
244 fn set_n_wires(&mut self, _n_wires: usize) {}
245
246 fn is_static_mode(&self) -> bool {
247 self.static_mode
248 }
249
250 fn static_on(&mut self) {
251 self.static_mode = true;
252 }
253
254 fn static_off(&mut self) {
255 self.static_mode = false;
256 }
257
258 fn name(&self) -> &str {
259 "RYY"
260 }
261
262 fn zero_grad(&mut self) {
263 if let Some(ref mut p) = self.params {
264 p.zero_grad();
265 }
266 }
267}
268
269impl TQOperator for TQRYY {
270 fn num_wires(&self) -> WiresEnum {
271 WiresEnum::Fixed(2)
272 }
273
274 fn num_params(&self) -> NParamsEnum {
275 NParamsEnum::Fixed(1)
276 }
277
278 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
279 let theta = params
280 .and_then(|p| p.first().copied())
281 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
282 .unwrap_or(0.0);
283
284 let theta = if self.inverse { -theta } else { theta };
285 let half_theta = theta / 2.0;
286 let c = half_theta.cos();
287 let s = half_theta.sin();
288
289 Array2::from_shape_vec(
294 (4, 4),
295 vec![
296 CType::new(c, 0.0),
297 CType::new(0.0, 0.0),
298 CType::new(0.0, 0.0),
299 CType::new(0.0, s),
300 CType::new(0.0, 0.0),
301 CType::new(c, 0.0),
302 CType::new(0.0, -s),
303 CType::new(0.0, 0.0),
304 CType::new(0.0, 0.0),
305 CType::new(0.0, -s),
306 CType::new(c, 0.0),
307 CType::new(0.0, 0.0),
308 CType::new(0.0, s),
309 CType::new(0.0, 0.0),
310 CType::new(0.0, 0.0),
311 CType::new(c, 0.0),
312 ],
313 )
314 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
315 }
316
317 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
318 self.apply_with_params(qdev, wires, None)
319 }
320
321 fn apply_with_params(
322 &mut self,
323 qdev: &mut TQDevice,
324 wires: &[usize],
325 params: Option<&[f64]>,
326 ) -> Result<()> {
327 if wires.len() < 2 {
328 return Err(MLError::InvalidConfiguration(
329 "RYY gate requires exactly 2 wires".to_string(),
330 ));
331 }
332 let matrix = self.get_matrix(params);
333 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
334
335 if qdev.record_op {
336 qdev.record_operation(OpHistoryEntry {
337 name: "ryy".to_string(),
338 wires: wires.to_vec(),
339 params: params.map(|p| p.to_vec()),
340 inverse: self.inverse,
341 trainable: self.trainable,
342 });
343 }
344
345 Ok(())
346 }
347
348 fn has_params(&self) -> bool {
349 self.has_params
350 }
351
352 fn trainable(&self) -> bool {
353 self.trainable
354 }
355
356 fn inverse(&self) -> bool {
357 self.inverse
358 }
359
360 fn set_inverse(&mut self, inverse: bool) {
361 self.inverse = inverse;
362 }
363}
364
365#[derive(Debug, Clone)]
367pub struct TQRZZ {
368 params: Option<TQParameter>,
369 has_params: bool,
370 trainable: bool,
371 inverse: bool,
372 static_mode: bool,
373}
374
375impl TQRZZ {
376 pub fn new(has_params: bool, trainable: bool) -> Self {
377 let params = if has_params {
378 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rzz_theta"))
379 } else {
380 None
381 };
382
383 Self {
384 params,
385 has_params,
386 trainable,
387 inverse: false,
388 static_mode: false,
389 }
390 }
391
392 pub fn with_init_params(mut self, theta: f64) -> Self {
393 if let Some(ref mut p) = self.params {
394 p.data[[0, 0]] = theta;
395 }
396 self
397 }
398}
399
400impl Default for TQRZZ {
401 fn default() -> Self {
402 Self::new(true, true)
403 }
404}
405
406impl TQModule for TQRZZ {
407 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
408 Err(MLError::InvalidConfiguration(
409 "Use apply() instead of forward() for operators".to_string(),
410 ))
411 }
412
413 fn parameters(&self) -> Vec<TQParameter> {
414 self.params.iter().cloned().collect()
415 }
416
417 fn n_wires(&self) -> Option<usize> {
418 Some(2)
419 }
420
421 fn set_n_wires(&mut self, _n_wires: usize) {}
422
423 fn is_static_mode(&self) -> bool {
424 self.static_mode
425 }
426
427 fn static_on(&mut self) {
428 self.static_mode = true;
429 }
430
431 fn static_off(&mut self) {
432 self.static_mode = false;
433 }
434
435 fn name(&self) -> &str {
436 "RZZ"
437 }
438
439 fn zero_grad(&mut self) {
440 if let Some(ref mut p) = self.params {
441 p.zero_grad();
442 }
443 }
444}
445
446impl TQOperator for TQRZZ {
447 fn num_wires(&self) -> WiresEnum {
448 WiresEnum::Fixed(2)
449 }
450
451 fn num_params(&self) -> NParamsEnum {
452 NParamsEnum::Fixed(1)
453 }
454
455 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
456 let theta = params
457 .and_then(|p| p.first().copied())
458 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
459 .unwrap_or(0.0);
460
461 let theta = if self.inverse { -theta } else { theta };
462 let half_theta = theta / 2.0;
463
464 let exp_neg = CType::from_polar(1.0, -half_theta);
466 let exp_pos = CType::from_polar(1.0, half_theta);
467
468 Array2::from_shape_vec(
469 (4, 4),
470 vec![
471 exp_neg,
472 CType::new(0.0, 0.0),
473 CType::new(0.0, 0.0),
474 CType::new(0.0, 0.0),
475 CType::new(0.0, 0.0),
476 exp_pos,
477 CType::new(0.0, 0.0),
478 CType::new(0.0, 0.0),
479 CType::new(0.0, 0.0),
480 CType::new(0.0, 0.0),
481 exp_pos,
482 CType::new(0.0, 0.0),
483 CType::new(0.0, 0.0),
484 CType::new(0.0, 0.0),
485 CType::new(0.0, 0.0),
486 exp_neg,
487 ],
488 )
489 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
490 }
491
492 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
493 self.apply_with_params(qdev, wires, None)
494 }
495
496 fn apply_with_params(
497 &mut self,
498 qdev: &mut TQDevice,
499 wires: &[usize],
500 params: Option<&[f64]>,
501 ) -> Result<()> {
502 if wires.len() < 2 {
503 return Err(MLError::InvalidConfiguration(
504 "RZZ gate requires exactly 2 wires".to_string(),
505 ));
506 }
507 let matrix = self.get_matrix(params);
508 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
509
510 if qdev.record_op {
511 qdev.record_operation(OpHistoryEntry {
512 name: "rzz".to_string(),
513 wires: wires.to_vec(),
514 params: params.map(|p| p.to_vec()),
515 inverse: self.inverse,
516 trainable: self.trainable,
517 });
518 }
519
520 Ok(())
521 }
522
523 fn has_params(&self) -> bool {
524 self.has_params
525 }
526
527 fn trainable(&self) -> bool {
528 self.trainable
529 }
530
531 fn inverse(&self) -> bool {
532 self.inverse
533 }
534
535 fn set_inverse(&mut self, inverse: bool) {
536 self.inverse = inverse;
537 }
538}
539
540#[derive(Debug, Clone)]
542pub struct TQRZX {
543 params: Option<TQParameter>,
544 has_params: bool,
545 trainable: bool,
546 inverse: bool,
547 static_mode: bool,
548}
549
550impl TQRZX {
551 pub fn new(has_params: bool, trainable: bool) -> Self {
552 let params = if has_params {
553 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rzx_theta"))
554 } else {
555 None
556 };
557
558 Self {
559 params,
560 has_params,
561 trainable,
562 inverse: false,
563 static_mode: false,
564 }
565 }
566
567 pub fn with_init_params(mut self, theta: f64) -> Self {
568 if let Some(ref mut p) = self.params {
569 p.data[[0, 0]] = theta;
570 }
571 self
572 }
573}
574
575impl Default for TQRZX {
576 fn default() -> Self {
577 Self::new(true, true)
578 }
579}
580
581impl TQModule for TQRZX {
582 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
583 Err(MLError::InvalidConfiguration(
584 "Use apply() instead of forward() for operators".to_string(),
585 ))
586 }
587
588 fn parameters(&self) -> Vec<TQParameter> {
589 self.params.iter().cloned().collect()
590 }
591
592 fn n_wires(&self) -> Option<usize> {
593 Some(2)
594 }
595
596 fn set_n_wires(&mut self, _n_wires: usize) {}
597
598 fn is_static_mode(&self) -> bool {
599 self.static_mode
600 }
601
602 fn static_on(&mut self) {
603 self.static_mode = true;
604 }
605
606 fn static_off(&mut self) {
607 self.static_mode = false;
608 }
609
610 fn name(&self) -> &str {
611 "RZX"
612 }
613
614 fn zero_grad(&mut self) {
615 if let Some(ref mut p) = self.params {
616 p.zero_grad();
617 }
618 }
619}
620
621impl TQOperator for TQRZX {
622 fn num_wires(&self) -> WiresEnum {
623 WiresEnum::Fixed(2)
624 }
625
626 fn num_params(&self) -> NParamsEnum {
627 NParamsEnum::Fixed(1)
628 }
629
630 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
631 let theta = params
632 .and_then(|p| p.first().copied())
633 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
634 .unwrap_or(0.0);
635
636 let theta = if self.inverse { -theta } else { theta };
637 let half_theta = theta / 2.0;
638 let c = half_theta.cos();
639 let s = half_theta.sin();
640
641 Array2::from_shape_vec(
646 (4, 4),
647 vec![
648 CType::new(c, 0.0),
649 CType::new(0.0, -s),
650 CType::new(0.0, 0.0),
651 CType::new(0.0, 0.0),
652 CType::new(0.0, -s),
653 CType::new(c, 0.0),
654 CType::new(0.0, 0.0),
655 CType::new(0.0, 0.0),
656 CType::new(0.0, 0.0),
657 CType::new(0.0, 0.0),
658 CType::new(c, 0.0),
659 CType::new(0.0, s),
660 CType::new(0.0, 0.0),
661 CType::new(0.0, 0.0),
662 CType::new(0.0, s),
663 CType::new(c, 0.0),
664 ],
665 )
666 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
667 }
668
669 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
670 self.apply_with_params(qdev, wires, None)
671 }
672
673 fn apply_with_params(
674 &mut self,
675 qdev: &mut TQDevice,
676 wires: &[usize],
677 params: Option<&[f64]>,
678 ) -> Result<()> {
679 if wires.len() < 2 {
680 return Err(MLError::InvalidConfiguration(
681 "RZX gate requires exactly 2 wires".to_string(),
682 ));
683 }
684 let matrix = self.get_matrix(params);
685 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
686
687 if qdev.record_op {
688 qdev.record_operation(OpHistoryEntry {
689 name: "rzx".to_string(),
690 wires: wires.to_vec(),
691 params: params.map(|p| p.to_vec()),
692 inverse: self.inverse,
693 trainable: self.trainable,
694 });
695 }
696
697 Ok(())
698 }
699
700 fn has_params(&self) -> bool {
701 self.has_params
702 }
703
704 fn trainable(&self) -> bool {
705 self.trainable
706 }
707
708 fn inverse(&self) -> bool {
709 self.inverse
710 }
711
712 fn set_inverse(&mut self, inverse: bool) {
713 self.inverse = inverse;
714 }
715}