Skip to main content

quantrs2_ml/torchquantum/layer/
tqexcitationpreservinglayer_traits.rs

1//! # TQExcitationPreservingLayer - Trait Implementations
2//!
3//! This module contains trait implementations for `TQExcitationPreservingLayer`.
4//!
5//! ## Implemented Traits
6//!
7//! - `TQModule`
8//!
9//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11use crate::error::{MLError, Result};
12use crate::torchquantum::gates::{TQFSimGate, TQGivensRotation};
13use crate::torchquantum::{TQDevice, TQModule, TQOperator, TQParameter};
14
15use super::functions::{create_single_qubit_gate, create_two_qubit_gate};
16use super::types::TQExcitationPreservingLayer;
17
18impl TQModule for TQExcitationPreservingLayer {
19    fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
20        let n_pairs = if self.circular {
21            self.n_wires
22        } else {
23            self.n_wires.saturating_sub(1)
24        };
25        let mut gate_idx = 0;
26        for _ in 0..self.n_blocks {
27            for pair in 0..n_pairs {
28                let w0 = pair;
29                let w1 = (pair + 1) % self.n_wires;
30                if gate_idx < self.gates.len() {
31                    self.gates[gate_idx].apply(qdev, &[w0, w1])?;
32                    gate_idx += 1;
33                }
34            }
35        }
36        Ok(())
37    }
38    fn parameters(&self) -> Vec<TQParameter> {
39        self.gates.iter().flat_map(|g| g.parameters()).collect()
40    }
41    fn n_wires(&self) -> Option<usize> {
42        Some(self.n_wires)
43    }
44    fn set_n_wires(&mut self, n_wires: usize) {
45        self.n_wires = n_wires;
46        let n_pairs = if self.circular {
47            n_wires
48        } else {
49            n_wires.saturating_sub(1)
50        };
51        let total_gates = n_pairs * self.n_blocks;
52        self.gates = (0..total_gates)
53            .map(|_| TQFSimGate::new(true, true))
54            .collect();
55    }
56    fn is_static_mode(&self) -> bool {
57        self.static_mode
58    }
59    fn static_on(&mut self) {
60        self.static_mode = true;
61        for gate in &mut self.gates {
62            gate.static_on();
63        }
64    }
65    fn static_off(&mut self) {
66        self.static_mode = false;
67        for gate in &mut self.gates {
68            gate.static_off();
69        }
70    }
71    fn name(&self) -> &str {
72        "ExcitationPreservingLayer"
73    }
74    fn zero_grad(&mut self) {
75        for gate in &mut self.gates {
76            gate.zero_grad();
77        }
78    }
79}