1use scirs2_core::ndarray::Array1;
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10use crate::error::{OptimError, Result};
11use crate::few_shot::{
12 AdaptationStrategyType, FastAdaptationEngine, PrototypicalNetwork, TaskSimilarityCalculator,
13};
14
15impl<T: Float + Debug + Send + Sync + 'static> PrototypicalNetwork<T> {
20 pub fn encode(&self, features: &Array1<T>) -> Result<Array1<T>> {
26 let emb_dim = self.embedding_dim();
27 let layers = self.encoder_layers();
28 if layers.is_empty() {
29 return Err(OptimError::InvalidState(
30 "Encoder has no layers".to_string(),
31 ));
32 }
33 let layer = &layers[0];
34 let (input_rows, output_cols) = (layer.weights.nrows(), layer.weights.ncols());
35 let actual_out = output_cols.min(emb_dim);
36
37 let mut input = vec![T::zero(); input_rows];
39 let copy_len = features.len().min(input_rows);
40 for i in 0..copy_len {
41 input[i] = features[i];
42 }
43
44 let mut output = Array1::<T>::zeros(emb_dim);
46 for j in 0..actual_out {
47 let mut acc = layer.bias[j];
48 for (i, &inp_val) in input.iter().enumerate().take(input_rows) {
49 acc = acc + inp_val * layer.weights[[i, j]];
50 }
51 output[j] = if acc > T::zero() { acc } else { T::zero() };
53 }
54 Ok(output)
55 }
56
57 pub fn compute_prototype(&self, examples: &[Array1<T>]) -> Result<Array1<T>> {
61 if examples.is_empty() {
62 return Err(OptimError::InsufficientData(
63 "Cannot compute prototype from empty example set".to_string(),
64 ));
65 }
66 let dim = examples[0].len();
67 let mut sum = Array1::<T>::zeros(dim);
68 for ex in examples {
69 let len = ex.len().min(dim);
70 for i in 0..len {
71 sum[i] = sum[i] + ex[i];
72 }
73 }
74 let count_t: T =
75 scirs2_core::numeric::NumCast::from(examples.len()).unwrap_or_else(|| T::one());
76 for i in 0..dim {
77 sum[i] = sum[i] / count_t;
78 }
79 Ok(sum)
80 }
81
82 pub fn classify(&self, query: &Array1<T>, prototypes: &[Array1<T>]) -> Result<usize> {
87 if prototypes.is_empty() {
88 return Err(OptimError::InsufficientData(
89 "No prototypes to classify against".to_string(),
90 ));
91 }
92 let (idx, _dist) = self.find_nearest_in_list(query, prototypes)?;
93 Ok(idx)
94 }
95
96 pub fn find_nearest_prototype(&self, query: &Array1<T>) -> Result<(usize, T)> {
98 let stored = self.prototypes();
99 if stored.is_empty() {
100 return Err(OptimError::InsufficientData(
101 "No stored prototypes".to_string(),
102 ));
103 }
104 let vecs: Vec<Array1<T>> = stored.values().map(|p| p.vector.clone()).collect();
105 self.find_nearest_in_list(query, &vecs)
106 }
107
108 fn find_nearest_in_list(
111 &self,
112 query: &Array1<T>,
113 candidates: &[Array1<T>],
114 ) -> Result<(usize, T)> {
115 if candidates.is_empty() {
116 return Err(OptimError::InsufficientData(
117 "No candidates for nearest search".to_string(),
118 ));
119 }
120 let mut best_idx = 0;
121 let mut best_dist = T::infinity();
122 for (i, proto) in candidates.iter().enumerate() {
123 let dist = squared_euclidean(query, proto);
124 if dist < best_dist {
125 best_dist = dist;
126 best_idx = i;
127 }
128 }
129 Ok((best_idx, best_dist))
130 }
131}
132
133impl<T: Float + Debug + Send + Sync + 'static> FastAdaptationEngine<T> {
138 pub fn adapt(
144 &self,
145 params: &Array1<T>,
146 gradients: &[Array1<T>],
147 inner_lr: T,
148 ) -> Result<Array1<T>> {
149 if params.is_empty() {
150 return Err(OptimError::InvalidState(
151 "Parameters must not be empty".to_string(),
152 ));
153 }
154 let mut current = params.clone();
155 for grad in gradients {
156 let len = current.len().min(grad.len());
157 for i in 0..len {
158 current[i] = current[i] - inner_lr * grad[i];
159 }
160 }
161 Ok(current)
162 }
163
164 pub fn select_algorithm(&self, task_complexity: T) -> AdaptationStrategyType {
170 let low: T = scirs2_core::numeric::NumCast::from(0.3).unwrap_or_else(|| T::zero());
171 let high: T = scirs2_core::numeric::NumCast::from(0.7).unwrap_or_else(|| T::one());
172
173 if task_complexity < low {
174 AdaptationStrategyType::FOMAML
175 } else if task_complexity < high {
176 AdaptationStrategyType::MAML
177 } else {
178 AdaptationStrategyType::MemoryAugmented
179 }
180 }
181
182 pub fn evaluate_adaptation(&self, before: &Array1<T>, after: &Array1<T>) -> Result<T> {
187 let norm_before = vec_norm(before);
188 let norm_after = vec_norm(after);
189
190 if norm_before == T::zero() {
191 return Ok(T::zero());
192 }
193 Ok((norm_before - norm_after) / norm_before)
194 }
195}
196
197impl<T: Float + Debug + Send + Sync + 'static> TaskSimilarityCalculator<T> {
202 pub fn compute_similarity(&self, task1_repr: &Array1<T>, task2_repr: &Array1<T>) -> Result<T> {
206 if task1_repr.is_empty() || task2_repr.is_empty() {
207 return Err(OptimError::InsufficientData(
208 "Task representations must not be empty".to_string(),
209 ));
210 }
211 let n1 = vec_norm(task1_repr);
212 let n2 = vec_norm(task2_repr);
213 if n1 == T::zero() || n2 == T::zero() {
214 return Ok(T::zero());
215 }
216 let len = task1_repr.len().min(task2_repr.len());
217 let mut dot = T::zero();
218 for i in 0..len {
219 dot = dot + task1_repr[i] * task2_repr[i];
220 }
221 Ok(dot / (n1 * n2))
222 }
223
224 pub fn find_most_similar(
228 &self,
229 query: &Array1<T>,
230 candidates: &[Array1<T>],
231 ) -> Result<(usize, T)> {
232 if candidates.is_empty() {
233 return Err(OptimError::InsufficientData(
234 "No candidates for similarity search".to_string(),
235 ));
236 }
237 let mut best_idx = 0;
238 let mut best_sim = T::neg_infinity();
239 for (i, cand) in candidates.iter().enumerate() {
240 let sim = self.compute_similarity(query, cand)?;
241 if sim > best_sim {
242 best_sim = sim;
243 best_idx = i;
244 }
245 }
246 Ok((best_idx, best_sim))
247 }
248}
249
250fn squared_euclidean<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
256 let len = a.len().min(b.len());
257 let mut sum = T::zero();
258 for i in 0..len {
259 let d = a[i] - b[i];
260 sum = sum + d * d;
261 }
262 sum
263}
264
265fn vec_norm<T: Float>(v: &Array1<T>) -> T {
267 let mut sum = T::zero();
268 for &x in v.iter() {
269 sum = sum + x * x;
270 }
271 sum.sqrt()
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_prototypical_network_encode() {
        let net = PrototypicalNetwork::<f64>::from_dims(4, 3)
            .expect("failed to create PrototypicalNetwork");
        let features = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let encoded = net.encode(&features).expect("encode failed");
        assert_eq!(encoded.len(), 4);
        // Expects a freshly constructed network to encode to all zeros.
        for &v in encoded.iter() {
            assert!(v.abs() < 1e-12);
        }
    }

    #[test]
    fn test_compute_prototype() {
        let net = PrototypicalNetwork::<f64>::from_dims(3, 2)
            .expect("failed to create PrototypicalNetwork");
        let examples = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![3.0, 4.0, 5.0]),
            Array1::from_vec(vec![5.0, 6.0, 7.0]),
        ];
        let proto = net
            .compute_prototype(&examples)
            .expect("compute_prototype failed");
        assert_eq!(proto.len(), 3);
        assert!((proto[0] - 3.0).abs() < 1e-12);
        assert!((proto[1] - 4.0).abs() < 1e-12);
        assert!((proto[2] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_classify_nearest() {
        let net = PrototypicalNetwork::<f64>::from_dims(3, 3)
            .expect("failed to create PrototypicalNetwork");
        let prototypes = vec![
            Array1::from_vec(vec![0.0, 0.0, 0.0]),
            Array1::from_vec(vec![10.0, 10.0, 10.0]),
            Array1::from_vec(vec![20.0, 20.0, 20.0]),
        ];
        let query = Array1::from_vec(vec![9.0, 11.0, 10.0]);
        let class = net.classify(&query, &prototypes).expect("classify failed");
        assert_eq!(class, 1);

        let query2 = Array1::from_vec(vec![0.1, -0.1, 0.2]);
        let class2 = net.classify(&query2, &prototypes).expect("classify failed");
        assert_eq!(class2, 0);
    }

    #[test]
    fn test_fast_adaptation() {
        let engine = FastAdaptationEngine::<f64>::from_params(0.1, 5)
            .expect("failed to create FastAdaptationEngine");
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = vec![
            Array1::from_vec(vec![0.5, 0.5, 0.5]),
            Array1::from_vec(vec![0.3, 0.3, 0.3]),
        ];
        let adapted = engine
            .adapt(&params, &gradients, 0.1)
            .expect("adapt failed");
        // Each element moves by -0.1 * (0.5 + 0.3) = -0.08.
        assert!((adapted[0] - 0.92).abs() < 1e-12);
        assert!((adapted[1] - 1.92).abs() < 1e-12);
        assert!((adapted[2] - 2.92).abs() < 1e-12);

        let strat_low = engine.select_algorithm(0.1);
        assert!(matches!(strat_low, AdaptationStrategyType::FOMAML));
        let strat_mid = engine.select_algorithm(0.5);
        assert!(matches!(strat_mid, AdaptationStrategyType::MAML));
        let strat_high = engine.select_algorithm(0.9);
        assert!(matches!(
            strat_high,
            AdaptationStrategyType::MemoryAugmented
        ));
    }

    #[test]
    fn test_task_similarity() {
        let calc = TaskSimilarityCalculator::<f64>::default_new()
            .expect("failed to create TaskSimilarityCalculator");
        let a = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let b = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let sim = calc
            .compute_similarity(&a, &b)
            .expect("compute_similarity failed");
        assert!(
            (sim - 1.0).abs() < 1e-12,
            "identical vectors should have similarity 1.0"
        );

        let c = Array1::from_vec(vec![0.0, 1.0, 0.0]);
        let sim2 = calc
            .compute_similarity(&a, &c)
            .expect("compute_similarity failed");
        assert!(
            sim2.abs() < 1e-12,
            "orthogonal vectors should have similarity ~0"
        );

        let d = Array1::from_vec(vec![-1.0, 0.0, 0.0]);
        let sim3 = calc
            .compute_similarity(&a, &d)
            .expect("compute_similarity failed");
        assert!(
            (sim3 - (-1.0)).abs() < 1e-12,
            "opposite vectors should have similarity -1.0"
        );
    }

    #[test]
    fn test_find_most_similar() {
        let calc = TaskSimilarityCalculator::<f64>::default_new()
            .expect("failed to create TaskSimilarityCalculator");
        let query = Array1::from_vec(vec![1.0, 1.0, 0.0]);
        let candidates = vec![
            Array1::from_vec(vec![0.0, 0.0, 1.0]),
            Array1::from_vec(vec![1.0, 1.0, 0.01]),
            Array1::from_vec(vec![-1.0, -1.0, 0.0]),
        ];
        let (idx, sim) = calc
            .find_most_similar(&query, &candidates)
            .expect("find_most_similar failed");
        assert_eq!(idx, 1);
        assert!(
            sim > 0.99,
            "best match should have high similarity, got {sim}"
        );
    }
}