1use super::config::{ActivationFunction, MetaLearningConfig};
4use super::feedforward::FeedForwardNetwork;
5use super::layers::LayerNormalization;
6use crate::error::Result;
7use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
8use scirs2_core::numeric::Float;
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::Instant;
12
/// Strategy used by the outer meta-learning loop (see `meta_step` dispatch).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MetaLearningStrategy {
    /// Model-Agnostic Meta-Learning: inner-loop adaptation + outer meta-update.
    MAML,
    /// First-order MAML. NOTE(review): currently delegates to the MAML path.
    FOMAML,
    /// Reptile: moves meta-parameters toward the average per-task delta.
    Reptile,
    /// Initialization predicted by the adaptation network from task context.
    GradientBased,
    /// Initialization averaged from similar experiences in the memory bank.
    MemoryAugmented,
}
27
/// Transformer-coupled meta-learner owning the meta-parameters, adaptation
/// networks, an experience memory bank, and performance tracking.
pub struct TransformerMetaLearning<T: Float + Debug + Send + Sync + 'static> {
    // Active strategy; switchable at runtime via `set_strategy`.
    strategy: MetaLearningStrategy,

    // Meta-learning hyperparameters (inner steps, learning rates, ...).
    config: MetaLearningConfig<T>,

    // Outer-loop optimizer applied to the meta-parameters.
    meta_optimizer: MetaOptimizer<T>,

    // Encodes task context and generates parameter predictions/updates.
    adaptation_network: AdaptationNetwork<T>,

    // Bounded store of (task signature, parameters, performance) experiences.
    memory_bank: MemoryBank<T>,

    // Rolling loss / meta-step histories used for convergence estimates.
    performance_tracker: PerformanceTracker<T>,

    // Current meta-parameters plus loss history and update scale factor.
    meta_state: MetaState<T>,
}
51
52impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::ndarray::ScalarOperand>
53 TransformerMetaLearning<T>
54{
55 pub fn new(config: &super::config::TransformerBasedOptimizerConfig<T>) -> Result<Self> {
57 let meta_config = config.meta_learning_config.clone();
58 let strategy = MetaLearningStrategy::MAML; let meta_optimizer = MetaOptimizer::new(&meta_config)?;
61 let adaptation_network =
62 AdaptationNetwork::new(config.model_dimension, config.feedforward_dimension)?;
63 let memory_bank = MemoryBank::new(1000, config.model_dimension)?;
64 let performance_tracker = PerformanceTracker::new();
65 let meta_state = MetaState::new(config.model_dimension)?;
66
67 Ok(Self {
68 strategy,
69 config: meta_config,
70 meta_optimizer,
71 adaptation_network,
72 memory_bank,
73 performance_tracker,
74 meta_state,
75 })
76 }
77
78 pub fn meta_step(
80 &mut self,
81 tasks: &[TaskBatch<T>],
82 support_data: &[Array2<T>],
83 query_data: &[Array2<T>],
84 ) -> Result<MetaLearningResult<T>> {
85 match self.strategy {
86 MetaLearningStrategy::MAML => self.maml_step(tasks, support_data, query_data),
87 MetaLearningStrategy::FOMAML => self.fomaml_step(tasks, support_data, query_data),
88 MetaLearningStrategy::Reptile => self.reptile_step(tasks, support_data, query_data),
89 MetaLearningStrategy::GradientBased => {
90 self.gradient_based_step(tasks, support_data, query_data)
91 }
92 MetaLearningStrategy::MemoryAugmented => {
93 self.memory_augmented_step(tasks, support_data, query_data)
94 }
95 }
96 }
97
    /// Full MAML step: per task, run `inner_steps` of SGD on the support set
    /// starting from a copy of the meta-parameters, evaluate the adapted
    /// parameters on the query set, apply a meta-update, and archive each
    /// adaptation in the memory bank.
    ///
    /// NOTE(review): assumes `support_data`/`query_data` are index-aligned
    /// with `tasks` and at least as long — confirm at the call site. An empty
    /// `tasks` slice divides by zero below (non-panicking for floats).
    fn maml_step(
        &mut self,
        tasks: &[TaskBatch<T>],
        support_data: &[Array2<T>],
        query_data: &[Array2<T>],
    ) -> Result<MetaLearningResult<T>> {
        let start_time = Instant::now();
        let mut total_loss = T::zero();
        let mut task_adaptations = Vec::new();

        for (i, task) in tasks.iter().enumerate() {
            // Inner loop starts from a fresh copy of the meta-parameters.
            let mut adapted_params = self.meta_state.get_parameters().clone();

            for inner_step in 0..self.config.inner_steps {
                let support_loss =
                    self.compute_task_loss(&adapted_params, &support_data[i], task)?;
                let gradients = self.compute_gradients(&adapted_params, support_loss)?;

                // Plain SGD on the task-specific parameter copy.
                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
                    *param = *param - self.config.inner_learning_rate * (*grad);
                }
            }

            // Query loss of the adapted parameters drives the meta-objective.
            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
            total_loss = total_loss + query_loss;

            task_adaptations.push(TaskAdaptation {
                task_id: task.id.clone(),
                adapted_parameters: adapted_params,
                // Support loss is re-evaluated at the ORIGINAL meta-parameters
                // (pre-adaptation baseline), not the adapted ones.
                support_loss: self.compute_task_loss(
                    self.meta_state.get_parameters(),
                    &support_data[i],
                    task,
                )?,
                query_loss,
                adaptation_steps: self.config.inner_steps,
            });
        }

        // Meta-loss = mean query loss; meta-gradients aggregated from the
        // per-task adaptations, then applied by the outer optimizer.
        let meta_loss = total_loss / T::from(tasks.len()).expect("unwrap failed");
        let meta_gradients = self.compute_meta_gradients(&task_adaptations)?;
        self.meta_optimizer
            .update(&mut self.meta_state, &meta_gradients)?;

        // Archive each adapted solution for later memory-augmented retrieval.
        for (i, adaptation) in task_adaptations.iter().enumerate() {
            self.memory_bank.store_experience(
                &tasks[i],
                &adaptation.adapted_parameters,
                adaptation.query_loss,
            )?;
        }

        let result = MetaLearningResult {
            meta_loss: meta_loss.to_f64().unwrap_or(0.0),
            task_adaptations,
            computation_time: start_time.elapsed(),
            convergence_rate: self.estimate_convergence_rate()?,
        };

        self.performance_tracker.record_meta_step(result.clone());
        Ok(result)
    }
