quantrs2_ml/torchquantum/gates/two_qubit/
controlled.rs

1//! Controlled Rotation Gates
2//!
3//! This module provides controlled rotation gates: CRX, CRY, CRZ
4
5use crate::error::{MLError, Result};
6use crate::torchquantum::{
7    CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
8};
9use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
10
11/// CRX gate - Controlled RX rotation
12#[derive(Debug, Clone)]
13pub struct TQCRX {
14    params: Option<TQParameter>,
15    has_params: bool,
16    trainable: bool,
17    inverse: bool,
18    static_mode: bool,
19}
20
21impl TQCRX {
22    pub fn new(has_params: bool, trainable: bool) -> Self {
23        let params = if has_params {
24            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "crx_theta"))
25        } else {
26            None
27        };
28
29        Self {
30            params,
31            has_params,
32            trainable,
33            inverse: false,
34            static_mode: false,
35        }
36    }
37
38    pub fn with_init_params(mut self, theta: f64) -> Self {
39        if let Some(ref mut p) = self.params {
40            p.data[[0, 0]] = theta;
41        }
42        self
43    }
44}
45
46impl Default for TQCRX {
47    fn default() -> Self {
48        Self::new(true, true)
49    }
50}
51
52impl TQModule for TQCRX {
53    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
54        Err(MLError::InvalidConfiguration(
55            "Use apply() instead of forward() for operators".to_string(),
56        ))
57    }
58
59    fn parameters(&self) -> Vec<TQParameter> {
60        self.params.iter().cloned().collect()
61    }
62
63    fn n_wires(&self) -> Option<usize> {
64        Some(2)
65    }
66
67    fn set_n_wires(&mut self, _n_wires: usize) {}
68
69    fn is_static_mode(&self) -> bool {
70        self.static_mode
71    }
72
73    fn static_on(&mut self) {
74        self.static_mode = true;
75    }
76
77    fn static_off(&mut self) {
78        self.static_mode = false;
79    }
80
81    fn name(&self) -> &str {
82        "CRX"
83    }
84
85    fn zero_grad(&mut self) {
86        if let Some(ref mut p) = self.params {
87            p.zero_grad();
88        }
89    }
90}
91
92impl TQOperator for TQCRX {
93    fn num_wires(&self) -> WiresEnum {
94        WiresEnum::Fixed(2)
95    }
96
97    fn num_params(&self) -> NParamsEnum {
98        NParamsEnum::Fixed(1)
99    }
100
101    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
102        let theta = params
103            .and_then(|p| p.first().copied())
104            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
105            .unwrap_or(0.0);
106
107        let theta = if self.inverse { -theta } else { theta };
108        let half_theta = theta / 2.0;
109        let c = half_theta.cos();
110        let s = half_theta.sin();
111
112        // CRX(θ) - controlled rotation about X-axis
113        Array2::from_shape_vec(
114            (4, 4),
115            vec![
116                CType::new(1.0, 0.0),
117                CType::new(0.0, 0.0),
118                CType::new(0.0, 0.0),
119                CType::new(0.0, 0.0),
120                CType::new(0.0, 0.0),
121                CType::new(1.0, 0.0),
122                CType::new(0.0, 0.0),
123                CType::new(0.0, 0.0),
124                CType::new(0.0, 0.0),
125                CType::new(0.0, 0.0),
126                CType::new(c, 0.0),
127                CType::new(0.0, -s),
128                CType::new(0.0, 0.0),
129                CType::new(0.0, 0.0),
130                CType::new(0.0, -s),
131                CType::new(c, 0.0),
132            ],
133        )
134        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
135    }
136
137    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
138        self.apply_with_params(qdev, wires, None)
139    }
140
141    fn apply_with_params(
142        &mut self,
143        qdev: &mut TQDevice,
144        wires: &[usize],
145        params: Option<&[f64]>,
146    ) -> Result<()> {
147        if wires.len() < 2 {
148            return Err(MLError::InvalidConfiguration(
149                "CRX gate requires exactly 2 wires".to_string(),
150            ));
151        }
152        let matrix = self.get_matrix(params);
153        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
154
155        if qdev.record_op {
156            qdev.record_operation(OpHistoryEntry {
157                name: "crx".to_string(),
158                wires: wires.to_vec(),
159                params: params.map(|p| p.to_vec()),
160                inverse: self.inverse,
161                trainable: self.trainable,
162            });
163        }
164
165        Ok(())
166    }
167
168    fn has_params(&self) -> bool {
169        self.has_params
170    }
171
172    fn trainable(&self) -> bool {
173        self.trainable
174    }
175
176    fn inverse(&self) -> bool {
177        self.inverse
178    }
179
180    fn set_inverse(&mut self, inverse: bool) {
181        self.inverse = inverse;
182    }
183}
184
185/// CRY gate - Controlled RY rotation
186#[derive(Debug, Clone)]
187pub struct TQCRY {
188    params: Option<TQParameter>,
189    has_params: bool,
190    trainable: bool,
191    inverse: bool,
192    static_mode: bool,
193}
194
195impl TQCRY {
196    pub fn new(has_params: bool, trainable: bool) -> Self {
197        let params = if has_params {
198            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "cry_theta"))
199        } else {
200            None
201        };
202
203        Self {
204            params,
205            has_params,
206            trainable,
207            inverse: false,
208            static_mode: false,
209        }
210    }
211
212    pub fn with_init_params(mut self, theta: f64) -> Self {
213        if let Some(ref mut p) = self.params {
214            p.data[[0, 0]] = theta;
215        }
216        self
217    }
218}
219
220impl Default for TQCRY {
221    fn default() -> Self {
222        Self::new(true, true)
223    }
224}
225
226impl TQModule for TQCRY {
227    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
228        Err(MLError::InvalidConfiguration(
229            "Use apply() instead of forward() for operators".to_string(),
230        ))
231    }
232
233    fn parameters(&self) -> Vec<TQParameter> {
234        self.params.iter().cloned().collect()
235    }
236
237    fn n_wires(&self) -> Option<usize> {
238        Some(2)
239    }
240
241    fn set_n_wires(&mut self, _n_wires: usize) {}
242
243    fn is_static_mode(&self) -> bool {
244        self.static_mode
245    }
246
247    fn static_on(&mut self) {
248        self.static_mode = true;
249    }
250
251    fn static_off(&mut self) {
252        self.static_mode = false;
253    }
254
255    fn name(&self) -> &str {
256        "CRY"
257    }
258
259    fn zero_grad(&mut self) {
260        if let Some(ref mut p) = self.params {
261            p.zero_grad();
262        }
263    }
264}
265
266impl TQOperator for TQCRY {
267    fn num_wires(&self) -> WiresEnum {
268        WiresEnum::Fixed(2)
269    }
270
271    fn num_params(&self) -> NParamsEnum {
272        NParamsEnum::Fixed(1)
273    }
274
275    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
276        let theta = params
277            .and_then(|p| p.first().copied())
278            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
279            .unwrap_or(0.0);
280
281        let theta = if self.inverse { -theta } else { theta };
282        let half_theta = theta / 2.0;
283        let c = half_theta.cos();
284        let s = half_theta.sin();
285
286        // CRY(θ) - controlled rotation about Y-axis
287        Array2::from_shape_vec(
288            (4, 4),
289            vec![
290                CType::new(1.0, 0.0),
291                CType::new(0.0, 0.0),
292                CType::new(0.0, 0.0),
293                CType::new(0.0, 0.0),
294                CType::new(0.0, 0.0),
295                CType::new(1.0, 0.0),
296                CType::new(0.0, 0.0),
297                CType::new(0.0, 0.0),
298                CType::new(0.0, 0.0),
299                CType::new(0.0, 0.0),
300                CType::new(c, 0.0),
301                CType::new(-s, 0.0),
302                CType::new(0.0, 0.0),
303                CType::new(0.0, 0.0),
304                CType::new(s, 0.0),
305                CType::new(c, 0.0),
306            ],
307        )
308        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
309    }
310
311    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
312        self.apply_with_params(qdev, wires, None)
313    }
314
315    fn apply_with_params(
316        &mut self,
317        qdev: &mut TQDevice,
318        wires: &[usize],
319        params: Option<&[f64]>,
320    ) -> Result<()> {
321        if wires.len() < 2 {
322            return Err(MLError::InvalidConfiguration(
323                "CRY gate requires exactly 2 wires".to_string(),
324            ));
325        }
326        let matrix = self.get_matrix(params);
327        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
328
329        if qdev.record_op {
330            qdev.record_operation(OpHistoryEntry {
331                name: "cry".to_string(),
332                wires: wires.to_vec(),
333                params: params.map(|p| p.to_vec()),
334                inverse: self.inverse,
335                trainable: self.trainable,
336            });
337        }
338
339        Ok(())
340    }
341
342    fn has_params(&self) -> bool {
343        self.has_params
344    }
345
346    fn trainable(&self) -> bool {
347        self.trainable
348    }
349
350    fn inverse(&self) -> bool {
351        self.inverse
352    }
353
354    fn set_inverse(&mut self, inverse: bool) {
355        self.inverse = inverse;
356    }
357}
358
359/// CRZ gate - Controlled RZ rotation
360#[derive(Debug, Clone)]
361pub struct TQCRZ {
362    params: Option<TQParameter>,
363    has_params: bool,
364    trainable: bool,
365    inverse: bool,
366    static_mode: bool,
367}
368
369impl TQCRZ {
370    pub fn new(has_params: bool, trainable: bool) -> Self {
371        let params = if has_params {
372            Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "crz_theta"))
373        } else {
374            None
375        };
376
377        Self {
378            params,
379            has_params,
380            trainable,
381            inverse: false,
382            static_mode: false,
383        }
384    }
385
386    pub fn with_init_params(mut self, theta: f64) -> Self {
387        if let Some(ref mut p) = self.params {
388            p.data[[0, 0]] = theta;
389        }
390        self
391    }
392}
393
394impl Default for TQCRZ {
395    fn default() -> Self {
396        Self::new(true, true)
397    }
398}
399
400impl TQModule for TQCRZ {
401    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
402        Err(MLError::InvalidConfiguration(
403            "Use apply() instead of forward() for operators".to_string(),
404        ))
405    }
406
407    fn parameters(&self) -> Vec<TQParameter> {
408        self.params.iter().cloned().collect()
409    }
410
411    fn n_wires(&self) -> Option<usize> {
412        Some(2)
413    }
414
415    fn set_n_wires(&mut self, _n_wires: usize) {}
416
417    fn is_static_mode(&self) -> bool {
418        self.static_mode
419    }
420
421    fn static_on(&mut self) {
422        self.static_mode = true;
423    }
424
425    fn static_off(&mut self) {
426        self.static_mode = false;
427    }
428
429    fn name(&self) -> &str {
430        "CRZ"
431    }
432
433    fn zero_grad(&mut self) {
434        if let Some(ref mut p) = self.params {
435            p.zero_grad();
436        }
437    }
438}
439
440impl TQOperator for TQCRZ {
441    fn num_wires(&self) -> WiresEnum {
442        WiresEnum::Fixed(2)
443    }
444
445    fn num_params(&self) -> NParamsEnum {
446        NParamsEnum::Fixed(1)
447    }
448
449    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
450        let theta = params
451            .and_then(|p| p.first().copied())
452            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
453            .unwrap_or(0.0);
454
455        let theta = if self.inverse { -theta } else { theta };
456        let half_theta = theta / 2.0;
457
458        // CRZ(θ) - controlled rotation about Z-axis
459        let exp_neg = CType::from_polar(1.0, -half_theta);
460        let exp_pos = CType::from_polar(1.0, half_theta);
461
462        Array2::from_shape_vec(
463            (4, 4),
464            vec![
465                CType::new(1.0, 0.0),
466                CType::new(0.0, 0.0),
467                CType::new(0.0, 0.0),
468                CType::new(0.0, 0.0),
469                CType::new(0.0, 0.0),
470                CType::new(1.0, 0.0),
471                CType::new(0.0, 0.0),
472                CType::new(0.0, 0.0),
473                CType::new(0.0, 0.0),
474                CType::new(0.0, 0.0),
475                exp_neg,
476                CType::new(0.0, 0.0),
477                CType::new(0.0, 0.0),
478                CType::new(0.0, 0.0),
479                CType::new(0.0, 0.0),
480                exp_pos,
481            ],
482        )
483        .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
484    }
485
486    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
487        self.apply_with_params(qdev, wires, None)
488    }
489
490    fn apply_with_params(
491        &mut self,
492        qdev: &mut TQDevice,
493        wires: &[usize],
494        params: Option<&[f64]>,
495    ) -> Result<()> {
496        if wires.len() < 2 {
497            return Err(MLError::InvalidConfiguration(
498                "CRZ gate requires exactly 2 wires".to_string(),
499            ));
500        }
501        let matrix = self.get_matrix(params);
502        qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
503
504        if qdev.record_op {
505            qdev.record_operation(OpHistoryEntry {
506                name: "crz".to_string(),
507                wires: wires.to_vec(),
508                params: params.map(|p| p.to_vec()),
509                inverse: self.inverse,
510                trainable: self.trainable,
511            });
512        }
513
514        Ok(())
515    }
516
517    fn has_params(&self) -> bool {
518        self.has_params
519    }
520
521    fn trainable(&self) -> bool {
522        self.trainable
523    }
524
525    fn inverse(&self) -> bool {
526        self.inverse
527    }
528
529    fn set_inverse(&mut self, inverse: bool) {
530        self.inverse = inverse;
531    }
532}