1use crate::error::{MLError, Result};
6use crate::torchquantum::{
7 CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
8};
9use scirs2_core::ndarray::Array2;
10
/// Two-wire CNOT gate operator.
///
/// The gate has no numeric state; it only tracks two boolean flags, both of
/// which are cleared on construction.
#[derive(Debug, Clone, Default)]
pub struct TQCNOT {
    /// Inverse flag; `false` for a freshly built gate.
    inverse: bool,
    /// Static-mode flag; `false` for a freshly built gate.
    static_mode: bool,
}

impl TQCNOT {
    /// Creates a CNOT gate with `inverse` and `static_mode` both set to `false`.
    pub fn new() -> Self {
        // `bool::default()` is `false`, so the derived `Default` produces
        // exactly the all-flags-cleared state.
        Self::default()
    }
}
32
33impl TQModule for TQCNOT {
34 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
35 Err(MLError::InvalidConfiguration(
36 "Use apply() instead of forward() for operators".to_string(),
37 ))
38 }
39
40 fn parameters(&self) -> Vec<TQParameter> {
41 Vec::new()
42 }
43
44 fn n_wires(&self) -> Option<usize> {
45 Some(2)
46 }
47
48 fn set_n_wires(&mut self, _n_wires: usize) {}
49
50 fn is_static_mode(&self) -> bool {
51 self.static_mode
52 }
53
54 fn static_on(&mut self) {
55 self.static_mode = true;
56 }
57
58 fn static_off(&mut self) {
59 self.static_mode = false;
60 }
61
62 fn name(&self) -> &str {
63 "CNOT"
64 }
65}
66
67impl TQOperator for TQCNOT {
68 fn num_wires(&self) -> WiresEnum {
69 WiresEnum::Fixed(2)
70 }
71
72 fn num_params(&self) -> NParamsEnum {
73 NParamsEnum::Fixed(0)
74 }
75
76 fn get_matrix(&self, _params: Option<&[f64]>) -> Array2<CType> {
77 Array2::from_shape_vec(
78 (4, 4),
79 vec![
80 CType::new(1.0, 0.0),
81 CType::new(0.0, 0.0),
82 CType::new(0.0, 0.0),
83 CType::new(0.0, 0.0),
84 CType::new(0.0, 0.0),
85 CType::new(1.0, 0.0),
86 CType::new(0.0, 0.0),
87 CType::new(0.0, 0.0),
88 CType::new(0.0, 0.0),
89 CType::new(0.0, 0.0),
90 CType::new(0.0, 0.0),
91 CType::new(1.0, 0.0),
92 CType::new(0.0, 0.0),
93 CType::new(0.0, 0.0),
94 CType::new(1.0, 0.0),
95 CType::new(0.0, 0.0),
96 ],
97 )
98 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
99 }
100
101 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
102 self.apply_with_params(qdev, wires, None)
103 }
104
105 fn apply_with_params(
106 &mut self,
107 qdev: &mut TQDevice,
108 wires: &[usize],
109 _params: Option<&[f64]>,
110 ) -> Result<()> {
111 if wires.len() < 2 {
112 return Err(MLError::InvalidConfiguration(
113 "CNOT gate requires exactly 2 wires".to_string(),
114 ));
115 }
116 let matrix = self.get_matrix(None);
117 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
118
119 if qdev.record_op {
120 qdev.record_operation(OpHistoryEntry {
121 name: "cnot".to_string(),
122 wires: wires.to_vec(),
123 params: None,
124 inverse: self.inverse,
125 trainable: false,
126 });
127 }
128
129 Ok(())
130 }
131
132 fn has_params(&self) -> bool {
133 false
134 }
135
136 fn trainable(&self) -> bool {
137 false
138 }
139
140 fn inverse(&self) -> bool {
141 self.inverse
142 }
143
144 fn set_inverse(&mut self, inverse: bool) {
145 self.inverse = inverse;
146 }
147}
148
/// Two-wire CZ gate operator.
///
/// The gate has no numeric state; it only tracks two boolean flags, both of
/// which are cleared on construction.
#[derive(Debug, Clone, Default)]
pub struct TQCZ {
    /// Inverse flag; `false` for a freshly built gate.
    inverse: bool,
    /// Static-mode flag; `false` for a freshly built gate.
    static_mode: bool,
}

