// optirs_learned/few_shot_impl.rs
// Few-Shot Learning Implementation
//
// Implements core methods for PrototypicalNetwork, FastAdaptationEngine,
// and TaskSimilarityCalculator types defined in crate::few_shot.

use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::few_shot::{
    AdaptationStrategyType, FastAdaptationEngine, PrototypicalNetwork, TaskSimilarityCalculator,
};
14
15// ---------------------------------------------------------------------------
16// PrototypicalNetwork additional impl
17// ---------------------------------------------------------------------------
18
19impl<T: Float + Debug + Send + Sync + 'static> PrototypicalNetwork<T> {
20    /// Encode a single feature vector through the encoder network.
21    ///
22    /// Performs a simple linear projection: out = features * W + b (truncating
23    /// or zero-padding the input to match the weight matrix dimensions).
24    /// A ReLU activation is applied element-wise.
25    pub fn encode(&self, features: &Array1<T>) -> Result<Array1<T>> {
26        let emb_dim = self.embedding_dim();
27        let layers = self.encoder_layers();
28        if layers.is_empty() {
29            return Err(OptimError::InvalidState(
30                "Encoder has no layers".to_string(),
31            ));
32        }
33        let layer = &layers[0];
34        let (input_rows, output_cols) = (layer.weights.nrows(), layer.weights.ncols());
35        let actual_out = output_cols.min(emb_dim);
36
37        // Build padded/truncated input
38        let mut input = vec![T::zero(); input_rows];
39        let copy_len = features.len().min(input_rows);
40        for i in 0..copy_len {
41            input[i] = features[i];
42        }
43
44        // Manual matmul: output[j] = sum_i input[i] * weights[i][j] + bias[j]
45        let mut output = Array1::<T>::zeros(emb_dim);
46        for j in 0..actual_out {
47            let mut acc = layer.bias[j];
48            for (i, &inp_val) in input.iter().enumerate().take(input_rows) {
49                acc = acc + inp_val * layer.weights[[i, j]];
50            }
51            // ReLU activation
52            output[j] = if acc > T::zero() { acc } else { T::zero() };
53        }
54        Ok(output)
55    }
56
57    /// Compute a prototype (class centroid) from a set of example embeddings.
58    ///
59    /// The prototype is the element-wise mean of the encoded examples.
60    pub fn compute_prototype(&self, examples: &[Array1<T>]) -> Result<Array1<T>> {
61        if examples.is_empty() {
62            return Err(OptimError::InsufficientData(
63                "Cannot compute prototype from empty example set".to_string(),
64            ));
65        }
66        let dim = examples[0].len();
67        let mut sum = Array1::<T>::zeros(dim);
68        for ex in examples {
69            let len = ex.len().min(dim);
70            for i in 0..len {
71                sum[i] = sum[i] + ex[i];
72            }
73        }
74        let count_t: T =
75            scirs2_core::numeric::NumCast::from(examples.len()).unwrap_or_else(|| T::one());
76        for i in 0..dim {
77            sum[i] = sum[i] / count_t;
78        }
79        Ok(sum)
80    }
81
82    /// Classify a query by finding the nearest prototype.
83    ///
84    /// Returns the index of the closest prototype according to squared
85    /// Euclidean distance.
86    pub fn classify(&self, query: &Array1<T>, prototypes: &[Array1<T>]) -> Result<usize> {
87        if prototypes.is_empty() {
88            return Err(OptimError::InsufficientData(
89                "No prototypes to classify against".to_string(),
90            ));
91        }
92        let (idx, _dist) = self.find_nearest_in_list(query, prototypes)?;
93        Ok(idx)
94    }
95
96    /// Find the nearest prototype and return its index plus the distance.
97    pub fn find_nearest_prototype(&self, query: &Array1<T>) -> Result<(usize, T)> {
98        let stored = self.prototypes();
99        if stored.is_empty() {
100            return Err(OptimError::InsufficientData(
101                "No stored prototypes".to_string(),
102            ));
103        }
104        let vecs: Vec<Array1<T>> = stored.values().map(|p| p.vector.clone()).collect();
105        self.find_nearest_in_list(query, &vecs)
106    }
107
108    // ---- private helpers ----
109
110    fn find_nearest_in_list(
111        &self,
112        query: &Array1<T>,
113        candidates: &[Array1<T>],
114    ) -> Result<(usize, T)> {
115        if candidates.is_empty() {
116            return Err(OptimError::InsufficientData(
117                "No candidates for nearest search".to_string(),
118            ));
119        }
120        let mut best_idx = 0;
121        let mut best_dist = T::infinity();
122        for (i, proto) in candidates.iter().enumerate() {
123            let dist = squared_euclidean(query, proto);
124            if dist < best_dist {
125                best_dist = dist;
126                best_idx = i;
127            }
128        }
129        Ok((best_idx, best_dist))
130    }
131}
132
133// ---------------------------------------------------------------------------
134// FastAdaptationEngine additional impl
135// ---------------------------------------------------------------------------
136
137impl<T: Float + Debug + Send + Sync + 'static> FastAdaptationEngine<T> {
138    /// Perform multi-step gradient adaptation on parameters.
139    ///
140    /// For each gradient in `gradients`, applies one step:
141    ///   params = params - inner_lr * gradient
142    /// returning the final adapted parameters.
143    pub fn adapt(
144        &self,
145        params: &Array1<T>,
146        gradients: &[Array1<T>],
147        inner_lr: T,
148    ) -> Result<Array1<T>> {
149        if params.is_empty() {
150            return Err(OptimError::InvalidState(
151                "Parameters must not be empty".to_string(),
152            ));
153        }
154        let mut current = params.clone();
155        for grad in gradients {
156            let len = current.len().min(grad.len());
157            for i in 0..len {
158                current[i] = current[i] - inner_lr * grad[i];
159            }
160        }
161        Ok(current)
162    }
163
164    /// Select an adaptation algorithm based on task complexity.
165    ///
166    /// Low complexity  -> FOMAML (fast, first-order)
167    /// Medium          -> MAML (second-order)
168    /// High            -> MemoryAugmented (richer capacity)
169    pub fn select_algorithm(&self, task_complexity: T) -> AdaptationStrategyType {
170        let low: T = scirs2_core::numeric::NumCast::from(0.3).unwrap_or_else(|| T::zero());
171        let high: T = scirs2_core::numeric::NumCast::from(0.7).unwrap_or_else(|| T::one());
172
173        if task_complexity < low {
174            AdaptationStrategyType::FOMAML
175        } else if task_complexity < high {
176            AdaptationStrategyType::MAML
177        } else {
178            AdaptationStrategyType::MemoryAugmented
179        }
180    }
181
182    /// Evaluate the quality of an adaptation step.
183    ///
184    /// Returns the relative improvement: ||before|| - ||after|| normalised by
185    /// ||before|| (higher is better; negative means regression).
186    pub fn evaluate_adaptation(&self, before: &Array1<T>, after: &Array1<T>) -> Result<T> {
187        let norm_before = vec_norm(before);
188        let norm_after = vec_norm(after);
189
190        if norm_before == T::zero() {
191            return Ok(T::zero());
192        }
193        Ok((norm_before - norm_after) / norm_before)
194    }
195}
196
197// ---------------------------------------------------------------------------
198// TaskSimilarityCalculator additional impl
199// ---------------------------------------------------------------------------
200
201impl<T: Float + Debug + Send + Sync + 'static> TaskSimilarityCalculator<T> {
202    /// Compute cosine similarity between two task representation vectors.
203    ///
204    /// Returns a value in [-1, 1], with 1 meaning identical direction.
205    pub fn compute_similarity(&self, task1_repr: &Array1<T>, task2_repr: &Array1<T>) -> Result<T> {
206        if task1_repr.is_empty() || task2_repr.is_empty() {
207            return Err(OptimError::InsufficientData(
208                "Task representations must not be empty".to_string(),
209            ));
210        }
211        let n1 = vec_norm(task1_repr);
212        let n2 = vec_norm(task2_repr);
213        if n1 == T::zero() || n2 == T::zero() {
214            return Ok(T::zero());
215        }
216        let len = task1_repr.len().min(task2_repr.len());
217        let mut dot = T::zero();
218        for i in 0..len {
219            dot = dot + task1_repr[i] * task2_repr[i];
220        }
221        Ok(dot / (n1 * n2))
222    }
223
224    /// Find the most similar candidate to the query.
225    ///
226    /// Returns `(index, similarity)` of the best match.
227    pub fn find_most_similar(
228        &self,
229        query: &Array1<T>,
230        candidates: &[Array1<T>],
231    ) -> Result<(usize, T)> {
232        if candidates.is_empty() {
233            return Err(OptimError::InsufficientData(
234                "No candidates for similarity search".to_string(),
235            ));
236        }
237        let mut best_idx = 0;
238        let mut best_sim = T::neg_infinity();
239        for (i, cand) in candidates.iter().enumerate() {
240            let sim = self.compute_similarity(query, cand)?;
241            if sim > best_sim {
242                best_sim = sim;
243                best_idx = i;
244            }
245        }
246        Ok((best_idx, best_sim))
247    }
248}
249
250// ---------------------------------------------------------------------------
251// Utility functions
252// ---------------------------------------------------------------------------
253
254/// Squared Euclidean distance between two vectors (truncated to shorter length).
255fn squared_euclidean<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
256    let len = a.len().min(b.len());
257    let mut sum = T::zero();
258    for i in 0..len {
259        let d = a[i] - b[i];
260        sum = sum + d * d;
261    }
262    sum
263}
264
265/// L2 norm of a vector.
266fn vec_norm<T: Float>(v: &Array1<T>) -> T {
267    let mut sum = T::zero();
268    for &x in v.iter() {
269        sum = sum + x * x;
270    }
271    sum.sqrt()
272}
273
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Shared absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_prototypical_network_encode() {
        let net = PrototypicalNetwork::<f64>::from_dims(4, 3)
            .expect("failed to create PrototypicalNetwork");
        // Zero-initialised weights give zero pre-activations, and ReLU(0) = 0,
        // so every output element must be zero.
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let embedding = net.encode(&input).expect("encode failed");
        assert_eq!(embedding.len(), 4);
        assert!(embedding.iter().all(|&v| (v - 0.0).abs() < EPS));
    }

    #[test]
    fn test_compute_prototype() {
        let net = PrototypicalNetwork::<f64>::from_dims(3, 2)
            .expect("failed to create PrototypicalNetwork");
        let support_set = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![3.0, 4.0, 5.0]),
            Array1::from_vec(vec![5.0, 6.0, 7.0]),
        ];
        let centroid = net
            .compute_prototype(&support_set)
            .expect("compute_prototype failed");
        assert_eq!(centroid.len(), 3);
        // Element-wise mean of the three examples.
        for (actual, expected) in centroid.iter().zip([3.0, 4.0, 5.0]) {
            assert!((actual - expected).abs() < EPS);
        }
    }

    #[test]
    fn test_classify_nearest() {
        let net = PrototypicalNetwork::<f64>::from_dims(3, 3)
            .expect("failed to create PrototypicalNetwork");
        let prototypes = vec![
            Array1::from_vec(vec![0.0, 0.0, 0.0]),
            Array1::from_vec(vec![10.0, 10.0, 10.0]),
            Array1::from_vec(vec![20.0, 20.0, 20.0]),
        ];
        // A query near the second prototype should map to class 1.
        let near_second = Array1::from_vec(vec![9.0, 11.0, 10.0]);
        let class = net
            .classify(&near_second, &prototypes)
            .expect("classify failed");
        assert_eq!(class, 1);

        // A query near the first prototype should map to class 0.
        let near_first = Array1::from_vec(vec![0.1, -0.1, 0.2]);
        let class2 = net
            .classify(&near_first, &prototypes)
            .expect("classify failed");
        assert_eq!(class2, 0);
    }

    #[test]
    fn test_fast_adaptation() {
        let engine = FastAdaptationEngine::<f64>::from_params(0.1, 5)
            .expect("failed to create FastAdaptationEngine");
        let initial = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let grads = vec![
            Array1::from_vec(vec![0.5, 0.5, 0.5]),
            Array1::from_vec(vec![0.3, 0.3, 0.3]),
        ];
        let result = engine.adapt(&initial, &grads, 0.1).expect("adapt failed");
        // Two lr * grad steps: each element loses 0.05 then 0.03.
        for (actual, expected) in result.iter().zip([0.92, 1.92, 2.92]) {
            assert!((actual - expected).abs() < EPS);
        }

        // Complexity thresholds pick increasingly powerful strategies.
        assert!(matches!(
            engine.select_algorithm(0.1),
            AdaptationStrategyType::FOMAML
        ));
        assert!(matches!(
            engine.select_algorithm(0.5),
            AdaptationStrategyType::MAML
        ));
        assert!(matches!(
            engine.select_algorithm(0.9),
            AdaptationStrategyType::MemoryAugmented
        ));
    }

    #[test]
    fn test_task_similarity() {
        let calc = TaskSimilarityCalculator::<f64>::default_new()
            .expect("failed to create TaskSimilarityCalculator");
        let unit_x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let unit_x_copy = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let same = calc
            .compute_similarity(&unit_x, &unit_x_copy)
            .expect("compute_similarity failed");
        assert!(
            (same - 1.0).abs() < EPS,
            "identical vectors should have similarity 1.0"
        );

        // Orthogonal
        let unit_y = Array1::from_vec(vec![0.0, 1.0, 0.0]);
        let ortho = calc
            .compute_similarity(&unit_x, &unit_y)
            .expect("compute_similarity failed");
        assert!(
            ortho.abs() < EPS,
            "orthogonal vectors should have similarity ~0"
        );

        // Opposite
        let neg_x = Array1::from_vec(vec![-1.0, 0.0, 0.0]);
        let opposite = calc
            .compute_similarity(&unit_x, &neg_x)
            .expect("compute_similarity failed");
        assert!(
            (opposite - (-1.0)).abs() < EPS,
            "opposite vectors should have similarity -1.0"
        );
    }

    #[test]
    fn test_find_most_similar() {
        let calc = TaskSimilarityCalculator::<f64>::default_new()
            .expect("failed to create TaskSimilarityCalculator");
        let query = Array1::from_vec(vec![1.0, 1.0, 0.0]);
        let pool = vec![
            Array1::from_vec(vec![0.0, 0.0, 1.0]),   // orthogonal to query
            Array1::from_vec(vec![1.0, 1.0, 0.01]),  // very similar to query
            Array1::from_vec(vec![-1.0, -1.0, 0.0]), // opposite to query
        ];
        let (idx, sim) = calc
            .find_most_similar(&query, &pool)
            .expect("find_most_similar failed");
        assert_eq!(idx, 1);
        assert!(
            sim > 0.99,
            "best match should have high similarity, got {sim}"
        );
    }
}