1use anyhow::{anyhow, Result};
33use serde::{Deserialize, Serialize};
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41pub enum EnsembleStrategy {
42 Voting,
44 WeightedAverage,
47 Stacking,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct EnsembleConfig {
55 pub strategy: EnsembleStrategy,
57 pub output_dim: usize,
59 pub stacking_hidden_dim: usize,
61 pub stacking_lr: f64,
63 pub stacking_epochs: usize,
65 pub normalize: bool,
67}
68
69impl Default for EnsembleConfig {
70 fn default() -> Self {
71 Self {
72 strategy: EnsembleStrategy::Voting,
73 output_dim: 64,
74 stacking_hidden_dim: 128,
75 stacking_lr: 0.01,
76 stacking_epochs: 50,
77 normalize: true,
78 }
79 }
80}
81
/// Two-layer perceptron (ReLU hidden layer, linear output) used as the
/// stacking meta-learner.
///
/// `Debug` and `Clone` are derived so the network can be inspected and
/// snapshotted.
#[derive(Debug, Clone)]
struct StackingMLP {
    /// First-layer weights: `hidden_dim` rows of `input_dim` values.
    w1: Vec<Vec<f64>>,
    /// First-layer biases, one per hidden unit.
    b1: Vec<f64>,
    /// Second-layer weights: `output_dim` rows of `hidden_dim` values.
    w2: Vec<Vec<f64>>,
    /// Second-layer biases, one per output unit.
    b2: Vec<f64>,
    /// Length of the input vector the network was built for.
    input_dim: usize,
    /// Number of hidden units.
    hidden_dim: usize,
    /// Length of the output vector.
    output_dim: usize,
}
102
103impl StackingMLP {
104 fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, seed: u64) -> Self {
106 let mut state = seed.wrapping_add(1);
107 let mut lcg = || -> f64 {
108 state = state
109 .wrapping_mul(6364136223846793005)
110 .wrapping_add(1442695040888963407);
111 (((state >> 11) as f64) / ((1u64 << 53) as f64)) * 2.0 - 1.0
113 };
114
115 let xavier1 = (6.0_f64 / (input_dim + hidden_dim) as f64).sqrt();
116 let xavier2 = (6.0_f64 / (hidden_dim + output_dim) as f64).sqrt();
117
118 let w1 = (0..hidden_dim)
119 .map(|_| (0..input_dim).map(|_| lcg() * xavier1).collect())
120 .collect();
121 let b1 = vec![0.0; hidden_dim];
122 let w2 = (0..output_dim)
123 .map(|_| (0..hidden_dim).map(|_| lcg() * xavier2).collect())
124 .collect();
125 let b2 = vec![0.0; output_dim];
126
127 Self {
128 w1,
129 b1,
130 w2,
131 b2,
132 input_dim,
133 hidden_dim,
134 output_dim,
135 }
136 }
137
138 fn forward(&self, x: &[f64]) -> (Vec<f64>, Vec<f64>) {
140 let mut h = vec![0.0; self.hidden_dim];
142 for (i, hi) in h.iter_mut().enumerate() {
143 let dot: f64 = self.w1[i].iter().zip(x.iter()).map(|(w, xi)| w * xi).sum();
144 *hi = (dot + self.b1[i]).max(0.0); }
146 let mut out = vec![0.0; self.output_dim];
148 for (i, oi) in out.iter_mut().enumerate() {
149 let dot: f64 = self.w2[i].iter().zip(h.iter()).map(|(w, hi)| w * hi).sum();
150 *oi = dot + self.b2[i];
151 }
152 (h, out)
153 }
154
155 fn backward_step(&mut self, x: &[f64], y: &[f64], lr: f64) {
157 let (h, out) = self.forward(x);
158
159 let d_out: Vec<f64> = out
161 .iter()
162 .zip(y.iter())
163 .map(|(o, t)| 2.0 * (o - t))
164 .collect();
165
166 for (i, di) in d_out.iter().enumerate() {
168 for (j, hj) in h.iter().enumerate() {
169 self.w2[i][j] -= lr * di * hj;
170 }
171 self.b2[i] -= lr * di;
172 }
173
174 let mut d_h = vec![0.0; self.hidden_dim];
176 for (j, dj) in d_h.iter_mut().enumerate() {
177 let back: f64 = (0..self.output_dim).map(|i| d_out[i] * self.w2[i][j]).sum();
178 *dj = if h[j] > 0.0 { back } else { 0.0 };
179 }
180
181 for (i, di) in d_h.iter().enumerate() {
183 for (j, xj) in x.iter().enumerate() {
184 self.w1[i][j] -= lr * di * xj;
185 }
186 self.b1[i] -= lr * di;
187 }
188 }
189
190 fn predict(&self, x: &[f64]) -> Vec<f64> {
191 self.forward(x).1
192 }
193}
194
/// A single ensemble member: a boxed `Send + Sync` closure that maps a key
/// string to that model's embedding vector.
type EnsembleModel = Box<dyn Fn(&str) -> Vec<f64> + Send + Sync>;
206
/// Combines the embeddings of several models into one vector according to
/// the configured [`EnsembleStrategy`].
pub struct EnsembleEmbedder {
    /// Ensemble members; the constructor rejects an empty list.
    models: Vec<EnsembleModel>,
    /// Strategy, dimensions and stacking hyper-parameters.
    config: EnsembleConfig,
    /// One weight per model, stored normalized to sum to 1; starts uniform.
    weights: Vec<f64>,
    /// Stacking meta-learner; `None` until `fit_stacking` has been called.
    stacking_mlp: Option<StackingMLP>,
}
215
216impl std::fmt::Debug for EnsembleEmbedder {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("EnsembleEmbedder")
219 .field("num_models", &self.models.len())
220 .field("config", &self.config)
221 .field("strategy", &self.config.strategy)
222 .finish()
223 }
224}
225
226impl EnsembleEmbedder {
227 pub fn new(models: Vec<EnsembleModel>, config: EnsembleConfig) -> Result<Self> {
232 if models.is_empty() {
233 return Err(anyhow!("EnsembleEmbedder requires at least one model"));
234 }
235 let n = models.len();
236 let weights = vec![1.0 / n as f64; n]; Ok(Self {
238 models,
239 config,
240 weights,
241 stacking_mlp: None,
242 })
243 }
244
245 pub fn num_models(&self) -> usize {
247 self.models.len()
248 }
249
250 pub fn set_weights(&mut self, weights: Vec<f64>) -> Result<()> {
258 if weights.len() != self.models.len() {
259 return Err(anyhow!(
260 "weight vector length {} != model count {}",
261 weights.len(),
262 self.models.len()
263 ));
264 }
265 for &w in &weights {
266 if !w.is_finite() || w < 0.0 {
267 return Err(anyhow!("all weights must be non-negative and finite"));
268 }
269 }
270 let sum: f64 = weights.iter().sum();
271 if sum == 0.0 {
272 return Err(anyhow!("weight sum must be > 0"));
273 }
274 self.weights = weights.iter().map(|w| w / sum).collect();
275 Ok(())
276 }
277
278 fn collect_embeddings(&self, key: &str) -> Result<Vec<Vec<f64>>> {
280 let embeddings: Vec<Vec<f64>> = self.models.iter().map(|m| m(key)).collect();
281 for (i, emb) in embeddings.iter().enumerate() {
283 if emb.len() != self.config.output_dim {
284 return Err(anyhow!(
285 "model {} returned embedding of dimension {} but config.output_dim is {}",
286 i,
287 emb.len(),
288 self.config.output_dim
289 ));
290 }
291 }
292 Ok(embeddings)
293 }
294
295 fn l2_normalize(v: &mut [f64]) {
297 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
298 if norm > 1e-10 {
299 for x in v.iter_mut() {
300 *x /= norm;
301 }
302 }
303 }
304
305 pub fn embed(&self, key: &str) -> Result<Vec<f64>> {
311 let embeddings = self.collect_embeddings(key)?;
312 let mut result = match self.config.strategy {
313 EnsembleStrategy::Voting => {
314 let dim = self.config.output_dim;
315 let mut agg = vec![0.0; dim];
316 for emb in &embeddings {
317 for (a, e) in agg.iter_mut().zip(emb.iter()) {
318 *a += e;
319 }
320 }
321 let n = embeddings.len() as f64;
322 agg.iter_mut().for_each(|a| *a /= n);
323 agg
324 }
325 EnsembleStrategy::WeightedAverage => {
326 let dim = self.config.output_dim;
327 let mut agg = vec![0.0; dim];
328 for (emb, &w) in embeddings.iter().zip(self.weights.iter()) {
329 for (a, e) in agg.iter_mut().zip(emb.iter()) {
330 *a += w * e;
331 }
332 }
333 agg
334 }
335 EnsembleStrategy::Stacking => {
336 let mlp = self.stacking_mlp.as_ref().ok_or_else(|| {
337 anyhow!("Stacking strategy requires calling fit_stacking() first")
338 })?;
339 let concat: Vec<f64> = embeddings.into_iter().flatten().collect();
341 mlp.predict(&concat)
342 }
343 };
344 if self.config.normalize {
345 Self::l2_normalize(&mut result);
346 }
347 Ok(result)
348 }
349
350 pub fn fit_stacking(&mut self, validation_pairs: &[(&str, Vec<f64>)]) -> Result<()> {
360 if validation_pairs.is_empty() {
361 return Err(anyhow!("validation set must not be empty for stacking"));
362 }
363 let concat_dim = self.models.len() * self.config.output_dim;
364 if self.stacking_mlp.is_none() {
365 self.stacking_mlp = Some(StackingMLP::new(
366 concat_dim,
367 self.config.stacking_hidden_dim,
368 self.config.output_dim,
369 42,
370 ));
371 }
372 for _epoch in 0..self.config.stacking_epochs {
373 for (key, target) in validation_pairs {
374 if target.len() != self.config.output_dim {
375 return Err(anyhow!(
376 "target embedding dimension {} != config.output_dim {}",
377 target.len(),
378 self.config.output_dim
379 ));
380 }
381 let embeddings = self.collect_embeddings(key)?;
382 let concat: Vec<f64> = embeddings.into_iter().flatten().collect();
383 if let Some(mlp) = &mut self.stacking_mlp {
384 mlp.backward_step(&concat, target, self.config.stacking_lr);
385 }
386 }
387 }
388 Ok(())
389 }
390
391 pub fn eval_cosine(
394 &self,
395 reference: &impl Fn(&str) -> Vec<f64>,
396 validation_keys: &[&str],
397 ) -> Result<f64> {
398 if validation_keys.is_empty() {
399 return Ok(0.0);
400 }
401 let mut total = 0.0;
402 for &key in validation_keys {
403 let pred = self.embed(key)?;
404 let ref_emb = reference(key);
405 let dot: f64 = pred.iter().zip(ref_emb.iter()).map(|(a, b)| a * b).sum();
406 let norm_pred: f64 = pred.iter().map(|x| x * x).sum::<f64>().sqrt();
407 let norm_ref: f64 = ref_emb.iter().map(|x| x * x).sum::<f64>().sqrt();
408 let cos = if norm_pred > 1e-10 && norm_ref > 1e-10 {
409 (dot / (norm_pred * norm_ref)).clamp(-1.0, 1.0)
410 } else {
411 0.0
412 };
413 total += cos;
414 }
415 Ok(total / validation_keys.len() as f64)
416 }
417
418 pub fn derive_weights(
424 &mut self,
425 reference: &impl Fn(&str) -> Vec<f64>,
426 validation_keys: &[&str],
427 ) -> Result<()> {
428 let mut scores = vec![0.0_f64; self.models.len()];
429 for &key in validation_keys {
430 let ref_emb = reference(key);
431 for (i, model) in self.models.iter().enumerate() {
432 let emb = model(key);
433 let dot: f64 = emb.iter().zip(ref_emb.iter()).map(|(a, b)| a * b).sum();
434 let na: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
435 let nb: f64 = ref_emb.iter().map(|x| x * x).sum::<f64>().sqrt();
436 let cos = if na > 1e-10 && nb > 1e-10 {
437 (dot / (na * nb)).clamp(-1.0, 1.0)
438 } else {
439 0.0
440 };
441 scores[i] += cos;
442 }
443 }
444 let n = validation_keys.len().max(1) as f64;
446 let weights: Vec<f64> = scores.iter().map(|s| (s / n).max(1e-6)).collect();
447 self.set_weights(weights)
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model that ignores the key and returns `dim` copies of `value`.
    fn make_model(value: f64, dim: usize) -> EnsembleModel {
        Box::new(move |_key: &str| vec![value; dim])
    }

    /// Voting yields the element-wise mean of the model outputs.
    #[test]
    fn test_voting_mean() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(3.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Voting,
            output_dim: 4,
            normalize: false,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 2.0).abs() < 1e-9, "expected 2.0 got {v}");
        }
    }

    /// Weights 1:3 over outputs 0 and 4 give 0.25 * 0 + 0.75 * 4 = 3.
    #[test]
    fn test_weighted_average() {
        let models: Vec<EnsembleModel> = vec![make_model(0.0, 4), make_model(4.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::WeightedAverage,
            output_dim: 4,
            normalize: false,
            ..Default::default()
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();
        embedder.set_weights(vec![1.0, 3.0]).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 3.0).abs() < 1e-9, "expected 3.0 got {v}");
        }
    }

    /// A negligible or exactly zero weight removes a model's influence.
    #[test]
    fn test_zero_weight_model_excluded() {
        let models: Vec<EnsembleModel> = vec![make_model(100.0, 4), make_model(1.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::WeightedAverage,
            output_dim: 4,
            normalize: false,
            ..Default::default()
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();

        embedder.set_weights(vec![1e-10, 1.0]).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 1.0).abs() < 0.01, "expected ≈1.0 got {v}");
        }

        embedder.set_weights(vec![0.0, 1.0]).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 1.0).abs() < 1e-9, "expected 1.0 got {v}");
        }
    }

    /// The stacker learns a constant target from constant model outputs.
    #[test]
    fn test_stacking_convergence() {
        let dim = 8;
        let models: Vec<EnsembleModel> = vec![make_model(1.0, dim), make_model(0.0, dim)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Stacking,
            output_dim: dim,
            stacking_hidden_dim: 32,
            stacking_lr: 0.01,
            stacking_epochs: 200,
            normalize: false,
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();

        // Own the keys so the borrowed `&str`s need no leaked allocation.
        let keys: Vec<String> = (0..20).map(|i| format!("e{i}")).collect();
        let targets: Vec<(&str, Vec<f64>)> = keys
            .iter()
            .map(|k| (k.as_str(), vec![0.5; dim]))
            .collect();

        embedder.fit_stacking(&targets).unwrap();
        let emb = embedder.embed("e0").unwrap();
        for v in &emb {
            assert!(
                (v - 0.5).abs() < 0.2,
                "expected ≈0.5 after stacking, got {v}"
            );
        }
    }

    /// A model matching the reference outweighs an all-zero model.
    #[test]
    fn test_derive_weights() {
        let dim = 4;
        let models: Vec<EnsembleModel> = vec![
            Box::new(move |_| vec![1.0; dim]),
            Box::new(move |_| vec![0.0; dim]),
        ];
        let reference = |_key: &str| vec![1.0f64; dim];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::WeightedAverage,
            output_dim: dim,
            normalize: false,
            ..Default::default()
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();
        let keys = vec!["e0", "e1", "e2"];
        embedder.derive_weights(&reference, &keys).unwrap();
        assert!(embedder.weights[0] > embedder.weights[1] * 100.0);
    }

    #[test]
    fn test_empty_models_rejected() {
        let models: Vec<EnsembleModel> = vec![];
        let config = EnsembleConfig::default();
        assert!(EnsembleEmbedder::new(models, config).is_err());
    }

    #[test]
    fn test_stacking_requires_fit() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Stacking,
            output_dim: 4,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        assert!(embedder.embed("e1").is_err());
    }

    #[test]
    fn test_normalize_output() {
        let models: Vec<EnsembleModel> = vec![make_model(3.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Voting,
            output_dim: 4,
            normalize: true,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        let emb = embedder.embed("e1").unwrap();
        let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-9,
            "norm should be 1.0 after normalize, got {norm}"
        );
    }

    /// Invalid weight vectors are rejected and leave the old weights intact.
    #[test]
    fn test_set_weights_validation() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(2.0, 4)];
        let config = EnsembleConfig {
            output_dim: 4,
            ..Default::default()
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();
        assert!(embedder.set_weights(vec![1.0]).is_err());
        assert!(embedder.set_weights(vec![-1.0, 2.0]).is_err());
        assert!(embedder.set_weights(vec![f64::NAN, 1.0]).is_err());
        assert!(embedder.set_weights(vec![0.0, 0.0]).is_err());
        assert_eq!(embedder.weights, vec![0.5, 0.5]);
    }

    /// A model whose output length differs from `output_dim` is an error.
    #[test]
    fn test_dimension_mismatch_rejected() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(1.0, 3)];
        let config = EnsembleConfig {
            output_dim: 4,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        assert!(embedder.embed("e1").is_err());
    }
}