quantrs2_ml/torchquantum/layer/
tqparticleconservinglayer_traits.rs1use crate::error::{MLError, Result};
12use crate::torchquantum::gates::{TQFSimGate, TQGivensRotation};
13use crate::torchquantum::{TQDevice, TQModule, TQOperator, TQParameter};
14
15use super::functions::{create_single_qubit_gate, create_two_qubit_gate};
16use super::types::TQParticleConservingLayer;
17
18impl TQModule for TQParticleConservingLayer {
19 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
20 let pairs = self.get_wire_pairs();
21 for (gate_idx, (w0, w1)) in pairs.iter().enumerate() {
22 if gate_idx < self.gates.len() {
23 self.gates[gate_idx].apply(qdev, &[*w0, *w1])?;
24 }
25 }
26 Ok(())
27 }
28 fn parameters(&self) -> Vec<TQParameter> {
29 self.gates.iter().flat_map(|g| g.parameters()).collect()
30 }
31 fn n_wires(&self) -> Option<usize> {
32 Some(self.n_wires)
33 }
34 fn set_n_wires(&mut self, n_wires: usize) {
35 self.n_wires = n_wires;
36 let n_gates = Self::count_gates(n_wires, self.n_blocks, self.pattern);
37 self.gates = (0..n_gates)
38 .map(|_| TQGivensRotation::new(true, true))
39 .collect();
40 }
41 fn is_static_mode(&self) -> bool {
42 self.static_mode
43 }
44 fn static_on(&mut self) {
45 self.static_mode = true;
46 for gate in &mut self.gates {
47 gate.static_on();
48 }
49 }
50 fn static_off(&mut self) {
51 self.static_mode = false;
52 for gate in &mut self.gates {
53 gate.static_off();
54 }
55 }
56 fn name(&self) -> &str {
57 "ParticleConservingLayer"
58 }
59 fn zero_grad(&mut self) {
60 for gate in &mut self.gates {
61 gate.zero_grad();
62 }
63 }
64}