use crate::error::{MLError, Result};
use crate::quantum_in_context_learning::types::*;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, Axis};
use scirs2_core::random::ChaCha20Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::Complex64;
use std::f64::consts::PI;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quantum_in_context_learner_creation() {
let config = QuantumInContextLearningConfig::default();
let learner = QuantumInContextLearner::new(config);
assert!(learner.is_ok());
}
#[test]
fn test_context_encoding() {
let config = QuantumInContextLearningConfig::default();
let encoder = QuantumContextEncoder::new(&config).expect("should create context encoder");
let example = ContextExample {
input: Array1::from_vec(vec![0.1, 0.2, 0.3]),
output: Array1::from_vec(vec![0.8]),
metadata: ContextMetadata {
task_type: "classification".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
quantum_encoding: QuantumContextState {
quantum_amplitudes: Array1::zeros(16).mapv(|_: f64| Complex64::new(1.0, 0.0)),
classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
entanglement_measure: 0.5,
coherence_time: 1.0,
fidelity: 1.0,
phase_information: Complex64::new(1.0, 0.0),
context_metadata: ContextMetadata {
task_type: "classification".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
},
};
let encoded = encoder.encode_example(&example);
assert!(encoded.is_ok());
}
#[test]
fn test_zero_shot_learning() {
let config = QuantumInContextLearningConfig::default();
let learner =
QuantumInContextLearner::new(config.clone()).expect("Failed to create learner");
let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
let result = learner.zero_shot_learning(&query);
assert!(result.is_ok());
let prediction = result.expect("Failed to perform zero-shot learning");
assert_eq!(prediction.len(), config.model_dim);
}
#[test]
fn test_few_shot_learning() {
let config = QuantumInContextLearningConfig {
model_dim: 3,
max_context_examples: 5,
..Default::default()
};
let mut learner = QuantumInContextLearner::new(config).expect("Failed to create learner");
let examples = vec![ContextExample {
input: Array1::from_vec(vec![0.1, 0.2, 0.3]),
output: Array1::from_vec(vec![0.8]),
metadata: ContextMetadata {
task_type: "test".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
quantum_encoding: QuantumContextState {
quantum_amplitudes: Array1::zeros(256).mapv(|_: f64| Complex64::new(1.0, 0.0)),
classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
entanglement_measure: 0.5,
coherence_time: 1.0,
fidelity: 1.0,
phase_information: Complex64::new(1.0, 0.0),
context_metadata: ContextMetadata {
task_type: "test".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
},
}];
let query = Array1::from_vec(vec![0.5, -0.3, 0.8]);
let result = learner.few_shot_learning(&examples, &query, 3);
assert!(result.is_ok());
}
#[test]
fn test_quantum_memory_operations() {
let config = QuantumInContextLearningConfig::default();
let mut memory = QuantumEpisodicMemory::new(&config).expect("Failed to create memory");
let test_state = QuantumContextState {
quantum_amplitudes: Array1::zeros(256).mapv(|_: f64| Complex64::new(1.0, 0.0)),
classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
entanglement_measure: 0.7,
coherence_time: 0.9,
fidelity: 0.95,
phase_information: Complex64::new(1.0, 0.0),
context_metadata: ContextMetadata {
task_type: "memory_test".to_string(),
difficulty_level: 0.6,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
};
let result = memory.add_experience(test_state.clone());
assert!(result.is_ok());
let retrieved = memory.retrieve_similar_contexts(&test_state, 1);
assert!(retrieved.is_ok());
assert_eq!(retrieved.expect("Failed to retrieve contexts").len(), 1);
}
#[test]
fn test_adaptation_strategies() {
let config = QuantumInContextLearningConfig {
adaptation_strategy: AdaptationStrategy::QuantumInterference {
interference_strength: 0.8,
},
..Default::default()
};
let learner = QuantumInContextLearner::new(config);
assert!(learner.is_ok());
}
#[test]
fn test_prototype_bank_operations() {
let config = QuantumInContextLearningConfig::default();
let mut bank = PrototypeBank::new(&config).expect("Failed to create prototype bank");
let test_state = QuantumContextState {
quantum_amplitudes: Array1::zeros(256).mapv(|_: f64| Complex64::new(1.0, 0.0)),
classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
entanglement_measure: 0.5,
coherence_time: 1.0,
fidelity: 1.0,
phase_information: Complex64::new(1.0, 0.0),
context_metadata: ContextMetadata {
task_type: "prototype_test".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
};
let result = bank.add_prototype(test_state.clone());
assert!(result.is_ok());
assert_eq!(bank.get_prototype_count(), 1);
let found = bank.find_nearest_prototypes(&test_state, 1);
assert!(found.is_ok());
assert_eq!(found.expect("Failed to find nearest prototypes").len(), 1);
}
#[test]
fn test_quantum_attention_mechanism() {
let config = QuantumInContextLearningConfig {
num_attention_heads: 2,
..Default::default()
};
let attention =
QuantumContextAttention::new(&config).expect("Failed to create attention mechanism");
let query_state = QuantumContextState {
quantum_amplitudes: Array1::zeros(256).mapv(|_: f64| Complex64::new(1.0, 0.0)),
classical_features: Array1::from_vec(vec![0.1, 0.2, 0.3]),
entanglement_measure: 0.5,
coherence_time: 1.0,
fidelity: 1.0,
phase_information: Complex64::new(1.0, 0.0),
context_metadata: ContextMetadata {
task_type: "attention_test".to_string(),
difficulty_level: 0.5,
modality: ContextModality::Tabular,
timestamp: 0,
importance_weight: 1.0,
},
};
let contexts = vec![query_state.clone(), query_state.clone()];
let weights = attention.compute_attention_weights(&query_state, &contexts);
assert!(weights.is_ok());
assert_eq!(
weights.expect("Failed to compute attention weights").len(),
2
);
}
}