use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
/// Number of labeled support examples per class in a few-shot episode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShotType {
/// Exactly one support example per class.
OneShot,
/// A fixed number of support examples per class.
FewShot(usize),
/// Caller-defined shot count (treated the same as `FewShot` by `k()`).
Custom(usize),
}
impl ShotType {
    /// Returns the per-class support count ("k") for this shot regime.
    pub fn k(&self) -> usize {
        match *self {
            ShotType::OneShot => 1,
            // FewShot and Custom carry the count directly.
            ShotType::FewShot(k) | ShotType::Custom(k) => k,
        }
    }
}
/// Distance function used to compare embedding vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
/// L2 distance: sqrt(sum((a_i - b_i)^2)).
Euclidean,
/// 1 - cosine similarity; defined as 0.0 when either vector has zero norm.
Cosine,
/// L1 distance: sum(|a_i - b_i|).
Manhattan,
/// L2 distance without the square root (cheaper; preserves ordering).
SquaredEuclidean,
}
impl DistanceMetric {
    /// Computes the distance between two vectors under this metric.
    ///
    /// Callers are expected to pass equal-length vectors; the zipped
    /// iterator forms truncate to the shorter length otherwise.
    pub fn compute(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
        match self {
            // Fold the element-wise differences directly instead of
            // materializing `a.to_owned() - b.to_owned()`: this method is
            // called once per prototype / support example per query, so
            // the two temporary allocations per call were pure overhead.
            DistanceMetric::Euclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y) * (x - y))
                .sum::<f64>()
                .sqrt(),
            DistanceMetric::Cosine => {
                let dot = a.dot(b);
                let norm_a = a.dot(a).sqrt();
                let norm_b = b.dot(b).sqrt();
                // A zero vector has no direction; report distance 0.0
                // rather than dividing by zero and returning NaN.
                if norm_a == 0.0 || norm_b == 0.0 {
                    0.0
                } else {
                    1.0 - (dot / (norm_a * norm_b))
                }
            }
            DistanceMetric::Manhattan => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).abs())
                .sum(),
            DistanceMetric::SquaredEuclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y) * (x - y))
                .sum(),
        }
    }
}
/// Labeled examples that define the classes for a few-shot episode.
#[derive(Debug, Clone)]
pub struct SupportSet {
/// One row per support example.
pub features: Array2<f64>,
/// Class label for each feature row (parallel to `features` rows).
pub labels: Array1<usize>,
/// Number of classes, derived from the labels at construction time.
pub num_classes: usize,
}
impl SupportSet {
pub fn new(features: Array2<f64>, labels: Array1<usize>) -> TrainResult<Self> {
if features.nrows() != labels.len() {
return Err(TrainError::InvalidParameter(format!(
"Feature rows ({}) must match label count ({})",
features.nrows(),
labels.len()
)));
}
let num_classes = labels.iter().max().copied().unwrap_or(0) + 1;
Ok(Self {
features,
labels,
num_classes,
})
}
pub fn get_class_examples(&self, class_id: usize) -> Array2<f64> {
let indices: Vec<usize> = self
.labels
.iter()
.enumerate()
.filter(|(_, &label)| label == class_id)
.map(|(idx, _)| idx)
.collect();
if indices.is_empty() {
return Array2::zeros((0, self.features.ncols()));
}
let mut result = Array2::zeros((indices.len(), self.features.ncols()));
for (i, &idx) in indices.iter().enumerate() {
result.row_mut(i).assign(&self.features.row(idx));
}
result
}
pub fn size(&self) -> usize {
self.features.nrows()
}
}
/// Nearest-class-prototype classifier: each class is represented by the
/// mean of its support examples, and queries are assigned to the class
/// whose prototype is closest under the configured metric.
#[derive(Debug, Clone)]
pub struct PrototypicalDistance {
// Metric used to compare queries against prototypes.
metric: DistanceMetric,
// One row per class; `None` until `compute_prototypes` is called.
prototypes: Option<Array2<f64>>,
}
impl PrototypicalDistance {
    /// Euclidean-metric classifier with no prototypes yet.
    pub fn euclidean() -> Self {
        Self::new(DistanceMetric::Euclidean)
    }

    /// Cosine-metric classifier with no prototypes yet.
    pub fn cosine() -> Self {
        Self::new(DistanceMetric::Cosine)
    }

    /// Creates a classifier for an arbitrary metric. Call
    /// `compute_prototypes` before any prediction method.
    pub fn new(metric: DistanceMetric) -> Self {
        Self {
            metric,
            prototypes: None,
        }
    }

