1use super::model::MultiModalEmbedding;
4use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{Array1, Array2};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Few-shot learning engine: episodic configuration plus the state needed by
/// the supported meta-learning algorithms (prototypical networks, MAML,
/// Reptile, and fallbacks).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotLearning {
    /// Number of labelled support examples per class in an episode.
    pub support_size: usize,
    /// Number of query examples per class in an episode.
    pub query_size: usize,
    /// Number of classes ("N ways") per episode.
    pub num_ways: usize,
    /// Which meta-learning algorithm drives task adaptation.
    pub meta_algorithm: MetaAlgorithm,
    /// Hyperparameters for the inner (per-task) adaptation loop.
    pub adaptation_config: AdaptationConfig,
    /// State for prototypical-network style classification.
    pub prototypical_network: PrototypicalNetwork,
    /// Parameter sets used by MAML- and Reptile-style adaptation.
    pub maml_components: MAMLComponents,
}
27
/// Meta-learning algorithm selector.
///
/// Only `PrototypicalNetworks`, `MAML` and `Reptile` have dedicated
/// adaptation routines in this module; the remaining variants fall back to
/// prototypical networks (see `FewShotLearning::few_shot_adapt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetaAlgorithm {
    /// Prototypical networks: classify queries by distance to class prototypes.
    PrototypicalNetworks,
    /// Model-Agnostic Meta-Learning: per-task inner-loop gradient adaptation.
    MAML,
    /// Reptile: first-order meta-learning via repeated task-parameter updates.
    Reptile,
    /// Matching networks (no dedicated implementation here; falls back).
    MatchingNetworks,
    /// Relation networks (no dedicated implementation here; falls back).
    RelationNetworks,
    /// Memory-augmented neural networks (no dedicated implementation; falls back).
    MANN,
}
44
/// Hyperparameters controlling the inner (per-task) adaptation loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationConfig {
    /// Learning rate for inner-loop parameter updates.
    pub adaptation_lr: f32,
    /// Number of gradient steps taken when adapting to a task.
    pub adaptation_steps: usize,
    /// Gradient clipping threshold.
    /// NOTE(review): not applied anywhere in this file — confirm it is used by callers.
    pub gradient_clip: f32,
    /// Whether second-order gradients should be used (full MAML vs first-order).
    /// NOTE(review): not referenced in this file — confirm intended consumers.
    pub second_order: bool,
    /// Temperature used when converting distances to scores (`exp(-d / T)`).
    pub temperature: f32,
}
59
60impl Default for AdaptationConfig {
61 fn default() -> Self {
62 Self {
63 adaptation_lr: 0.01,
64 adaptation_steps: 5,
65 gradient_clip: 1.0,
66 second_order: true,
67 temperature: 1.0,
68 }
69 }
70}
71
/// State for prototypical-network classification: a named collection of
/// weight matrices plus the prototype-aggregation and distance choices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrototypicalNetwork {
    /// Feature-extractor weight matrices, keyed by layer name
    /// (defaults populate "conv1", "conv2", "fc").
    pub feature_extractor: HashMap<String, Array2<f32>>,
    /// How per-class support embeddings are collapsed into a prototype.
    pub prototype_method: PrototypeMethod,
    /// Distance metric used to compare query embeddings to prototypes.
    pub distance_metric: DistanceMetric,
}
82
/// Strategy for aggregating a class's support embeddings into one prototype.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PrototypeMethod {
    /// Arithmetic mean of the support embeddings.
    Mean,
    /// Weighted mean; this module weights each embedding by its L2 norm.
    AttentionWeighted,
    /// Position-weighted aggregation; this module uses fixed 1/(1+i) weights.
    LearnableAggregation,
}
93
/// Distance metric between a query embedding and a prototype.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// L2 (Euclidean) distance.
    Euclidean,
    /// Cosine distance: 1 − cosine similarity (1.0 when either norm is zero).
    Cosine,
    /// "Learned" metric; currently implemented as L1 distance (placeholder).
    Learned,
    /// Mahalanobis distance; currently falls back to Euclidean
    /// (i.e. identity covariance — placeholder).
    Mahalanobis,
}
106
/// Parameter sets maintained for MAML/Reptile-style meta-learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MAMLComponents {
    /// Task-level (inner-loop) parameters, keyed by layer name.
    pub inner_loop_params: HashMap<String, Array2<f32>>,
    /// Meta-level (outer-loop) parameters, keyed by layer name.
    pub outer_loop_params: HashMap<String, Array2<f32>>,
    /// Accumulated meta-gradients per layer.
    pub meta_gradients: HashMap<String, Array2<f32>>,
    /// Snapshot of adapted parameters per task id, recorded by `maml_adapt`.
    pub task_adaptations: HashMap<String, HashMap<String, Array2<f32>>>,
}
119
120impl Default for FewShotLearning {
121 fn default() -> Self {
122 Self {
123 support_size: 5,
124 query_size: 15,
125 num_ways: 3,
126 meta_algorithm: MetaAlgorithm::PrototypicalNetworks,
127 adaptation_config: AdaptationConfig::default(),
128 prototypical_network: PrototypicalNetwork::default(),
129 maml_components: MAMLComponents::default(),
130 }
131 }
132}
133
134impl Default for PrototypicalNetwork {
135 fn default() -> Self {
136 let mut feature_extractor = HashMap::new();
137 feature_extractor.insert(
138 "conv1".to_string(),
139 Array2::from_shape_fn((64, 32), |(_, _)| {
140 use scirs2_core::random::{Random, Rng};
141 let mut random = Random::default();
142 (random.random::<f32>() - 0.5) * 0.1
143 }),
144 );
145 feature_extractor.insert(
146 "conv2".to_string(),
147 Array2::from_shape_fn((128, 64), |(_, _)| {
148 use scirs2_core::random::{Random, Rng};
149 let mut random = Random::default();
150 (random.random::<f32>() - 0.5) * 0.1
151 }),
152 );
153 feature_extractor.insert(
154 "fc".to_string(),
155 Array2::from_shape_fn((256, 128), |(_, _)| {
156 use scirs2_core::random::{Random, Rng};
157 let mut random = Random::default();
158 (random.random::<f32>() - 0.5) * 0.1
159 }),
160 );
161
162 Self {
163 feature_extractor,
164 prototype_method: PrototypeMethod::Mean,
165 distance_metric: DistanceMetric::Euclidean,
166 }
167 }
168}
169
170impl Default for MAMLComponents {
171 fn default() -> Self {
172 let mut inner_params = HashMap::new();
173 let mut outer_params = HashMap::new();
174 let mut meta_grads = HashMap::new();
175
176 for layer in ["layer1", "layer2", "output"] {
177 inner_params.insert(
178 layer.to_string(),
179 Array2::from_shape_fn((128, 128), |(_, _)| {
180 use scirs2_core::random::{Random, Rng};
181 let mut random = Random::default();
182 (random.random::<f32>() - 0.5) * 0.1
183 }),
184 );
185 outer_params.insert(
186 layer.to_string(),
187 Array2::from_shape_fn((128, 128), |(_, _)| {
188 use scirs2_core::random::{Random, Rng};
189 let mut random = Random::default();
190 (random.random::<f32>() - 0.5) * 0.1
191 }),
192 );
193 meta_grads.insert(layer.to_string(), Array2::zeros((128, 128)));
194 }
195
196 Self {
197 inner_loop_params: inner_params,
198 outer_loop_params: outer_params,
199 meta_gradients: meta_grads,
200 task_adaptations: HashMap::new(),
201 }
202 }
203}
204
205impl FewShotLearning {
206 pub fn new(
208 support_size: usize,
209 query_size: usize,
210 num_ways: usize,
211 meta_algorithm: MetaAlgorithm,
212 ) -> Self {
213 Self {
214 support_size,
215 query_size,
216 num_ways,
217 meta_algorithm,
218 adaptation_config: AdaptationConfig::default(),
219 prototypical_network: PrototypicalNetwork::default(),
220 maml_components: MAMLComponents::default(),
221 }
222 }
223
224 pub async fn few_shot_adapt(
226 &mut self,
227 support_examples: &[(String, String, String)], query_examples: &[(String, String)], model: &MultiModalEmbedding,
230 ) -> Result<Vec<(String, f32)>> {
231 match self.meta_algorithm {
232 MetaAlgorithm::PrototypicalNetworks => {
233 self.prototypical_adapt(support_examples, query_examples, model)
234 .await
235 }
236 MetaAlgorithm::MAML => {
237 self.maml_adapt(support_examples, query_examples, model)
238 .await
239 }
240 MetaAlgorithm::Reptile => {
241 self.reptile_adapt(support_examples, query_examples, model)
242 .await
243 }
244 _ => {
245 self.prototypical_adapt(support_examples, query_examples, model)
247 .await
248 }
249 }
250 }
251
252 async fn prototypical_adapt(
254 &mut self,
255 support_examples: &[(String, String, String)],
256 query_examples: &[(String, String)],
257 model: &MultiModalEmbedding,
258 ) -> Result<Vec<(String, f32)>> {
259 let mut prototypes = HashMap::new();
261 let mut label_embeddings: HashMap<String, Vec<Array1<f32>>> = HashMap::new();
262
263 for (text, entity, label) in support_examples {
264 let text_emb = model.text_encoder.encode(text)?;
265 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
266 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
267
268 let combined_emb = &text_emb + &kg_emb;
270
271 label_embeddings
272 .entry(label.clone())
273 .or_default()
274 .push(combined_emb);
275 }
276
277 for (label, embeddings) in &label_embeddings {
279 let prototype = self.compute_prototype(embeddings)?;
280 prototypes.insert(label.clone(), prototype);
281 }
282
283 let mut predictions = Vec::new();
285 for (text, entity) in query_examples {
286 let text_emb = model.text_encoder.encode(text)?;
287 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
288 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
289
290 let query_emb = &text_emb + &kg_emb;
291
292 let mut best_score = f32::NEG_INFINITY;
293 let mut best_label = String::new();
294
295 for (label, prototype) in &prototypes {
296 let distance = self.compute_distance(&query_emb, prototype);
297 let score = (-distance / self.adaptation_config.temperature).exp();
298
299 if score > best_score {
300 best_score = score;
301 best_label = label.clone();
302 }
303 }
304
305 predictions.push((best_label, best_score));
306 }
307
308 Ok(predictions)
309 }
310
311 async fn maml_adapt(
313 &mut self,
314 support_examples: &[(String, String, String)],
315 query_examples: &[(String, String)],
316 model: &MultiModalEmbedding,
317 ) -> Result<Vec<(String, f32)>> {
318 let task_id = {
319 use scirs2_core::random::{Random, Rng};
320 let mut random = Random::default();
321 format!("task_{}", random.random::<u32>())
322 };
323
324 let mut task_params = HashMap::new();
326 for (layer_name, params) in &self.maml_components.inner_loop_params {
327 task_params.insert(layer_name.clone(), params.clone());
328 }
329
330 for _ in 0..self.adaptation_config.adaptation_steps {
332 let mut gradients = HashMap::new();
333
334 for (text, entity, label) in support_examples {
336 let text_emb = model.text_encoder.encode(text)?;
337 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
338 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
339
340 let input_emb = &text_emb + &kg_emb;
341 let predicted = self.forward_pass(&input_emb, &task_params)?;
342
343 let target = self.label_to_target(label)?;
345 let loss_grad = &predicted - ⌖
346
347 for layer_name in task_params.keys() {
349 let grad = self.compute_layer_gradient(&input_emb, &loss_grad, layer_name)?;
350 *gradients
351 .entry(layer_name.clone())
352 .or_insert_with(|| Array2::zeros(grad.dim())) += &grad;
353 }
354 }
355
356 for (layer_name, params) in &mut task_params {
358 if let Some(grad) = gradients.get(layer_name) {
359 *params = &*params - &(grad * self.adaptation_config.adaptation_lr);
360 }
361 }
362 }
363
364 self.maml_components
366 .task_adaptations
367 .insert(task_id.clone(), task_params.clone());
368
369 let mut predictions = Vec::new();
371 for (text, entity) in query_examples {
372 let text_emb = model.text_encoder.encode(text)?;
373 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
374 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
375
376 let query_emb = &text_emb + &kg_emb;
377 let output = self.forward_pass(&query_emb, &task_params)?;
378
379 let (predicted_label, confidence) = self.output_to_prediction(&output)?;
381 predictions.push((predicted_label, confidence));
382 }
383
384 Ok(predictions)
385 }
386
387 async fn reptile_adapt(
389 &mut self,
390 support_examples: &[(String, String, String)],
391 query_examples: &[(String, String)],
392 model: &MultiModalEmbedding,
393 ) -> Result<Vec<(String, f32)>> {
394 let mut adapted_params = HashMap::new();
396
397 for (layer_name, params) in &self.maml_components.outer_loop_params {
399 adapted_params.insert(layer_name.clone(), params.clone());
400 }
401
402 for _ in 0..self.adaptation_config.adaptation_steps {
404 let mut param_updates = HashMap::new();
405
406 for (text, entity, label) in support_examples {
407 let text_emb = model.text_encoder.encode(text)?;
408 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
409 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
410
411 let input_emb = &text_emb + &kg_emb;
412 let predicted = self.forward_pass(&input_emb, &adapted_params)?;
413
414 let target = self.label_to_target(label)?;
416 let error = &predicted - ⌖
417
418 for (layer_name, params) in &adapted_params {
420 let update = &error * self.adaptation_config.adaptation_lr;
421 let param_change = Array2::from_shape_fn(params.dim(), |(i, j)| {
422 if i < update.len() && j < params.dim().1 {
423 update[i] * params[(i, j)]
424 } else {
425 0.0
426 }
427 });
428
429 *param_updates
430 .entry(layer_name.clone())
431 .or_insert_with(|| Array2::zeros(params.dim())) += ¶m_change;
432 }
433 }
434
435 for (layer_name, params) in &mut adapted_params {
437 if let Some(update) = param_updates.get(layer_name) {
438 *params = &*params - update;
439 }
440 }
441 }
442
443 let mut predictions = Vec::new();
445 for (text, entity) in query_examples {
446 let text_emb = model.text_encoder.encode(text)?;
447 let kg_emb_raw = model.get_or_create_kg_embedding(entity)?;
448 let kg_emb = model.kg_encoder.encode_entity(&kg_emb_raw)?;
449
450 let query_emb = &text_emb + &kg_emb;
451 let output = self.forward_pass(&query_emb, &adapted_params)?;
452
453 let (predicted_label, confidence) = self.output_to_prediction(&output)?;
454 predictions.push((predicted_label, confidence));
455 }
456
457 Ok(predictions)
458 }
459
460 pub fn compute_prototype(&self, embeddings: &[Array1<f32>]) -> Result<Array1<f32>> {
462 if embeddings.is_empty() {
463 return Err(anyhow!("Cannot compute prototype from empty embeddings"));
464 }
465
466 match self.prototypical_network.prototype_method {
467 PrototypeMethod::Mean => {
468 let mut prototype = Array1::zeros(embeddings[0].len());
469 for emb in embeddings {
470 prototype = &prototype + emb;
471 }
472 prototype /= embeddings.len() as f32;
473 Ok(prototype)
474 }
475 PrototypeMethod::AttentionWeighted => {
476 let mut weights = Vec::new();
478 let mut weight_sum = 0.0;
479
480 for emb in embeddings {
481 let weight = emb.dot(emb).sqrt(); weights.push(weight);
483 weight_sum += weight;
484 }
485
486 let mut prototype = Array1::zeros(embeddings[0].len());
487 for (emb, &weight) in embeddings.iter().zip(weights.iter()) {
488 prototype = &prototype + &(emb * (weight / weight_sum));
489 }
490 Ok(prototype)
491 }
492 PrototypeMethod::LearnableAggregation => {
493 let mut prototype = Array1::zeros(embeddings[0].len());
495 for (i, emb) in embeddings.iter().enumerate() {
496 let weight = 1.0 / (1.0 + i as f32); prototype = &prototype + &(emb * weight);
498 }
499 let total_weight: f32 = (0..embeddings.len()).map(|i| 1.0 / (1.0 + i as f32)).sum();
500 prototype /= total_weight;
501 Ok(prototype)
502 }
503 }
504 }
505
506 pub fn compute_distance(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 {
508 match self.prototypical_network.distance_metric {
509 DistanceMetric::Euclidean => {
510 let diff = emb1 - emb2;
511 diff.dot(&diff).sqrt()
512 }
513 DistanceMetric::Cosine => {
514 let dot_product = emb1.dot(emb2);
515 let norm1 = emb1.dot(emb1).sqrt();
516 let norm2 = emb2.dot(emb2).sqrt();
517 if norm1 > 0.0 && norm2 > 0.0 {
518 1.0 - (dot_product / (norm1 * norm2))
519 } else {
520 1.0
521 }
522 }
523 DistanceMetric::Learned => {
524 let diff = emb1 - emb2;
526 diff.mapv(|x| x.abs()).sum()
527 }
528 DistanceMetric::Mahalanobis => {
529 let diff = emb1 - emb2;
531 diff.dot(&diff).sqrt()
532 }
533 }
534 }
535
536 fn forward_pass(
538 &self,
539 input: &Array1<f32>,
540 params: &HashMap<String, Array2<f32>>,
541 ) -> Result<Array1<f32>> {
542 let mut output = input.clone();
543
544 for layer_name in ["layer1", "layer2", "output"] {
546 if let Some(weights) = params.get(layer_name) {
547 output = weights.dot(&output);
548 if layer_name != "output" {
549 output = output.mapv(|x| x.max(0.0)); }
551 }
552 }
553
554 Ok(output)
555 }
556
557 fn label_to_target(&self, label: &str) -> Result<Array1<f32>> {
559 let label_hash = label.chars().map(|c| c as u8).sum::<u8>() as usize;
561 let target_dim = 128; let mut target = Array1::zeros(target_dim);
563 target[label_hash % target_dim] = 1.0;
564 Ok(target)
565 }
566
567 fn compute_layer_gradient(
569 &self,
570 input: &Array1<f32>,
571 loss_grad: &Array1<f32>,
572 _layer_name: &str,
573 ) -> Result<Array2<f32>> {
574 let input_len = input.len();
576 let grad_len = loss_grad.len();
577 let mut gradient = Array2::zeros((grad_len.min(128), input_len.min(128)));
578
579 for i in 0..gradient.nrows() {
580 for j in 0..gradient.ncols() {
581 if i < loss_grad.len() && j < input.len() {
582 gradient[(i, j)] = loss_grad[i] * input[j];
583 }
584 }
585 }
586
587 Ok(gradient)
588 }
589
590 fn output_to_prediction(&self, output: &Array1<f32>) -> Result<(String, f32)> {
592 let (max_idx, &max_val) = output
594 .iter()
595 .enumerate()
596 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
597 .unwrap_or((0, &0.0));
598
599 let label = format!("class_{max_idx}");
601 let confidence = 1.0 / (1.0 + (-max_val).exp()); Ok((label, confidence))
604 }
605
606 pub fn meta_update(&mut self, tasks: &[Vec<(String, String, String)>]) -> Result<()> {
608 match self.meta_algorithm {
609 MetaAlgorithm::MAML => {
610 let mut meta_gradients = HashMap::new();
612
613 for _task in tasks {
614 for layer_name in self.maml_components.outer_loop_params.keys() {
616 let grad = Array2::from_shape_fn((128, 128), |(_, _)| {
617 use scirs2_core::random::{Random, Rng};
618 let mut random = Random::default();
619 (random.random::<f32>() - 0.5) * 0.01
620 });
621 *meta_gradients
622 .entry(layer_name.clone())
623 .or_insert_with(|| Array2::zeros((128, 128))) += &grad;
624 }
625 }
626
627 for (layer_name, params) in &mut self.maml_components.outer_loop_params {
629 if let Some(meta_grad) = meta_gradients.get(layer_name) {
630 *params = &*params - &(meta_grad * self.adaptation_config.adaptation_lr);
631 }
632 }
633 }
634 MetaAlgorithm::Reptile => {
635 for _task in tasks {
637 for params in self.maml_components.outer_loop_params.values_mut() {
639 let update = Array2::from_shape_fn(params.dim(), |(_, _)| {
640 use scirs2_core::random::{Random, Rng};
641 let mut random = Random::default();
642 (random.random::<f32>() - 0.5) * 0.001
643 });
644 *params = &*params + &update;
645 }
646 }
647 }
648 _ => {
649 for params in self.prototypical_network.feature_extractor.values_mut() {
651 let update = Array2::from_shape_fn(params.dim(), |(_, _)| {
652 use scirs2_core::random::{Random, Rng};
653 let mut random = Random::default();
654 (random.random::<f32>() - 0.5) * 0.001
655 });
656 *params = &*params + &update;
657 }
658 }
659 }
660
661 Ok(())
662 }
663}