quantrs2_ml/torchquantum/gates/two_qubit/special/
tqcontrolledrot_traits.rs1use crate::error::{MLError, Result};
14use crate::torchquantum::{
15 CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
16};
17use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
18
19use super::types::TQControlledRot;
20
21impl Default for TQControlledRot {
22 fn default() -> Self {
23 Self::new(true, true)
24 }
25}
26
27impl TQModule for TQControlledRot {
28 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
29 Err(MLError::InvalidConfiguration(
30 "Use apply() instead of forward() for operators".to_string(),
31 ))
32 }
33 fn parameters(&self) -> Vec<TQParameter> {
34 self.params.iter().cloned().collect()
35 }
36 fn n_wires(&self) -> Option<usize> {
37 Some(2)
38 }
39 fn set_n_wires(&mut self, _n_wires: usize) {}
40 fn is_static_mode(&self) -> bool {
41 self.static_mode
42 }
43 fn static_on(&mut self) {
44 self.static_mode = true;
45 }
46 fn static_off(&mut self) {
47 self.static_mode = false;
48 }
49 fn name(&self) -> &str {
50 "CRot"
51 }
52 fn zero_grad(&mut self) {
53 if let Some(ref mut p) = self.params {
54 p.zero_grad();
55 }
56 }
57}
58
59impl TQOperator for TQControlledRot {
60 fn num_wires(&self) -> WiresEnum {
61 WiresEnum::Fixed(2)
62 }
63 fn num_params(&self) -> NParamsEnum {
64 NParamsEnum::Fixed(3)
65 }
66 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
67 let theta = params
68 .and_then(|p| p.first().copied())
69 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
70 .unwrap_or(0.0);
71 let phi = params
72 .and_then(|p| p.get(1).copied())
73 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 1]]))
74 .unwrap_or(0.0);
75 let omega = params
76 .and_then(|p| p.get(2).copied())
77 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 2]]))
78 .unwrap_or(0.0);
79 let (theta, phi, omega) = if self.inverse {
80 (-omega, -phi, -theta)
81 } else {
82 (theta, phi, omega)
83 };
84 let half_theta = theta / 2.0;
85 let half_phi = phi / 2.0;
86 let half_omega = omega / 2.0;
87 let cos_phi = half_phi.cos();
88 let sin_phi = half_phi.sin();
89 let u00 = CType::from_polar(cos_phi, -(half_theta + half_omega));
90 let u01 = CType::from_polar(-sin_phi, -(half_theta - half_omega));
91 let u10 = CType::from_polar(sin_phi, half_theta - half_omega);
92 let u11 = CType::from_polar(cos_phi, half_theta + half_omega);
93 Array2::from_shape_vec(
94 (4, 4),
95 vec![
96 CType::new(1.0, 0.0),
97 CType::new(0.0, 0.0),
98 CType::new(0.0, 0.0),
99 CType::new(0.0, 0.0),
100 CType::new(0.0, 0.0),
101 CType::new(1.0, 0.0),
102 CType::new(0.0, 0.0),
103 CType::new(0.0, 0.0),
104 CType::new(0.0, 0.0),
105 CType::new(0.0, 0.0),
106 u00,
107 u01,
108 CType::new(0.0, 0.0),
109 CType::new(0.0, 0.0),
110 u10,
111 u11,
112 ],
113 )
114 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
115 }
116 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
117 self.apply_with_params(qdev, wires, None)
118 }
119 fn apply_with_params(
120 &mut self,
121 qdev: &mut TQDevice,
122 wires: &[usize],
123 params: Option<&[f64]>,
124 ) -> Result<()> {
125 if wires.len() < 2 {
126 return Err(MLError::InvalidConfiguration(
127 "CRot gate requires exactly 2 wires".to_string(),
128 ));
129 }
130 let matrix = self.get_matrix(params);
131 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
132 if qdev.record_op {
133 qdev.record_operation(OpHistoryEntry {
134 name: "crot".to_string(),
135 wires: wires.to_vec(),
136 params: params.map(|p| p.to_vec()),
137 inverse: self.inverse,
138 trainable: self.trainable,
139 });
140 }
141 Ok(())
142 }
143 fn has_params(&self) -> bool {
144 self.has_params
145 }
146 fn trainable(&self) -> bool {
147 self.trainable
148 }
149 fn inverse(&self) -> bool {
150 self.inverse
151 }
152 fn set_inverse(&mut self, inverse: bool) {
153 self.inverse = inverse;
154 }
155}