1use crate::error::{MLError, Result};
25use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
26use scirs2_core::Complex64;
27use std::f64::consts::PI;
28
29pub mod ansatz;
31pub mod autograd;
32pub mod conv;
33pub mod encoding;
34pub mod functional;
35pub mod gates;
36pub mod layer;
37pub mod measurement;
38pub mod noise;
39pub mod pooling;
40pub mod tensor_network;
41
/// Complex scalar type used for state-vector amplitudes and gate matrices.
pub type CType = Complex64;

/// Real floating-point scalar type (an alias for `f64`).
pub type FType = f64;
51
/// Declares how many wires (qubits) an operator expects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WiresEnum {
    /// No fixed wire count.
    AnyWires,
    /// Spans every wire of the device (per the variant name).
    AllWires,
    /// Exactly this many wires.
    Fixed(usize),
}
62
/// Declares how many scalar parameters an operator expects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NParamsEnum {
    /// No fixed parameter count.
    AnyNParams,
    /// Exactly this many parameters.
    Fixed(usize),
}
71
72pub trait TQModule: Send + Sync {
84 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()>;
86
87 fn forward_with_input(&mut self, qdev: &mut TQDevice, _x: Option<&Array2<f64>>) -> Result<()> {
89 self.forward(qdev)
90 }
91
92 fn parameters(&self) -> Vec<TQParameter>;
94
95 fn n_wires(&self) -> Option<usize>;
97
98 fn set_n_wires(&mut self, n_wires: usize);
100
101 fn is_static_mode(&self) -> bool;
103
104 fn static_on(&mut self);
106
107 fn static_off(&mut self);
109
110 fn get_unitary(&self) -> Option<Array2<CType>> {
112 None
113 }
114
115 fn name(&self) -> &str;
117
118 fn zero_grad(&mut self) {
120 }
122
123 fn train(&mut self, _mode: bool) {
125 }
127
128 fn training(&self) -> bool {
130 true
131 }
132}
133
/// A named tensor of `f64` values with an optional gradient slot.
#[derive(Debug, Clone)]
pub struct TQParameter {
    /// Parameter values (arbitrary rank).
    pub data: ArrayD<f64>,
    /// Identifier for this parameter.
    pub name: String,
    /// Trainability flag: `TQParameter::new` sets it to `true`, `TQParameter::no_grad` to `false`.
    pub requires_grad: bool,
    /// Stored gradient, if any; `TQParameter::zero_grad` resets it to `None`.
    pub grad: Option<ArrayD<f64>>,
}
150
151impl TQParameter {
152 pub fn new(data: ArrayD<f64>, name: impl Into<String>) -> Self {
154 Self {
155 data,
156 name: name.into(),
157 requires_grad: true,
158 grad: None,
159 }
160 }
161
162 pub fn no_grad(data: ArrayD<f64>, name: impl Into<String>) -> Self {
164 Self {
165 data,
166 name: name.into(),
167 requires_grad: false,
168 grad: None,
169 }
170 }
171
172 pub fn shape(&self) -> &[usize] {
174 self.data.shape()
175 }
176
177 pub fn numel(&self) -> usize {
179 self.data.len()
180 }
181
182 pub fn zero_grad(&mut self) {
184 self.grad = None;
185 }
186
187 pub fn init_uniform_pi(&mut self) {
189 for elem in self.data.iter_mut() {
190 *elem = (fastrand::f64() * 2.0 - 1.0) * PI;
191 }
192 }
193
194 pub fn init_constant(&mut self, value: f64) {
196 for elem in self.data.iter_mut() {
197 *elem = value;
198 }
199 }
200}
201
/// Batched state-vector simulator for `n_wires` qubits.
///
/// `states` has shape `[bsz, 2, 2, ..., 2]`, with one size-2 axis per wire.
/// Flattening batch row `b` in row-major order gives that entry's state
/// vector of length `2^n_wires`. Wire 0 maps to the most significant bit of
/// the basis-state index, matching the `n_wires - 1 - wire` bit position
/// used by the gate kernels.
#[derive(Debug, Clone)]
pub struct TQDevice {
    /// Number of qubits (wires).
    pub n_wires: usize,
    /// Device label; the constructors set it to `"default"`.
    pub device_name: String,
    /// Batch size: number of independent state vectors in `states`.
    pub bsz: usize,
    /// Batched amplitudes, shape `[bsz, 2, ..., 2]`.
    pub states: ArrayD<CType>,
    /// When `true`, `record_operation` appends entries to `op_history`.
    pub record_op: bool,
    /// Operations recorded while `record_op` was enabled.
    pub op_history: Vec<OpHistoryEntry>,
}
227
/// One recorded operation.
///
/// `TQDevice::record_operation` appends these to `TQDevice::op_history`
/// while `record_op` is set.
#[derive(Debug, Clone)]
pub struct OpHistoryEntry {
    /// Operation name.
    pub name: String,
    /// Wire indices the operation was applied to.
    pub wires: Vec<usize>,
    /// Parameter values, if any (presumably `None` for fixed gates — confirm against callers).
    pub params: Option<Vec<f64>>,
    /// Inverse flag of the recorded operation (presumably mirrors `TQOperator::inverse`).
    pub inverse: bool,
    /// Trainable flag of the recorded operation (presumably mirrors `TQOperator::trainable`).
    pub trainable: bool,
}
242
243impl TQDevice {
244 pub fn new(n_wires: usize) -> Self {
246 Self::with_batch_size(n_wires, 1)
247 }
248
249 pub fn with_batch_size(n_wires: usize, bsz: usize) -> Self {
251 let state_size = 1 << n_wires; let mut state_data = vec![CType::new(0.0, 0.0); state_size * bsz];
254 for b in 0..bsz {
256 state_data[b * state_size] = CType::new(1.0, 0.0);
257 }
258
259 let mut shape = vec![bsz];
261 shape.extend(vec![2; n_wires]);
262
263 let states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
264 .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
265
266 Self {
267 n_wires,
268 device_name: "default".to_string(),
269 bsz,
270 states,
271 record_op: false,
272 op_history: Vec::new(),
273 }
274 }
275
276 pub fn reset_states(&mut self, bsz: usize) {
278 self.bsz = bsz;
279 let state_size = 1 << self.n_wires;
280 let mut state_data = vec![CType::new(0.0, 0.0); state_size * bsz];
281 for b in 0..bsz {
282 state_data[b * state_size] = CType::new(1.0, 0.0);
283 }
284
285 let mut shape = vec![bsz];
286 shape.extend(vec![2; self.n_wires]);
287 self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
288 .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
289 }
290
291 pub fn reset_identity_states(&mut self) {
293 let state_size = 1 << self.n_wires;
294 self.bsz = state_size;
295
296 let mut state_data = vec![CType::new(0.0, 0.0); state_size * state_size];
297 for i in 0..state_size {
299 state_data[i * state_size + i] = CType::new(1.0, 0.0);
300 }
301
302 let mut shape = vec![state_size];
303 shape.extend(vec![2; self.n_wires]);
304 self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
305 .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
306 }
307
308 pub fn reset_all_eq_states(&mut self, bsz: usize) {
310 self.bsz = bsz;
311 let state_size = 1 << self.n_wires;
312 let amplitude = 1.0 / (state_size as f64).sqrt();
313 let state_data = vec![CType::new(amplitude, 0.0); state_size * bsz];
314
315 let mut shape = vec![bsz];
316 shape.extend(vec![2; self.n_wires]);
317 self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
318 .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
319 }
320
321 pub fn clone_states(&mut self, other: &TQDevice) {
323 self.states = other.states.clone();
324 self.bsz = other.bsz;
325 }
326
327 pub fn set_states(&mut self, states: ArrayD<CType>) {
329 self.bsz = states.shape()[0];
330 self.states = states;
331 }
332
333 pub fn get_states_1d(&self) -> Array2<CType> {
335 let state_size = 1 << self.n_wires;
336 let flat: Vec<CType> = self.states.iter().cloned().collect();
337 Array2::from_shape_vec((self.bsz, state_size), flat)
338 .unwrap_or_else(|_| Array2::zeros((self.bsz, state_size)))
339 }
340
341 pub fn get_probs_1d(&self) -> Array2<f64> {
343 let states_1d = self.get_states_1d();
344 states_1d.mapv(|c| c.norm_sqr())
345 }
346
347 pub fn record_operation(&mut self, entry: OpHistoryEntry) {
349 if self.record_op {
350 self.op_history.push(entry);
351 }
352 }
353
354 pub fn reset_op_history(&mut self) {
356 self.op_history.clear();
357 }
358
359 pub fn apply_single_qubit_gate(&mut self, wire: usize, matrix: &Array2<CType>) -> Result<()> {
361 if wire >= self.n_wires {
362 return Err(MLError::InvalidConfiguration(format!(
363 "Wire {} out of range for {} qubits",
364 wire, self.n_wires
365 )));
366 }
367
368 let state_size = 1 << self.n_wires;
369 let states_1d = self.get_states_1d();
370 let mut new_states = states_1d.clone();
371
372 for batch in 0..self.bsz {
373 for i in 0..state_size {
374 let bit = (i >> (self.n_wires - 1 - wire)) & 1;
376 if bit == 0 {
377 let j = i | (1 << (self.n_wires - 1 - wire));
378 let amp0 = states_1d[[batch, i]];
379 let amp1 = states_1d[[batch, j]];
380 new_states[[batch, i]] = matrix[[0, 0]] * amp0 + matrix[[0, 1]] * amp1;
381 new_states[[batch, j]] = matrix[[1, 0]] * amp0 + matrix[[1, 1]] * amp1;
382 }
383 }
384 }
385
386 let flat: Vec<CType> = new_states.iter().cloned().collect();
388 let mut shape = vec![self.bsz];
389 shape.extend(vec![2; self.n_wires]);
390 self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
391 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
392
393 Ok(())
394 }
395
396 pub fn apply_two_qubit_gate(
398 &mut self,
399 wire0: usize,
400 wire1: usize,
401 matrix: &Array2<CType>,
402 ) -> Result<()> {
403 if wire0 >= self.n_wires || wire1 >= self.n_wires {
404 return Err(MLError::InvalidConfiguration(format!(
405 "Wires ({}, {}) out of range for {} qubits",
406 wire0, wire1, self.n_wires
407 )));
408 }
409
410 let state_size = 1 << self.n_wires;
411 let states_1d = self.get_states_1d();
412 let mut new_states = states_1d.clone();
413
414 let pos0 = self.n_wires - 1 - wire0;
415 let pos1 = self.n_wires - 1 - wire1;
416
417 for batch in 0..self.bsz {
418 let mut visited = vec![false; state_size];
419
420 for i in 0..state_size {
421 if visited[i] {
422 continue;
423 }
424
425 let base = i & !(1 << pos0) & !(1 << pos1);
428
429 let indices = [
430 base, base | (1 << pos1), base | (1 << pos0), base | (1 << pos0) | (1 << pos1), ];
435
436 let amps: Vec<CType> = indices.iter().map(|&idx| states_1d[[batch, idx]]).collect();
437
438 for (row, &idx) in indices.iter().enumerate() {
439 let mut new_amp = CType::new(0.0, 0.0);
440 for (col, &) in amps.iter().enumerate() {
441 new_amp += matrix[[row, col]] * amp;
442 }
443 new_states[[batch, idx]] = new_amp;
444 visited[idx] = true;
445 }
446 }
447 }
448
449 let flat: Vec<CType> = new_states.iter().cloned().collect();
451 let mut shape = vec![self.bsz];
452 shape.extend(vec![2; self.n_wires]);
453 self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
454 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
455
456 Ok(())
457 }
458
459 pub fn apply_multi_qubit_gate(
461 &mut self,
462 wires: &[usize],
463 matrix: &Array2<CType>,
464 ) -> Result<()> {
465 let n_qubits = wires.len();
466
467 for &wire in wires {
469 if wire >= self.n_wires {
470 return Err(MLError::InvalidConfiguration(format!(
471 "Wire {} out of range for {} qubits",
472 wire, self.n_wires
473 )));
474 }
475 }
476
477 let gate_dim = 1 << n_qubits;
479 if matrix.nrows() != gate_dim || matrix.ncols() != gate_dim {
480 return Err(MLError::InvalidConfiguration(format!(
481 "Gate matrix must be {}x{} for {}-qubit gate",
482 gate_dim, gate_dim, n_qubits
483 )));
484 }
485
486 let state_size = 1 << self.n_wires;
487 let states_1d = self.get_states_1d();
488 let mut new_states = states_1d.clone();
489
490 let positions: Vec<usize> = wires.iter().map(|&w| self.n_wires - 1 - w).collect();
492
493 let mut wire_mask: usize = 0;
495 for &pos in &positions {
496 wire_mask |= 1 << pos;
497 }
498
499 for batch in 0..self.bsz {
500 let mut visited = vec![false; state_size];
501
502 for base_idx in 0..state_size {
503 if visited[base_idx] {
504 continue;
505 }
506
507 let base = base_idx & !wire_mask;
509
510 let mut indices = Vec::with_capacity(gate_dim);
512 for gate_idx in 0..gate_dim {
513 let mut idx = base;
514 for (bit_pos, &pos) in positions.iter().enumerate() {
516 if (gate_idx >> (n_qubits - 1 - bit_pos)) & 1 == 1 {
517 idx |= 1 << pos;
518 }
519 }
520 indices.push(idx);
521 }
522
523 let amps: Vec<CType> = indices.iter().map(|&idx| states_1d[[batch, idx]]).collect();
525
526 for (row, &idx) in indices.iter().enumerate() {
528 let mut new_amp = CType::new(0.0, 0.0);
529 for (col, &) in amps.iter().enumerate() {
530 new_amp += matrix[[row, col]] * amp;
531 }
532 new_states[[batch, idx]] = new_amp;
533 visited[idx] = true;
534 }
535 }
536 }
537
538 let flat: Vec<CType> = new_states.iter().cloned().collect();
540 let mut shape = vec![self.bsz];
541 shape.extend(vec![2; self.n_wires]);
542 self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
543 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
544
545 Ok(())
546 }
547}
548
549pub trait TQOperator: TQModule {
555 fn num_wires(&self) -> WiresEnum;
557
558 fn num_params(&self) -> NParamsEnum;
560
561 fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType>;
563
564 fn get_eigvals(&self, _params: Option<&[f64]>) -> Option<Array1<CType>> {
566 None
567 }
568
569 fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()>;
571
572 fn apply_with_params(
574 &mut self,
575 qdev: &mut TQDevice,
576 wires: &[usize],
577 params: Option<&[f64]>,
578 ) -> Result<()>;
579
580 fn has_params(&self) -> bool;
582
583 fn trainable(&self) -> bool;
585
586 fn inverse(&self) -> bool;
588 fn set_inverse(&mut self, inverse: bool);
589}
590
591pub struct TQModuleList {
597 modules: Vec<Box<dyn TQModule>>,
598 static_mode: bool,
599}
600
601impl TQModuleList {
602 pub fn new() -> Self {
604 Self {
605 modules: Vec::new(),
606 static_mode: false,
607 }
608 }
609
610 pub fn append(&mut self, module: Box<dyn TQModule>) {
612 self.modules.push(module);
613 }
614
615 pub fn len(&self) -> usize {
617 self.modules.len()
618 }
619
620 pub fn is_empty(&self) -> bool {
622 self.modules.is_empty()
623 }
624
625 pub fn get(&self, index: usize) -> Option<&Box<dyn TQModule>> {
627 self.modules.get(index)
628 }
629
630 pub fn get_mut(&mut self, index: usize) -> Option<&mut Box<dyn TQModule>> {
632 self.modules.get_mut(index)
633 }
634
635 pub fn iter(&self) -> impl Iterator<Item = &Box<dyn TQModule>> {
637 self.modules.iter()
638 }
639
640 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn TQModule>> {
642 self.modules.iter_mut()
643 }
644}
645
646impl Default for TQModuleList {
647 fn default() -> Self {
648 Self::new()
649 }
650}
651
652impl TQModule for TQModuleList {
653 fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
654 for module in &mut self.modules {
655 module.forward(qdev)?;
656 }
657 Ok(())
658 }
659
660 fn parameters(&self) -> Vec<TQParameter> {
661 self.modules.iter().flat_map(|m| m.parameters()).collect()
662 }
663
664 fn n_wires(&self) -> Option<usize> {
665 self.modules.first().and_then(|m| m.n_wires())
666 }
667
668 fn set_n_wires(&mut self, n_wires: usize) {
669 for module in &mut self.modules {
670 module.set_n_wires(n_wires);
671 }
672 }
673
674 fn is_static_mode(&self) -> bool {
675 self.static_mode
676 }
677
678 fn static_on(&mut self) {
679 self.static_mode = true;
680 for module in &mut self.modules {
681 module.static_on();
682 }
683 }
684
685 fn static_off(&mut self) {
686 self.static_mode = false;
687 for module in &mut self.modules {
688 module.static_off();
689 }
690 }
691
692 fn name(&self) -> &str {
693 "ModuleList"
694 }
695
696 fn zero_grad(&mut self) {
697 for module in &mut self.modules {
698 module.zero_grad();
699 }
700 }
701}
702
/// Convenience re-exports of this module's core types and the commonly used
/// items from its submodules, intended for glob import (`use …::prelude::*`).
pub mod prelude {
    // Core types defined in this module.
    pub use super::{
        CType, FType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQModuleList, TQOperator,
        TQParameter, WiresEnum,
    };

