quantrs2_tytan/tensor_network_sampler/
functions.rs1use scirs2_core::random::prelude::*;
6use std::collections::HashMap;
7
8use super::types::{
9 CacheOptimization, CompressionConfig, CompressionMethod, LoadBalancingStrategy, MemoryConfig,
10 NetworkTopology, OptimizationResult, ParallelConfig, QualityAssessment, Tensor, TensorNetwork,
11 TensorNetworkConfig, TensorNetworkError, TensorNetworkMetrics, TensorNetworkSampler,
12 TensorNetworkType, TopologyType,
13};
14
/// A symmetry that can be enforced on, and checked against, a [`Tensor`].
///
/// Implementors must be `Send + Sync + Debug`, so trait objects can be shared
/// across threads and printed for diagnostics.
pub trait TensorSymmetry: Send + Sync + std::fmt::Debug {
    /// Returns a new tensor derived from `tensor` with this symmetry applied,
    /// or a [`TensorNetworkError`] if it cannot be applied.
    /// NOTE(review): whether this projects, symmetrizes, or validates is
    /// implementation-defined — confirm against implementors.
    fn apply_symmetry(&self, tensor: &Tensor) -> Result<Tensor, TensorNetworkError>;
    /// Returns `true` if `tensor` is judged to respect this symmetry
    /// (presumably within some tolerance chosen by the implementor).
    fn check_symmetry(&self, tensor: &Tensor) -> bool;
    /// Returns the quantum numbers associated with this symmetry; their
    /// ordering and meaning are implementation-defined.
    fn get_quantum_numbers(&self) -> Vec<i32>;
    /// Returns a human-readable name identifying this symmetry.
    fn get_symmetry_name(&self) -> &str;
}
/// An optimization strategy that adjusts a [`TensorNetwork`] with respect to a
/// target [`Tensor`].
///
/// Implementors must be `Send + Sync + Debug`, so trait objects can be shared
/// across threads and printed for diagnostics.
pub trait TensorOptimizationAlgorithm: Send + Sync + std::fmt::Debug {
    /// Runs the optimization on `network` (taken mutably, so it may be updated
    /// in place) against `target`.
    ///
    /// # Errors
    /// Returns a [`TensorNetworkError`] if the optimization cannot be carried out.
    fn optimize(
        &self,
        network: &mut TensorNetwork,
        target: &Tensor,
    ) -> Result<OptimizationResult, TensorNetworkError>;
    /// Returns a human-readable name identifying this algorithm.
    fn get_algorithm_name(&self) -> &str;
    /// Returns this algorithm's numeric parameters, presumably keyed by
    /// parameter name (TODO confirm the key convention with implementors).
    fn get_parameters(&self) -> HashMap<String, f64>;
}
/// Decides, iteration by iteration, whether an optimization has converged.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait ConvergenceMonitor: Send + Sync + std::fmt::Debug {
    /// Inspects the current `iteration`, `energy`, and `gradient_norm` and
    /// returns a flag; `true` presumably means "converged, stop iterating"
    /// (NOTE(review): confirm with callers).
    fn check_convergence(&self, iteration: usize, energy: f64, gradient_norm: f64) -> bool;
    /// Returns a human-readable name identifying this monitor.
    fn get_monitor_name(&self) -> &str;
}
/// Observer that receives per-iteration [`TensorNetworkMetrics`].
///
/// Implementors must be `Send + Sync + Debug`.
pub trait PerformanceTracker: Send + Sync + std::fmt::Debug {
    /// Receives the `metrics` for `iteration`. This takes `&self`, so an
    /// implementation that records state needs interior mutability.
    fn track_performance(&self, iteration: usize, metrics: &TensorNetworkMetrics);
    /// Returns a human-readable name identifying this tracker.
    fn get_tracker_name(&self) -> &str;
}
/// A method for compressing a [`Tensor`] down to a requested dimension.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait CompressionAlgorithm: Send + Sync + std::fmt::Debug {
    /// Produces a compressed copy of `tensor` targeting `target_dimension`.
    ///
    /// # Errors
    /// Returns a [`TensorNetworkError`] if compression fails.
    fn compress(
        &self,
        tensor: &Tensor,
        target_dimension: usize,
    ) -> Result<Tensor, TensorNetworkError>;
    /// Returns a human-readable name identifying this compression method.
    fn get_method_name(&self) -> &str;
    /// Scores how well `compressed` approximates `original`.
    /// NOTE(review): the scale and direction (higher = better?) are
    /// implementation-defined — confirm before comparing across methods.
    fn estimate_quality(&self, original: &Tensor, compressed: &Tensor) -> f64;
}
/// Produces a structured [`QualityAssessment`] comparing a tensor with its
/// compressed counterpart.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait CompressionQualityAssessor: Send + Sync + std::fmt::Debug {
    /// Compares `original` against `compressed` and returns the assessment.
    fn assess_quality(&self, original: &Tensor, compressed: &Tensor) -> QualityAssessment;
    /// Returns a human-readable name identifying this assessor.
    fn get_assessor_name(&self) -> &str;
}
73pub const fn create_default_tensor_config() -> TensorNetworkConfig {
75 TensorNetworkConfig {
76 network_type: TensorNetworkType::MPS { bond_dimension: 64 },
77 max_bond_dimension: 128,
78 compression_tolerance: 1e-10,
79 num_sweeps: 100,
80 convergence_tolerance: 1e-8,
81 use_gpu: false,
82 parallel_config: ParallelConfig {
83 num_threads: 4,
84 distributed: false,
85 chunk_size: 1000,
86 load_balancing: LoadBalancingStrategy::Dynamic,
87 },
88 memory_config: MemoryConfig {
89 max_memory_gb: 8.0,
90 memory_mapping: false,
91 gc_frequency: 100,
92 cache_optimization: CacheOptimization::Combined,
93 },
94 }
95}
96pub fn create_mps_sampler(bond_dimension: usize) -> TensorNetworkSampler {
98 let mut config = create_default_tensor_config();
99 config.network_type = TensorNetworkType::MPS { bond_dimension };
100 config.max_bond_dimension = bond_dimension * 2;
101 TensorNetworkSampler::new(config)
102}
103pub fn create_peps_sampler(
105 bond_dimension: usize,
106 lattice_shape: (usize, usize),
107) -> TensorNetworkSampler {
108 let mut config = create_default_tensor_config();
109 config.network_type = TensorNetworkType::PEPS {
110 bond_dimension,
111 lattice_shape,
112 };
113 config.max_bond_dimension = bond_dimension * 2;
114 TensorNetworkSampler::new(config)
115}
116pub fn create_mera_sampler(layers: usize) -> TensorNetworkSampler {
118 let mut config = create_default_tensor_config();
119 config.network_type = TensorNetworkType::MERA {
120 layers,
121 branching_factor: 2,
122 };
123 TensorNetworkSampler::new(config)
124}
#[cfg(test)]
mod tests {
    use super::*;

