quantrs2_ml/torchquantum/layer/
tqexcitationpreservinglayer_traits.rs1use crate::error::{MLError, Result};
12use crate::torchquantum::gates::{TQFSimGate, TQGivensRotation};
13use crate::torchquantum::{TQDevice, TQModule, TQOperator, TQParameter};
14
15use super::functions::{create_single_qubit_gate, create_two_qubit_gate};
16use super::types::TQExcitationPreservingLayer;
17
18impl TQModule for TQExcitationPreservingLayer {
19 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
20 let n_pairs = if self.circular {
21 self.n_wires
22 } else {
23 self.n_wires.saturating_sub(1)
24 };
25 let mut gate_idx = 0;
26 for _ in 0..self.n_blocks {
27 for pair in 0..n_pairs {
28 let w0 = pair;
29 let w1 = (pair + 1) % self.n_wires;
30 if gate_idx < self.gates.len() {
31 self.gates[gate_idx].apply(qdev, &[w0, w1])?;
32 gate_idx += 1;
33 }
34 }
35 }
36 Ok(())
37 }
38 fn parameters(&self) -> Vec<TQParameter> {
39 self.gates.iter().flat_map(|g| g.parameters()).collect()
40 }
41 fn n_wires(&self) -> Option<usize> {
42 Some(self.n_wires)
43 }
44 fn set_n_wires(&mut self, n_wires: usize) {
45 self.n_wires = n_wires;
46 let n_pairs = if self.circular {
47 n_wires
48 } else {
49 n_wires.saturating_sub(1)
50 };
51 let total_gates = n_pairs * self.n_blocks;
52 self.gates = (0..total_gates)
53 .map(|_| TQFSimGate::new(true, true))
54 .collect();
55 }
56 fn is_static_mode(&self) -> bool {
57 self.static_mode
58 }
59 fn static_on(&mut self) {
60 self.static_mode = true;
61 for gate in &mut self.gates {
62 gate.static_on();
63 }
64 }
65 fn static_off(&mut self) {
66 self.static_mode = false;
67 for gate in &mut self.gates {
68 gate.static_off();
69 }
70 }
71 fn name(&self) -> &str {
72 "ExcitationPreservingLayer"
73 }
74 fn zero_grad(&mut self) {
75 for gate in &mut self.gates {
76 gate.zero_grad();
77 }
78 }
79}