    // Gate operators.
    pub use super::gates::{
        TQHadamard,
        TQPauliX,
        TQPauliY,
        TQPauliZ,
        TQRx,
        TQRy,
        TQRz,
        TQCNOT,
        TQCRX,
        TQCRY,
        TQCRZ,
        TQCZ,
        TQRXX,
        TQRYY,
        TQRZX,
        TQRZZ,
        TQS,
        TQSWAP,
        TQSX,
        TQT,
        TQU1,
        TQU2,
        TQU3,
    };

    // Classical-to-quantum encoders.
    pub use super::encoding::{
        EncodingOp, TQAmplitudeEncoder, TQEncoder, TQGeneralEncoder, TQPhaseEncoder, TQStateEncoder,
    };

    // Measurement and expectation-value helpers.
    pub use super::measurement::{
        expval_joint_analytical, expval_joint_sampling, gen_bitstrings, measure, TQMeasureAll,
    };

    // Layer templates.
    pub use super::layer::{
        TQBarrenLayer, TQFarhiLayer, TQLayerConfig, TQMaxwellLayer, TQOp1QAllLayer, TQOp2QAllLayer,
        TQRXYZCXLayer, TQSethLayer, TQStrongEntanglingLayer,
    };

    // Gradient utilities and parameter bookkeeping.
    pub use super::autograd::{
        gradient_norm, gradient_statistics, ClippingStatistics, ClippingStrategy,
        GradientAccumulator, GradientCheckResult, GradientChecker, GradientClipper,
        GradientStatistics, ParameterGroup, ParameterGroupManager, ParameterRegistry,
        ParameterStatistics,
    };

