quantrs2_ml/quantum_in_context_learning/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

5use crate::error::{MLError, Result};
6use crate::quantum_in_context_learning::types::*;
7use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
8use scirs2_core::random::ChaCha20Rng;
9use scirs2_core::random::{Rng, SeedableRng};
10use scirs2_core::Complex64;
11use std::f64::consts::PI;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the tabular-modality metadata shared by every fixture in this module.
    ///
    /// Uses difficulty 0.5, timestamp 0 and importance weight 1.0. Tests needing a
    /// different difficulty override it with struct-update syntax, which avoids
    /// committing to the field's concrete float type here.
    fn tabular_metadata(task_type: &str) -> ContextMetadata {
        ContextMetadata {
            task_type: task_type.to_string(),
            difficulty_level: 0.5,
            modality: ContextModality::Tabular,
            timestamp: 0,
            importance_weight: 1.0,
        }
    }

    /// Builds a context state whose `num_amplitudes` amplitudes are all `1 + 0i`.
    ///
    /// The state has classical features `[0.1, 0.2, 0.3]` and baseline metrics:
    /// entanglement 0.5, coherence 1.0, fidelity 1.0 and phase `1 + 0i`.
    /// The amplitudes are deliberately left unnormalised, matching the original
    /// fixtures.
    fn uniform_state(task_type: &str, num_amplitudes: usize) -> QuantumContextState {
        QuantumContextState {
            quantum_amplitudes: Array1::from_elem(num_amplitudes, Complex64::new(1.0, 0.0)),
            classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            entanglement_measure: 0.5,
            coherence_time: 1.0,
            fidelity: 1.0,
            phase_information: Complex64::new(1.0, 0.0),
            context_metadata: tabular_metadata(task_type),
        }
    }

    /// Builds a single example mapping input `[0.1, 0.2, 0.3]` to output `[0.8]`.
    fn sample_example(task_type: &str, num_amplitudes: usize) -> ContextExample {
        ContextExample {
            input: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            output: Array1::from_vec(vec![0.8]),
            metadata: tabular_metadata(task_type),
            quantum_encoding: uniform_state(task_type, num_amplitudes),
        }
    }

    #[test]
    fn test_quantum_in_context_learner_creation() {
        let config = QuantumInContextLearningConfig::default();
        // `expect` (rather than `assert!(is_ok())`) prints the error on failure.
        let _learner =
            QuantumInContextLearner::new(config).expect("should create learner from default config");
    }

    #[test]
    fn test_context_encoding() {
        let config = QuantumInContextLearningConfig::default();
        let encoder = QuantumContextEncoder::new(&config).expect("should create context encoder");
        // NOTE(review): this fixture uses 16 amplitudes while every other test uses 256;
        // kept as-is, but confirm whether it should match the configured state size.
        let example = sample_example("classification", 16);
        let _encoded = encoder
            .encode_example(&example)
            .expect("should encode context example");
    }

    #[test]
    fn test_zero_shot_learning() {
        let config = QuantumInContextLearningConfig::default();
        let learner =
            QuantumInContextLearner::new(config.clone()).expect("Failed to create learner");
        let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
        let prediction = learner
            .zero_shot_learning(&query)
            .expect("Failed to perform zero-shot learning");
        assert_eq!(prediction.len(), config.model_dim);
    }

    #[test]
    fn test_few_shot_learning() {
        let config = QuantumInContextLearningConfig {
            model_dim: 3,
            max_context_examples: 5,
            ..Default::default()
        };
        let mut learner = QuantumInContextLearner::new(config).expect("Failed to create learner");
        let examples = vec![sample_example("test", 256)];
        let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
        let _result = learner
            .few_shot_learning(&examples, &query, 3)
            .expect("few-shot learning should succeed with a single example");
    }

    #[test]
    fn test_quantum_memory_operations() {
        let config = QuantumInContextLearningConfig::default();
        let mut memory = QuantumEpisodicMemory::new(&config).expect("Failed to create memory");
        // Non-baseline metrics so the stored experience differs from the default fixture.
        let test_state = QuantumContextState {
            entanglement_measure: 0.7,
            coherence_time: 0.9,
            fidelity: 0.95,
            context_metadata: ContextMetadata {
                difficulty_level: 0.6,
                ..tabular_metadata("memory_test")
            },
            ..uniform_state("memory_test", 256)
        };
        memory
            .add_experience(test_state.clone())
            .expect("Failed to add experience");
        let retrieved = memory
            .retrieve_similar_contexts(&test_state, 1)
            .expect("Failed to retrieve contexts");
        assert_eq!(retrieved.len(), 1);
    }

    #[test]
    fn test_adaptation_strategies() {
        let config = QuantumInContextLearningConfig {
            adaptation_strategy: AdaptationStrategy::QuantumInterference {
                interference_strength: 0.8,
            },
            ..Default::default()
        };
        let _learner = QuantumInContextLearner::new(config)
            .expect("should create learner with quantum-interference adaptation");
    }

    #[test]
    fn test_prototype_bank_operations() {
        let config = QuantumInContextLearningConfig::default();
        let mut bank = PrototypeBank::new(&config).expect("Failed to create prototype bank");
        let test_state = uniform_state("prototype_test", 256);
        bank.add_prototype(test_state.clone())
            .expect("Failed to add prototype");
        assert_eq!(bank.get_prototype_count(), 1);
        let found = bank
            .find_nearest_prototypes(&test_state, 1)
            .expect("Failed to find nearest prototypes");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_quantum_attention_mechanism() {
        let config = QuantumInContextLearningConfig {
            num_attention_heads: 2,
            ..Default::default()
        };
        let attention =
            QuantumContextAttention::new(&config).expect("Failed to create attention mechanism");
        let query_state = uniform_state("attention_test", 256);
        let contexts = vec![query_state.clone(), query_state.clone()];
        let weights = attention
            .compute_attention_weights(&query_state, &contexts)
            .expect("Failed to compute attention weights");
        // One weight per context entry.
        assert_eq!(weights.len(), 2);
    }
}