quantrs2_ml/torchquantum/gates/two_qubit/
rotation.rs

1//! Parameterized Two-Qubit Rotation Gates
2//!
3//! This module provides Ising-type rotation gates: RXX, RYY, RZZ, RZX
4
5use crate::error::{MLError, Result};
6use crate::torchquantum::{
7    CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
8};
9use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
10
11/// RXX gate - Ising XX coupling: exp(-i θ/2 XX)
12#[derive(Debug, Clone)]
13pub struct TQRXX {
14    params: Option<TQParameter>,
15    has_params: bool,
16    trainable: bool,
17    inverse: bool,
18    static_mode: bool,
19}
20
21impl TQRXX {
22    pub fn new(has_params: bool, trainable: bool) -> Self {
23        let params = if has_params {
24            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rxx_theta"))
25        } else {
26            None
27        };
28
29        Self {
30            params,
31            has_params,
32            trainable,
33            inverse: false,
34            static_mode: false,
35        }
36    }
37
38    pub fn with_init_params(mut self, theta: f64) -> Self {
39        if let Some(ref mut p) = self.params {
40            p.data[[0, 0]] = theta;
41        }
42        self
43    }
44}
45
46impl Default for TQRXX {
47    fn default() -> Self {
48        Self::new(true, true)
49    }
50}
51
52impl TQModule for TQRXX {
53    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
54        Err(MLError::InvalidConfiguration(
55            "Use apply() instead of forward() for operators".to_string(),
56        ))
57    }
58
59    fn parameters(&self) -> Vec<TQParameter> {
60        self.params.iter().cloned().collect()
61    }
62
63    fn n_wires(&self) -> Option<usize> {
64        Some(2)
65    }
66
67    fn set_n_wires(&mut self, _n_wires: usize) {}
68
69    fn is_static_mode(&self) -> bool {
70        self.static_mode
71    }
72
73    fn static_on(&mut self) {
74        self.static_mode = true;
75    }
76
77    fn static_off(&mut self) {
78        self.static_mode = false;
79    }
80
81    fn name(&self) -> &str {
82        "RXX"
83    }
84
85    fn zero_grad(&mut self) {
86        if let Some(ref mut p) = self.params {
87            p.zero_grad();
88        }
89    }
90}
91
92impl TQOperator for TQRXX {
93    fn num_wires(&self) -> WiresEnum {
94        WiresEnum::Fixed(2)
95    }
96
97    fn num_params(&self) -> NParamsEnum {
98        NParamsEnum::Fixed(1)
99    }
100
101    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
102        let theta = params
103            .and_then(|p| p.first().copied())
104            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
105            .unwrap_or(0.0);
106
107        let theta = if self.inverse { -theta } else { theta };
108        let half_theta = theta / 2.0;
109        let c = half_theta.cos();
110        let s = half_theta.sin();
111
112        // RXX(θ) = [[cos(θ/2), 0, 0, -i sin(θ/2)],
113        //           [0, cos(θ/2), -i sin(θ/2), 0],
114        //           [0, -i sin(θ/2), cos(θ/2), 0],
115        //           [-i sin(θ/2), 0, 0, cos(θ/2)]]
116        Array2::from_shape_vec(
117            (4, 4),
118            vec![
119                CType::new(c, 0.0),
120                CType::new(0.0, 0.0),
121                CType::new(0.0, 0.0),
122                CType::new(0.0, -s),
123                CType::new(0.0, 0.0),
124                CType::new(c, 0.0),
125                CType::new(0.0, -s),
126                CType::new(0.0, 0.0),
127                CType::new(0.0, 0.0),
128                CType::new(0.0, -s),
129                CType::new(c, 0.0),
130                CType::new(0.0, 0.0),
131                CType::new(0.0, -s),
132                CType::new(0.0, 0.0),
133                CType::new(0.0, 0.0),
134                CType::new(c, 0.0),
135            ],
136        )
137        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
138    }
139
140    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
141        self.apply_with_params(qdev, wires, None)
142    }
143
144    fn apply_with_params(
145        &mut self,
146        qdev: &mut TQDevice,
147        wires: &[usize],
148        params: Option<&[f64]>,
149    ) -> Result<()> {
150        if wires.len() < 2 {
151            return Err(MLError::InvalidConfiguration(
152                "RXX gate requires exactly 2 wires".to_string(),
153            ));
154        }
155        let matrix = self.get_matrix(params);
156        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
157
158        if qdev.record_op {
159            qdev.record_operation(OpHistoryEntry {
160                name: "rxx".to_string(),
161                wires: wires.to_vec(),
162                params: params.map(|p| p.to_vec()),
163                inverse: self.inverse,
164                trainable: self.trainable,
165            });
166        }
167
168        Ok(())
169    }
170
171    fn has_params(&self) -> bool {
172        self.has_params
173    }
174
175    fn trainable(&self) -> bool {
176        self.trainable
177    }
178
179    fn inverse(&self) -> bool {
180        self.inverse
181    }
182
183    fn set_inverse(&mut self, inverse: bool) {
184        self.inverse = inverse;
185    }
186}
187
188/// RYY gate - Ising YY coupling: exp(-i θ/2 YY)
189#[derive(Debug, Clone)]
190pub struct TQRYY {
191    params: Option<TQParameter>,
192    has_params: bool,
193    trainable: bool,
194    inverse: bool,
195    static_mode: bool,
196}
197
198impl TQRYY {
199    pub fn new(has_params: bool, trainable: bool) -> Self {
200        let params = if has_params {
201            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "ryy_theta"))
202        } else {
203            None
204        };
205
206        Self {
207            params,
208            has_params,
209            trainable,
210            inverse: false,
211            static_mode: false,
212        }
213    }
214
215    pub fn with_init_params(mut self, theta: f64) -> Self {
216        if let Some(ref mut p) = self.params {
217            p.data[[0, 0]] = theta;
218        }
219        self
220    }
221}
222
223impl Default for TQRYY {
224    fn default() -> Self {
225        Self::new(true, true)
226    }
227}
228
229impl TQModule for TQRYY {
230    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
231        Err(MLError::InvalidConfiguration(
232            "Use apply() instead of forward() for operators".to_string(),
233        ))
234    }
235
236    fn parameters(&self) -> Vec<TQParameter> {
237        self.params.iter().cloned().collect()
238    }
239
240    fn n_wires(&self) -> Option<usize> {
241        Some(2)
242    }
243
244    fn set_n_wires(&mut self, _n_wires: usize) {}
245
246    fn is_static_mode(&self) -> bool {
247        self.static_mode
248    }
249
250    fn static_on(&mut self) {
251        self.static_mode = true;
252    }
253
254    fn static_off(&mut self) {
255        self.static_mode = false;
256    }
257
258    fn name(&self) -> &str {
259        "RYY"
260    }
261
262    fn zero_grad(&mut self) {
263        if let Some(ref mut p) = self.params {
264            p.zero_grad();
265        }
266    }
267}
268
269impl TQOperator for TQRYY {
270    fn num_wires(&self) -> WiresEnum {
271        WiresEnum::Fixed(2)
272    }
273
274    fn num_params(&self) -> NParamsEnum {
275        NParamsEnum::Fixed(1)
276    }
277
278    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
279        let theta = params
280            .and_then(|p| p.first().copied())
281            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
282            .unwrap_or(0.0);
283
284        let theta = if self.inverse { -theta } else { theta };
285        let half_theta = theta / 2.0;
286        let c = half_theta.cos();
287        let s = half_theta.sin();
288
289        // RYY(θ) = [[cos(θ/2), 0, 0, i sin(θ/2)],
290        //           [0, cos(θ/2), -i sin(θ/2), 0],
291        //           [0, -i sin(θ/2), cos(θ/2), 0],
292        //           [i sin(θ/2), 0, 0, cos(θ/2)]]
293        Array2::from_shape_vec(
294            (4, 4),
295            vec![
296                CType::new(c, 0.0),
297                CType::new(0.0, 0.0),
298                CType::new(0.0, 0.0),
299                CType::new(0.0, s),
300                CType::new(0.0, 0.0),
301                CType::new(c, 0.0),
302                CType::new(0.0, -s),
303                CType::new(0.0, 0.0),
304                CType::new(0.0, 0.0),
305                CType::new(0.0, -s),
306                CType::new(c, 0.0),
307                CType::new(0.0, 0.0),
308                CType::new(0.0, s),
309                CType::new(0.0, 0.0),
310                CType::new(0.0, 0.0),
311                CType::new(c, 0.0),
312            ],
313        )
314        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
315    }
316
317    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
318        self.apply_with_params(qdev, wires, None)
319    }
320
321    fn apply_with_params(
322        &mut self,
323        qdev: &mut TQDevice,
324        wires: &[usize],
325        params: Option<&[f64]>,
326    ) -> Result<()> {
327        if wires.len() < 2 {
328            return Err(MLError::InvalidConfiguration(
329                "RYY gate requires exactly 2 wires".to_string(),
330            ));
331        }
332        let matrix = self.get_matrix(params);
333        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
334
335        if qdev.record_op {
336            qdev.record_operation(OpHistoryEntry {
337                name: "ryy".to_string(),
338                wires: wires.to_vec(),
339                params: params.map(|p| p.to_vec()),
340                inverse: self.inverse,
341                trainable: self.trainable,
342            });
343        }
344
345        Ok(())
346    }
347
348    fn has_params(&self) -> bool {
349        self.has_params
350    }
351
352    fn trainable(&self) -> bool {
353        self.trainable
354    }
355
356    fn inverse(&self) -> bool {
357        self.inverse
358    }
359
360    fn set_inverse(&mut self, inverse: bool) {
361        self.inverse = inverse;
362    }
363}
364
365/// RZZ gate - Ising ZZ coupling: exp(-i θ/2 ZZ)
366#[derive(Debug, Clone)]
367pub struct TQRZZ {
368    params: Option<TQParameter>,
369    has_params: bool,
370    trainable: bool,
371    inverse: bool,
372    static_mode: bool,
373}
374
375impl TQRZZ {
376    pub fn new(has_params: bool, trainable: bool) -> Self {
377        let params = if has_params {
378            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rzz_theta"))
379        } else {
380            None
381        };
382
383        Self {
384            params,
385            has_params,
386            trainable,
387            inverse: false,
388            static_mode: false,
389        }
390    }
391
392    pub fn with_init_params(mut self, theta: f64) -> Self {
393        if let Some(ref mut p) = self.params {
394            p.data[[0, 0]] = theta;
395        }
396        self
397    }
398}
399
400impl Default for TQRZZ {
401    fn default() -> Self {
402        Self::new(true, true)
403    }
404}
405
406impl TQModule for TQRZZ {
407    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
408        Err(MLError::InvalidConfiguration(
409            "Use apply() instead of forward() for operators".to_string(),
410        ))
411    }
412
413    fn parameters(&self) -> Vec<TQParameter> {
414        self.params.iter().cloned().collect()
415    }
416
417    fn n_wires(&self) -> Option<usize> {
418        Some(2)
419    }
420
421    fn set_n_wires(&mut self, _n_wires: usize) {}
422
423    fn is_static_mode(&self) -> bool {
424        self.static_mode
425    }
426
427    fn static_on(&mut self) {
428        self.static_mode = true;
429    }
430
431    fn static_off(&mut self) {
432        self.static_mode = false;
433    }
434
435    fn name(&self) -> &str {
436        "RZZ"
437    }
438
439    fn zero_grad(&mut self) {
440        if let Some(ref mut p) = self.params {
441            p.zero_grad();
442        }
443    }
444}
445
446impl TQOperator for TQRZZ {
447    fn num_wires(&self) -> WiresEnum {
448        WiresEnum::Fixed(2)
449    }
450
451    fn num_params(&self) -> NParamsEnum {
452        NParamsEnum::Fixed(1)
453    }
454
455    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
456        let theta = params
457            .and_then(|p| p.first().copied())
458            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
459            .unwrap_or(0.0);
460
461        let theta = if self.inverse { -theta } else { theta };
462        let half_theta = theta / 2.0;
463
464        // RZZ(θ) = diag(e^(-iθ/2), e^(iθ/2), e^(iθ/2), e^(-iθ/2))
465        let exp_neg = CType::from_polar(1.0, -half_theta);
466        let exp_pos = CType::from_polar(1.0, half_theta);
467
468        Array2::from_shape_vec(
469            (4, 4),
470            vec![
471                exp_neg,
472                CType::new(0.0, 0.0),
473                CType::new(0.0, 0.0),
474                CType::new(0.0, 0.0),
475                CType::new(0.0, 0.0),
476                exp_pos,
477                CType::new(0.0, 0.0),
478                CType::new(0.0, 0.0),
479                CType::new(0.0, 0.0),
480                CType::new(0.0, 0.0),
481                exp_pos,
482                CType::new(0.0, 0.0),
483                CType::new(0.0, 0.0),
484                CType::new(0.0, 0.0),
485                CType::new(0.0, 0.0),
486                exp_neg,
487            ],
488        )
489        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
490    }
491
492    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
493        self.apply_with_params(qdev, wires, None)
494    }
495
496    fn apply_with_params(
497        &mut self,
498        qdev: &mut TQDevice,
499        wires: &[usize],
500        params: Option<&[f64]>,
501    ) -> Result<()> {
502        if wires.len() < 2 {
503            return Err(MLError::InvalidConfiguration(
504                "RZZ gate requires exactly 2 wires".to_string(),
505            ));
506        }
507        let matrix = self.get_matrix(params);
508        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
509
510        if qdev.record_op {
511            qdev.record_operation(OpHistoryEntry {
512                name: "rzz".to_string(),
513                wires: wires.to_vec(),
514                params: params.map(|p| p.to_vec()),
515                inverse: self.inverse,
516                trainable: self.trainable,
517            });
518        }
519
520        Ok(())
521    }
522
523    fn has_params(&self) -> bool {
524        self.has_params
525    }
526
527    fn trainable(&self) -> bool {
528        self.trainable
529    }
530
531    fn inverse(&self) -> bool {
532        self.inverse
533    }
534
535    fn set_inverse(&mut self, inverse: bool) {
536        self.inverse = inverse;
537    }
538}
539
540/// RZX gate - Cross-resonance rotation: exp(-i θ/2 ZX)
541#[derive(Debug, Clone)]
542pub struct TQRZX {
543    params: Option<TQParameter>,
544    has_params: bool,
545    trainable: bool,
546    inverse: bool,
547    static_mode: bool,
548}
549
550impl TQRZX {
551    pub fn new(has_params: bool, trainable: bool) -> Self {
552        let params = if has_params {
553            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "rzx_theta"))
554        } else {
555            None
556        };
557
558        Self {
559            params,
560            has_params,
561            trainable,
562            inverse: false,
563            static_mode: false,
564        }
565    }
566
567    pub fn with_init_params(mut self, theta: f64) -> Self {
568        if let Some(ref mut p) = self.params {
569            p.data[[0, 0]] = theta;
570        }
571        self
572    }
573}
574
575impl Default for TQRZX {
576    fn default() -> Self {
577        Self::new(true, true)
578    }
579}
580
581impl TQModule for TQRZX {
582    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
583        Err(MLError::InvalidConfiguration(
584            "Use apply() instead of forward() for operators".to_string(),
585        ))
586    }
587
588    fn parameters(&self) -> Vec<TQParameter> {
589        self.params.iter().cloned().collect()
590    }
591
592    fn n_wires(&self) -> Option<usize> {
593        Some(2)
594    }
595
596    fn set_n_wires(&mut self, _n_wires: usize) {}
597
598    fn is_static_mode(&self) -> bool {
599        self.static_mode
600    }
601
602    fn static_on(&mut self) {
603        self.static_mode = true;
604    }
605
606    fn static_off(&mut self) {
607        self.static_mode = false;
608    }
609
610    fn name(&self) -> &str {
611        "RZX"
612    }
613
614    fn zero_grad(&mut self) {
615        if let Some(ref mut p) = self.params {
616            p.zero_grad();
617        }
618    }
619}
620
621impl TQOperator for TQRZX {
622    fn num_wires(&self) -> WiresEnum {
623        WiresEnum::Fixed(2)
624    }
625
626    fn num_params(&self) -> NParamsEnum {
627        NParamsEnum::Fixed(1)
628    }
629
630    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
631        let theta = params
632            .and_then(|p| p.first().copied())
633            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
634            .unwrap_or(0.0);
635
636        let theta = if self.inverse { -theta } else { theta };
637        let half_theta = theta / 2.0;
638        let c = half_theta.cos();
639        let s = half_theta.sin();
640
641        // RZX(θ) = [[cos(θ/2), -i sin(θ/2), 0, 0],
642        //           [-i sin(θ/2), cos(θ/2), 0, 0],
643        //           [0, 0, cos(θ/2), i sin(θ/2)],
644        //           [0, 0, i sin(θ/2), cos(θ/2)]]
645        Array2::from_shape_vec(
646            (4, 4),
647            vec![
648                CType::new(c, 0.0),
649                CType::new(0.0, -s),
650                CType::new(0.0, 0.0),
651                CType::new(0.0, 0.0),
652                CType::new(0.0, -s),
653                CType::new(c, 0.0),
654                CType::new(0.0, 0.0),
655                CType::new(0.0, 0.0),
656                CType::new(0.0, 0.0),
657                CType::new(0.0, 0.0),
658                CType::new(c, 0.0),
659                CType::new(0.0, s),
660                CType::new(0.0, 0.0),
661                CType::new(0.0, 0.0),
662                CType::new(0.0, s),
663                CType::new(c, 0.0),
664            ],
665        )
666        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
667    }
668
669    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
670        self.apply_with_params(qdev, wires, None)
671    }
672
673    fn apply_with_params(
674        &mut self,
675        qdev: &mut TQDevice,
676        wires: &[usize],
677        params: Option<&[f64]>,
678    ) -> Result<()> {
679        if wires.len() < 2 {
680            return Err(MLError::InvalidConfiguration(
681                "RZX gate requires exactly 2 wires".to_string(),
682            ));
683        }
684        let matrix = self.get_matrix(params);
685        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
686
687        if qdev.record_op {
688            qdev.record_operation(OpHistoryEntry {
689                name: "rzx".to_string(),
690                wires: wires.to_vec(),
691                params: params.map(|p| p.to_vec()),
692                inverse: self.inverse,
693                trainable: self.trainable,
694            });
695        }
696
697        Ok(())
698    }
699
700    fn has_params(&self) -> bool {
701        self.has_params
702    }
703
704    fn trainable(&self) -> bool {
705        self.trainable
706    }
707
708    fn inverse(&self) -> bool {
709        self.inverse
710    }
711
712    fn set_inverse(&mut self, inverse: bool) {
713        self.inverse = inverse;
714    }
715}