1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::Result as SklResult,
9 prelude::{Predict, SklearsError},
10 traits::{Estimator, Fit, Untrained},
11 types::{Float, FloatBounds},
12};
13use std::collections::HashMap;
14
15use crate::{PipelinePredictor, PipelineStep};
16
17#[derive(Debug, Clone)]
19pub struct SupportSet {
20 pub features: Array2<f64>,
22 pub labels: Array1<f64>,
24 pub n_shot: usize,
26 pub n_way: usize,
28}
29
30impl SupportSet {
31 #[must_use]
33 pub fn new(features: Array2<f64>, labels: Array1<f64>, n_shot: usize, n_way: usize) -> Self {
34 Self {
35 features,
36 labels,
37 n_shot,
38 n_way,
39 }
40 }
41
42 #[must_use]
44 pub fn get_class_examples(&self, class_label: f64) -> (Array2<f64>, Array1<f64>) {
45 let mut class_features = Vec::new();
46 let mut class_labels = Vec::new();
47
48 for (i, &label) in self.labels.iter().enumerate() {
49 if (label - class_label).abs() < 1e-6 {
50 class_features.push(self.features.row(i).to_owned());
51 class_labels.push(label);
52 }
53 }
54
55 if class_features.is_empty() {
56 return (Array2::zeros((0, self.features.ncols())), Array1::zeros(0));
57 }
58
59 let n_examples = class_features.len();
60 let n_features = class_features[0].len();
61 let mut features_array = Array2::zeros((n_examples, n_features));
62
63 for (i, features) in class_features.iter().enumerate() {
64 features_array.row_mut(i).assign(features);
65 }
66
67 (features_array, Array1::from_vec(class_labels))
68 }
69
70 #[must_use]
72 pub fn get_classes(&self) -> Vec<f64> {
73 let mut classes: Vec<f64> = self.labels.iter().copied().collect();
74 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
75 classes.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
76 classes
77 }
78}
79
80#[derive(Debug)]
82pub struct PrototypicalNetwork<S = Untrained> {
83 state: S,
84 distance_metric: DistanceMetric,
85 embedding_dim: Option<usize>,
86 prototypes: HashMap<String, Array1<f64>>,
87}
88
89#[derive(Debug)]
91pub struct PrototypicalNetworkTrained {
92 prototypes: HashMap<String, Array1<f64>>,
93 distance_metric: DistanceMetric,
94 n_features_in: usize,
95 feature_names_in: Option<Vec<String>>,
96}
97
98#[derive(Debug, Clone)]
100pub enum DistanceMetric {
101 Euclidean,
103 Cosine,
105 Manhattan,
107 Mahalanobis { covariance: Array2<f64> },
109}
110
111impl DistanceMetric {
112 #[must_use]
114 pub fn distance(&self, a: &ArrayView1<'_, f64>, b: &ArrayView1<'_, f64>) -> f64 {
115 match self {
116 DistanceMetric::Euclidean => ((a - b).mapv(|x| x * x).sum()).sqrt(),
117 DistanceMetric::Cosine => {
118 let dot_product = a.dot(b);
119 let norm_a = (a.mapv(|x| x * x).sum()).sqrt();
120 let norm_b = (b.mapv(|x| x * x).sum()).sqrt();
121 if norm_a == 0.0 || norm_b == 0.0 {
122 1.0
123 } else {
124 1.0 - dot_product / (norm_a * norm_b)
125 }
126 }
127 DistanceMetric::Manhattan => (a - b).mapv(f64::abs).sum(),
128 DistanceMetric::Mahalanobis { covariance } => {
129 let diff = a - b;
130 let weighted_diff = &diff * &covariance.diag();
132 (weighted_diff.mapv(|x| x * x).sum()).sqrt()
133 }
134 }
135 }
136}
137
138impl PrototypicalNetwork<Untrained> {
139 #[must_use]
141 pub fn new() -> Self {
142 Self {
143 state: Untrained,
144 distance_metric: DistanceMetric::Euclidean,
145 embedding_dim: None,
146 prototypes: HashMap::new(),
147 }
148 }
149
150 #[must_use]
152 pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
153 self.distance_metric = metric;
154 self
155 }
156
157 #[must_use]
159 pub fn embedding_dim(mut self, dim: usize) -> Self {
160 self.embedding_dim = Some(dim);
161 self
162 }
163}
164
165impl Default for PrototypicalNetwork<Untrained> {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl Estimator for PrototypicalNetwork<Untrained> {
172 type Config = ();
173 type Error = SklearsError;
174 type Float = Float;
175
176 fn config(&self) -> &Self::Config {
177 &()
178 }
179}
180
181impl PrototypicalNetwork<Untrained> {
182 pub fn fit_support_set(
184 self,
185 support_set: &SupportSet,
186 ) -> SklResult<PrototypicalNetwork<PrototypicalNetworkTrained>> {
187 let mut prototypes = HashMap::new();
188 let classes = support_set.get_classes();
189
190 for class_label in classes {
192 let (class_features, _) = support_set.get_class_examples(class_label);
193
194 if class_features.nrows() > 0 {
195 let prototype = class_features.mean_axis(Axis(0)).unwrap();
197 prototypes.insert(class_label.to_string(), prototype);
198 }
199 }
200
201 Ok(PrototypicalNetwork {
202 state: PrototypicalNetworkTrained {
203 prototypes,
204 distance_metric: self.distance_metric,
205 n_features_in: support_set.features.ncols(),
206 feature_names_in: None,
207 },
208 distance_metric: DistanceMetric::Euclidean, embedding_dim: None,
210 prototypes: HashMap::new(),
211 })
212 }
213}
214
215impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for PrototypicalNetwork<Untrained> {
216 type Fitted = PrototypicalNetwork<PrototypicalNetworkTrained>;
217
218 fn fit(
219 self,
220 x: &ArrayView2<'_, Float>,
221 y: &Option<&ArrayView1<'_, Float>>,
222 ) -> SklResult<Self::Fitted> {
223 if let Some(y_values) = y.as_ref() {
224 let x_f64 = x.mapv(|v| v);
225 let y_f64 = y_values.mapv(|v| v);
226
227 let unique_labels: Vec<f64> = {
229 let mut labels: Vec<f64> = y_f64.iter().copied().collect();
230 labels.sort_by(|a, b| a.partial_cmp(b).unwrap());
231 labels.dedup_by(|a, b| (*a - *b).abs() < 1e-6);
232 labels
233 };
234
235 let n_way = unique_labels.len();
236 let n_shot = y_f64.len() / n_way; let support_set = SupportSet::new(x_f64, y_f64, n_shot, n_way);
239 self.fit_support_set(&support_set)
240 } else {
241 Err(SklearsError::InvalidInput(
242 "Labels required for few-shot learning".to_string(),
243 ))
244 }
245 }
246}
247
248impl PrototypicalNetwork<PrototypicalNetworkTrained> {
249 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
251 let x_f64 = x.mapv(|v| v);
252 let mut predictions = Array1::zeros(x_f64.nrows());
253
254 for (i, sample) in x_f64.axis_iter(Axis(0)).enumerate() {
255 let mut min_distance = f64::INFINITY;
256 let mut predicted_class = 0.0;
257
258 for (class_str, prototype) in &self.state.prototypes {
259 let distance = self
260 .state
261 .distance_metric
262 .distance(&sample, &prototype.view());
263 if distance < min_distance {
264 min_distance = distance;
265 predicted_class = class_str.parse().unwrap_or(0.0);
266 }
267 }
268
269 predictions[i] = predicted_class;
270 }
271
272 Ok(predictions)
273 }
274
275 #[must_use]
277 pub fn prototypes(&self) -> &HashMap<String, Array1<f64>> {
278 &self.state.prototypes
279 }
280}
281
282#[derive(Debug)]
284pub struct MAMLLearner<S = Untrained> {
285 state: S,
286 base_learner: Option<Box<dyn PipelinePredictor>>,
287 inner_lr: f64,
288 outer_lr: f64,
289 inner_steps: usize,
290 meta_parameters: HashMap<String, f64>,
291}
292
293#[derive(Debug)]
295pub struct MAMLLearnerTrained {
296 fitted_learner: Box<dyn PipelinePredictor>,
297 inner_lr: f64,
298 outer_lr: f64,
299 inner_steps: usize,
300 meta_parameters: HashMap<String, f64>,
301 n_features_in: usize,
302 feature_names_in: Option<Vec<String>>,
303}
304
305impl MAMLLearner<Untrained> {
306 #[must_use]
308 pub fn new(base_learner: Box<dyn PipelinePredictor>) -> Self {
309 Self {
310 state: Untrained,
311 base_learner: Some(base_learner),
312 inner_lr: 0.01,
313 outer_lr: 0.001,
314 inner_steps: 5,
315 meta_parameters: HashMap::new(),
316 }
317 }
318
319 #[must_use]
321 pub fn inner_lr(mut self, lr: f64) -> Self {
322 self.inner_lr = lr;
323 self
324 }
325
326 #[must_use]
328 pub fn outer_lr(mut self, lr: f64) -> Self {
329 self.outer_lr = lr;
330 self
331 }
332
333 #[must_use]
335 pub fn inner_steps(mut self, steps: usize) -> Self {
336 self.inner_steps = steps;
337 self
338 }
339
340 #[must_use]
342 pub fn meta_parameters(mut self, params: HashMap<String, f64>) -> Self {
343 self.meta_parameters = params;
344 self
345 }
346}
347
348impl Estimator for MAMLLearner<Untrained> {
349 type Config = ();
350 type Error = SklearsError;
351 type Float = Float;
352
353 fn config(&self) -> &Self::Config {
354 &()
355 }
356}
357
358impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for MAMLLearner<Untrained> {
359 type Fitted = MAMLLearner<MAMLLearnerTrained>;
360
361 fn fit(
362 self,
363 x: &ArrayView2<'_, Float>,
364 y: &Option<&ArrayView1<'_, Float>>,
365 ) -> SklResult<Self::Fitted> {
366 let mut base_learner = self
367 .base_learner
368 .ok_or_else(|| SklearsError::InvalidInput("No base learner provided".to_string()))?;
369
370 if let Some(y_values) = y.as_ref() {
371 base_learner.fit(x, y_values)?;
373
374 Ok(MAMLLearner {
375 state: MAMLLearnerTrained {
376 fitted_learner: base_learner,
377 inner_lr: self.inner_lr,
378 outer_lr: self.outer_lr,
379 inner_steps: self.inner_steps,
380 meta_parameters: self.meta_parameters,
381 n_features_in: x.ncols(),
382 feature_names_in: None,
383 },
384 base_learner: None,
385 inner_lr: 0.0,
386 outer_lr: 0.0,
387 inner_steps: 0,
388 meta_parameters: HashMap::new(),
389 })
390 } else {
391 Err(SklearsError::InvalidInput(
392 "Labels required for MAML training".to_string(),
393 ))
394 }
395 }
396}
397
398impl MAMLLearner<MAMLLearnerTrained> {
399 pub fn adapt_to_task(&mut self, support_set: &SupportSet) -> SklResult<()> {
401 for _ in 0..self.state.inner_steps {
403 let mapped_features = support_set.features.view().mapv(|v| v as Float);
406 let mapped_labels = support_set.labels.view().mapv(|v| v as Float);
407 self.state
408 .fitted_learner
409 .fit(&mapped_features.view(), &mapped_labels.view())?;
410 }
411 Ok(())
412 }
413
414 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
416 self.state.fitted_learner.predict(x)
417 }
418
419 #[must_use]
421 pub fn meta_parameters(&self) -> &HashMap<String, f64> {
422 &self.state.meta_parameters
423 }
424}
425
426#[derive(Debug)]
428pub struct FewShotPipeline<S = Untrained> {
429 state: S,
430 learner_type: FewShotLearnerType,
431 meta_learner: Option<MetaLearnerWrapper>,
432}
433
434#[derive(Debug)]
436pub struct FewShotPipelineTrained {
437 fitted_learner: MetaLearnerWrapper,
438 n_features_in: usize,
439 feature_names_in: Option<Vec<String>>,
440}
441
442#[derive(Debug, Clone)]
444pub enum FewShotLearnerType {
445 Prototypical { distance_metric: DistanceMetric },
447 MAML {
449 inner_lr: f64,
450 outer_lr: f64,
451 inner_steps: usize,
452 },
453}
454
455#[derive(Debug)]
457pub enum MetaLearnerWrapper {
458 Prototypical(PrototypicalNetwork<PrototypicalNetworkTrained>),
460 MAML(MAMLLearner<MAMLLearnerTrained>),
462}
463
464impl MetaLearnerWrapper {
465 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
467 match self {
468 MetaLearnerWrapper::Prototypical(learner) => learner.predict(x),
469 MetaLearnerWrapper::MAML(learner) => learner.predict(x),
470 }
471 }
472}
473
474impl FewShotPipeline<Untrained> {
475 #[must_use]
477 pub fn new(learner_type: FewShotLearnerType) -> Self {
478 Self {
479 state: Untrained,
480 learner_type,
481 meta_learner: None,
482 }
483 }
484
485 #[must_use]
487 pub fn prototypical(distance_metric: DistanceMetric) -> Self {
488 Self::new(FewShotLearnerType::Prototypical { distance_metric })
489 }
490
491 #[must_use]
493 pub fn maml(inner_lr: f64, outer_lr: f64, inner_steps: usize) -> Self {
494 Self::new(FewShotLearnerType::MAML {
495 inner_lr,
496 outer_lr,
497 inner_steps,
498 })
499 }
500}
501
502impl Estimator for FewShotPipeline<Untrained> {
503 type Config = ();
504 type Error = SklearsError;
505 type Float = Float;
506
507 fn config(&self) -> &Self::Config {
508 &()
509 }
510}
511
512impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for FewShotPipeline<Untrained> {
513 type Fitted = FewShotPipeline<FewShotPipelineTrained>;
514
515 fn fit(
516 self,
517 x: &ArrayView2<'_, Float>,
518 y: &Option<&ArrayView1<'_, Float>>,
519 ) -> SklResult<Self::Fitted> {
520 let fitted_learner = match &self.learner_type {
521 FewShotLearnerType::Prototypical { distance_metric } => {
522 let learner = PrototypicalNetwork::new().distance_metric(distance_metric.clone());
523 let fitted = learner.fit(x, y)?;
524 MetaLearnerWrapper::Prototypical(fitted)
525 }
526 FewShotLearnerType::MAML {
527 inner_lr,
528 outer_lr,
529 inner_steps,
530 } => {
531 use crate::MockPredictor;
532 let base_learner = Box::new(MockPredictor::new());
533 let learner = MAMLLearner::new(base_learner)
534 .inner_lr(*inner_lr)
535 .outer_lr(*outer_lr)
536 .inner_steps(*inner_steps);
537 let fitted = learner.fit(x, y)?;
538 MetaLearnerWrapper::MAML(fitted)
539 }
540 };
541
542 Ok(FewShotPipeline {
543 state: FewShotPipelineTrained {
544 fitted_learner,
545 n_features_in: x.ncols(),
546 feature_names_in: None,
547 },
548 learner_type: FewShotLearnerType::Prototypical {
549 distance_metric: DistanceMetric::Euclidean,
550 }, meta_learner: None,
552 })
553 }
554}
555
556impl FewShotPipeline<FewShotPipelineTrained> {
557 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
559 self.state.fitted_learner.predict(x)
560 }
561
562 pub fn adapt_to_task(&mut self, support_set: &SupportSet) -> SklResult<()> {
564 match &mut self.state.fitted_learner {
565 MetaLearnerWrapper::Prototypical(_) => {
566 Ok(())
569 }
570 MetaLearnerWrapper::MAML(learner) => learner.adapt_to_task(support_set),
571 }
572 }
573
574 #[must_use]
576 pub fn learner(&self) -> &MetaLearnerWrapper {
577 &self.state.fitted_learner
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two well-separated classes with two examples each.
    fn toy_data() -> (Array2<f64>, Array1<f64>) {
        (
            array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
            array![0.0, 0.0, 1.0, 1.0],
        )
    }

    #[test]
    fn test_support_set() {
        let (features, labels) = toy_data();
        let support_set = SupportSet::new(features, labels, 2, 2);

        let (class_features, class_labels) = support_set.get_class_examples(0.0);
        assert_eq!(class_features, array![[1.0, 2.0], [3.0, 4.0]]);
        assert_eq!(class_labels, array![0.0, 0.0]);

        // An absent class yields correctly shaped empty outputs.
        let (missing_features, missing_labels) = support_set.get_class_examples(5.0);
        assert_eq!(missing_features.dim(), (0, 2));
        assert!(missing_labels.is_empty());

        assert_eq!(support_set.get_classes(), vec![0.0, 1.0]);
    }

    #[test]
    fn test_distance_metrics() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let euclidean = DistanceMetric::Euclidean.distance(&a.view(), &b.view());
        assert!((euclidean - 27.0_f64.sqrt()).abs() < 1e-12);
        assert_eq!(DistanceMetric::Euclidean.distance(&a.view(), &a.view()), 0.0);

        let manhattan = DistanceMetric::Manhattan.distance(&a.view(), &b.view());
        assert!((manhattan - 9.0).abs() < 1e-12);

        let cosine = DistanceMetric::Cosine.distance(&a.view(), &b.view());
        assert!((0.0..=2.0).contains(&cosine));
        let zero = array![0.0, 0.0, 0.0];
        assert_eq!(DistanceMetric::Cosine.distance(&a.view(), &zero.view()), 1.0);
    }

    #[test]
    fn test_prototypical_network() {
        let (x, y) = toy_data();
        let fitted = PrototypicalNetwork::new()
            .fit(&x.view(), &Some(&y.view()))
            .unwrap();

        // Prototypes are per-class means keyed by the label's string form.
        assert_eq!(fitted.prototypes().len(), 2);
        assert_eq!(fitted.prototypes()["0"], array![2.0, 3.0]);
        assert_eq!(fitted.prototypes()["1"], array![6.0, 7.0]);

        let predictions = fitted.predict(&x.view()).unwrap();
        assert_eq!(predictions, y);
    }

    #[test]
    fn test_prototypical_network_requires_labels() {
        let (x, _) = toy_data();
        let no_labels: Option<&ArrayView1<'_, Float>> = None;
        assert!(PrototypicalNetwork::new().fit(&x.view(), &no_labels).is_err());
    }

    #[test]
    fn test_few_shot_pipeline() {
        let (x, y) = toy_data();
        let fitted = FewShotPipeline::prototypical(DistanceMetric::Euclidean)
            .fit(&x.view(), &Some(&y.view()))
            .unwrap();

        assert!(matches!(fitted.learner(), MetaLearnerWrapper::Prototypical(_)));
        let predictions = fitted.predict(&x.view()).unwrap();
        assert_eq!(predictions, y);
    }
}
639}