1use super::{
33 gates::{TQRx, TQRy, TQRz, TQCNOT},
34 CType, TQDevice, TQModule, TQOperator, TQParameter,
35};
36use crate::error::{MLError, Result};
37use scirs2_core::ndarray::{Array1, ArrayD};
38
/// Strategy selecting which wire pairs receive a CNOT in the entangling stage
/// of the ansatz layers defined in this module.
///
/// Every generated pair is consumed as `(control, target)` by the layers'
/// `forward` implementations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EntanglementPattern {
    /// Nearest neighbours in ascending order: `(0, 1), (1, 2), …, (n-2, n-1)`.
    Linear,
    /// Nearest neighbours in descending order, higher wire as control:
    /// `(n-1, n-2), …, (1, 0)`.
    ReverseLinear,
    /// `Linear` plus a closing `(n-1, 0)` pair, which is only added when there
    /// are more than two wires.
    Circular,
    /// Every pair `(i, j)` with `i < j`, ordered by `i` then `j`.
    Full,
    /// Caller-supplied `(control, target)` pairs, used verbatim.
    Custom(Vec<(usize, usize)>),
}
53
54impl EntanglementPattern {
55 pub fn generate_pairs(&self, n_wires: usize) -> Vec<(usize, usize)> {
57 match self {
58 EntanglementPattern::Linear => {
59 (0..n_wires.saturating_sub(1)).map(|i| (i, i + 1)).collect()
60 }
61 EntanglementPattern::ReverseLinear => (1..n_wires).rev().map(|i| (i, i - 1)).collect(),
62 EntanglementPattern::Circular => {
63 let mut pairs: Vec<(usize, usize)> =
64 (0..n_wires.saturating_sub(1)).map(|i| (i, i + 1)).collect();
65 if n_wires > 2 {
66 pairs.push((n_wires - 1, 0));
67 }
68 pairs
69 }
70 EntanglementPattern::Full => {
71 let mut pairs = Vec::new();
72 for i in 0..n_wires {
73 for j in (i + 1)..n_wires {
74 pairs.push((i, j));
75 }
76 }
77 pairs
78 }
79 EntanglementPattern::Custom(pairs) => pairs.clone(),
80 }
81 }
82}
83
84pub struct RealAmplitudesLayer {
92 pub n_wires: usize,
94 pub reps: usize,
96 pub entanglement: EntanglementPattern,
98 ry_gates: Vec<Vec<TQRy>>,
100 pub final_rotation: bool,
102 static_mode: bool,
103}
104
105impl RealAmplitudesLayer {
106 pub fn new(n_wires: usize, reps: usize, entanglement: EntanglementPattern) -> Self {
114 let mut ry_gates = Vec::new();
116 for _ in 0..reps {
117 let layer: Vec<TQRy> = (0..n_wires).map(|_| TQRy::new(true, true)).collect();
118 ry_gates.push(layer);
119 }
120
121 let final_layer: Vec<TQRy> = (0..n_wires).map(|_| TQRy::new(true, true)).collect();
123 ry_gates.push(final_layer);
124
125 Self {
126 n_wires,
127 reps,
128 entanglement,
129 ry_gates,
130 final_rotation: true,
131 static_mode: false,
132 }
133 }
134
135 pub fn without_final_rotation(mut self) -> Self {
137 self.final_rotation = false;
138 self
139 }
140}
141
142impl TQModule for RealAmplitudesLayer {
143 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
144 let entanglement_pairs = self.entanglement.generate_pairs(self.n_wires);
145
146 for rep in 0..self.reps {
147 for (wire, gate) in self.ry_gates[rep].iter_mut().enumerate() {
149 gate.apply(qdev, &[wire])?;
150 }
151
152 for (control, target) in &entanglement_pairs {
154 let mut cnot = TQCNOT::new();
155 cnot.apply(qdev, &[*control, *target])?;
156 }
157 }
158
159 if self.final_rotation && self.reps < self.ry_gates.len() {
161 for (wire, gate) in self.ry_gates[self.reps].iter_mut().enumerate() {
162 gate.apply(qdev, &[wire])?;
163 }
164 }
165
166 Ok(())
167 }
168
169 fn parameters(&self) -> Vec<TQParameter> {
170 let num_layers = if self.final_rotation {
171 self.reps + 1
172 } else {
173 self.reps
174 };
175
176 self.ry_gates[..num_layers]
177 .iter()
178 .flat_map(|layer| layer.iter().flat_map(|g| g.parameters()))
179 .collect()
180 }
181
182 fn n_wires(&self) -> Option<usize> {
183 Some(self.n_wires)
184 }
185
186 fn set_n_wires(&mut self, n_wires: usize) {
187 self.n_wires = n_wires;
188 }
189
190 fn is_static_mode(&self) -> bool {
191 self.static_mode
192 }
193
194 fn static_on(&mut self) {
195 self.static_mode = true;
196 }
197
198 fn static_off(&mut self) {
199 self.static_mode = false;
200 }
201
202 fn name(&self) -> &str {
203 "RealAmplitudesLayer"
204 }
205}
206
207pub struct EfficientSU2Layer {
216 pub n_wires: usize,
218 pub reps: usize,
220 pub entanglement: EntanglementPattern,
222 ry_gates: Vec<Vec<TQRy>>,
224 rz_gates: Vec<Vec<TQRz>>,
226 pub final_rotation: bool,
228 static_mode: bool,
229}
230
231impl EfficientSU2Layer {
232 pub fn new(n_wires: usize, reps: usize, entanglement: EntanglementPattern) -> Self {
234 let mut ry_gates = Vec::new();
235 let mut rz_gates = Vec::new();
236
237 for _ in 0..=reps {
238 let ry_layer: Vec<TQRy> = (0..n_wires).map(|_| TQRy::new(true, true)).collect();
240 ry_gates.push(ry_layer);
241
242 let rz_layer: Vec<TQRz> = (0..n_wires).map(|_| TQRz::new(true, true)).collect();
244 rz_gates.push(rz_layer);
245 }
246
247 Self {
248 n_wires,
249 reps,
250 entanglement,
251 ry_gates,
252 rz_gates,
253 final_rotation: true,
254 static_mode: false,
255 }
256 }
257
258 pub fn without_final_rotation(mut self) -> Self {
259 self.final_rotation = false;
260 self
261 }
262}
263
264impl TQModule for EfficientSU2Layer {
265 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
266 let entanglement_pairs = self.entanglement.generate_pairs(self.n_wires);
267
268 for rep in 0..self.reps {
269 for (wire, gate) in self.ry_gates[rep].iter_mut().enumerate() {
271 gate.apply(qdev, &[wire])?;
272 }
273
274 for (wire, gate) in self.rz_gates[rep].iter_mut().enumerate() {
276 gate.apply(qdev, &[wire])?;
277 }
278
279 for (control, target) in &entanglement_pairs {
281 let mut cnot = TQCNOT::new();
282 cnot.apply(qdev, &[*control, *target])?;
283 }
284 }
285
286 if self.final_rotation {
288 for (wire, gate) in self.ry_gates[self.reps].iter_mut().enumerate() {
289 gate.apply(qdev, &[wire])?;
290 }
291 for (wire, gate) in self.rz_gates[self.reps].iter_mut().enumerate() {
292 gate.apply(qdev, &[wire])?;
293 }
294 }
295
296 Ok(())
297 }
298
299 fn parameters(&self) -> Vec<TQParameter> {
300 let num_layers = if self.final_rotation {
301 self.reps + 1
302 } else {
303 self.reps
304 };
305
306 let ry_params = self.ry_gates[..num_layers]
307 .iter()
308 .flat_map(|layer| layer.iter().flat_map(|g| g.parameters()));
309
310 let rz_params = self.rz_gates[..num_layers]
311 .iter()
312 .flat_map(|layer| layer.iter().flat_map(|g| g.parameters()));
313
314 ry_params.chain(rz_params).collect()
315 }
316
317 fn n_wires(&self) -> Option<usize> {
318 Some(self.n_wires)
319 }
320
321 fn set_n_wires(&mut self, n_wires: usize) {
322 self.n_wires = n_wires;
323 }
324
325 fn is_static_mode(&self) -> bool {
326 self.static_mode
327 }
328
329 fn static_on(&mut self) {
330 self.static_mode = true;
331 }
332
333 fn static_off(&mut self) {
334 self.static_mode = false;
335 }
336
337 fn name(&self) -> &str {
338 "EfficientSU2Layer"
339 }
340}
341
/// Single-qubit rotation kind used by [`TwoLocalLayer`]'s rotation rounds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RotationType {
    /// Rotation applied through `TQRx`.
    RX,
    /// Rotation applied through `TQRy`.
    RY,
    /// Rotation applied through `TQRz`.
    RZ,
}
349
350pub struct TwoLocalLayer {
354 pub n_wires: usize,
356 pub reps: usize,
358 pub rotation_gates: Vec<RotationType>,
360 pub entanglement: EntanglementPattern,
362 parameters: Vec<TQParameter>,
364 static_mode: bool,
365}
366
367impl TwoLocalLayer {
368 pub fn new(
370 n_wires: usize,
371 reps: usize,
372 rotation_gates: Vec<RotationType>,
373 entanglement: EntanglementPattern,
374 ) -> Self {
375 let params_per_rep = n_wires * rotation_gates.len();
377 let total_params = (reps + 1) * params_per_rep;
378
379 let parameters: Vec<TQParameter> = (0..total_params)
380 .map(|i| {
381 let param_data = ArrayD::zeros(scirs2_core::ndarray::IxDyn(&[1, 1]));
382 TQParameter::new(param_data, format!("param_{}", i))
383 })
384 .collect();
385
386 Self {
387 n_wires,
388 reps,
389 rotation_gates,
390 entanglement,
391 parameters,
392 static_mode: false,
393 }
394 }
395}
396
397impl TQModule for TwoLocalLayer {
398 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
399 let entanglement_pairs = self.entanglement.generate_pairs(self.n_wires);
400 let params_per_layer = self.n_wires * self.rotation_gates.len();
401
402 for rep in 0..=self.reps {
403 let param_offset = rep * params_per_layer;
404
405 for (rot_idx, rot_type) in self.rotation_gates.iter().enumerate() {
407 for wire in 0..self.n_wires {
408 let param_idx = param_offset + rot_idx * self.n_wires + wire;
409 let param_val = if self.parameters[param_idx].data.len() > 0 {
410 self.parameters[param_idx].data[[0, 0]]
411 } else {
412 0.0
413 };
414
415 match rot_type {
416 RotationType::RX => {
417 let mut gate = TQRx::new(true, true);
418 gate.apply_with_params(qdev, &[wire], Some(&[param_val]))?;
419 }
420 RotationType::RY => {
421 let mut gate = TQRy::new(true, true);
422 gate.apply_with_params(qdev, &[wire], Some(&[param_val]))?;
423 }
424 RotationType::RZ => {
425 let mut gate = TQRz::new(true, true);
426 gate.apply_with_params(qdev, &[wire], Some(&[param_val]))?;
427 }
428 }
429 }
430 }
431
432 if rep < self.reps {
434 for (control, target) in &entanglement_pairs {
435 let mut cnot = TQCNOT::new();
436 cnot.apply(qdev, &[*control, *target])?;
437 }
438 }
439 }
440
441 Ok(())
442 }
443
444 fn parameters(&self) -> Vec<TQParameter> {
445 self.parameters.clone()
446 }
447
448 fn n_wires(&self) -> Option<usize> {
449 Some(self.n_wires)
450 }
451
452 fn set_n_wires(&mut self, n_wires: usize) {
453 self.n_wires = n_wires;
454 }
455
456 fn is_static_mode(&self) -> bool {
457 self.static_mode
458 }
459
460 fn static_on(&mut self) {
461 self.static_mode = true;
462 }
463
464 fn static_off(&mut self) {
465 self.static_mode = false;
466 }
467
468 fn name(&self) -> &str {
469 "TwoLocalLayer"
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entanglement_linear() {
        let pattern = EntanglementPattern::Linear;
        let pairs = pattern.generate_pairs(4);
        assert_eq!(pairs, vec![(0, 1), (1, 2), (2, 3)]);
    }

    #[test]
    fn test_entanglement_reverse_linear() {
        let pairs = EntanglementPattern::ReverseLinear.generate_pairs(4);
        assert_eq!(pairs, vec![(3, 2), (2, 1), (1, 0)]);
    }

    #[test]
    fn test_entanglement_circular() {
        let pattern = EntanglementPattern::Circular;
        let pairs = pattern.generate_pairs(4);
        assert_eq!(pairs, vec![(0, 1), (1, 2), (2, 3), (3, 0)]);
    }

    #[test]
    fn test_entanglement_circular_two_wires_has_no_wraparound() {
        // With two wires the closing pair would duplicate (0, 1), so it is omitted.
        let pairs = EntanglementPattern::Circular.generate_pairs(2);
        assert_eq!(pairs, vec![(0, 1)]);
    }

    #[test]
    fn test_entanglement_full() {
        let pattern = EntanglementPattern::Full;
        let pairs = pattern.generate_pairs(3);
        assert_eq!(pairs, vec![(0, 1), (0, 2), (1, 2)]);
    }

    #[test]
    fn test_entanglement_custom_passthrough() {
        let custom = vec![(2, 0), (1, 3)];
        let pattern = EntanglementPattern::Custom(custom.clone());
        // Custom pairs are returned verbatim, independent of n_wires.
        assert_eq!(pattern.generate_pairs(4), custom);
        assert_eq!(pattern.generate_pairs(0), custom);
    }

    #[test]
    fn test_entanglement_degenerate_widths() {
        let patterns = [
            EntanglementPattern::Linear,
            EntanglementPattern::ReverseLinear,
            EntanglementPattern::Circular,
            EntanglementPattern::Full,
        ];
        for pattern in patterns.iter() {
            assert!(pattern.generate_pairs(0).is_empty(), "{:?} with 0 wires", pattern);
            assert!(pattern.generate_pairs(1).is_empty(), "{:?} with 1 wire", pattern);
        }
    }

    #[test]
    fn test_real_amplitudes_creation() {
        let layer = RealAmplitudesLayer::new(4, 2, EntanglementPattern::Linear);
        assert_eq!(layer.n_wires, 4);
        assert_eq!(layer.reps, 2);

        // reps rotation layers plus the final-rotation layer.
        assert_eq!(layer.ry_gates.len(), 3);

        assert_eq!(layer.ry_gates[0].len(), 4);
    }

    #[test]
    fn test_efficient_su2_creation() {
        let layer = EfficientSU2Layer::new(3, 2, EntanglementPattern::Circular);
        assert_eq!(layer.n_wires, 3);
        assert_eq!(layer.reps, 2);

        assert_eq!(layer.ry_gates.len(), 3);
        assert_eq!(layer.rz_gates.len(), 3);
    }

    #[test]
    fn test_two_local_creation() {
        let rotations = vec![RotationType::RY, RotationType::RZ];
        let layer = TwoLocalLayer::new(4, 2, rotations, EntanglementPattern::Linear);

        assert_eq!(layer.n_wires, 4);
        assert_eq!(layer.reps, 2);

        // (reps + 1) rounds * 4 wires * 2 rotation kinds.
        assert_eq!(layer.parameters.len(), 24);
        assert_eq!(layer.parameters().len(), 24);
    }

    #[test]
    fn test_real_amplitudes_parameters() {
        let layer = RealAmplitudesLayer::new(3, 2, EntanglementPattern::Linear);

        let params = layer.parameters();
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_real_amplitudes_without_final_rotation_parameters() {
        let layer =
            RealAmplitudesLayer::new(3, 2, EntanglementPattern::Linear).without_final_rotation();

        assert!(!layer.final_rotation);
        // Only the `reps` rotation layers contribute parameters.
        assert_eq!(layer.parameters().len(), 6);
    }

    #[test]
    fn test_efficient_su2_parameters() {
        let layer = EfficientSU2Layer::new(3, 2, EntanglementPattern::Linear);

        let params = layer.parameters();
        assert_eq!(params.len(), 18);
    }

    #[test]
    fn test_efficient_su2_without_final_rotation_parameters() {
        let layer =
            EfficientSU2Layer::new(3, 2, EntanglementPattern::Linear).without_final_rotation();

        assert!(!layer.final_rotation);
        assert_eq!(layer.parameters().len(), 12);
    }
}