    /// Computes one prototype per class as the mean of that class's
    /// support examples. Classes with no examples keep a zero prototype.
    pub fn compute_prototypes(&mut self, support: &SupportSet) {
        let mut prototypes = Array2::zeros((support.num_classes, support.features.ncols()));
        for class_id in 0..support.num_classes {
            let class_examples = support.get_class_examples(class_id);
            if class_examples.nrows() > 0 {
                let prototype = class_examples
                    .mean_axis(Axis(0))
                    .expect("mean_axis on non-empty class examples");
                prototypes.row_mut(class_id).assign(&prototype);
            }
        }
        self.prototypes = Some(prototypes);
    }

    /// Distance from `query` to every class prototype, indexed by class.
    ///
    /// # Errors
    /// `TrainError::Other` when `compute_prototypes` has not been called.
    pub fn compute_distances(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
        let prototypes = self
            .prototypes
            .as_ref()
            .ok_or_else(|| TrainError::Other("Prototypes not computed".to_string()))?;
        let mut distances = Array1::zeros(prototypes.nrows());
        for (i, prototype) in prototypes.axis_iter(Axis(0)).enumerate() {
            distances[i] = self.metric.compute(query, &prototype);
        }
        Ok(distances)
    }

    /// Predicts the class whose prototype is nearest to `query`
    /// (ties keep the lowest class index).
    ///
    /// # Errors
    /// `TrainError::Other` when prototypes are missing or empty.
    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
        let distances = self.compute_distances(query)?;
        // Guard before indexing: the original `distances[0]` panicked on
        // an empty prototype matrix instead of returning an error.
        if distances.is_empty() {
            return Err(TrainError::Other(
                "No prototypes to compare against".to_string(),
            ));
        }
        let mut min_idx = 0;
        let mut min_dist = distances[0];
        for (i, &dist) in distances.iter().enumerate().skip(1) {
            if dist < min_dist {
                min_dist = dist;
                min_idx = i;
            }
        }
        Ok(min_idx)
    }

    /// Class probabilities via a temperature-scaled softmax over negated
    /// distances (lower distance => higher probability).
    ///
    /// # Errors
    /// `TrainError::InvalidParameter` when `temperature` is not a
    /// positive finite number (which would yield NaN probabilities), and
    /// any error from `compute_distances`.
    pub fn predict_proba(
        &self,
        query: &ArrayView1<f64>,
        temperature: f64,
    ) -> TrainResult<Array1<f64>> {
        if !(temperature.is_finite() && temperature > 0.0) {
            return Err(TrainError::InvalidParameter(format!(
                "temperature must be positive and finite, got {temperature}"
            )));
        }
        let distances = self.compute_distances(query)?;
        let logits = distances.mapv(|d| -d / temperature);
        // Subtract the max logit before exponentiating for numerical
        // stability (standard softmax trick).
        let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exp_logits = logits.mapv(|x| (x - max_logit).exp());
        let sum_exp = exp_logits.sum();
        let probs = exp_logits.mapv(|x| x / sum_exp);
        Ok(probs)
    }
}
/// Describes the layout of a few-shot episode: how many classes, how many
/// support examples per class, and how many query examples per class.
#[derive(Debug, Clone)]
pub struct EpisodeSampler {
// Number of classes per episode.
n_way: usize,
// Support examples per class (via `ShotType::k()`).
shot_type: ShotType,
// Query examples per class.
n_query: usize,
}
impl EpisodeSampler {
    /// Configures an episode layout of `n_way` classes with
    /// `shot_type.k()` support and `n_query` query examples per class.
    pub fn new(n_way: usize, shot_type: ShotType, n_query: usize) -> Self {
        Self {
            n_way,
            shot_type,
            n_query,
        }
    }

    /// Total number of support examples per episode (`n_way * k`).
    pub fn support_size(&self) -> usize {
        self.shot_type.k() * self.n_way
    }

    /// Total number of query examples per episode (`n_way * n_query`).
    pub fn query_size(&self) -> usize {
        self.n_query * self.n_way
    }

