quantrs2_ml/quantum_self_supervised_learning/
functions.rs1use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
6use scirs2_core::random::prelude::*;
7use scirs2_core::Complex64;
8
9use super::types::{
10 DecoherenceModel, EvolutionType, NegativeSamplingStrategy, NoiseType, PreparationMethod,
11 QuantumActivation, QuantumAugmentationStrategy, QuantumAugmenter, QuantumDecoder,
12 QuantumEncoder, QuantumEncoderDecoder, QuantumMaskingStrategy, QuantumProjectionHead,
13 QuantumSSLMethod, QuantumSelfSupervisedConfig, QuantumSelfSupervisedLearner,
14 QuantumSimilarityMetric, QuantumState, QuantumStateEvolution, QuantumStatePreparation,
15 ReconstructionObjective, ReconstructionStrategy, SSLTrainingConfig,
16};
17
/// Unit tests for the quantum self-supervised learning types imported into this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a learner from the default configuration must return `Ok`.
    #[test]
    fn test_quantum_ssl_creation() {
        let config = QuantumSelfSupervisedConfig::default();
        let ssl = QuantumSelfSupervisedLearner::new(config);
        assert!(ssl.is_ok());
    }

    /// An augmenter with a single Gaussian quantum-noise strategy must produce
    /// exactly the requested number of views for a small 4-element sample.
    #[test]
    fn test_quantum_augmentations() {
        // The augmenter is built directly from its fields, so no top-level
        // `QuantumSelfSupervisedConfig` is needed for this test.
        let augmenter = QuantumAugmenter {
            augmentation_strategies: vec![QuantumAugmentationStrategy::QuantumNoise {
                noise_type: NoiseType::Gaussian,
                strength: 0.1,
            }],
            augmentation_strength: 0.5,
            quantum_coherence_preservation: 0.9,
        };
        let sample = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        // `expect` fails the test on `Err`, so a separate `is_ok()` check is redundant.
        let views = augmenter
            .generate_augmented_views(&sample, 2)
            .expect("views should be ok");
        assert_eq!(views.len(), 2);
    }

    /// Pins the default training hyper-parameters: batch size 256, 100 epochs.
    #[test]
    fn test_ssl_training_config() {
        let config = SSLTrainingConfig::default();
        assert_eq!(config.batch_size, 256);
        assert_eq!(config.epochs, 100);
    }

    /// The learner must accept a contrastive configuration.
    ///
    /// The configuration uses quantum cosine similarity, random negative sampling,
    /// and a projection head with hidden dims `[128, 64]` and output dim 32.
    #[test]
    fn test_quantum_contrastive_method() {
        let config = QuantumSelfSupervisedConfig {
            ssl_method: QuantumSSLMethod::QuantumContrastive {
                similarity_metric: QuantumSimilarityMetric::QuantumCosine,
                negative_sampling_strategy: NegativeSamplingStrategy::Random,
                quantum_projection_head: QuantumProjectionHead {
                    hidden_dims: vec![128, 64],
                    output_dim: 32,
                    use_batch_norm: true,
                    quantum_layers: Vec::new(),
                    activation: QuantumActivation::QuantumReLU,
                },
            },
            ..Default::default()
        };
        let ssl = QuantumSelfSupervisedLearner::new(config);
        assert!(ssl.is_ok());
    }

    /// The learner must accept a masked-modeling configuration.
    ///
    /// The configuration uses random masking (p = 0.15) and an MSE reconstruction
    /// objective. Its encoder/decoder pair has no layers. The encoder evolves
    /// unitarily under an 8x8 identity Hamiltonian over 10 time steps in [0, 1].
    #[test]
    fn test_quantum_masked_method() {
        let config = QuantumSelfSupervisedConfig {
            ssl_method: QuantumSSLMethod::QuantumMasked {
                masking_strategy: QuantumMaskingStrategy::Random {
                    mask_probability: 0.15,
                },
                reconstruction_objective: ReconstructionObjective::MSE,
                quantum_encoder_decoder: QuantumEncoderDecoder {
                    encoder: QuantumEncoder {
                        layers: Vec::new(),
                        quantum_state_evolution: QuantumStateEvolution {
                            evolution_type: EvolutionType::Unitary,
                            time_steps: Array1::linspace(0.0, 1.0, 10),
                            // Real identity matrix lifted to complex entries.
                            hamiltonian: Array2::<f64>::eye(8)
                                .mapv(|x| Complex64::new(x, 0.0)),
                            decoherence_model: DecoherenceModel::default(),
                        },
                        measurement_points: vec![0, 1],
                    },
                    decoder: QuantumDecoder {
                        layers: Vec::new(),
                        quantum_state_preparation: QuantumStatePreparation {
                            preparation_method: PreparationMethod::DirectPreparation,
                            target_state: QuantumState::default(),
                            fidelity_threshold: 0.95,
                        },
                        reconstruction_strategy: ReconstructionStrategy::FullReconstruction,
                    },
                    shared_quantum_state: true,
                    entanglement_coupling: 0.5,
                },
            },
            ..Default::default()
        };
        let ssl = QuantumSelfSupervisedLearner::new(config);
        assert!(ssl.is_ok());
    }
}