oxirs_embed/multimodal/impl/learning.rs

//! Few-shot learning and meta-learning components for multi-modal embeddings

use super::model::MultiModalEmbedding;
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Few-shot learning module for rapid adaptation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotLearning {
    /// Support set size
    pub support_size: usize,
    /// Query set size
    pub query_size: usize,
    /// Number of ways (classes/entities)
    pub num_ways: usize,
    /// Meta-learning algorithm
    pub meta_algorithm: MetaAlgorithm,
    /// Adaptation parameters
    pub adaptation_config: AdaptationConfig,
    /// Prototypical network
    pub prototypical_network: PrototypicalNetwork,
    /// Model-agnostic meta-learning (MAML) components
    pub maml_components: MAMLComponents,
}

/// Meta-learning algorithms for few-shot learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetaAlgorithm {
    /// Prototypical Networks
    PrototypicalNetworks,
    /// Model-Agnostic Meta-Learning
    MAML,
    /// Reptile algorithm
    Reptile,
    /// Matching Networks
    MatchingNetworks,
    /// Relation Networks
    RelationNetworks,
    /// Memory-Augmented Neural Networks
    MANN,
}

/// Adaptation configuration
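///
/// A minimal sketch of overriding individual fields via struct-update syntax
/// (the values shown are illustrative, not recommended settings):
///
/// ```ignore
/// let config = AdaptationConfig {
///     adaptation_steps: 10,
///     second_order: false,
///     ..AdaptationConfig::default()
/// };
/// ```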
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationConfig {
    /// Learning rate for few-shot adaptation
    pub adaptation_lr: f32,
    /// Number of adaptation steps
    pub adaptation_steps: usize,
    /// Gradient clipping threshold
    pub gradient_clip: f32,
    /// Use second-order gradients (for MAML)
    pub second_order: bool,
    /// Temperature for prototypical networks
    pub temperature: f32,
}

impl Default for AdaptationConfig {
    fn default() -> Self {
        Self {
            adaptation_lr: 0.01,
            adaptation_steps: 5,
            gradient_clip: 1.0,
            second_order: true,
            temperature: 1.0,
        }
    }
}

/// Prototypical network for few-shot learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrototypicalNetwork {
    /// Feature extractor parameters
    pub feature_extractor: HashMap<String, Array2<f32>>,
    /// Prototype computation method
    pub prototype_method: PrototypeMethod,
    /// Distance metric
    pub distance_metric: DistanceMetric,
}

/// Prototype computation methods
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PrototypeMethod {
    /// Simple mean of support examples
    Mean,
    /// Weighted mean with attention
    AttentionWeighted,
    /// Learnable prototype aggregation
    LearnableAggregation,
}

/// Distance metrics for prototype comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Euclidean distance
    Euclidean,
    /// Cosine distance
    Cosine,
    /// Learned distance metric
    Learned,
    /// Mahalanobis distance
    Mahalanobis,
}

/// MAML components for meta-learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MAMLComponents {
    /// Inner loop parameters
    pub inner_loop_params: HashMap<String, Array2<f32>>,
    /// Outer loop parameters
    pub outer_loop_params: HashMap<String, Array2<f32>>,
    /// Meta-gradients
    pub meta_gradients: HashMap<String, Array2<f32>>,
    /// Task-specific adaptations
    pub task_adaptations: HashMap<String, HashMap<String, Array2<f32>>>,
}

impl Default for FewShotLearning {
    fn default() -> Self {
        Self {
            support_size: 5,
            query_size: 15,
            num_ways: 3,
            meta_algorithm: MetaAlgorithm::PrototypicalNetworks,
            adaptation_config: AdaptationConfig::default(),
            prototypical_network: PrototypicalNetwork::default(),
            maml_components: MAMLComponents::default(),
        }
    }
}

impl Default for PrototypicalNetwork {
    fn default() -> Self {
        let mut feature_extractor = HashMap::new();
        feature_extractor.insert(
            "conv1".to_string(),
            Array2::from_shape_fn((64, 32), |(_, _)| {
                use scirs2_core::random::{Random, Rng};
                let mut random = Random::default();
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );
        feature_extractor.insert(
            "conv2".to_string(),
            Array2::from_shape_fn((128, 64), |(_, _)| {
                use scirs2_core::random::{Random, Rng};
                let mut random = Random::default();
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );
        feature_extractor.insert(
            "fc".to_string(),
            Array2::from_shape_fn((256, 128), |(_, _)| {
                use scirs2_core::random::{Random, Rng};
                let mut random = Random::default();
                (random.random::<f32>() - 0.5) * 0.1
            }),
        );

        Self {
            feature_extractor,
            prototype_method: PrototypeMethod::Mean,
            distance_metric: DistanceMetric::Euclidean,
        }
    }
}

impl Default for MAMLComponents {
    fn default() -> Self {
        let mut inner_params = HashMap::new();
        let mut outer_params = HashMap::new();
        let mut meta_grads = HashMap::new();

        for layer in ["layer1", "layer2", "output"] {
            inner_params.insert(
                layer.to_string(),
                Array2::from_shape_fn((128, 128), |(_, _)| {
                    use scirs2_core::random::{Random, Rng};
                    let mut random = Random::default();
                    (random.random::<f32>() - 0.5) * 0.1
                }),
            );
            outer_params.insert(
                layer.to_string(),
                Array2::from_shape_fn((128, 128), |(_, _)| {
                    use scirs2_core::random::{Random, Rng};
                    let mut random = Random::default();
                    (random.random::<f32>() - 0.5) * 0.1
                }),
            );
            meta_grads.insert(layer.to_string(), Array2::zeros((128, 128)));
        }

        Self {
            inner_loop_params: inner_params,
            outer_loop_params: outer_params,
            meta_gradients: meta_grads,
            task_adaptations: HashMap::new(),
        }
    }
}

impl FewShotLearning {
    /// Create new few-shot learning module
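    ///
    /// A minimal usage sketch (marked `ignore` because the exact crate
    /// re-export path for these types is not assumed here):
    ///
    /// ```ignore
    /// // 3-way, 5-shot episodes with 15 query examples per episode.
    /// let fsl = FewShotLearning::new(5, 15, 3, MetaAlgorithm::PrototypicalNetworks);
    /// assert_eq!(fsl.num_ways, 3);
    /// ```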
    pub fn new(
        support_size: usize,
        query_size: usize,
        num_ways: usize,
        meta_algorithm: MetaAlgorithm,
    ) -> Self {
        Self {
            support_size,
            query_size,
            num_ways,
            meta_algorithm,
            adaptation_config: AdaptationConfig::default(),
            prototypical_network: PrototypicalNetwork::default(),
            maml_components: MAMLComponents::default(),
        }
    }

    /// Perform few-shot adaptation
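    ///
    /// Dispatches on `meta_algorithm`; variants without a dedicated
    /// implementation fall back to prototypical networks. Illustrative call
    /// shape only (`fsl` and `model` are assumed to exist, and the entity
    /// IRIs are placeholders):
    ///
    /// ```ignore
    /// let support = vec![
    ///     ("A city in France".to_string(), "ex:Paris".to_string(), "city".to_string()),
    /// ];
    /// let query = vec![("Capital of Germany".to_string(), "ex:Berlin".to_string())];
    /// let predictions = fsl.few_shot_adapt(&support, &query, &model).await?;
    /// ```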
    pub async fn few_shot_adapt(
        &mut self,
        support_examples: &[(String, String, String)], // (text, entity, label)
        query_examples: &[(String, String)],           // (text, entity)
        model: &MultiModalEmbedding,
    ) -> Result<Vec<(String, f32)>> {
        match self.meta_algorithm {
            MetaAlgorithm::PrototypicalNetworks => {
                self.prototypical_adapt(support_examples, query_examples, model)
                    .await
            }
            MetaAlgorithm::MAML => {
                self.maml_adapt(support_examples, query_examples, model)
                    .await
            }
            MetaAlgorithm::Reptile => {
                self.reptile_adapt(support_examples, query_examples, model)
                    .await
            }
            _ => {
                // Fallback to prototypical networks
                self.prototypical_adapt(support_examples, query_examples, model)
                    .await
            }
        }
    }

    /// Prototypical networks adaptation
    async fn prototypical_adapt(
        &mut self,
        support_examples: &[(String, String, String)],
        query_examples: &[(String, String)],
        model: &MultiModalEmbedding,
    ) -> Result<Vec<(String, f32)>> {
        // Extract features for support examples
        let mut prototypes = HashMap::new();
        let mut label_embeddings: HashMap<String, Vec<Array1<f32>>> = HashMap::new();

        for (text, entity, label) in support_examples {
            let text_emb = model.text_encoder.encode(text)?;
            let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
            let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

            // Combine text and KG embeddings
            let combined_emb = &text_emb + &kg_emb;

            label_embeddings
                .entry(label.clone())
                .or_default()
                .push(combined_emb);
        }

        // Compute prototypes
        for (label, embeddings) in &label_embeddings {
            let prototype = self.compute_prototype(embeddings)?;
            prototypes.insert(label.clone(), prototype);
        }

        // Classify query examples
        let mut predictions = Vec::new();
        for (text, entity) in query_examples {
            let text_emb = model.text_encoder.encode(text)?;
            let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
            let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

            let query_emb = &text_emb + &kg_emb;

            let mut best_score = f32::NEG_INFINITY;
            let mut best_label = String::new();

            for (label, prototype) in &prototypes {
                let distance = self.compute_distance(&query_emb, prototype);
                let score = (-distance / self.adaptation_config.temperature).exp();

                if score > best_score {
                    best_score = score;
                    best_label = label.clone();
                }
            }

            predictions.push((best_label, best_score));
        }

        Ok(predictions)
    }

    /// MAML adaptation
    async fn maml_adapt(
        &mut self,
        support_examples: &[(String, String, String)],
        query_examples: &[(String, String)],
        model: &MultiModalEmbedding,
    ) -> Result<Vec<(String, f32)>> {
        let task_id = {
            use scirs2_core::random::{Random, Rng};
            let mut random = Random::default();
            format!("task_{}", random.random::<u32>())
        };

        // Initialize task-specific parameters
        let mut task_params = HashMap::new();
        for (layer_name, params) in &self.maml_components.inner_loop_params {
            task_params.insert(layer_name.clone(), params.clone());
        }

        // Inner loop: adapt on support set
        for _ in 0..self.adaptation_config.adaptation_steps {
            let mut gradients = HashMap::new();

            // Compute gradients on support set
            for (text, entity, label) in support_examples {
                let text_emb = model.text_encoder.encode(text)?;
                let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
                let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

                let input_emb = &text_emb + &kg_emb;
                let predicted = self.forward_pass(&input_emb, &task_params)?;

                // Compute loss and gradients (simplified)
                let target = self.label_to_target(label)?;
                let loss_grad = &predicted - &target;

                // Accumulate gradients
                for layer_name in task_params.keys() {
                    let grad = self.compute_layer_gradient(&input_emb, &loss_grad, layer_name)?;
                    *gradients
                        .entry(layer_name.clone())
                        .or_insert_with(|| Array2::zeros(grad.dim())) += &grad;
                }
            }

            // Update task parameters
            for (layer_name, params) in &mut task_params {
                if let Some(grad) = gradients.get(layer_name) {
                    *params = &*params - &(grad * self.adaptation_config.adaptation_lr);
                }
            }
        }

        // Store task adaptation
        self.maml_components
            .task_adaptations
            .insert(task_id.clone(), task_params.clone());

        // Evaluate on query set
        let mut predictions = Vec::new();
        for (text, entity) in query_examples {
            let text_emb = model.text_encoder.encode(text)?;
            let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
            let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

            let query_emb = &text_emb + &kg_emb;
            let output = self.forward_pass(&query_emb, &task_params)?;

            // Convert output to prediction
            let (predicted_label, confidence) = self.output_to_prediction(&output)?;
            predictions.push((predicted_label, confidence));
        }

        Ok(predictions)
    }

    /// Reptile adaptation
    async fn reptile_adapt(
        &mut self,
        support_examples: &[(String, String, String)],
        query_examples: &[(String, String)],
        model: &MultiModalEmbedding,
    ) -> Result<Vec<(String, f32)>> {
        // Reptile is similar to MAML but uses first-order gradients
        let mut adapted_params = HashMap::new();

        // Initialize with current parameters
        for (layer_name, params) in &self.maml_components.outer_loop_params {
            adapted_params.insert(layer_name.clone(), params.clone());
        }

        // Adapt on support set with multiple steps
        for _ in 0..self.adaptation_config.adaptation_steps {
            let mut param_updates = HashMap::new();

            for (text, entity, label) in support_examples {
                let text_emb = model.text_encoder.encode(text)?;
                let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
                let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

                let input_emb = &text_emb + &kg_emb;
                let predicted = self.forward_pass(&input_emb, &adapted_params)?;

                // Simple gradient approximation
                let target = self.label_to_target(label)?;
                let error = &predicted - &target;

                // Update parameters toward reducing error
                for (layer_name, params) in &adapted_params {
                    let update = &error * self.adaptation_config.adaptation_lr;
                    let param_change = Array2::from_shape_fn(params.dim(), |(i, j)| {
                        if i < update.len() && j < params.dim().1 {
                            update[i] * params[(i, j)]
                        } else {
                            0.0
                        }
                    });

                    *param_updates
                        .entry(layer_name.clone())
                        .or_insert_with(|| Array2::zeros(params.dim())) += &param_change;
                }
            }

            // Apply updates
            for (layer_name, params) in &mut adapted_params {
                if let Some(update) = param_updates.get(layer_name) {
                    *params = &*params - update;
                }
            }
        }

        // Evaluate on query set
        let mut predictions = Vec::new();
        for (text, entity) in query_examples {
            let text_emb = model.text_encoder.encode(text)?;
            let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
            let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;

            let query_emb = &text_emb + &kg_emb;
            let output = self.forward_pass(&query_emb, &adapted_params)?;

            let (predicted_label, confidence) = self.output_to_prediction(&output)?;
            predictions.push((predicted_label, confidence));
        }

        Ok(predictions)
    }

    /// Compute prototype from embeddings
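    ///
    /// With the default `PrototypeMethod::Mean`, the prototype is the
    /// element-wise average of the support embeddings; see the unit tests at
    /// the end of this file for a concrete check.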
    pub fn compute_prototype(&self, embeddings: &[Array1<f32>]) -> Result<Array1<f32>> {
        if embeddings.is_empty() {
            return Err(anyhow!("Cannot compute prototype from empty embeddings"));
        }

        match self.prototypical_network.prototype_method {
            PrototypeMethod::Mean => {
                let mut prototype = Array1::zeros(embeddings[0].len());
                for emb in embeddings {
                    prototype = &prototype + emb;
                }
                prototype /= embeddings.len() as f32;
                Ok(prototype)
            }
            PrototypeMethod::AttentionWeighted => {
                // Compute attention-weighted prototype
                let mut weights = Vec::new();
                let mut weight_sum = 0.0;

                for emb in embeddings {
                    let weight = emb.dot(emb).sqrt(); // Use norm as attention weight
                    weights.push(weight);
                    weight_sum += weight;
                }

                let mut prototype = Array1::zeros(embeddings[0].len());
                for (emb, &weight) in embeddings.iter().zip(weights.iter()) {
                    prototype = &prototype + &(emb * (weight / weight_sum));
                }
                Ok(prototype)
            }
            PrototypeMethod::LearnableAggregation => {
                // Use learnable aggregation (simplified)
                let mut prototype = Array1::zeros(embeddings[0].len());
                for (i, emb) in embeddings.iter().enumerate() {
                    let weight = 1.0 / (1.0 + i as f32); // Decay weight
                    prototype = &prototype + &(emb * weight);
                }
                let total_weight: f32 = (0..embeddings.len()).map(|i| 1.0 / (1.0 + i as f32)).sum();
                prototype /= total_weight;
                Ok(prototype)
            }
        }
    }

    /// Compute distance between embeddings
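    ///
    /// `Euclidean` is the L2 norm of the difference, `Cosine` returns
    /// `1 - cos(theta)` (treating zero vectors as maximally distant), and the
    /// `Learned`/`Mahalanobis` variants currently use simplified stand-ins
    /// (L1 and Euclidean distance, respectively).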
    pub fn compute_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
        match self.prototypical_network.distance_metric {
            DistanceMetric::Euclidean => {
                let diff = emb1 - emb2;
                diff.dot(&diff).sqrt()
            }
            DistanceMetric::Cosine => {
                let dot_product = emb1.dot(emb2);
                let norm1 = emb1.dot(emb1).sqrt();
                let norm2 = emb2.dot(emb2).sqrt();
                if norm1 > 0.0 && norm2 > 0.0 {
                    1.0 - (dot_product / (norm1 * norm2))
                } else {
                    1.0
                }
            }
            DistanceMetric::Learned => {
                // Simplified stand-in for a learned metric: L1 (Manhattan) distance
                let diff = emb1 - emb2;
                diff.mapv(|x| x.abs()).sum()
            }
            DistanceMetric::Mahalanobis => {
                // Simplified Mahalanobis: without a learned covariance this reduces to Euclidean
                let diff = emb1 - emb2;
                diff.dot(&diff).sqrt()
            }
        }
    }

    /// Forward pass through adapted network
    fn forward_pass(
        &self,
        input: &Array1<f32>,
        params: &HashMap<String, Array2<f32>>,
    ) -> Result<Array1<f32>> {
        let mut output = input.clone();

        // Simple feedforward network
        for layer_name in ["layer1", "layer2", "output"] {
            if let Some(weights) = params.get(layer_name) {
                output = weights.dot(&output);
                if layer_name != "output" {
                    output = output.mapv(|x| x.max(0.0)); // ReLU
                }
            }
        }

        Ok(output)
    }

    /// Convert label to target vector
    fn label_to_target(&self, label: &str) -> Result<Array1<f32>> {
        // Simple one-hot encoding based on a byte-sum label hash
        // (summed as usize to avoid the u8 overflow a per-char u8 sum would hit)
        let label_hash = label.bytes().map(|b| b as usize).sum::<usize>();
        let target_dim = 128; // Fixed target dimension
        let mut target = Array1::zeros(target_dim);
        target[label_hash % target_dim] = 1.0;
        Ok(target)
    }

    /// Compute layer gradient
    fn compute_layer_gradient(
        &self,
        input: &Array1<f32>,
        loss_grad: &Array1<f32>,
        _layer_name: &str,
    ) -> Result<Array2<f32>> {
        // Simplified gradient computation
        let input_len = input.len();
        let grad_len = loss_grad.len();
        let mut gradient = Array2::zeros((grad_len.min(128), input_len.min(128)));

        for i in 0..gradient.nrows() {
            for j in 0..gradient.ncols() {
                if i < loss_grad.len() && j < input.len() {
                    gradient[(i, j)] = loss_grad[i] * input[j];
                }
            }
        }

        Ok(gradient)
    }

    /// Convert output to prediction
    fn output_to_prediction(&self, output: &Array1<f32>) -> Result<(String, f32)> {
        // Find the index with maximum value
        let (max_idx, &max_val) = output
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((0, &0.0));

        // Convert index to label
        let label = format!("class_{max_idx}");
        let confidence = 1.0 / (1.0 + (-max_val).exp()); // Sigmoid

        Ok((label, confidence))
    }

    /// Meta-update for improving few-shot performance
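    ///
    /// Each task is a vector of `(text, entity, label)` support triples.
    /// Illustrative call shape only (placeholder strings; `fsl` is assumed
    /// to exist):
    ///
    /// ```ignore
    /// let tasks = vec![vec![
    ///     ("A city in France".to_string(), "ex:Paris".to_string(), "city".to_string()),
    /// ]];
    /// fsl.meta_update(&tasks)?;
    /// ```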
    pub fn meta_update(&mut self, tasks: &[Vec<(String, String, String)>]) -> Result<()> {
        match self.meta_algorithm {
            MetaAlgorithm::MAML => {
                // Update outer loop parameters based on task performance
                let mut meta_gradients = HashMap::new();

                for _task in tasks {
                    // Simulate task-specific adaptation
                    for layer_name in self.maml_components.outer_loop_params.keys() {
                        let grad = Array2::from_shape_fn((128, 128), |(_, _)| {
                            use scirs2_core::random::{Random, Rng};
                            let mut random = Random::default();
                            (random.random::<f32>() - 0.5) * 0.01
                        });
                        *meta_gradients
                            .entry(layer_name.clone())
                            .or_insert_with(|| Array2::zeros((128, 128))) += &grad;
                    }
                }

                // Apply meta-gradients
                for (layer_name, params) in &mut self.maml_components.outer_loop_params {
                    if let Some(meta_grad) = meta_gradients.get(layer_name) {
                        *params = &*params - &(meta_grad * self.adaptation_config.adaptation_lr);
                    }
                }
            }
            MetaAlgorithm::Reptile => {
                // Reptile meta-update
                for _task in tasks {
                    // Simulate task adaptation and update toward adapted parameters
                    for params in self.maml_components.outer_loop_params.values_mut() {
                        let update = Array2::from_shape_fn(params.dim(), |(_, _)| {
                            use scirs2_core::random::{Random, Rng};
                            let mut random = Random::default();
                            (random.random::<f32>() - 0.5) * 0.001
                        });
                        *params = &*params + &update;
                    }
                }
            }
            _ => {
                // For prototypical networks, update feature extractor
                for params in self.prototypical_network.feature_extractor.values_mut() {
                    let update = Array2::from_shape_fn(params.dim(), |(_, _)| {
                        use scirs2_core::random::{Random, Rng};
                        let mut random = Random::default();
                        (random.random::<f32>() - 0.5) * 0.001
                    });
                    *params = &*params + &update;
                }
            }
        }

        Ok(())
    }
}
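
// Illustrative unit tests for the deterministic helpers above: a minimal
// sketch that assumes only the types defined in this module and checks the
// prototype, distance, and label-encoding logic on small hand-built inputs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mean_prototype_and_euclidean_distance() {
        // Defaults use PrototypeMethod::Mean and DistanceMetric::Euclidean.
        let fsl = FewShotLearning::default();

        // Two 2-dimensional embeddings built with the same primitives used
        // elsewhere in this module.
        let mut a = Array1::zeros(2);
        a[0] = 1.0;
        let mut b = Array1::zeros(2);
        b[1] = 1.0;

        // The mean prototype of [1, 0] and [0, 1] is [0.5, 0.5].
        let prototype = fsl.compute_prototype(&[a.clone(), b.clone()]).unwrap();
        assert!((prototype[0] - 0.5).abs() < 1e-6);
        assert!((prototype[1] - 0.5).abs() < 1e-6);

        // The Euclidean distance between the two unit vectors is sqrt(2).
        let distance = fsl.compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn label_targets_are_one_hot() {
        let fsl = FewShotLearning::default();
        let target = fsl.label_to_target("city").unwrap();
        assert_eq!(target.len(), 128);
        assert!((target.sum() - 1.0).abs() < 1e-6);
    }
}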