    // Ansatz layers.
    pub use super::ansatz::{
        EfficientSU2Layer, EntanglementPattern, RealAmplitudesLayer, TwoLocalLayer,
    };

    // Quantum convolution layers.
    pub use super::conv::{QConv1D, QConv2D};

    // Quantum pooling layers.
    pub use super::pooling::{QAvgPool, QMaxPool};

    // Tensor-network backend.
    pub use super::tensor_network::{
        CompressionMethod, MPSTensor, MatrixProductState, TQTensorNetworkBackend,
        TensorNetworkConfig,
    };

    // Noise models, error mitigation and noise-aware training.
    pub use super::noise::{
        GateTimes, MitigatedExpectation, MitigatedExpectationConfig, MitigationMethod,
        NoiseAwareGradient, NoiseAwareGradientConfig, NoiseAwareTrainer, NoiseModel,
        SingleQubitNoiseType, TrainingHistory, TrainingStatistics, TwoQubitNoiseType,
        VarianceReduction, ZNEExtrapolation,
    };
}
795
#[cfg(test)]
mod tests {
    use super::prelude::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Asserts that batch row 0 of the device's probabilities matches `expected` entry by entry.
    fn assert_probs(qdev: &TQDevice, expected: &[f64]) {
        let probs = qdev.get_probs_1d();
        expected
            .iter()
            .enumerate()
            .for_each(|(idx, &want)| assert_close(probs[[0, idx]], want));
    }