    /// Human-readable summary, e.g. `"5-way 1-shot (query: 15 per class)"`.
    pub fn description(&self) -> String {
        let shots = self.shot_type.k();
        format!(
            "{}-way {}-shot (query: {} per class)",
            self.n_way, shots, self.n_query
        )
    }
}
/// Attention-based few-shot classifier: a query's class distribution is
/// the softmax-weighted vote of individual support examples, weighted by
/// similarity (negated distance) to the query.
#[derive(Debug, Clone)]
pub struct MatchingNetwork {
// Metric whose negation serves as the similarity score.
metric: DistanceMetric,
// Labeled support examples; `None` until `set_support` is called.
support: Option<SupportSet>,
}
impl MatchingNetwork {
    /// Creates a matching network using `metric` for similarity. A support
    /// set must be installed via `set_support` before predicting.
    pub fn new(metric: DistanceMetric) -> Self {
        Self {
            metric,
            support: None,
        }
    }

    /// Installs (or replaces) the support set used for attention.
    pub fn set_support(&mut self, support: SupportSet) {
        self.support = Some(support);
    }

    /// Softmax attention weights over the support examples, using negated
    /// distance as similarity; the weights sum to 1.
    ///
    /// # Errors
    /// `TrainError::Other` when no support set has been installed.
    pub fn compute_attention(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
        let support = self
            .support
            .as_ref()
            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
        let n_support = support.size();
        let mut similarities = Array1::zeros(n_support);
        for i in 0..n_support {
            let support_example = support.features.row(i);
            similarities[i] = -self.metric.compute(query, &support_example);
        }
        // Numerically stable softmax: shift by the max similarity.
        let max_sim = similarities
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sims = similarities.mapv(|x| (x - max_sim).exp());
        let sum_exp = exp_sims.sum();
        let weights = exp_sims.mapv(|x| x / sum_exp);
        Ok(weights)
    }

    /// Class probabilities: each support example contributes its attention
    /// weight to its own class, so the result sums to 1.
    ///
    /// # Errors
    /// `TrainError::Other` when no support set has been installed.
    pub fn predict_proba(&self, query: &ArrayView1<f64>) -> TrainResult<Array1<f64>> {
        let support = self
            .support
            .as_ref()
            .ok_or_else(|| TrainError::Other("Support set not set".to_string()))?;
        let attention = self.compute_attention(query)?;
        let mut class_probs = Array1::zeros(support.num_classes);
        for (i, &weight) in attention.iter().enumerate() {
            let label = support.labels[i];
            class_probs[label] += weight;
        }
        Ok(class_probs)
    }

    /// Predicts the class with the highest probability
    /// (ties keep the lowest class index).
    ///
    /// # Errors
    /// `TrainError::Other` when no support set has been installed or the
    /// support set has no classes.
    pub fn predict(&self, query: &ArrayView1<f64>) -> TrainResult<usize> {
        let probs = self.predict_proba(query)?;
        // Guard before indexing: the original `probs[0]` panicked on an
        // empty class-probability vector instead of returning an error.
        if probs.is_empty() {
            return Err(TrainError::Other(
                "Support set has no classes".to_string(),
            ));
        }
        let mut max_idx = 0;
        let mut max_prob = probs[0];
        for (i, &prob) in probs.iter().enumerate().skip(1) {
            if prob > max_prob {
                max_prob = prob;
                max_idx = i;
            }
        }
        Ok(max_idx)
    }
}
/// Running accuracy tracker for few-shot evaluation episodes.
#[derive(Debug, Clone, Default)]
pub struct FewShotAccuracy {
// Number of correct predictions observed.
correct: usize,
// Total number of predictions observed.
total: usize,
}
impl FewShotAccuracy {
    /// Fresh tracker with zero observations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one prediction / ground-truth pair.
    pub fn update(&mut self, predicted: usize, actual: usize) {
        // `usize::from(bool)` is 1 on a correct prediction, 0 otherwise.
        self.correct += usize::from(predicted == actual);
        self.total += 1;
    }

    /// Fraction of correct predictions so far; 0.0 before any updates.
    pub fn accuracy(&self) -> f64 {
        match self.total {
            0 => 0.0,
            n => self.correct as f64 / n as f64,
        }
    }