    /// The MPS constructor records the bond dimension and doubles it for the cap.
    #[test]
    fn test_tensor_network_sampler_creation() {
        let sampler = create_mps_sampler(32);
        assert_eq!(sampler.config.max_bond_dimension, 64);
        if let TensorNetworkType::MPS { bond_dimension } = sampler.config.network_type {
            assert_eq!(bond_dimension, 32);
        } else {
            panic!("Expected MPS network type ");
        }
    }

    /// The PEPS constructor records bond dimension and lattice shape, and
    /// doubles the bond dimension for the cap.
    #[test]
    fn test_peps_sampler_creation() {
        let sampler = create_peps_sampler(16, (4, 4));
        assert_eq!(sampler.config.max_bond_dimension, 32);
        if let TensorNetworkType::PEPS {
            bond_dimension,
            lattice_shape,
        } = sampler.config.network_type
        {
            assert_eq!(bond_dimension, 16);
            assert_eq!(lattice_shape, (4, 4));
        } else {
            panic!("Expected PEPS network type ");
        }
    }

    /// The MERA constructor records layers with branching factor 2 and leaves
    /// the default bond-dimension cap untouched.
    #[test]
    fn test_mera_sampler_creation() {
        let sampler = create_mera_sampler(3);
        assert_eq!(
            sampler.config.max_bond_dimension,
            create_default_tensor_config().max_bond_dimension
        );
        if let TensorNetworkType::MERA {
            layers,
            branching_factor,
        } = sampler.config.network_type
        {
            assert_eq!(layers, 3);
            assert_eq!(branching_factor, 2);
        } else {
            panic!("Expected MERA network type ");
        }
    }

    /// The default (MPS) configuration maps to a chain topology.
    #[test]
    fn test_tensor_network_topology() {
        // No `mut`: the config is only read (avoids an `unused_mut` warning).
        let config = create_default_tensor_config();
        let topology = NetworkTopology::new(&config.network_type);
        assert_eq!(topology.topology_type, TopologyType::Chain);
    }

    /// Pins the documented defaults of `CompressionConfig`.
    #[test]
    fn test_compression_config() {
        // No `mut`: the config is only read (avoids an `unused_mut` warning).
        let config = CompressionConfig::default();
        assert_eq!(config.target_compression_ratio, 0.5);
        assert_eq!(config.method, CompressionMethod::SVD);
    }
}
178}