Skip to main content

quantrs2_ml/torchquantum/layer/
tqparticleconservinglayer_traits.rs

//! # TQParticleConservingLayer - Trait Implementations
//!
//! This module contains trait implementations for `TQParticleConservingLayer`.
//!
//! ## Implemented Traits
//!
//! - `TQModule`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use crate::error::{MLError, Result};
12use crate::torchquantum::gates::{TQFSimGate, TQGivensRotation};
13use crate::torchquantum::{TQDevice, TQModule, TQOperator, TQParameter};
14
15use super::functions::{create_single_qubit_gate, create_two_qubit_gate};
16use super::types::TQParticleConservingLayer;
17
18impl TQModule for TQParticleConservingLayer {
19    fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
20        let pairs = self.get_wire_pairs();
21        for (gate_idx, (w0, w1)) in pairs.iter().enumerate() {
22            if gate_idx < self.gates.len() {
23                self.gates[gate_idx].apply(qdev, &[*w0, *w1])?;
24            }
25        }
26        Ok(())
27    }
28    fn parameters(&self) -> Vec<TQParameter> {
29        self.gates.iter().flat_map(|g| g.parameters()).collect()
30    }
31    fn n_wires(&self) -> Option<usize> {
32        Some(self.n_wires)
33    }
34    fn set_n_wires(&mut self, n_wires: usize) {
35        self.n_wires = n_wires;
36        let n_gates = Self::count_gates(n_wires, self.n_blocks, self.pattern);
37        self.gates = (0..n_gates)
38            .map(|_| TQGivensRotation::new(true, true))
39            .collect();
40    }
41    fn is_static_mode(&self) -> bool {
42        self.static_mode
43    }
44    fn static_on(&mut self) {
45        self.static_mode = true;
46        for gate in &mut self.gates {
47            gate.static_on();
48        }
49    }
50    fn static_off(&mut self) {
51        self.static_mode = false;
52        for gate in &mut self.gates {
53            gate.static_off();
54        }
55    }
56    fn name(&self) -> &str {
57        "ParticleConservingLayer"
58    }
59    fn zero_grad(&mut self) {
60        for gate in &mut self.gates {
61            gate.zero_grad();
62        }
63    }
64}