quantrs2_ml/quantum_in_context_learning/
functions.rs1use crate::error::{MLError, Result};
6use crate::quantum_in_context_learning::types::*;
7use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
8use scirs2_core::random::ChaCha20Rng;
9use scirs2_core::random::{Rng, SeedableRng};
10use scirs2_core::Complex64;
11use std::f64::consts::PI;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds tabular-modality metadata carrying the fixture values shared by these tests:
    /// difficulty 0.5, timestamp 0 and importance weight 1.0.
    fn tabular_metadata(task_type: &str) -> ContextMetadata {
        ContextMetadata {
            task_type: task_type.to_string(),
            difficulty_level: 0.5,
            modality: ContextModality::Tabular,
            timestamp: 0,
            importance_weight: 1.0,
        }
    }

    /// Builds a context state with `num_amplitudes` amplitudes all equal to `1 + 0i`,
    /// classical features `[0.1, 0.2, 0.3]`, entanglement 0.5, coherence 1.0,
    /// fidelity 1.0 and unit phase.
    ///
    /// NOTE(review): an all-ones amplitude vector is not unit-normalised. The tests in
    /// this module only check for success and result lengths, so the original fixture
    /// values are kept — confirm whether the implementation expects normalised states.
    fn uniform_state(num_amplitudes: usize, task_type: &str) -> QuantumContextState {
        QuantumContextState {
            quantum_amplitudes: Array1::from_elem(num_amplitudes, Complex64::new(1.0, 0.0)),
            classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            entanglement_measure: 0.5,
            coherence_time: 1.0,
            fidelity: 1.0,
            phase_information: Complex64::new(1.0, 0.0),
            context_metadata: tabular_metadata(task_type),
        }
    }

    /// Builds one example mapping input `[0.1, 0.2, 0.3]` to output `[0.8]`, with an
    /// attached [`uniform_state`] encoding of `num_amplitudes` amplitudes.
    fn tabular_example(num_amplitudes: usize, task_type: &str) -> ContextExample {
        ContextExample {
            input: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            output: Array1::from_vec(vec![0.8]),
            metadata: tabular_metadata(task_type),
            quantum_encoding: uniform_state(num_amplitudes, task_type),
        }
    }

    // Throughout this module, fallible calls use `expect` rather than
    // `assert!(result.is_ok())`: the pass/fail outcome is identical, but a failure
    // prints the underlying error instead of just "assertion failed".

    #[test]
    fn test_quantum_in_context_learner_creation() {
        let config = QuantumInContextLearningConfig::default();
        let _learner = QuantumInContextLearner::new(config).expect("Failed to create learner");
    }

    #[test]
    fn test_context_encoding() {
        let config = QuantumInContextLearningConfig::default();
        let encoder = QuantumContextEncoder::new(&config).expect("should create context encoder");
        // Uses a 16-amplitude encoding, unlike the 256 amplitudes used elsewhere here.
        let example = tabular_example(16, "classification");
        let _encoded = encoder
            .encode_example(&example)
            .expect("should encode context example");
    }

    #[test]
    fn test_zero_shot_learning() {
        let config = QuantumInContextLearningConfig::default();
        // Copy the dimension out before `config` is moved into the learner (no clone needed).
        let model_dim = config.model_dim;
        let learner = QuantumInContextLearner::new(config).expect("Failed to create learner");
        let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
        let prediction = learner
            .zero_shot_learning(&query)
            .expect("Failed to perform zero-shot learning");
        assert_eq!(prediction.len(), model_dim);
    }

    #[test]
    fn test_few_shot_learning() {
        let config = QuantumInContextLearningConfig {
            model_dim: 3,
            max_context_examples: 5,
            ..Default::default()
        };
        let mut learner = QuantumInContextLearner::new(config).expect("Failed to create learner");
        let examples = vec![tabular_example(256, "test")];
        let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
        let _prediction = learner
            .few_shot_learning(&examples, &query, 3)
            .expect("Failed to perform few-shot learning");
    }

    #[test]
    fn test_quantum_memory_operations() {
        let config = QuantumInContextLearningConfig::default();
        let mut memory = QuantumEpisodicMemory::new(&config).expect("Failed to create memory");

        // Same fixture as elsewhere, but with distinct quality measures and difficulty.
        let mut test_state = uniform_state(256, "memory_test");
        test_state.entanglement_measure = 0.7;
        test_state.coherence_time = 0.9;
        test_state.fidelity = 0.95;
        test_state.context_metadata.difficulty_level = 0.6;

        memory
            .add_experience(test_state.clone())
            .expect("Failed to add experience");
        let retrieved = memory
            .retrieve_similar_contexts(&test_state, 1)
            .expect("Failed to retrieve contexts");
        assert_eq!(retrieved.len(), 1);
    }

    #[test]
    fn test_adaptation_strategies() {
        let config = QuantumInContextLearningConfig {
            adaptation_strategy: AdaptationStrategy::QuantumInterference {
                interference_strength: 0.8,
            },
            ..Default::default()
        };
        let _learner = QuantumInContextLearner::new(config)
            .expect("Failed to create learner with quantum-interference adaptation");
    }

    #[test]
    fn test_prototype_bank_operations() {
        let config = QuantumInContextLearningConfig::default();
        let mut bank = PrototypeBank::new(&config).expect("Failed to create prototype bank");
        let test_state = uniform_state(256, "prototype_test");

        bank.add_prototype(test_state.clone())
            .expect("Failed to add prototype");
        assert_eq!(bank.get_prototype_count(), 1);

        let found = bank
            .find_nearest_prototypes(&test_state, 1)
            .expect("Failed to find nearest prototypes");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_quantum_attention_mechanism() {
        let config = QuantumInContextLearningConfig {
            num_attention_heads: 2,
            ..Default::default()
        };
        let attention =
            QuantumContextAttention::new(&config).expect("Failed to create attention mechanism");
        let query_state = uniform_state(256, "attention_test");
        // Two identical contexts: one attention weight is expected per context.
        let contexts = vec![query_state.clone(); 2];
        let weights = attention
            .compute_attention_weights(&query_state, &contexts)
            .expect("Failed to compute attention weights");
        assert_eq!(weights.len(), 2);
    }
}