167
    /// First-order MAML step.
    ///
    /// NOTE(review): currently delegates to `maml_step` unchanged. Since
    /// `compute_gradients` is already a first-order placeholder, the two
    /// paths coincide here — confirm this is intentional before relying on
    /// distinct FOMAML semantics.
    fn fomaml_step(
        &mut self,
        tasks: &[TaskBatch<T>],
        support_data: &[Array2<T>],
        query_data: &[Array2<T>],
    ) -> Result<MetaLearningResult<T>> {
        self.maml_step(tasks, support_data, query_data)
    }
179
180 fn reptile_step(
182 &mut self,
183 tasks: &[TaskBatch<T>],
184 support_data: &[Array2<T>],
185 query_data: &[Array2<T>],
186 ) -> Result<MetaLearningResult<T>> {
187 let start_time = Instant::now();
188 let mut parameter_updates = Vec::new();
189 let mut total_loss = T::zero();
190
191 for (i, task) in tasks.iter().enumerate() {
192 let mut adapted_params = self.meta_state.get_parameters().clone();
194
195 for _ in 0..self.config.inner_steps {
196 let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
197 let gradients = self.compute_gradients(&adapted_params, loss)?;
198
199 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
200 *param = *param - self.config.inner_learning_rate * (*grad);
201 }
202 }
203
204 let original_params = self.meta_state.get_parameters();
206 let param_diff: Vec<T> = adapted_params
207 .iter()
208 .zip(original_params.iter())
209 .map(|(adapted, original)| *adapted - *original)
210 .collect();
211
212 parameter_updates.push(param_diff);
213
214 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
216 total_loss = total_loss + query_loss;
217 }
218
219 let mut meta_update = vec![T::zero(); self.meta_state.get_parameters().len()];
221 for param_update in ¶meter_updates {
222 for (i, &update) in param_update.iter().enumerate() {
223 meta_update[i] = meta_update[i] + update;
224 }
225 }
226
227 let num_tasks = T::from(tasks.len()).expect("unwrap failed");
228 for update in meta_update.iter_mut() {
229 *update = *update / num_tasks;
230 }
231
232 self.meta_state
233 .update_parameters(&meta_update, self.config.meta_learning_rate)?;
234
235 let result = MetaLearningResult {
236 meta_loss: (total_loss / num_tasks).to_f64().unwrap_or(0.0),
237 task_adaptations: Vec::new(), computation_time: start_time.elapsed(),
239 convergence_rate: self.estimate_convergence_rate()?,
240 };
241
242 self.performance_tracker.record_meta_step(result.clone());
243 Ok(result)
244 }
245
246 fn gradient_based_step(
248 &mut self,
249 tasks: &[TaskBatch<T>],
250 support_data: &[Array2<T>],
251 query_data: &[Array2<T>],
252 ) -> Result<MetaLearningResult<T>> {
253 let start_time = Instant::now();
255 let mut total_loss = T::zero();
256
257 for (i, task) in tasks.iter().enumerate() {
258 let context_embedding = self.adaptation_network.encode_task_context(task)?;
260 let predicted_params = self
261 .adaptation_network
262 .predict_parameters(&context_embedding)?;
263
264 let mut adapted_params = predicted_params;
266 for _ in 0..self.config.inner_steps {
267 let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
268 let gradients = self.compute_gradients(&adapted_params, loss)?;
269
270 for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
271 *param = *param - self.config.inner_learning_rate * (*grad);
272 }
273 }
274
275 let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
276 total_loss = total_loss + query_loss;
277 }
278
279 let result = MetaLearningResult {
280 meta_loss: (total_loss / T::from(tasks.len()).expect("unwrap failed"))
281 .to_f64()
282 .unwrap_or(0.0),
283 task_adaptations: Vec::new(),
284 computation_time: start_time.elapsed(),
285 convergence_rate: self.estimate_convergence_rate()?,
286 };
287
288 Ok(result)
289 }
290
    /// Memory-augmented step: initialize each task's parameters from the
    /// average of similar past experiences, refine with inner-loop SGD, and
    /// write the new adaptation back into the memory bank.
    fn memory_augmented_step(
        &mut self,
        tasks: &[TaskBatch<T>],
        support_data: &[Array2<T>],
        query_data: &[Array2<T>],
    ) -> Result<MetaLearningResult<T>> {
        let start_time = Instant::now();
        let mut total_loss = T::zero();

        for (i, task) in tasks.iter().enumerate() {
            // The 5 most similar stored experiences seed the initialization.
            let relevant_experiences = self.memory_bank.retrieve_similar_experiences(task, 5)?;

            let memory_guided_params = self.initialize_from_memory(&relevant_experiences)?;

            let mut adapted_params = memory_guided_params;
            for _ in 0..self.config.inner_steps {
                let loss = self.compute_task_loss(&adapted_params, &support_data[i], task)?;
                let gradients = self.compute_gradients(&adapted_params, loss)?;

                // Plain SGD on the memory-seeded parameters.
                for (param, grad) in adapted_params.iter_mut().zip(gradients.iter()) {
                    *param = *param - self.config.inner_learning_rate * (*grad);
                }
            }

            let query_loss = self.compute_task_loss(&adapted_params, &query_data[i], task)?;
            total_loss = total_loss + query_loss;

            // Close the loop: persist the refined parameters for future tasks.
            self.memory_bank
                .store_experience(task, &adapted_params, query_loss)?;
        }

        let result = MetaLearningResult {
            meta_loss: (total_loss / T::from(tasks.len()).expect("unwrap failed"))
                .to_f64()
                .unwrap_or(0.0),
            task_adaptations: Vec::new(),
            computation_time: start_time.elapsed(),
            convergence_rate: self.estimate_convergence_rate()?,
        };

        // NOTE(review): unlike the MAML/Reptile paths, this result is not
        // recorded in the performance tracker — confirm whether intentional.
        Ok(result)
    }