    #[test]
    fn test_tq_device_creation() {
        let qdev = TQDevice::new(4);
        assert_eq!(qdev.n_wires, 4);
        assert_eq!(qdev.bsz, 1);

        // A fresh device puts all probability on |0000>.
        let mut expected = [0.0; 1 << 4];
        expected[0] = 1.0;
        assert_probs(&qdev, &expected);
    }

    #[test]
    fn test_tq_device_reset() {
        let mut qdev = TQDevice::new(2);
        qdev.reset_all_eq_states(1);

        // Uniform superposition over 4 basis states.
        assert_probs(&qdev, &[0.25; 4]);
    }

    #[test]
    fn test_tq_parameter() {
        use scirs2_core::ndarray::{ArrayD, IxDyn};

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2, 3])), "test");
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);

        param.init_constant(1.5);
        param.data.iter().for_each(|&v| assert_close(v, 1.5));
    }

    #[test]
    fn test_hadamard_gate() {
        let mut qdev = TQDevice::new(1);
        TQHadamard::new()
            .apply(&mut qdev, &[0])
            .expect("Hadamard should succeed");

        assert_probs(&qdev, &[0.5, 0.5]);
    }

    #[test]
    fn test_pauli_x_gate() {
        let mut qdev = TQDevice::new(1);
        TQPauliX::new()
            .apply(&mut qdev, &[0])
            .expect("PauliX should succeed");

        assert_probs(&qdev, &[0.0, 1.0]);
    }

    #[test]
    fn test_rx_gate() {
        let mut qdev = TQDevice::new(1);
        // RX(pi) flips |0> to |1> (up to phase).
        TQRx::new(true, false)
            .apply_with_params(&mut qdev, &[0], Some(&[PI]))
            .expect("RX should succeed");

        assert_probs(&qdev, &[0.0, 1.0]);
    }

    #[test]
    fn test_cnot_gate() {
        let mut qdev = TQDevice::new(2);
        // Prepare |10>; CNOT with control 0 and target 1 then yields |11>.
        TQPauliX::new()
            .apply(&mut qdev, &[0])
            .expect("X should succeed");
        TQCNOT::new()
            .apply(&mut qdev, &[0, 1])
            .expect("CNOT should succeed");

        assert_probs(&qdev, &[0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_bell_state() {
        let mut qdev = TQDevice::new(2);
        TQHadamard::new()
            .apply(&mut qdev, &[0])
            .expect("H should succeed");
        TQCNOT::new()
            .apply(&mut qdev, &[0, 1])
            .expect("CNOT should succeed");

        // (|00> + |11>) / sqrt(2)
        assert_probs(&qdev, &[0.5, 0.0, 0.0, 0.5]);
    }

    #[test]
    fn test_module_list() {
        let _qdev = TQDevice::new(2);
        let mut module_list = TQModuleList::new();

        module_list.append(Box::new(TQHadamard::new()));
        module_list.append(Box::new(TQPauliX::new()));

        assert_eq!(module_list.len(), 2);
        assert!(!module_list.is_empty());
    }
}