// oxirs_embed/vision_language_graph/meta_learner.rs
use std::collections::HashMap;

use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, Rng};

use super::*;
8#[derive(Debug)]
9pub struct MetaLearner {
10 pub config: MetaLearningConfig,
11 pub meta_parameters: HashMap<String, Array2<f32>>,
13 pub task_parameters: HashMap<String, Array2<f32>>,
15}
16
17impl MetaLearner {
18 pub fn new(config: MetaLearningConfig) -> Self {
19 let mut meta_parameters = HashMap::new();
20 let mut task_parameters = HashMap::new();
21
22 let mut random = Random::default();
24 meta_parameters.insert(
25 "meta_weights".to_string(),
26 Array2::from_shape_fn((512, 512), |_| (random.random::<f32>() - 0.5) * 0.1),
27 );
28
29 let mut random = Random::default();
30 task_parameters.insert(
31 "adaptation_weights".to_string(),
32 Array2::from_shape_fn((256, 512), |_| (random.random::<f32>() - 0.5) * 0.1),
33 );
34
35 Self {
36 config,
37 meta_parameters,
38 task_parameters,
39 }
40 }
41
42 pub fn adapt_to_task(
44 &mut self,
45 support_set: &[(Array1<f32>, Array1<f32>)],
46 _query_set: &[(Array1<f32>, Array1<f32>)],
47 ) -> Result<HashMap<String, Array2<f32>>> {
48 match self.config.algorithm {
49 MetaLearningAlgorithm::MAML => self.maml_adaptation(support_set),
50 MetaLearningAlgorithm::ProtoNet => self.prototypical_adaptation(support_set),
51 _ => self.maml_adaptation(support_set),
52 }
53 }
54
55 fn maml_adaptation(
57 &mut self,
58 support_set: &[(Array1<f32>, Array1<f32>)],
59 ) -> Result<HashMap<String, Array2<f32>>> {
60 let mut adapted_params = self.meta_parameters.clone();
61
62 for _step in 0..self.config.adaptation_steps {
64 for (input, _target) in support_set {
66 if let Some(weights) = adapted_params.get_mut("meta_weights") {
67 let _output = weights.dot(input);
69
70 *weights = &*weights * 0.99; }
73 }
74 }
75
76 Ok(adapted_params)
77 }
78
79 fn prototypical_adaptation(
81 &self,
82 support_set: &[(Array1<f32>, Array1<f32>)],
83 ) -> Result<HashMap<String, Array2<f32>>> {
84 let mut prototypes = HashMap::new();
86 let mut class_counts = HashMap::new();
87
88 for (input, target) in support_set {
89 let class_id = target[0] as i32;
91
92 let class_key = class_id.to_string();
93 let prototype = prototypes
94 .entry(class_key.clone())
95 .or_insert(Array1::zeros(input.len()));
96 let count = class_counts.entry(class_key).or_insert(0);
97
98 *prototype = &*prototype + input;
99 *count += 1;
100 }
101
102 for (class_key, count) in class_counts {
104 if let Some(prototype) = prototypes.get_mut(&class_key) {
105 *prototype /= count as f32;
106 }
107 }
108
109 Ok(self.meta_parameters.clone())
111 }
112}