337
338 pub fn generate_update(
340 &mut self,
341 transformer_output: &Array2<T>,
342 current_parameters: &Array1<T>,
343 ) -> Result<Array1<T>> {
344 let update = self
346 .adaptation_network
347 .generate_parameter_update(transformer_output, current_parameters)?;
348
349 let scaled_update = self.apply_meta_scaling(&update)?;
351
352 Ok(scaled_update)
353 }
354
355 pub fn update_from_loss(&mut self, loss: T) -> Result<()> {
357 self.meta_state.update_loss_history(loss);
358 self.performance_tracker
359 .record_loss(loss.to_f64().unwrap_or(0.0));
360 Ok(())
361 }
362
    /// Switches the strategy used by subsequent `meta_step` calls.
    pub fn set_strategy(&mut self, strategy: MetaLearningStrategy) {
        self.strategy = strategy;
    }
367
    /// Returns the currently active meta-learning strategy.
    pub fn get_strategy(&self) -> MetaLearningStrategy {
        self.strategy
    }
372
373 fn compute_task_loss(&self, params: &[T], data: &Array2<T>, task: &TaskBatch<T>) -> Result<T> {
375 let prediction_error = self.compute_prediction_error(params, data, task)?;
377 Ok(prediction_error)
378 }
379
380 fn compute_prediction_error(
381 &self,
382 _params: &[T],
383 data: &Array2<T>,
384 _task: &TaskBatch<T>,
385 ) -> Result<T> {
386 let mean_squared_error = data
388 .iter()
389 .map(|&x| x * x)
390 .fold(T::zero(), |acc, x| acc + x);
391 Ok(mean_squared_error / T::from(data.len()).expect("unwrap failed"))
392 }
393
394 fn compute_gradients(&self, params: &[T], loss: T) -> Result<Vec<T>> {
395 let gradients = params
397 .iter()
398 .map(|_| loss / T::from(params.len()).expect("unwrap failed"))
399 .collect();
400 Ok(gradients)
401 }
402
403 fn compute_meta_gradients(&self, adaptations: &[TaskAdaptation<T>]) -> Result<Vec<T>> {
404 let param_count = adaptations[0].adapted_parameters.len();
405 let mut meta_gradients = vec![T::zero(); param_count];
406
407 for adaptation in adaptations {
408 for (i, ¶m) in adaptation.adapted_parameters.iter().enumerate() {
409 meta_gradients[i] = meta_gradients[i] + param * adaptation.query_loss;
410 }
411 }
412
413 let num_tasks = T::from(adaptations.len()).expect("unwrap failed");
414 for grad in meta_gradients.iter_mut() {
415 *grad = *grad / num_tasks;
416 }
417
418 Ok(meta_gradients)
419 }
420
    /// Estimates a convergence rate from the recorded loss history.
    ///
    /// `recent_losses` is in newest-first order (reversed iteration), so
    /// `last - first` equals (oldest recent loss) - (newest recent loss):
    /// positive when loss has been decreasing over the last up-to-5 samples.
    /// The result is clamped into [0, 1]; fewer than 2 samples yields 0.
    fn estimate_convergence_rate(&self) -> Result<f64> {
        let loss_history = self.performance_tracker.get_loss_history();
        if loss_history.len() < 2 {
            return Ok(0.0);
        }

        let recent_losses: Vec<_> = loss_history.iter().rev().take(5).cloned().collect();
        let improvement = recent_losses.last().expect("unwrap failed")
            - recent_losses.first().expect("unwrap failed");
        Ok(improvement.clamp(0.0, 1.0))
    }
