quantrs2_ml/torchquantum/gates/single_qubit/
tqrx_traits.rs

1//! # TQRx - Trait Implementations
2//!
3//! This module contains trait implementations for `TQRx`.
4//!
5//! ## Implemented Traits
6//!
7//! - `TQModule`
8//! - `TQOperator`
9//!
10//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
11
12use super::super::super::{
13    CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
14};
15use crate::error::{MLError, Result};
16use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
17
18use super::types::TQRx;
19
20impl TQModule for TQRx {
21    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
22        Err(MLError::InvalidConfiguration(
23            "Use apply() instead of forward() for operators".to_string(),
24        ))
25    }
26    fn parameters(&self) -> Vec<TQParameter> {
27        self.params.iter().cloned().collect()
28    }
29    fn n_wires(&self) -> Option<usize> {
30        Some(1)
31    }
32    fn set_n_wires(&mut self, _n_wires: usize) {}
33    fn is_static_mode(&self) -> bool {
34        self.static_mode
35    }
36    fn static_on(&mut self) {
37        self.static_mode = true;
38    }
39    fn static_off(&mut self) {
40        self.static_mode = false;
41    }
42    fn name(&self) -> &str {
43        "RX"
44    }
45    fn zero_grad(&mut self) {
46        if let Some(ref mut p) = self.params {
47            p.zero_grad();
48        }
49    }
50}
51
52impl TQOperator for TQRx {
53    fn num_wires(&self) -> WiresEnum {
54        WiresEnum::Fixed(1)
55    }
56    fn num_params(&self) -> NParamsEnum {
57        NParamsEnum::Fixed(1)
58    }
59    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
60        let theta = params
61            .and_then(|p| p.first().copied())
62            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
63            .unwrap_or(0.0);
64        let theta = if self.inverse { -theta } else { theta };
65        let cos_half = (theta / 2.0).cos();
66        let sin_half = (theta / 2.0).sin();
67        Array2::from_shape_vec(
68            (2, 2),
69            vec![
70                CType::new(cos_half, 0.0),
71                CType::new(0.0, -sin_half),
72                CType::new(0.0, -sin_half),
73                CType::new(cos_half, 0.0),
74            ],
75        )
76        .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)))
77    }
78    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
79        self.apply_with_params(qdev, wires, None)
80    }
81    fn apply_with_params(
82        &mut self,
83        qdev: &mut TQDevice,
84        wires: &[usize],
85        params: Option<&[f64]>,
86    ) -> Result<()> {
87        if wires.is_empty() {
88            return Err(MLError::InvalidConfiguration(
89                "RX gate requires exactly 1 wire".to_string(),
90            ));
91        }
92        let matrix = self.get_matrix(params);
93        qdev.apply_single_qubit_gate(wires[0], &matrix)?;
94        if qdev.record_op {
95            qdev.record_operation(OpHistoryEntry {
96                name: "rx".to_string(),
97                wires: wires.to_vec(),
98                params: params.map(|p| p.to_vec()),
99                inverse: self.inverse,
100                trainable: self.trainable,
101            });
102        }
103        Ok(())
104    }
105    fn has_params(&self) -> bool {
106        self.has_params
107    }
108    fn trainable(&self) -> bool {
109        self.trainable
110    }
111    fn inverse(&self) -> bool {
112        self.inverse
113    }
114    fn set_inverse(&mut self, inverse: bool) {
115        self.inverse = inverse;
116    }
117}