impl TQCZ {
    /// Creates a CZ gate with `inverse` and `static_mode` both set to `false`.
    pub fn new() -> Self {
        // `bool::default()` is `false`, so the derived `Default` produces
        // exactly the all-flags-cleared state.
        Self::default()
    }
}
170
171impl TQModule for TQCZ {
172 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
173 Err(MLError::InvalidConfiguration(
174 "Use apply() instead of forward() for operators".to_string(),
175 ))
176 }
177
178 fn parameters(&self) -> Vec<TQParameter> {
179 Vec::new()
180 }
181
182 fn n_wires(&self) -> Option<usize> {
183 Some(2)
184 }
185
186 fn set_n_wires(&mut self, _n_wires: usize) {}
187
188 fn is_static_mode(&self) -> bool {
189 self.static_mode
190 }
191
192 fn static_on(&mut self) {
193 self.static_mode = true;
194 }
195
196 fn static_off(&mut self) {
197 self.static_mode = false;
198 }
199
200 fn name(&self) -> &str {
201 "CZ"
202 }
203}
204
205impl TQOperator for TQCZ {
206 fn num_wires(&self) -> WiresEnum {
207 WiresEnum::Fixed(2)
208 }
209
210 fn num_params(&self) -> NParamsEnum {
211 NParamsEnum::Fixed(0)
212 }
213
214 fn get_matrix(&self, _params: Option<&[f64]>) -> Array2<CType> {
215 Array2::from_shape_vec(
216 (4, 4),
217 vec![
218 CType::new(1.0, 0.0),
219 CType::new(0.0, 0.0),
220 CType::new(0.0, 0.0),
221 CType::new(0.0, 0.0),
222 CType::new(0.0, 0.0),
223 CType::new(1.0, 0.0),
224 CType::new(0.0, 0.0),
225 CType::new(0.0, 0.0),
226 CType::new(0.0, 0.0),
227 CType::new(0.0, 0.0),
228 CType::new(1.0, 0.0),
229 CType::new(0.0, 0.0),
230 CType::new(0.0, 0.0),
231 CType::new(0.0, 0.0),
232 CType::new(0.0, 0.0),
233 CType::new(-1.0, 0.0),
234 ],
235 )
236 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
237 }
238
239 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
240 self.apply_with_params(qdev, wires, None)
241 }
242
243 fn apply_with_params(
244 &mut self,
245 qdev: &mut TQDevice,
246 wires: &[usize],
247 _params: Option<&[f64]>,
248 ) -> Result<()> {
249 if wires.len() < 2 {
250 return Err(MLError::InvalidConfiguration(
251 "CZ gate requires exactly 2 wires".to_string(),
252 ));
253 }
254 let matrix = self.get_matrix(None);
255 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
256
257 if qdev.record_op {
258 qdev.record_operation(OpHistoryEntry {
259 name: "cz".to_string(),
260 wires: wires.to_vec(),
261 params: None,
262 inverse: self.inverse,
263 trainable: false,
264 });
265 }
266
267 Ok(())
268 }
269
270 fn has_params(&self) -> bool {
271 false
272 }
273
274 fn trainable(&self) -> bool {
275 false
276 }
277
278 fn inverse(&self) -> bool {
279 self.inverse
280 }
281
282 fn set_inverse(&mut self, inverse: bool) {
283 self.inverse = inverse;
284 }
285}
286
/// Two-wire SWAP gate operator.
///
/// The gate has no numeric state; it only tracks two boolean flags, both of
/// which are cleared on construction.
#[derive(Debug, Clone, Default)]
pub struct TQSWAP {
    /// Inverse flag; `false` for a freshly built gate.
    inverse: bool,
    /// Static-mode flag; `false` for a freshly built gate.
    static_mode: bool,
}