432
433 fn apply_meta_scaling(&self, update: &Array1<T>) -> Result<Array1<T>> {
434 let scale_factor = self.meta_state.get_scale_factor();
436 Ok(update * scale_factor)
437 }
438
439 fn initialize_from_memory(&self, experiences: &[MemoryExperience<T>]) -> Result<Vec<T>> {
440 if experiences.is_empty() {
441 return Ok(self.meta_state.get_parameters().clone());
442 }
443
444 let param_count = experiences[0].parameters.len();
446 let mut averaged_params = vec![T::zero(); param_count];
447
448 for experience in experiences {
449 for (i, ¶m) in experience.parameters.iter().enumerate() {
450 averaged_params[i] = averaged_params[i] + param;
451 }
452 }
453
454 let num_experiences = T::from(experiences.len()).expect("unwrap failed");
455 for param in averaged_params.iter_mut() {
456 *param = *param / num_experiences;
457 }
458
459 Ok(averaged_params)
460 }
461}
462
/// Outer-loop optimizer for the meta-parameters (SGD with optional momentum).
pub struct MetaOptimizer<T: Float + Debug + Send + Sync + 'static> {
    // Meta learning rate, copied from `MetaLearningConfig`.
    learning_rate: T,

    // Momentum coefficient; `None` selects plain SGD (the default from `new`).
    momentum: Option<T>,

    // Velocity buffer, lazily allocated on the first momentum update.
    velocity: Option<Vec<T>>,
}
474
475impl<T: Float + Debug + Send + Sync + 'static> MetaOptimizer<T> {
476 pub fn new(config: &MetaLearningConfig<T>) -> Result<Self> {
477 Ok(Self {
478 learning_rate: config.meta_learning_rate,
479 momentum: None,
480 velocity: None,
481 })
482 }
483
484 pub fn update(&mut self, state: &mut MetaState<T>, gradients: &[T]) -> Result<()> {
485 let params = state.get_parameters_mut();
486
487 if let Some(momentum) = self.momentum {
488 if self.velocity.is_none() {
490 self.velocity = Some(vec![T::zero(); params.len()]);
491 }
492
493 if let Some(ref mut velocity) = self.velocity {
494 for i in 0..params.len() {
495 velocity[i] = momentum * velocity[i] + self.learning_rate * gradients[i];
496 params[i] = params[i] - velocity[i];
497 }
498 }
499 } else {
500 for (param, &grad) in params.iter_mut().zip(gradients.iter()) {
502 *param = *param - self.learning_rate * grad;
503 }
504 }
505
506 Ok(())
507 }
508}
509
/// Networks mapping task context to parameter predictions and updates.
pub struct AdaptationNetwork<T: Float + Debug + Send + Sync + 'static> {
    // Encodes raw task features into a hidden context representation.
    context_encoder: FeedForwardNetwork<T>,

    // Predicts an initial parameter vector from the encoded context.
    parameter_predictor: FeedForwardNetwork<T>,

    // Produces an update from [transformer output | current parameters].
    update_generator: FeedForwardNetwork<T>,

    // Embedding width; also the parameter-vector length used by this module.
    model_dimension: usize,
}
524
525impl<T: Float + Debug + Send + Sync + 'static> AdaptationNetwork<T> {
526 pub fn new(model_dimension: usize, hidden_dimension: usize) -> Result<Self> {
527 let context_encoder =
528 FeedForwardNetwork::new(model_dimension, hidden_dimension, ActivationFunction::ReLU)?;
529
530 let parameter_predictor =
531 FeedForwardNetwork::new(hidden_dimension, model_dimension, ActivationFunction::Tanh)?;
532
533 let update_generator = FeedForwardNetwork::new(
534 model_dimension * 2, model_dimension,
536 ActivationFunction::ReLU,
537 )?;
538
539 Ok(Self {
540 context_encoder,
541 parameter_predictor,
542 update_generator,
543 model_dimension,
544 })
545 }
546
547 pub fn encode_task_context(&mut self, _task: &TaskBatch<T>) -> Result<Array2<T>> {
548 let task_features = Array2::ones((1, self.model_dimension));
550 self.context_encoder.forward(&task_features)
551 }
552
553 pub fn predict_parameters(&mut self, context: &Array2<T>) -> Result<Vec<T>> {
554 let predicted = self.parameter_predictor.forward(context)?;
555 Ok(predicted.row(0).to_vec())
556 }
557
558 pub fn generate_parameter_update(
559 &mut self,
560 transformer_output: &Array2<T>,
561 current_parameters: &Array1<T>,
562 ) -> Result<Array1<T>> {
563 let batch_size = transformer_output.shape()[0];
565 let mut input = Array2::zeros((batch_size, self.model_dimension * 2));
566
567 for i in 0..batch_size {
568 for j in 0..self.model_dimension {
569 input[[i, j]] = transformer_output[[i, j]];
570 if j < current_parameters.len() {
571 input[[i, j + self.model_dimension]] = current_parameters[j];
572 }
573 }
574 }
575
576 let update = self.update_generator.forward(&input)?;
577 Ok(update.row(0).to_owned())
578 }
579}
580
/// Bounded FIFO store of past adaptations, queried by task similarity.
pub struct MemoryBank<T: Float + Debug + Send + Sync + 'static> {
    // Experiences in insertion order; oldest are evicted first.
    experiences: VecDeque<MemoryExperience<T>>,

    // Maximum number of retained experiences.
    max_size: usize,

    // Expected parameter-vector length (currently stored but not enforced).
    parameter_dimension: usize,
}
592
593impl<T: Float + Debug + Send + Sync + 'static> MemoryBank<T> {
594 pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
595 Ok(Self {
596 experiences: VecDeque::new(),
597 max_size,
598 parameter_dimension,
599 })
600 }
601
602 pub fn store_experience(
603 &mut self,
604 task: &TaskBatch<T>,
605 parameters: &[T],
606 performance: T,
607 ) -> Result<()> {
608 let experience = MemoryExperience {
609 task_signature: self.compute_task_signature(task),
610 parameters: parameters.to_vec(),
611 performance: performance.to_f64().unwrap_or(0.0),
612 timestamp: Instant::now(),
613 };
614
615 self.experiences.push_back(experience);
616
617 if self.experiences.len() > self.max_size {
618 self.experiences.pop_front();
619 }
620
621 Ok(())
622 }
623
624 pub fn retrieve_similar_experiences(
625 &self,
626 task: &TaskBatch<T>,
627 k: usize,
628 ) -> Result<Vec<MemoryExperience<T>>> {
629 let target_signature = self.compute_task_signature(task);
630
631 let mut scored_experiences: Vec<_> = self
632 .experiences
633 .iter()
634 .map(|exp| {
635 let similarity = self.compute_similarity(&target_signature, &exp.task_signature);
636 (similarity, exp.clone())
637 })
638 .collect();
639
640 scored_experiences.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("unwrap failed"));
641
642 Ok(scored_experiences
643 .into_iter()
644 .take(k)
645 .map(|(_, exp)| exp)
646 .collect())
647 }
648
649 fn compute_task_signature(&self, task: &TaskBatch<T>) -> Vec<f64> {
650 vec![task.difficulty, task.complexity, task.data_characteristics]
652 }
653
654 fn compute_similarity(&self, sig1: &[f64], sig2: &[f64]) -> f64 {
655 if sig1.len() != sig2.len() {
656 return 0.0;
657 }
658
659 let dot_product: f64 = sig1.iter().zip(sig2.iter()).map(|(a, b)| a * b).sum();
660 let norm1: f64 = sig1.iter().map(|x| x * x).sum::<f64>().sqrt();
661 let norm2: f64 = sig2.iter().map(|x| x * x).sum::<f64>().sqrt();
662
663 if norm1 == 0.0 || norm2 == 0.0 {
664 0.0
665 } else {
666 dot_product / (norm1 * norm2)
667 }
668 }
669}
670
/// Description of a single meta-learning task.
#[derive(Debug, Clone)]
pub struct TaskBatch<T: Float + Debug + Send + Sync + 'static> {
    // Unique task identifier.
    pub id: String,
    // Scalar task descriptors used only as a similarity signature in
    // `MemoryBank`. NOTE(review): valid ranges are not enforced anywhere
    // visible here — confirm expected scale with producers.
    pub difficulty: f64,
    pub complexity: f64,
    pub data_characteristics: f64,
    // Ties the task to element type `T` without storing data of that type.
    pub _phantom: std::marker::PhantomData<T>,
}
680
/// Outcome of adapting the meta-parameters to a single task.
#[derive(Debug, Clone)]
pub struct TaskAdaptation<T: Float + Debug + Send + Sync + 'static> {
    // Id of the task this adaptation belongs to.
    pub task_id: String,
    // Parameters after the inner-loop adaptation.
    pub adapted_parameters: Vec<T>,
    // Support-set loss evaluated at the original meta-parameters.
    pub support_loss: T,
    // Query-set loss evaluated at the adapted parameters.
    pub query_loss: T,
    // Number of inner-loop steps that were taken.
    pub adaptation_steps: usize,
}
689
/// Summary of one outer meta-learning step.
#[derive(Debug, Clone)]
pub struct MetaLearningResult<T: Float + Debug + Send + Sync + 'static> {
    // Mean query loss across tasks (as f64; 0.0 when not convertible).
    pub meta_loss: f64,
    // Per-task adaptations (empty for strategies that do not track them).
    pub task_adaptations: Vec<TaskAdaptation<T>>,
    // Wall-clock duration of the meta step.
    pub computation_time: std::time::Duration,
    // Clamped [0, 1] improvement estimate from recent loss history.
    pub convergence_rate: f64,
}
697
/// A stored adaptation experience, retrievable by task similarity.
#[derive(Debug, Clone)]
pub struct MemoryExperience<T: Float + Debug + Send + Sync + 'static> {
    // Task signature used for cosine-similarity retrieval.
    pub task_signature: Vec<f64>,
    // Adapted parameter vector associated with the task.
    pub parameters: Vec<T>,
    // Query loss recorded for these parameters (as f64).
    pub performance: f64,
    // Insertion time of the experience.
    pub timestamp: Instant,
}
705
/// Rolling histories of losses and meta-step results.
pub struct PerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
    // Most recent scalar losses (bounded to 1000 entries).
    loss_history: VecDeque<f64>,
    // Most recent meta-step summaries (bounded to 100 entries).
    meta_results: VecDeque<MetaLearningResult<T>>,
}
710
/// `Default` simply delegates to `new` (empty histories).
impl<T: Float + Debug + Send + Sync + 'static> Default for PerformanceTracker<T> {
    fn default() -> Self {
        Self::new()
    }
}
716
717impl<T: Float + Debug + Send + Sync + 'static> PerformanceTracker<T> {
718 pub fn new() -> Self {
719 Self {
720 loss_history: VecDeque::new(),
721 meta_results: VecDeque::new(),
722 }
723 }
724
725 pub fn record_loss(&mut self, loss: f64) {
726 self.loss_history.push_back(loss);
727 if self.loss_history.len() > 1000 {
728 self.loss_history.pop_front();
729 }
730 }
731
732 pub fn record_meta_step(&mut self, result: MetaLearningResult<T>) {
733 self.meta_results.push_back(result);
734 if self.meta_results.len() > 100 {
735 self.meta_results.pop_front();
736 }
737 }
738
739 pub fn get_loss_history(&self) -> &VecDeque<f64> {
740 &self.loss_history
741 }
742}
743
/// Mutable meta-parameter state shared between strategies.
#[derive(Debug, Clone)]
pub struct MetaState<T: Float + Debug + Send + Sync + 'static> {
    // Current meta-parameter vector (zero-initialized by `new`).
    parameters: Vec<T>,
    // Recent losses (bounded to 100 entries).
    loss_history: VecDeque<T>,
    // Multiplier applied to generated updates (starts at 1).
    scale_factor: T,
}
750
751impl<T: Float + Debug + Send + Sync + 'static> MetaState<T> {
752 pub fn new(parameter_count: usize) -> Result<Self> {
753 Ok(Self {
754 parameters: vec![T::zero(); parameter_count],
755 loss_history: VecDeque::new(),
756 scale_factor: T::one(),
757 })
758 }
759
760 pub fn get_parameters(&self) -> &Vec<T> {
761 &self.parameters
762 }
763
764 pub fn get_parameters_mut(&mut self) -> &mut Vec<T> {
765 &mut self.parameters
766 }
767
768 pub fn update_parameters(&mut self, updates: &[T], learning_rate: T) -> Result<()> {
769 for (param, &update) in self.parameters.iter_mut().zip(updates.iter()) {
770 *param = *param + learning_rate * update;
771 }
772 Ok(())
773 }
774
775 pub fn update_loss_history(&mut self, loss: T) {
776 self.loss_history.push_back(loss);
777 if self.loss_history.len() > 100 {
778 self.loss_history.pop_front();
779 }
780 }
781
782 pub fn get_scale_factor(&self) -> T {
783 self.scale_factor
784 }
785}
786
#[cfg(test)]
mod tests {
    use super::*;

    /// Requires the full project config; kept `#[ignore]`d.
    #[test]
    #[ignore]
    fn test_meta_learning_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let meta_learning = TransformerMetaLearning::new(&config);
        assert!(meta_learning.is_ok());
    }

    #[test]
    fn test_memory_bank() {
        let memory = MemoryBank::<f32>::new(100, 64);
        assert!(memory.is_ok());

        let mut bank = memory.expect("unwrap failed");
        let task = TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        };

        // Fix: the second argument was a mangled `&params` token
        // (HTML-entity corruption) that did not compile.
        let params = vec![0.1f32; 64];
        assert!(bank.store_experience(&task, &params, 0.8).is_ok());
    }

    #[test]
    fn test_adaptation_network() {
        let network = AdaptationNetwork::<f32>::new(128, 256);
        assert!(network.is_ok());

        let mut net = network.expect("unwrap failed");
        let task = TaskBatch {
            id: "test".to_string(),
            difficulty: 0.5,
            complexity: 0.7,
            data_characteristics: 0.3,
            _phantom: std::marker::PhantomData,
        };

        let context = net.encode_task_context(&task);
        assert!(context.is_ok());
    }
}