Skip to main content

quantrs2_tytan/tensor_network_sampler/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use scirs2_core::random::prelude::*;
6use std::collections::HashMap;
7
8use super::types::{
9    CacheOptimization, CompressionConfig, CompressionMethod, LoadBalancingStrategy, MemoryConfig,
10    NetworkTopology, OptimizationResult, ParallelConfig, QualityAssessment, Tensor, TensorNetwork,
11    TensorNetworkConfig, TensorNetworkError, TensorNetworkMetrics, TensorNetworkSampler,
12    TensorNetworkType, TopologyType,
13};
14
15/// Tensor symmetry trait
16pub trait TensorSymmetry: Send + Sync + std::fmt::Debug {
17    /// Apply symmetry transformation
18    fn apply_symmetry(&self, tensor: &Tensor) -> Result<Tensor, TensorNetworkError>;
19    /// Check if tensor respects symmetry
20    fn check_symmetry(&self, tensor: &Tensor) -> bool;
21    /// Get symmetry quantum numbers
22    fn get_quantum_numbers(&self) -> Vec<i32>;
23    /// Get symmetry name
24    fn get_symmetry_name(&self) -> &str;
25}
26/// Tensor optimization algorithm trait
27pub trait TensorOptimizationAlgorithm: Send + Sync + std::fmt::Debug {
28    /// Optimize tensor network
29    fn optimize(
30        &self,
31        network: &mut TensorNetwork,
32        target: &Tensor,
33    ) -> Result<OptimizationResult, TensorNetworkError>;
34    /// Get algorithm name
35    fn get_algorithm_name(&self) -> &str;
36    /// Get algorithm parameters
37    fn get_parameters(&self) -> HashMap<String, f64>;
38}
/// Decides when an iterative optimization should be considered converged.
///
/// The `Send + Sync + Debug` supertraits allow implementations to be shared
/// across threads as trait objects and printed for diagnostics.
pub trait ConvergenceMonitor: Send + Sync + std::fmt::Debug {
    /// Name identifying this monitor.
    fn get_monitor_name(&self) -> &str;

    /// Returns `true` once the run should be treated as converged.
    ///
    /// * `iteration` - index of the current iteration.
    /// * `energy` - objective value observed at this iteration.
    /// * `gradient_norm` - norm of the gradient observed at this iteration.
    fn check_convergence(&self, iteration: usize, energy: f64, gradient_norm: f64) -> bool;
}
46/// Performance tracker trait
47pub trait PerformanceTracker: Send + Sync + std::fmt::Debug {
48    /// Track performance metrics
49    fn track_performance(&self, iteration: usize, metrics: &TensorNetworkMetrics);
50    /// Get tracker name
51    fn get_tracker_name(&self) -> &str;
52}
53/// Compression algorithm trait
54pub trait CompressionAlgorithm: Send + Sync + std::fmt::Debug {
55    /// Compress tensor
56    fn compress(
57        &self,
58        tensor: &Tensor,
59        target_dimension: usize,
60    ) -> Result<Tensor, TensorNetworkError>;
61    /// Get compression method name
62    fn get_method_name(&self) -> &str;
63    /// Estimate compression quality
64    fn estimate_quality(&self, original: &Tensor, compressed: &Tensor) -> f64;
65}
66/// Compression quality assessor trait
67pub trait CompressionQualityAssessor: Send + Sync + std::fmt::Debug {
68    /// Assess compression quality
69    fn assess_quality(&self, original: &Tensor, compressed: &Tensor) -> QualityAssessment;
70    /// Get assessor name
71    fn get_assessor_name(&self) -> &str;
72}
73/// Create default tensor network configuration
74pub const fn create_default_tensor_config() -> TensorNetworkConfig {
75    TensorNetworkConfig {
76        network_type: TensorNetworkType::MPS { bond_dimension: 64 },
77        max_bond_dimension: 128,
78        compression_tolerance: 1e-10,
79        num_sweeps: 100,
80        convergence_tolerance: 1e-8,
81        use_gpu: false,
82        parallel_config: ParallelConfig {
83            num_threads: 4,
84            distributed: false,
85            chunk_size: 1000,
86            load_balancing: LoadBalancingStrategy::Dynamic,
87        },
88        memory_config: MemoryConfig {
89            max_memory_gb: 8.0,
90            memory_mapping: false,
91            gc_frequency: 100,
92            cache_optimization: CacheOptimization::Combined,
93        },
94    }
95}
96/// Create MPS-based tensor network sampler
97pub fn create_mps_sampler(bond_dimension: usize) -> TensorNetworkSampler {
98    let mut config = create_default_tensor_config();
99    config.network_type = TensorNetworkType::MPS { bond_dimension };
100    config.max_bond_dimension = bond_dimension * 2;
101    TensorNetworkSampler::new(config)
102}
103/// Create PEPS-based tensor network sampler
104pub fn create_peps_sampler(
105    bond_dimension: usize,
106    lattice_shape: (usize, usize),
107) -> TensorNetworkSampler {
108    let mut config = create_default_tensor_config();
109    config.network_type = TensorNetworkType::PEPS {
110        bond_dimension,
111        lattice_shape,
112    };
113    config.max_bond_dimension = bond_dimension * 2;
114    TensorNetworkSampler::new(config)
115}
116/// Create MERA-based tensor network sampler
117pub fn create_mera_sampler(layers: usize) -> TensorNetworkSampler {
118    let mut config = create_default_tensor_config();
119    config.network_type = TensorNetworkType::MERA {
120        layers,
121        branching_factor: 2,
122    };
123    TensorNetworkSampler::new(config)
124}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_network_sampler_creation() {
        let sampler = create_mps_sampler(32);
        // The MPS constructor sets `max_bond_dimension` to twice the bond dimension.
        assert_eq!(sampler.config.max_bond_dimension, 64);
        if let TensorNetworkType::MPS { bond_dimension } = sampler.config.network_type {
            assert_eq!(bond_dimension, 32);
        } else {
            panic!("Expected MPS network type ");
        }
    }

    #[test]
    fn test_peps_sampler_creation() {
        let sampler = create_peps_sampler(16, (4, 4));
        // Like the MPS constructor, PEPS doubles the bond dimension for the cap.
        assert_eq!(sampler.config.max_bond_dimension, 32);
        if let TensorNetworkType::PEPS {
            bond_dimension,
            lattice_shape,
        } = sampler.config.network_type
        {
            assert_eq!(bond_dimension, 16);
            assert_eq!(lattice_shape, (4, 4));
        } else {
            panic!("Expected PEPS network type ");
        }
    }

    #[test]
    fn test_mera_sampler_creation() {
        let sampler = create_mera_sampler(3);
        // The MERA constructor does not override the default bond-dimension cap.
        assert_eq!(
            sampler.config.max_bond_dimension,
            create_default_tensor_config().max_bond_dimension
        );
        if let TensorNetworkType::MERA {
            layers,
            branching_factor,
        } = sampler.config.network_type
        {
            assert_eq!(layers, 3);
            assert_eq!(branching_factor, 2);
        } else {
            panic!("Expected MERA network type ");
        }
    }

    #[test]
    fn test_tensor_network_topology() {
        // Never mutated, so no `mut` (avoids the `unused_mut` warning).
        let config = create_default_tensor_config();
        let topology = NetworkTopology::new(&config.network_type);
        assert_eq!(topology.topology_type, TopologyType::Chain);
    }

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert_eq!(config.target_compression_ratio, 0.5);
        assert_eq!(config.method, CompressionMethod::SVD);
    }
}