quantrs2_ml/torchquantum/gates/single_qubit/
tqu1_traits.rs

//! # TQU1 - Trait Implementations
//!
//! This module contains trait implementations for `TQU1`.
//!
//! ## Implemented Traits
//!
//! - `TQModule`
//! - `TQOperator`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
11
12use super::super::super::{
13    CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
14};
15use crate::error::{MLError, Result};
16use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
17
18use super::types::TQU1;
19
20impl TQModule for TQU1 {
21    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
22        Err(MLError::InvalidConfiguration(
23            "Use apply() instead of forward() for operators".to_string(),
24        ))
25    }
26    fn parameters(&self) -> Vec<TQParameter> {
27        self.params.iter().cloned().collect()
28    }
29    fn n_wires(&self) -> Option<usize> {
30        Some(1)
31    }
32    fn set_n_wires(&mut self, _n_wires: usize) {}
33    fn is_static_mode(&self) -> bool {
34        self.static_mode
35    }
36    fn static_on(&mut self) {
37        self.static_mode = true;
38    }
39    fn static_off(&mut self) {
40        self.static_mode = false;
41    }
42    fn name(&self) -> &str {
43        "U1"
44    }
45    fn zero_grad(&mut self) {
46        if let Some(ref mut p) = self.params {
47            p.zero_grad();
48        }
49    }
50}
51
52impl TQOperator for TQU1 {
53    fn num_wires(&self) -> WiresEnum {
54        WiresEnum::Fixed(1)
55    }
56    fn num_params(&self) -> NParamsEnum {
57        NParamsEnum::Fixed(1)
58    }
59    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType> {
60        let lambda = params
61            .and_then(|p| p.first().copied())
62            .or_else(|| self.params.as_ref().map(|p| p.data[[0, 0]]))
63            .unwrap_or(0.0);
64        let lambda = if self.inverse { -lambda } else { lambda };
65        Array2::from_shape_vec(
66            (2, 2),
67            vec![
68                CType::new(1.0, 0.0),
69                CType::new(0.0, 0.0),
70                CType::new(0.0, 0.0),
71                CType::from_polar(1.0, lambda),
72            ],
73        )
74        .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)))
75    }
76    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
77        self.apply_with_params(qdev, wires, None)
78    }
79    fn apply_with_params(
80        &mut self,
81        qdev: &mut TQDevice,
82        wires: &[usize],
83        params: Option<&[f64]>,
84    ) -> Result<()> {
85        if wires.is_empty() {
86            return Err(MLError::InvalidConfiguration(
87                "U1 gate requires exactly 1 wire".to_string(),
88            ));
89        }
90        let matrix = self.get_matrix(params);
91        qdev.apply_single_qubit_gate(wires[0], &matrix)?;
92        if qdev.record_op {
93            qdev.record_operation(OpHistoryEntry {
94                name: "u1".to_string(),
95                wires: wires.to_vec(),
96                params: params.map(|p| p.to_vec()),
97                inverse: self.inverse,
98                trainable: self.trainable,
99            });
100        }
101        Ok(())
102    }
103    fn has_params(&self) -> bool {
104        self.has_params
105    }
106    fn trainable(&self) -> bool {
107        self.trainable
108    }
109    fn inverse(&self) -> bool {
110        self.inverse
111    }
112    fn set_inverse(&mut self, inverse: bool) {
113        self.inverse = inverse;
114    }
115}