1use crate::error::{MLError, Result};
6use crate::torchquantum::{
7 CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
8};
9use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
10
11#[derive(Debug, Clone)]
13pub struct TQCRX {
14 params: Option<TQParameter>,
15 has_params: bool,
16 trainable: bool,
17 inverse: bool,
18 static_mode: bool,
19}
20
21impl TQCRX {
22 pub fn new(has_params: bool, trainable: bool) -> Self {
23 let params = if has_params {
24 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "crx_theta"))
25 } else {
26 None
27 };
28
29 Self {
30 params,
31 has_params,
32 trainable,
33 inverse: false,
34 static_mode: false,
35 }
36 }
37
38 pub fn with_init_params(mut self, theta: f64) -> Self {
39 if let Some(ref mut p) = self.params {
40 p.data[[0, 0]] = theta;
41 }
42 self
43 }
44}
45
46impl Default for TQCRX {
47 fn default() -> Self {
48 Self::new(true, true)
49 }
50}
51
52impl TQModule for TQCRX {
53 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
54 Err(MLError::InvalidConfiguration(
55 "Use apply() instead of forward() for operators".to_string(),
56 ))
57 }
58
59 fn parameters(&self) -> Vec<TQParameter> {
60 self.params.iter().cloned().collect()
61 }
62
63 fn n_wires(&self) -> Option<usize> {
64 Some(2)
65 }
66
67 fn set_n_wires(&mut self, _n_wires: usize) {}
68
69 fn is_static_mode(&self) -> bool {
70 self.static_mode
71 }
72
73 fn static_on(&mut self) {
74 self.static_mode = true;
75 }
76
77 fn static_off(&mut self) {
78 self.static_mode = false;
79 }
80
81 fn name(&self) -> &str {
82 "CRX"
83 }
84
85 fn zero_grad(&mut self) {
86 if let Some(ref mut p) = self.params {
87 p.zero_grad();
88 }
89 }
90}
91
92impl TQOperator for TQCRX {
93 fn num_wires(&self) -> WiresEnum {
94 WiresEnum::Fixed(2)
95 }
96
97 fn num_params(&self) -> NParamsEnum {
98 NParamsEnum::Fixed(1)
99 }
100
101 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
102 let theta = params
103 .and_then(|p| p.first().copied())
104 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
105 .unwrap_or(0.0);
106
107 let theta = if self.inverse { -theta } else { theta };
108 let half_theta = theta / 2.0;
109 let c = half_theta.cos();
110 let s = half_theta.sin();
111
112 Array2::from_shape_vec(
114 (4, 4),
115 vec![
116 CType::new(1.0, 0.0),
117 CType::new(0.0, 0.0),
118 CType::new(0.0, 0.0),
119 CType::new(0.0, 0.0),
120 CType::new(0.0, 0.0),
121 CType::new(1.0, 0.0),
122 CType::new(0.0, 0.0),
123 CType::new(0.0, 0.0),
124 CType::new(0.0, 0.0),
125 CType::new(0.0, 0.0),
126 CType::new(c, 0.0),
127 CType::new(0.0, -s),
128 CType::new(0.0, 0.0),
129 CType::new(0.0, 0.0),
130 CType::new(0.0, -s),
131 CType::new(c, 0.0),
132 ],
133 )
134 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
135 }
136
137 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
138 self.apply_with_params(qdev, wires, None)
139 }
140
141 fn apply_with_params(
142 &mut self,
143 qdev: &mut TQDevice,
144 wires: &[usize],
145 params: Option<&[f64]>,
146 ) -> Result<()> {
147 if wires.len() < 2 {
148 return Err(MLError::InvalidConfiguration(
149 "CRX gate requires exactly 2 wires".to_string(),
150 ));
151 }
152 let matrix = self.get_matrix(params);
153 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
154
155 if qdev.record_op {
156 qdev.record_operation(OpHistoryEntry {
157 name: "crx".to_string(),
158 wires: wires.to_vec(),
159 params: params.map(|p| p.to_vec()),
160 inverse: self.inverse,
161 trainable: self.trainable,
162 });
163 }
164
165 Ok(())
166 }
167
168 fn has_params(&self) -> bool {
169 self.has_params
170 }
171
172 fn trainable(&self) -> bool {
173 self.trainable
174 }
175
176 fn inverse(&self) -> bool {
177 self.inverse
178 }
179
180 fn set_inverse(&mut self, inverse: bool) {
181 self.inverse = inverse;
182 }
183}
184
185#[derive(Debug, Clone)]
187pub struct TQCRY {
188 params: Option<TQParameter>,
189 has_params: bool,
190 trainable: bool,
191 inverse: bool,
192 static_mode: bool,
193}
194
195impl TQCRY {
196 pub fn new(has_params: bool, trainable: bool) -> Self {
197 let params = if has_params {
198 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "cry_theta"))
199 } else {
200 None
201 };
202
203 Self {
204 params,
205 has_params,
206 trainable,
207 inverse: false,
208 static_mode: false,
209 }
210 }
211
212 pub fn with_init_params(mut self, theta: f64) -> Self {
213 if let Some(ref mut p) = self.params {
214 p.data[[0, 0]] = theta;
215 }
216 self
217 }
218}
219
220impl Default for TQCRY {
221 fn default() -> Self {
222 Self::new(true, true)
223 }
224}
225
226impl TQModule for TQCRY {
227 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
228 Err(MLError::InvalidConfiguration(
229 "Use apply() instead of forward() for operators".to_string(),
230 ))
231 }
232
233 fn parameters(&self) -> Vec<TQParameter> {
234 self.params.iter().cloned().collect()
235 }
236
237 fn n_wires(&self) -> Option<usize> {
238 Some(2)
239 }
240
241 fn set_n_wires(&mut self, _n_wires: usize) {}
242
243 fn is_static_mode(&self) -> bool {
244 self.static_mode
245 }
246
247 fn static_on(&mut self) {
248 self.static_mode = true;
249 }
250
251 fn static_off(&mut self) {
252 self.static_mode = false;
253 }
254
255 fn name(&self) -> &str {
256 "CRY"
257 }
258
259 fn zero_grad(&mut self) {
260 if let Some(ref mut p) = self.params {
261 p.zero_grad();
262 }
263 }
264}
265
266impl TQOperator for TQCRY {
267 fn num_wires(&self) -> WiresEnum {
268 WiresEnum::Fixed(2)
269 }
270
271 fn num_params(&self) -> NParamsEnum {
272 NParamsEnum::Fixed(1)
273 }
274
275 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
276 let theta = params
277 .and_then(|p| p.first().copied())
278 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
279 .unwrap_or(0.0);
280
281 let theta = if self.inverse { -theta } else { theta };
282 let half_theta = theta / 2.0;
283 let c = half_theta.cos();
284 let s = half_theta.sin();
285
286 Array2::from_shape_vec(
288 (4, 4),
289 vec![
290 CType::new(1.0, 0.0),
291 CType::new(0.0, 0.0),
292 CType::new(0.0, 0.0),
293 CType::new(0.0, 0.0),
294 CType::new(0.0, 0.0),
295 CType::new(1.0, 0.0),
296 CType::new(0.0, 0.0),
297 CType::new(0.0, 0.0),
298 CType::new(0.0, 0.0),
299 CType::new(0.0, 0.0),
300 CType::new(c, 0.0),
301 CType::new(-s, 0.0),
302 CType::new(0.0, 0.0),
303 CType::new(0.0, 0.0),
304 CType::new(s, 0.0),
305 CType::new(c, 0.0),
306 ],
307 )
308 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
309 }
310
311 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
312 self.apply_with_params(qdev, wires, None)
313 }
314
315 fn apply_with_params(
316 &mut self,
317 qdev: &mut TQDevice,
318 wires: &[usize],
319 params: Option<&[f64]>,
320 ) -> Result<()> {
321 if wires.len() < 2 {
322 return Err(MLError::InvalidConfiguration(
323 "CRY gate requires exactly 2 wires".to_string(),
324 ));
325 }
326 let matrix = self.get_matrix(params);
327 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
328
329 if qdev.record_op {
330 qdev.record_operation(OpHistoryEntry {
331 name: "cry".to_string(),
332 wires: wires.to_vec(),
333 params: params.map(|p| p.to_vec()),
334 inverse: self.inverse,
335 trainable: self.trainable,
336 });
337 }
338
339 Ok(())
340 }
341
342 fn has_params(&self) -> bool {
343 self.has_params
344 }
345
346 fn trainable(&self) -> bool {
347 self.trainable
348 }
349
350 fn inverse(&self) -> bool {
351 self.inverse
352 }
353
354 fn set_inverse(&mut self, inverse: bool) {
355 self.inverse = inverse;
356 }
357}
358
359#[derive(Debug, Clone)]
361pub struct TQCRZ {
362 params: Option<TQParameter>,
363 has_params: bool,
364 trainable: bool,
365 inverse: bool,
366 static_mode: bool,
367}
368
369impl TQCRZ {
370 pub fn new(has_params: bool, trainable: bool) -> Self {
371 let params = if has_params {
372 Some(TQParameter::new(ArrayD::zeros(IxDyn(&[1, 1])), "crz_theta"))
373 } else {
374 None
375 };
376
377 Self {
378 params,
379 has_params,
380 trainable,
381 inverse: false,
382 static_mode: false,
383 }
384 }
385
386 pub fn with_init_params(mut self, theta: f64) -> Self {
387 if let Some(ref mut p) = self.params {
388 p.data[[0, 0]] = theta;
389 }
390 self
391 }
392}
393
394impl Default for TQCRZ {
395 fn default() -> Self {
396 Self::new(true, true)
397 }
398}
399
400impl TQModule for TQCRZ {
401 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
402 Err(MLError::InvalidConfiguration(
403 "Use apply() instead of forward() for operators".to_string(),
404 ))
405 }
406
407 fn parameters(&self) -> Vec<TQParameter> {
408 self.params.iter().cloned().collect()
409 }
410
411 fn n_wires(&self) -> Option<usize> {
412 Some(2)
413 }
414
415 fn set_n_wires(&mut self, _n_wires: usize) {}
416
417 fn is_static_mode(&self) -> bool {
418 self.static_mode
419 }
420
421 fn static_on(&mut self) {
422 self.static_mode = true;
423 }
424
425 fn static_off(&mut self) {
426 self.static_mode = false;
427 }
428
429 fn name(&self) -> &str {
430 "CRZ"
431 }
432
433 fn zero_grad(&mut self) {
434 if let Some(ref mut p) = self.params {
435 p.zero_grad();
436 }
437 }
438}
439
440impl TQOperator for TQCRZ {
441 fn num_wires(&self) -> WiresEnum {
442 WiresEnum::Fixed(2)
443 }
444
445 fn num_params(&self) -> NParamsEnum {
446 NParamsEnum::Fixed(1)
447 }
448
449 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
450 let theta = params
451 .and_then(|p| p.first().copied())
452 .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
453 .unwrap_or(0.0);
454
455 let theta = if self.inverse { -theta } else { theta };
456 let half_theta = theta / 2.0;
457
458 let exp_neg = CType::from_polar(1.0, -half_theta);
460 let exp_pos = CType::from_polar(1.0, half_theta);
461
462 Array2::from_shape_vec(
463 (4, 4),
464 vec![
465 CType::new(1.0, 0.0),
466 CType::new(0.0, 0.0),
467 CType::new(0.0, 0.0),
468 CType::new(0.0, 0.0),
469 CType::new(0.0, 0.0),
470 CType::new(1.0, 0.0),
471 CType::new(0.0, 0.0),
472 CType::new(0.0, 0.0),
473 CType::new(0.0, 0.0),
474 CType::new(0.0, 0.0),
475 exp_neg,
476 CType::new(0.0, 0.0),
477 CType::new(0.0, 0.0),
478 CType::new(0.0, 0.0),
479 CType::new(0.0, 0.0),
480 exp_pos,
481 ],
482 )
483 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
484 }
485
486 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
487 self.apply_with_params(qdev, wires, None)
488 }
489
490 fn apply_with_params(
491 &mut self,
492 qdev: &mut TQDevice,
493 wires: &[usize],
494 params: Option<&[f64]>,
495 ) -> Result<()> {
496 if wires.len() < 2 {
497 return Err(MLError::InvalidConfiguration(
498 "CRZ gate requires exactly 2 wires".to_string(),
499 ));
500 }
501 let matrix = self.get_matrix(params);
502 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
503
504 if qdev.record_op {
505 qdev.record_operation(OpHistoryEntry {
506 name: "crz".to_string(),
507 wires: wires.to_vec(),
508 params: params.map(|p| p.to_vec()),
509 inverse: self.inverse,
510 trainable: self.trainable,
511 });
512 }
513
514 Ok(())
515 }
516
517 fn has_params(&self) -> bool {
518 self.has_params
519 }
520
521 fn trainable(&self) -> bool {
522 self.trainable
523 }
524
525 fn inverse(&self) -> bool {
526 self.inverse
527 }
528
529 fn set_inverse(&mut self, inverse: bool) {
530 self.inverse = inverse;
531 }
532}