    /// Discards all recorded observations.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Returns the raw `(correct, total)` counters.
    pub fn counts(&self) -> (usize, usize) {
        (self.correct, self.total)
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the few-shot primitives above. All classifier tests
// share the same 4-example / 2-class support set, so class 0's
// prototype is (2, 3) and class 1's is (6, 7); the query (2, 3) sits
// exactly on class 0's prototype.
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_shot_type() {
assert_eq!(ShotType::OneShot.k(), 1);
assert_eq!(ShotType::FewShot(5).k(), 5);
assert_eq!(ShotType::Custom(10).k(), 10);
}
#[test]
fn test_euclidean_distance() {
// sqrt(9 + 9 + 9) = 3 * sqrt(3) ≈ 5.196152
let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
let dist = DistanceMetric::Euclidean.compute(&a.view(), &b.view());
assert_relative_eq!(dist, 5.196152, epsilon = 1e-5);
}
#[test]
fn test_cosine_distance() {
// Orthogonal unit vectors: cosine similarity 0, so distance 1.
let a = Array1::from_vec(vec![1.0, 0.0]);
let b = Array1::from_vec(vec![0.0, 1.0]);
let dist = DistanceMetric::Cosine.compute(&a.view(), &b.view());
assert_relative_eq!(dist, 1.0, epsilon = 1e-5);
}
#[test]
fn test_support_set_creation() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
assert_eq!(support.size(), 4);
// num_classes is derived from the max label (1) + 1.
assert_eq!(support.num_classes, 2);
}
#[test]
fn test_support_set_get_class_examples() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
// Class 0 owns rows 0 and 1, in their original order.
let class_0 = support.get_class_examples(0);
assert_eq!(class_0.nrows(), 2);
assert_eq!(class_0[[0, 0]], 1.0);
assert_eq!(class_0[[1, 0]], 3.0);
}
#[test]
fn test_prototypical_distance() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
let mut proto = PrototypicalDistance::euclidean();
proto.compute_prototypes(&support);
// Query equals class 0's prototype, so class 0 must win.
let query = Array1::from_vec(vec![2.0, 3.0]);
let prediction = proto.predict(&query.view()).expect("unwrap");
assert_eq!(prediction, 0); }
#[test]
fn test_prototypical_predict_proba() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
let mut proto = PrototypicalDistance::euclidean();
proto.compute_prototypes(&support);
let query = Array1::from_vec(vec![2.0, 3.0]);
let probs = proto.predict_proba(&query.view(), 1.0).expect("unwrap");
assert_eq!(probs.len(), 2);
// Softmax output: closer prototype gets higher mass, and it normalizes.
assert!(probs[0] > probs[1]); assert_relative_eq!(probs.sum(), 1.0, epsilon = 1e-10);
}
#[test]
fn test_episode_sampler() {
// 5-way 1-shot: 5 support examples, 5 * 15 = 75 query examples.
let sampler = EpisodeSampler::new(5, ShotType::OneShot, 15);
assert_eq!(sampler.support_size(), 5); assert_eq!(sampler.query_size(), 75); assert!(sampler.description().contains("5-way"));
}
#[test]
fn test_matching_network() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
matcher.set_support(support);
// Query is closest to class 0's support examples.
let query = Array1::from_vec(vec![2.0, 3.0]);
let prediction = matcher.predict(&query.view()).expect("unwrap");
assert_eq!(prediction, 0); }
#[test]
fn test_matching_network_attention() {
let features = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("unwrap");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let support = SupportSet::new(features, labels).expect("unwrap");
let mut matcher = MatchingNetwork::new(DistanceMetric::Euclidean);
matcher.set_support(support);
let query = Array1::from_vec(vec![2.0, 3.0]);
// One attention weight per support example, summing to 1 (softmax).
let attention = matcher.compute_attention(&query.view()).expect("unwrap");
assert_eq!(attention.len(), 4);
assert_relative_eq!(attention.sum(), 1.0, epsilon = 1e-10);
}
#[test]
fn test_few_shot_accuracy() {
let mut acc = FewShotAccuracy::new();
// Two correct, one wrong => 2/3 accuracy.
acc.update(0, 0); acc.update(1, 1); acc.update(1, 0);
assert_eq!(acc.accuracy(), 2.0 / 3.0);
assert_eq!(acc.counts(), (2, 3));
acc.reset();
assert_eq!(acc.accuracy(), 0.0);
}
#[test]
fn test_manhattan_distance() {
// |1-4| + |2-5| + |3-6| = 9
let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
let dist = DistanceMetric::Manhattan.compute(&a.view(), &b.view());
assert_eq!(dist, 9.0);
}
#[test]
fn test_squared_euclidean_distance() {
// 3^2 + 4^2 = 25 (no square root taken)
let a = Array1::from_vec(vec![1.0, 2.0]);
let b = Array1::from_vec(vec![4.0, 6.0]);
let dist = DistanceMetric::SquaredEuclidean.compute(&a.view(), &b.view());
assert_eq!(dist, 25.0); }
}