impl TQSWAP {
    /// Creates a SWAP gate with `inverse` and `static_mode` both set to `false`.
    pub fn new() -> Self {
        // `bool::default()` is `false`, so the derived `Default` produces
        // exactly the all-flags-cleared state.
        Self::default()
    }
}
308
309impl TQModule for TQSWAP {
310 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
311 Err(MLError::InvalidConfiguration(
312 "Use apply() instead of forward() for operators".to_string(),
313 ))
314 }
315
316 fn parameters(&self) -> Vec<TQParameter> {
317 Vec::new()
318 }
319
320 fn n_wires(&self) -> Option<usize> {
321 Some(2)
322 }
323
324 fn set_n_wires(&mut self, _n_wires: usize) {}
325
326 fn is_static_mode(&self) -> bool {
327 self.static_mode
328 }
329
330 fn static_on(&mut self) {
331 self.static_mode = true;
332 }
333
334 fn static_off(&mut self) {
335 self.static_mode = false;
336 }
337
338 fn name(&self) -> &str {
339 "SWAP"
340 }
341}
342
343impl TQOperator for TQSWAP {
344 fn num_wires(&self) -> WiresEnum {
345 WiresEnum::Fixed(2)
346 }
347
348 fn num_params(&self) -> NParamsEnum {
349 NParamsEnum::Fixed(0)
350 }
351
352 fn get_matrix(&self, _params: Option<&[f64]>) -> Array2<CType> {
353 Array2::from_shape_vec(
354 (4, 4),
355 vec![
356 CType::new(1.0, 0.0),
357 CType::new(0.0, 0.0),
358 CType::new(0.0, 0.0),
359 CType::new(0.0, 0.0),
360 CType::new(0.0, 0.0),
361 CType::new(0.0, 0.0),
362 CType::new(1.0, 0.0),
363 CType::new(0.0, 0.0),
364 CType::new(0.0, 0.0),
365 CType::new(1.0, 0.0),
366 CType::new(0.0, 0.0),
367 CType::new(0.0, 0.0),
368 CType::new(0.0, 0.0),
369 CType::new(0.0, 0.0),
370 CType::new(0.0, 0.0),
371 CType::new(1.0, 0.0),
372 ],
373 )
374 .unwrap_or_else(|_| Array2::eye(4).mapv(|x| CType::new(x, 0.0)))
375 }
376
377 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()> {
378 self.apply_with_params(qdev, wires, None)
379 }
380
381 fn apply_with_params(
382 &mut self,
383 qdev: &mut TQDevice,
384 wires: &[usize],
385 _params: Option<&[f64]>,
386 ) -> Result<()> {
387 if wires.len() < 2 {
388 return Err(MLError::InvalidConfiguration(
389 "SWAP gate requires exactly 2 wires".to_string(),
390 ));
391 }
392 let matrix = self.get_matrix(None);
393 qdev.apply_two_qubit_gate(wires[0], wires[1], &matrix)?;
394
395 if qdev.record_op {
396 qdev.record_operation(OpHistoryEntry {
397 name: "swap".to_string(),
398 wires: wires.to_vec(),
399 params: None,
400 inverse: self.inverse,
401 trainable: false,
402 });
403 }
404
405 Ok(())
406 }
407
408 fn has_params(&self) -> bool {
409 false
410 }
411
412 fn trainable(&self) -> bool {
413 false
414 }
415
416 fn inverse(&self) -> bool {
417 self.inverse
418 }
419
420 fn set_inverse(&mut self, inverse: bool) {
421 self.inverse = inverse;
422 }
423}