Skip to main content

quantrs2_ml/quantum_self_supervised_learning/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
6use scirs2_core::random::prelude::*;
7use scirs2_core::Complex64;
8
9use super::types::{
10    DecoherenceModel, EvolutionType, NegativeSamplingStrategy, NoiseType, PreparationMethod,
11    QuantumActivation, QuantumAugmentationStrategy, QuantumAugmenter, QuantumDecoder,
12    QuantumEncoder, QuantumEncoderDecoder, QuantumMaskingStrategy, QuantumProjectionHead,
13    QuantumSSLMethod, QuantumSelfSupervisedConfig, QuantumSelfSupervisedLearner,
14    QuantumSimilarityMetric, QuantumState, QuantumStateEvolution, QuantumStatePreparation,
15    ReconstructionObjective, ReconstructionStrategy, SSLTrainingConfig,
16};
17
/// Unit tests for the quantum self-supervised learning module.
///
/// The tests cover three things:
/// - learner construction from default and method-specific configurations;
/// - generation of augmented views;
/// - the `SSLTrainingConfig` defaults.
///
/// Fallible calls are unwrapped with `expect` rather than checked with
/// `assert!(x.is_ok())`. A failing test then reports the underlying error
/// instead of a bare "assertion failed".
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must produce a learner.
    #[test]
    fn test_quantum_ssl_creation() {
        let config = QuantumSelfSupervisedConfig::default();
        QuantumSelfSupervisedLearner::new(config)
            .expect("default config should construct a learner");
    }

    /// A noise-based augmenter must return exactly as many views as requested.
    #[test]
    fn test_quantum_augmentations() {
        let augmenter = QuantumAugmenter {
            augmentation_strategies: vec![QuantumAugmentationStrategy::QuantumNoise {
                noise_type: NoiseType::Gaussian,
                strength: 0.1,
            }],
            augmentation_strength: 0.5,
            quantum_coherence_preservation: 0.9,
        };
        let sample = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let views = augmenter
            .generate_augmented_views(&sample, 2)
            .expect("views should be ok");
        assert_eq!(views.len(), 2);
    }

    /// Pins the documented defaults of `SSLTrainingConfig`.
    #[test]
    fn test_ssl_training_config() {
        let config = SSLTrainingConfig::default();
        assert_eq!(config.batch_size, 256);
        assert_eq!(config.epochs, 100);
    }

    /// A contrastive configuration with a custom projection head must be accepted.
    #[test]
    fn test_quantum_contrastive_method() {
        let config = QuantumSelfSupervisedConfig {
            ssl_method: QuantumSSLMethod::QuantumContrastive {
                similarity_metric: QuantumSimilarityMetric::QuantumCosine,
                negative_sampling_strategy: NegativeSamplingStrategy::Random,
                quantum_projection_head: QuantumProjectionHead {
                    hidden_dims: vec![128, 64],
                    output_dim: 32,
                    use_batch_norm: true,
                    quantum_layers: Vec::new(),
                    activation: QuantumActivation::QuantumReLU,
                },
            },
            ..Default::default()
        };
        QuantumSelfSupervisedLearner::new(config)
            .expect("contrastive config should construct a learner");
    }

    /// A masked-reconstruction configuration with an explicit encoder/decoder
    /// must be accepted.
    #[test]
    fn test_quantum_masked_method() {
        let config = QuantumSelfSupervisedConfig {
            ssl_method: QuantumSSLMethod::QuantumMasked {
                masking_strategy: QuantumMaskingStrategy::Random {
                    mask_probability: 0.15,
                },
                reconstruction_objective: ReconstructionObjective::MSE,
                quantum_encoder_decoder: QuantumEncoderDecoder {
                    encoder: QuantumEncoder {
                        layers: Vec::new(),
                        quantum_state_evolution: QuantumStateEvolution {
                            evolution_type: EvolutionType::Unitary,
                            time_steps: Array1::linspace(0.0, 1.0, 10),
                            // Identity Hamiltonian lifted to complex entries; 8x8 is an
                            // arbitrary test size.
                            // NOTE(review): confirm it need not match a qubit count in the default config.
                            hamiltonian: Array2::<f64>::eye(8).mapv(|x| Complex64::new(x, 0.0)),
                            decoherence_model: DecoherenceModel::default(),
                        },
                        measurement_points: vec![0, 1],
                    },
                    decoder: QuantumDecoder {
                        layers: Vec::new(),
                        quantum_state_preparation: QuantumStatePreparation {
                            preparation_method: PreparationMethod::DirectPreparation,
                            target_state: QuantumState::default(),
                            fidelity_threshold: 0.95,
                        },
                        reconstruction_strategy: ReconstructionStrategy::FullReconstruction,
                    },
                    shared_quantum_state: true,
                    entanglement_coupling: 0.5,
                },
            },
            ..Default::default()
        };
        QuantumSelfSupervisedLearner::new(config)
            .expect("masked config should construct a learner");
    }
}