1use super::config::{ActivationFunction, MetaLearningConfig};
4use super::feedforward::FeedForwardNetwork;
5use super::layers::LayerNormalization;
6use crate::error::Result;
7use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::Instant;
12
/// Meta-learning algorithm dispatched by `TransformerMetaLearning::meta_step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning: adapt a copy of the meta-parameters per task,
    /// then update the shared meta-parameters from the adapted solutions.
    MAML,
    /// First-order MAML; in this module it currently runs the same step as `MAML`.
    FOMAML,
    /// Reptile: move the meta-parameters toward the average adapted parameters.
    Reptile,
    /// Predict a per-task starting point from a task-context embedding, then fine-tune it.
    GradientBased,
    /// Start from the average of similar stored experiences, then fine-tune.
    MemoryAugmented,
}
27
28pub struct TransformerMetaLearning<T: Float + Debug + Send + Sync + 'static> {
30 strategy: MetaLearningStrategy,
32
33 config: MetaLearningConfig<T>,
35
36 meta_optimizer: MetaOptimizer<T>,
38
39 adaptation_network: AdaptationNetwork<T>,
41
42 memory_bank: MemoryBank<T>,
44
45 performance_tracker: PerformanceTracker<T>,
47
48 meta_state: MetaState<T>,
50}
51
52impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
53 TransformerMetaLearning<T>
54{
55 pub fn new(config: &super::config::TransformerBasedOptimizerConfig<T>) -> Result<Self> {
57 let meta_config = config.meta_learning_config.clone();
58 let strategy = MetaLearningStrategy::MAML; let meta_optimizer = MetaOptimizer::new(&meta_config)?;
61 let adaptation_network =
62 AdaptationNetwork::new(config.model_dimension, config.feedforward_dimension)?;
63 let memory_bank = MemoryBank::new(1000, config.model_dimension)?;
64 let performance_tracker = PerformanceTracker::new();
65 let meta_state = MetaState::new(config.model_dimension)?;
66
67 Ok(Self {
68 strategy,
69 config: meta_config,
70 meta_optimizer,
71 adaptation_network,
72 memory_bank,
73 performance_tracker,
74 meta_state,
75 })
76 }
77
78 pub fn meta_step(
80 &mut self,
81 tasks: &[TaskBatch<T>],
82 support_data: &[Array2<T>],
83 query_data: &[Array2<T>],
84 ) -> Result<MetaLearningResult<T>> {
85 match self.strategy {
86 MetaLearningStrategy::MAML => self.maml_step(tasks, support_data, query_data),
87 MetaLearningStrategy::FOMAML => self.fomaml_step(tasks, support_data, query_data),
88 MetaLearningStrategy::Reptile => self.reptile_step(tasks, support_data, query_data),
89 MetaLearningStrategy::GradientBased => {
90 self.gradient_based_step(tasks, support_data, query_data)
91 }
92 MetaLearningStrategy::MemoryAugmented => {
93 self.memory_augmented_step(tasks, support_data, query_data)
94 }
95 }
96 }
97
98 fn maml_step(
100 &mut self,
101 tasks: &[TaskBatch<T>],
102 support_data: &[Array2<T>],
103 query_data: &[Array2<T>],
104 ) -> Result<MetaLearningResult<T>> {
105 let start_time = Instant::now();
106 let mut total_loss = T::zero();
107 let mut task_adaptations = Vec::new();
108
109 for (i, task) in tasks.iter().enumerate() {
110 let mut adapted_params = self.meta_state.get_parameters().clone();
112
113 for inner_step in 0..self.config.inner_steps {
114 let support_loss =
116 self.compute_task_loss(&adapted_params, &support_data[i], task)?;
117 let gradients = self.compute_gradients(&adapted_params, support_loss)?;
118
119 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
121 *param = *param - self.config.inner_learning_rate * (*grad);
122 }
123 }
124
125 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
127 total_loss = total_loss + query_loss;
128
129 task_adaptations.push(TaskAdaptation {
130 task_id: task.id.clone(),
131 adapted_parameters: adapted_params,
132 support_loss: self.compute_task_loss(
133 self.meta_state.get_parameters(),
134 &support_data[i],
135 task,
136 )?,
137 query_loss,
138 adaptation_steps: self.config.inner_steps,
139 });
140 }
141
142 let meta_loss = total_loss / T::from(tasks.len()).unwrap();
144 let meta_gradients = self.compute_meta_gradients(&task_adaptations)?;
145 self.meta_optimizer
146 .update(&mut self.meta_state, &meta_gradients)?;
147
148 for (i, adaptation) in task_adaptations.iter().enumerate() {
150 self.memory_bank.store_experience(
151 &tasks[i],
152 &adaptation.adapted_parameters,
153 adaptation.query_loss,
154 )?;
155 }
156
157 let result = MetaLearningResult {
158 meta_loss: meta_loss.to_f64().unwrap_or(0.0),
159 task_adaptations,
160 computation_time: start_time.elapsed(),
161 convergence_rate: self.estimate_convergence_rate()?,
162 };
163
164 self.performance_tracker.record_meta_step(result.clone());
165 Ok(result)
166 }
167
168 fn fomaml_step(
170 &mut self,
171 tasks: &[TaskBatch<T>],
172 support_data: &[Array2<T>],
173 query_data: &[Array2<T>],
174 ) -> Result<MetaLearningResult<T>> {
175 self.maml_step(tasks, support_data, query_data)
178 }
179
180 fn reptile_step(
182 &mut self,
183 tasks: &[TaskBatch<T>],
184 support_data: &[Array2<T>],
185 query_data: &[Array2<T>],
186 ) -> Result<MetaLearningResult<T>> {
187 let start_time = Instant::now();
188 let mut parameter_updates = Vec::new();
189 let mut total_loss = T::zero();
190
191 for (i, task) in tasks.iter().enumerate() {
192 let mut adapted_params = self.meta_state.get_parameters().clone();
194
195 for _ in 0..self.config.inner_steps {
196 let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
197 let gradients = self.compute_gradients(&adapted_params, loss)?;
198
199 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
200 *param = *param - self.config.inner_learning_rate * (*grad);
201 }
202 }
203
204 let original_params = self.meta_state.get_parameters();
206 let param_diff: Vec<T> = adapted_params
207 .iter()
208 .zip(original_params.iter())
209 .map(|(adapted, original)| *adapted - *original)
210 .collect();
211
212 parameter_updates.push(param_diff);
213
214 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
216 total_loss = total_loss + query_loss;
217 }
218
219 let mut meta_update = vec![T::zero(); self.meta_state.get_parameters().len()];
221 for param_update in ¶meter_updates {
222 for (i, &update) in param_update.iter().enumerate() {
223 meta_update[i] = meta_update[i] + update;
224 }
225 }
226
227 let num_tasks = T::from(tasks.len()).unwrap();
228 for update in meta_update.iter_mut() {
229 *update = *update / num_tasks;
230 }
231
232 self.meta_state
233 .update_parameters(&meta_update, self.config.meta_learning_rate)?;
234
235 let result = MetaLearningResult {
236 meta_loss: (total_loss / num_tasks).to_f64().unwrap_or(0.0),
237 task_adaptations: Vec::new(), computation_time: start_time.elapsed(),
239 convergence_rate: self.estimate_convergence_rate()?,
240 };
241
242 self.performance_tracker.record_meta_step(result.clone());
243 Ok(result)
244 }
245
246 fn gradient_based_step(
248 &mut self,
249 tasks: &[TaskBatch<T>],
250 support_data: &[Array2<T>],
251 query_data: &[Array2<T>],
252 ) -> Result<MetaLearningResult<T>> {
253 let start_time = Instant::now();
255 let mut total_loss = T::zero();
256
257 for (i, task) in tasks.iter().enumerate() {
258 let context_embedding = self.adaptation_network.encode_task_context(task)?;
260 let predicted_params = self
261 .adaptation_network
262 .predict_parameters(&context_embedding)?;
263
264 let mut adapted_params = predicted_params;
266 for _ in 0..self.config.inner_steps {
267 let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
268 let gradients = self.compute_gradients(&adapted_params, loss)?;
269
270 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
271 *param = *param - self.config.inner_learning_rate * (*grad);
272 }
273 }
274
275 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
276 total_loss = total_loss + query_loss;
277 }
278
279 let result = MetaLearningResult {
280 meta_loss: (total_loss / T::from(tasks.len()).unwrap())
281 .to_f64()
282 .unwrap_or(0.0),
283 task_adaptations: Vec::new(),
284 computation_time: start_time.elapsed(),
285 convergence_rate: self.estimate_convergence_rate()?,
286 };
287
288 Ok(result)
289 }
290
291 fn memory_augmented_step(
293 &mut self,
294 tasks: &[TaskBatch<T>],
295 support_data: &[Array2<T>],
296 query_data: &[Array2<T>],
297 ) -> Result<MetaLearningResult<T>> {
298 let start_time = Instant::now();
299 let mut total_loss = T::zero();
300
301 for (i, task) in tasks.iter().enumerate() {
302 let relevant_experiences = self.memory_bank.retrieve_similar_experiences(task, 5)?;
304
305 let memory_guided_params = self.initialize_from_memory(&relevant_experiences)?;
307
308 let mut adapted_params = memory_guided_params;
309 for _ in 0..self.config.inner_steps {
310 let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
311 let gradients = self.compute_gradients(&adapted_params, loss)?;
312
313 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
314 *param = *param - self.config.inner_learning_rate * (*grad);
315 }
316 }
317
318 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
319 total_loss = total_loss + query_loss;
320
321 self.memory_bank
323 .store_experience(task, &adapted_params, query_loss)?;
324 }
325
326 let result = MetaLearningResult {
327 meta_loss: (total_loss / T::from(tasks.len()).unwrap())
328 .to_f64()
329 .unwrap_or(0.0),
330 task_adaptations: Vec::new(),
331 computation_time: start_time.elapsed(),
332 convergence_rate: self.estimate_convergence_rate()?,
333 };
334
335 Ok(result)
336 }
337
338 pub fn generate_update(
340 &mut self,
341 transformer_output: &Array2<T>,
342 current_parameters: &Array1<T>,
343 ) -> Result<Array1<T>> {
344 let update = self
346 .adaptation_network
347 .generate_parameter_update(transformer_output, current_parameters)?;
348
349 let scaled_update = self.apply_meta_scaling(&update)?;
351
352 Ok(scaled_update)
353 }
354
355 pub fn update_from_loss(&mut self, loss: T) -> Result<()> {
357 self.meta_state.update_loss_history(loss);
358 self.performance_tracker
359 .record_loss(loss.to_f64().unwrap_or(0.0));
360 Ok(())
361 }
362
363 pub fn set_strategy(&mut self, strategy: MetaLearningStrategy) {
365 self.strategy = strategy;
366 }
367
368 pub fn get_strategy(&self) -> MetaLearningStrategy {
370 self.strategy
371 }
372
373 fn compute_task_loss(&self, params: &[T], data: &Array2<T>, task: &TaskBatch<T>) -> Result<T> {
375 let prediction_error = self.compute_prediction_error(params, data, task)?;
377 Ok(prediction_error)
378 }
379
380 fn compute_prediction_error(
381 &self,
382 _params: &[T],
383 data: &Array2<T>,
384 _task: &TaskBatch<T>,
385 ) -> Result<T> {
386 let mean_squared_error = data
388 .iter()
389 .map(|&x| x * x)
390 .fold(T::zero(), |acc, x| acc + x);
391 Ok(mean_squared_error / T::from(data.len()).unwrap())
392 }
393
394 fn compute_gradients(&self, params: &[T], loss: T) -> Result<Vec<T>> {
395 let gradients = params
397 .iter()
398 .map(|_| loss / T::from(params.len()).unwrap())
399 .collect();
400 Ok(gradients)
401 }
402
403 fn compute_meta_gradients(&self, adaptations: &[TaskAdaptation<T>]) -> Result<Vec<T>> {
404 let param_count = adaptations[0].adapted_parameters.len();
405 let mut meta_gradients = vec![T::zero(); param_count];
406
407 for adaptation in adaptations {
408 for (i, ¶m) in adaptation.adapted_parameters.iter().enumerate() {
409 meta_gradients[i] = meta_gradients[i] + param * adaptation.query_loss;
410 }
411 }
412
413 let num_tasks = T::from(adaptations.len()).unwrap();
414 for grad in meta_gradients.iter_mut() {
415 *grad = *grad / num_tasks;
416 }
417
418 Ok(meta_gradients)
419 }
420
421 fn estimate_convergence_rate(&self) -> Result<f64> {
422 let loss_history = self.performance_tracker.get_loss_history();
423 if loss_history.len() < 2 {
424 return Ok(0.0);
425 }
426
427 let recent_losses: Vec<_> = loss_history.iter().rev().take(5).cloned().collect();
428 let improvement = recent_losses.last().unwrap() - recent_losses.first().unwrap();
429 Ok(improvement.clamp(0.0, 1.0))
430 }
431
432 fn apply_meta_scaling(&self, update: &Array1<T>) -> Result<Array1<T>> {
433 let scale_factor = self.meta_state.get_scale_factor();
435 Ok(update * scale_factor)
436 }
437
438 fn initialize_from_memory(&self, experiences: &[MemoryExperience<T>]) -> Result<Vec<T>> {
439 if experiences.is_empty() {
440 return Ok(self.meta_state.get_parameters().clone());
441 }
442
443 let param_count = experiences[0].parameters.len();
445 let mut averaged_params = vec![T::zero(); param_count];
446
447 for experience in experiences {
448 for (i, ¶m) in experience.parameters.iter().enumerate() {
449 averaged_params[i] = averaged_params[i] + param;
450 }
451 }
452
453 let num_experiences = T::from(experiences.len()).unwrap();
454 for param in averaged_params.iter_mut() {
455 *param = *param / num_experiences;
456 }
457
458 Ok(averaged_params)
459 }
460}
461
462pub struct MetaOptimizer<T: Float + Debug + Send + Sync + 'static> {
464 learning_rate: T,
466
467 momentum: Option<T>,
469
470 velocity: Option<Vec<T>>,
472}
473
474impl<T: Float + Debug + Send + Sync + 'static> MetaOptimizer<T> {
475 pub fn new(config: &MetaLearningConfig<T>) -> Result<Self> {
476 Ok(Self {
477 learning_rate: config.meta_learning_rate,
478 momentum: None,
479 velocity: None,
480 })
481 }
482
483 pub fn update(&mut self, state: &mut MetaState<T>, gradients: &[T]) -> Result<()> {
484 let params = state.get_parameters_mut();
485
486 if let Some(momentum) = self.momentum {
487 if self.velocity.is_none() {
489 self.velocity = Some(vec![T::zero(); params.len()]);
490 }
491
492 if let Some(ref mut velocity) = self.velocity {
493 for i in 0..params.len() {
494 velocity[i] = momentum * velocity[i] + self.learning_rate * gradients[i];
495 params[i] = params[i] - velocity[i];
496 }
497 }
498 } else {
499 for (param, &grad) in params.iter_mut().zip(gradients.iter()) {
501 *param = *param - self.learning_rate * grad;
502 }
503 }
504
505 Ok(())
506 }
507}
508
509pub struct AdaptationNetwork<T: Float + Debug + Send + Sync + 'static> {
511 context_encoder: FeedForwardNetwork<T>,
513
514 parameter_predictor: FeedForwardNetwork<T>,
516
517 update_generator: FeedForwardNetwork<T>,
519
520 model_dimension: usize,
522}
523
524impl<T: Float + Debug + Send + Sync + 'static> AdaptationNetwork<T> {
525 pub fn new(model_dimension: usize, hidden_dimension: usize) -> Result<Self> {
526 let context_encoder =
527 FeedForwardNetwork::new(model_dimension, hidden_dimension, ActivationFunction::ReLU)?;
528
529 let parameter_predictor =
530 FeedForwardNetwork::new(hidden_dimension, model_dimension, ActivationFunction::Tanh)?;
531
532 let update_generator = FeedForwardNetwork::new(
533 model_dimension * 2, model_dimension,
535 ActivationFunction::ReLU,
536 )?;
537
538 Ok(Self {
539 context_encoder,
540 parameter_predictor,
541 update_generator,
542 model_dimension,
543 })
544 }
545
546 pub fn encode_task_context(&mut self, _task: &TaskBatch<T>) -> Result<Array2<T>> {
547 let task_features = Array2::ones((1, self.model_dimension));
549 self.context_encoder.forward(&task_features)
550 }
551
552 pub fn predict_parameters(&mut self, context: &Array2<T>) -> Result<Vec<T>> {
553 let predicted = self.parameter_predictor.forward(context)?;
554 Ok(predicted.row(0).to_vec())
555 }
556
557 pub fn generate_parameter_update(
558 &mut self,
559 transformer_output: &Array2<T>,
560 current_parameters: &Array1<T>,
561 ) -> Result<Array1<T>> {
562 let batch_size = transformer_output.shape()[0];
564 let mut input = Array2::zeros((batch_size, self.model_dimension * 2));
565
566 for i in 0..batch_size {
567 for j in 0..self.model_dimension {
568 input[[i, j]] = transformer_output[[i, j]];
569 if j < current_parameters.len() {
570 input[[i, j + self.model_dimension]] = current_parameters[j];
571 }
572 }
573 }
574
575 let update = self.update_generator.forward(&input)?;
576 Ok(update.row(0).to_owned())
577 }
578}
579
580pub struct MemoryBank<T: Float + Debug + Send + Sync + 'static> {
582 experiences: VecDeque<MemoryExperience<T>>,
584
585 max_size: usize,
587
588 parameter_dimension: usize,
590}
591
592impl<T: Float + Debug + Send + Sync + 'static> MemoryBank<T> {
593 pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
594 Ok(Self {
595 experiences: VecDeque::new(),
596 max_size,
597 parameter_dimension,
598 })
599 }
600
601 pub fn store_experience(
602 &mut self,
603 task: &TaskBatch<T>,
604 parameters: &[T],
605 performance: T,
606 ) -> Result<()> {
607 let experience = MemoryExperience {
608 task_signature: self.compute_task_signature(task),
609 parameters: parameters.to_vec(),
610 performance: performance.to_f64().unwrap_or(0.0),
611 timestamp: Instant::now(),
612 };
613
614 self.experiences.push_back(experience);
615
616 if self.experiences.len() > self.max_size {
617 self.experiences.pop_front();
618 }
619
620 Ok(())
621 }
622
623 pub fn retrieve_similar_experiences(
624 &self,
625 task: &TaskBatch<T>,
626 k: usize,
627 ) -> Result<Vec<MemoryExperience<T>>> {
628 let target_signature = self.compute_task_signature(task);
629
630 let mut scored_experiences: Vec<_> = self
631 .experiences
632 .iter()
633 .map(|exp| {
634 let similarity = self.compute_similarity(&target_signature, &exp.task_signature);
635 (similarity, exp.clone())
636 })
637 .collect();
638
639 scored_experiences.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
640
641 Ok(scored_experiences
642 .into_iter()
643 .take(k)
644 .map(|(_, exp)| exp)
645 .collect())
646 }
647
648 fn compute_task_signature(&self, task: &TaskBatch<T>) -> Vec<f64> {
649 vec![task.difficulty, task.complexity, task.data_characteristics]
651 }
652
653 fn compute_similarity(&self, sig1: &[f64], sig2: &[f64]) -> f64 {
654 if sig1.len() != sig2.len() {
655 return 0.0;
656 }
657
658 let dot_product: f64 = sig1.iter().zip(sig2.iter()).map(|(a, b)| a * b).sum();
659 let norm1: f64 = sig1.iter().map(|x| x * x).sum::<f64>().sqrt();
660 let norm2: f64 = sig2.iter().map(|x| x * x).sum::<f64>().sqrt();
661
662 if norm1 == 0.0 || norm2 == 0.0 {
663 0.0
664 } else {
665 dot_product / (norm1 * norm2)
666 }
667 }
668}
669
670#[derive(Debug, Clone)]
672pub struct TaskBatch<T: Float + Debug + Send + Sync + 'static> {
673 pub id: String,
674 pub difficulty: f64,
675 pub complexity: f64,
676 pub data_characteristics: f64,
677 pub _phantom: std::marker::PhantomData<T>,
678}
679
680#[derive(Debug, Clone)]
681pub struct TaskAdaptation<T: Float + Debug + Send + Sync + 'static> {
682 pub task_id: String,
683 pub adapted_parameters: Vec<T>,
684 pub support_loss: T,
685 pub query_loss: T,
686 pub adaptation_steps: usize,
687}
688
689#[derive(Debug, Clone)]
690pub struct MetaLearningResult<T: Float + Debug + Send + Sync + 'static> {
691 pub meta_loss: f64,
692 pub task_adaptations: Vec<TaskAdaptation<T>>,
693 pub computation_time: std::time::Duration,
694 pub convergence_rate: f64,
695}
696
697#[derive(Debug, Clone)]
698pub struct MemoryExperience<T: Float + Debug + Send + Sync + 'static> {
699 pub task_signature: Vec<f64>,
700 pub parameters: Vec<T>,
701 pub performance: f64,
702 pub timestamp: Instant,
703}
704
705pub struct PerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
706 loss_history: VecDeque<f64>,
707 meta_results: VecDeque<MetaLearningResult<T>>,
708}
709
710impl<T: Float + Debug + Send + Sync + 'static> Default for PerformanceTracker<T> {
711 fn default() -> Self {
712 Self::new()
713 }
714}
715
716impl<T: Float + Debug + Send + Sync + 'static> PerformanceTracker<T> {
717 pub fn new() -> Self {
718 Self {
719 loss_history: VecDeque::new(),
720 meta_results: VecDeque::new(),
721 }
722 }
723
724 pub fn record_loss(&mut self, loss: f64) {
725 self.loss_history.push_back(loss);
726 if self.loss_history.len() > 1000 {
727 self.loss_history.pop_front();
728 }
729 }
730
731 pub fn record_meta_step(&mut self, result: MetaLearningResult<T>) {
732 self.meta_results.push_back(result);
733 if self.meta_results.len() > 100 {
734 self.meta_results.pop_front();
735 }
736 }
737
738 pub fn get_loss_history(&self) -> &VecDeque<f64> {
739 &self.loss_history
740 }
741}
742
743#[derive(Debug, Clone)]
744pub struct MetaState<T: Float + Debug + Send + Sync + 'static> {
745 parameters: Vec<T>,
746 loss_history: VecDeque<T>,
747 scale_factor: T,
748}
749
750impl<T: Float + Debug + Send + Sync + 'static> MetaState<T> {
751 pub fn new(parameter_count: usize) -> Result<Self> {
752 Ok(Self {
753 parameters: vec![T::zero(); parameter_count],
754 loss_history: VecDeque::new(),
755 scale_factor: T::one(),
756 })
757 }
758
759 pub fn get_parameters(&self) -> &Vec<T> {
760 &self.parameters
761 }
762
763 pub fn get_parameters_mut(&mut self) -> &mut Vec<T> {
764 &mut self.parameters
765 }
766
767 pub fn update_parameters(&mut self, updates: &[T], learning_rate: T) -> Result<()> {
768 for (param, &update) in self.parameters.iter_mut().zip(updates.iter()) {
769 *param = *param + learning_rate * update;
770 }
771 Ok(())
772 }
773
774 pub fn update_loss_history(&mut self, loss: T) {
775 self.loss_history.push_back(loss);
776 if self.loss_history.len() > 100 {
777 self.loss_history.pop_front();
778 }
779 }
780
781 pub fn get_scale_factor(&self) -> T {
782 self.scale_factor
783 }
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    /// Task fixture shared by the tests below.
    fn sample_task() -> TaskBatch<f32> {
        TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        }
    }

    #[test]
    #[ignore]
    fn test_meta_learning_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let meta_learning = TransformerMetaLearning::new(&config);
        assert!(meta_learning.is_ok());
    }

    #[test]
    fn test_memory_bank() {
        let memory = MemoryBank::<f32>::new(100, 64);
        assert!(memory.is_ok());

        let mut bank = memory.unwrap();
        let task = sample_task();

        let params = vec![0.1f32; 64];
        assert!(bank.store_experience(&task, &params, 0.8).is_ok());

        // The stored experience must be retrievable for the same task.
        let retrieved = bank.retrieve_similar_experiences(&task, 5).unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].parameters, params);
    }

    #[test]
    fn test_adaptation_network() {
        let network = AdaptationNetwork::<f32>::new(128, 256);
        assert!(network.is_ok());

        let mut net = network.unwrap();
        let task = sample_task();

        let context = net.encode_task_context(&task);
        assert!(context.is_ok());
    }
}