Skip to main content

oxirs_embed/
ensemble.rs

1//! Ensemble Embedding Methods (v0.3.0)
2//!
3//! Aggregates multiple embedding models using three strategies:
4//! - **Voting** (mean pooling): average embeddings across all models element-wise.
5//! - **WeightedAverage**: performance-weighted mean (weights from validation cosine similarity).
6//! - **Stacking**: a two-layer meta-learner (linear + ReLU + linear) trained on concatenated
7//!   model outputs; requires a held-out validation set.
8//!
9//! ## Design
10//!
11//! Each "model" in the ensemble is represented as a boxed function `Fn(&str) -> Vec<f64>`
12//! (same contract as `ModelVariant` in the A/B testing module).  This keeps the
13//! ensemble decoupled from specific KGE implementations.
14//!
15//! ```rust,no_run
16//! use oxirs_embed::ensemble::{EnsembleConfig, EnsembleEmbedder, EnsembleStrategy};
17//!
18//! let models: Vec<Box<dyn Fn(&str) -> Vec<f64> + Send + Sync>> = vec![
19//!     Box::new(|_key: &str| vec![1.0f64; 16]),
20//!     Box::new(|_key: &str| vec![2.0f64; 16]),
21//! ];
22//! let config = EnsembleConfig {
23//!     strategy: EnsembleStrategy::Voting,
24//!     output_dim: 16,
25//!     ..Default::default()
26//! };
27//! let embedder = EnsembleEmbedder::new(models, config).expect("valid config");
28//! let embedding = embedder.embed("entity:Alice").expect("embedding");
29//! assert_eq!(embedding.len(), 16);
30//! ```
31
32use anyhow::{anyhow, Result};
33use serde::{Deserialize, Serialize};
34
35// ─────────────────────────────────────────────────────────────────────────────
36// EnsembleStrategy
37// ─────────────────────────────────────────────────────────────────────────────
38
39/// Aggregation strategy for combining model embeddings.
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41pub enum EnsembleStrategy {
42    /// Element-wise mean over all model outputs.
43    Voting,
44    /// Weighted element-wise mean. Weights must be supplied via
45    /// [`EnsembleEmbedder::set_weights`]; default is uniform.
46    WeightedAverage,
47    /// Two-layer meta-learner trained on concatenated model outputs.
48    /// Requires calling [`EnsembleEmbedder::fit_stacking`] before inference.
49    Stacking,
50}
51
52/// Configuration for [`EnsembleEmbedder`].
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct EnsembleConfig {
55    /// Aggregation strategy.
56    pub strategy: EnsembleStrategy,
57    /// Dimensionality of each individual model's output.
58    pub output_dim: usize,
59    /// Hidden dimension of the stacking meta-learner.
60    pub stacking_hidden_dim: usize,
61    /// Learning rate for stacking meta-learner training.
62    pub stacking_lr: f64,
63    /// Number of gradient-descent epochs for stacking.
64    pub stacking_epochs: usize,
65    /// L2-normalize final embedding (all strategies).
66    pub normalize: bool,
67}
68
69impl Default for EnsembleConfig {
70    fn default() -> Self {
71        Self {
72            strategy: EnsembleStrategy::Voting,
73            output_dim: 64,
74            stacking_hidden_dim: 128,
75            stacking_lr: 0.01,
76            stacking_epochs: 50,
77            normalize: true,
78        }
79    }
80}
81
82// ─────────────────────────────────────────────────────────────────────────────
83// Stacking meta-learner (linear → ReLU → linear)
84// ─────────────────────────────────────────────────────────────────────────────
85
/// A minimal two-layer MLP meta-learner for the stacking strategy.
///
/// Architecture: `[concat_dim] → hidden → ReLU → output_dim`
struct StackingMLP {
    /// Weight matrix W1: shape [hidden, concat_dim]
    w1: Vec<Vec<f64>>,
    /// Bias b1: shape [hidden]
    b1: Vec<f64>,
    /// Weight matrix W2: shape [output_dim, hidden]
    w2: Vec<Vec<f64>>,
    /// Bias b2: shape [output_dim]
    b2: Vec<f64>,
    /// Expected length of every input vector (checked in debug builds).
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
}

impl StackingMLP {
    /// Initialise weights with Xavier uniform (±sqrt(6/(fan_in+fan_out))) and zero biases.
    ///
    /// Weights come from a seeded 64-bit LCG, so the same `seed` always yields the
    /// same network.
    fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, seed: u64) -> Self {
        let mut state = seed.wrapping_add(1);
        let mut lcg = || -> f64 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 53 bits → [0, 1), then mapped to [-1, 1).
            (((state >> 11) as f64) / ((1u64 << 53) as f64)) * 2.0 - 1.0
        };

        let xavier1 = (6.0_f64 / (input_dim + hidden_dim) as f64).sqrt();
        let xavier2 = (6.0_f64 / (hidden_dim + output_dim) as f64).sqrt();

        // W1 is drawn before W2; keep this order so seeded initialisation is stable.
        let w1 = (0..hidden_dim)
            .map(|_| (0..input_dim).map(|_| lcg() * xavier1).collect())
            .collect();
        let b1 = vec![0.0; hidden_dim];
        let w2 = (0..output_dim)
            .map(|_| (0..hidden_dim).map(|_| lcg() * xavier2).collect())
            .collect();
        let b2 = vec![0.0; output_dim];

        Self {
            w1,
            b1,
            w2,
            b2,
            input_dim,
            hidden_dim,
            output_dim,
        }
    }

    /// Forward pass: returns `(hidden activations after ReLU, output)`.
    fn forward(&self, x: &[f64]) -> (Vec<f64>, Vec<f64>) {
        debug_assert_eq!(
            x.len(),
            self.input_dim,
            "stacking input length does not match the meta-learner's input_dim"
        );
        let h: Vec<f64> = self
            .w1
            .iter()
            .zip(&self.b1)
            .map(|(row, b)| {
                let dot: f64 = row.iter().zip(x).map(|(w, xi)| w * xi).sum();
                (dot + b).max(0.0) // ReLU
            })
            .collect();
        let out: Vec<f64> = self
            .w2
            .iter()
            .zip(&self.b2)
            .map(|(row, b)| {
                let dot: f64 = row.iter().zip(&h).map(|(w, hi)| w * hi).sum();
                dot + b
            })
            .collect();
        (h, out)
    }

    /// Single SGD step on the squared error `Σ (out - y)²` for one sample.
    ///
    /// All gradients are computed from the parameters as they were *before* this
    /// step; in particular the hidden-layer error is backpropagated through the
    /// pre-update `w2`.
    fn backward_step(&mut self, x: &[f64], y: &[f64], lr: f64) {
        let (h, out) = self.forward(x);

        // dL/d_out for squared error.
        let d_out: Vec<f64> = out.iter().zip(y).map(|(o, t)| 2.0 * (o - t)).collect();

        // Backprop through W2 and the ReLU *before* W2 is modified.
        let d_h: Vec<f64> = h
            .iter()
            .enumerate()
            .map(|(j, &hj)| {
                if hj > 0.0 {
                    d_out
                        .iter()
                        .zip(&self.w2)
                        .map(|(d, row)| d * row[j])
                        .sum::<f64>()
                } else {
                    0.0
                }
            })
            .collect();

        // Output layer update (W2, b2).
        for ((row, bias), d) in self.w2.iter_mut().zip(self.b2.iter_mut()).zip(&d_out) {
            for (w, hj) in row.iter_mut().zip(&h) {
                *w -= lr * d * hj;
            }
            *bias -= lr * d;
        }

        // Hidden layer update (W1, b1).
        for ((row, bias), d) in self.w1.iter_mut().zip(self.b1.iter_mut()).zip(&d_h) {
            for (w, xj) in row.iter_mut().zip(x) {
                *w -= lr * d * xj;
            }
            *bias -= lr * d;
        }
    }

    /// Inference: the output half of [`StackingMLP::forward`].
    fn predict(&self, x: &[f64]) -> Vec<f64> {
        self.forward(x).1
    }
}
194
195// ─────────────────────────────────────────────────────────────────────────────
196// EnsembleEmbedder
197// ─────────────────────────────────────────────────────────────────────────────
198
199/// Aggregates multiple embedding model functions under a single interface.
200///
201/// Each model is a boxed `Fn(&str) -> Vec<f64>` that maps an entity key to
202/// an embedding vector.  All models must produce embeddings of the same
203/// dimensionality (`config.output_dim`).
204/// Type alias for an ensemble model function: maps an entity key to its embedding.
205type EnsembleModel = Box<dyn Fn(&str) -> Vec<f64> + Send + Sync>;
206
207pub struct EnsembleEmbedder {
208    models: Vec<EnsembleModel>,
209    config: EnsembleConfig,
210    /// Per-model weights (only used for `WeightedAverage`).
211    weights: Vec<f64>,
212    /// Trained stacking meta-learner (only for `Stacking` strategy).
213    stacking_mlp: Option<StackingMLP>,
214}
215
216impl std::fmt::Debug for EnsembleEmbedder {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        f.debug_struct("EnsembleEmbedder")
219            .field("num_models", &self.models.len())
220            .field("config", &self.config)
221            .field("strategy", &self.config.strategy)
222            .finish()
223    }
224}
225
226impl EnsembleEmbedder {
227    /// Create a new ensemble embedder.
228    ///
229    /// # Errors
230    /// Returns an error if `models` is empty.
231    pub fn new(models: Vec<EnsembleModel>, config: EnsembleConfig) -> Result<Self> {
232        if models.is_empty() {
233            return Err(anyhow!("EnsembleEmbedder requires at least one model"));
234        }
235        let n = models.len();
236        let weights = vec![1.0 / n as f64; n]; // uniform by default
237        Ok(Self {
238            models,
239            config,
240            weights,
241            stacking_mlp: None,
242        })
243    }
244
245    /// Return the number of models in the ensemble.
246    pub fn num_models(&self) -> usize {
247        self.models.len()
248    }
249
250    /// Set per-model weights for `WeightedAverage`.
251    ///
252    /// Weights are automatically normalised to sum to 1.
253    ///
254    /// # Errors
255    /// Returns an error if the weight vector length does not match the model count,
256    /// or if any weight is negative or not finite, or if the sum is zero.
257    pub fn set_weights(&mut self, weights: Vec<f64>) -> Result<()> {
258        if weights.len() != self.models.len() {
259            return Err(anyhow!(
260                "weight vector length {} != model count {}",
261                weights.len(),
262                self.models.len()
263            ));
264        }
265        for &w in &weights {
266            if !w.is_finite() || w < 0.0 {
267                return Err(anyhow!("all weights must be non-negative and finite"));
268            }
269        }
270        let sum: f64 = weights.iter().sum();
271        if sum == 0.0 {
272            return Err(anyhow!("weight sum must be > 0"));
273        }
274        self.weights = weights.iter().map(|w| w / sum).collect();
275        Ok(())
276    }
277
278    /// Collect raw embeddings from every model for `key`.
279    fn collect_embeddings(&self, key: &str) -> Result<Vec<Vec<f64>>> {
280        let embeddings: Vec<Vec<f64>> = self.models.iter().map(|m| m(key)).collect();
281        // Validate dimensionality
282        for (i, emb) in embeddings.iter().enumerate() {
283            if emb.len() != self.config.output_dim {
284                return Err(anyhow!(
285                    "model {} returned embedding of dimension {} but config.output_dim is {}",
286                    i,
287                    emb.len(),
288                    self.config.output_dim
289                ));
290            }
291        }
292        Ok(embeddings)
293    }
294
295    /// L2-normalize a vector in-place. No-op if the norm is zero.
296    fn l2_normalize(v: &mut [f64]) {
297        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
298        if norm > 1e-10 {
299            for x in v.iter_mut() {
300                *x /= norm;
301            }
302        }
303    }
304
305    /// Produce an ensemble embedding for the given entity key.
306    ///
307    /// # Errors
308    /// - Dimensionality mismatch between model output and `config.output_dim`.
309    /// - `Stacking` strategy called before [`EnsembleEmbedder::fit_stacking`].
310    pub fn embed(&self, key: &str) -> Result<Vec<f64>> {
311        let embeddings = self.collect_embeddings(key)?;
312        let mut result = match self.config.strategy {
313            EnsembleStrategy::Voting => {
314                let dim = self.config.output_dim;
315                let mut agg = vec![0.0; dim];
316                for emb in &embeddings {
317                    for (a, e) in agg.iter_mut().zip(emb.iter()) {
318                        *a += e;
319                    }
320                }
321                let n = embeddings.len() as f64;
322                agg.iter_mut().for_each(|a| *a /= n);
323                agg
324            }
325            EnsembleStrategy::WeightedAverage => {
326                let dim = self.config.output_dim;
327                let mut agg = vec![0.0; dim];
328                for (emb, &w) in embeddings.iter().zip(self.weights.iter()) {
329                    for (a, e) in agg.iter_mut().zip(emb.iter()) {
330                        *a += w * e;
331                    }
332                }
333                agg
334            }
335            EnsembleStrategy::Stacking => {
336                let mlp = self.stacking_mlp.as_ref().ok_or_else(|| {
337                    anyhow!("Stacking strategy requires calling fit_stacking() first")
338                })?;
339                // Concatenate all embeddings into a single input vector
340                let concat: Vec<f64> = embeddings.into_iter().flatten().collect();
341                mlp.predict(&concat)
342            }
343        };
344        if self.config.normalize {
345            Self::l2_normalize(&mut result);
346        }
347        Ok(result)
348    }
349
350    /// Train the stacking meta-learner on a validation set.
351    ///
352    /// `validation_pairs` is a slice of `(key, target_embedding)` pairs where
353    /// `target_embedding` is the ground-truth embedding (e.g. from a stronger
354    /// reference model or link-prediction evaluation).
355    ///
356    /// # Errors
357    /// Returns an error if the validation set is empty, or if any target
358    /// embedding has the wrong dimensionality.
359    pub fn fit_stacking(&mut self, validation_pairs: &[(&str, Vec<f64>)]) -> Result<()> {
360        if validation_pairs.is_empty() {
361            return Err(anyhow!("validation set must not be empty for stacking"));
362        }
363        let concat_dim = self.models.len() * self.config.output_dim;
364        if self.stacking_mlp.is_none() {
365            self.stacking_mlp = Some(StackingMLP::new(
366                concat_dim,
367                self.config.stacking_hidden_dim,
368                self.config.output_dim,
369                42,
370            ));
371        }
372        for _epoch in 0..self.config.stacking_epochs {
373            for (key, target) in validation_pairs {
374                if target.len() != self.config.output_dim {
375                    return Err(anyhow!(
376                        "target embedding dimension {} != config.output_dim {}",
377                        target.len(),
378                        self.config.output_dim
379                    ));
380                }
381                let embeddings = self.collect_embeddings(key)?;
382                let concat: Vec<f64> = embeddings.into_iter().flatten().collect();
383                if let Some(mlp) = &mut self.stacking_mlp {
384                    mlp.backward_step(&concat, target, self.config.stacking_lr);
385                }
386            }
387        }
388        Ok(())
389    }
390
391    /// Compute the average cosine similarity between ensemble output and a reference
392    /// model on the given validation keys. Useful for tuning weights.
393    pub fn eval_cosine(
394        &self,
395        reference: &impl Fn(&str) -> Vec<f64>,
396        validation_keys: &[&str],
397    ) -> Result<f64> {
398        if validation_keys.is_empty() {
399            return Ok(0.0);
400        }
401        let mut total = 0.0;
402        for &key in validation_keys {
403            let pred = self.embed(key)?;
404            let ref_emb = reference(key);
405            let dot: f64 = pred.iter().zip(ref_emb.iter()).map(|(a, b)| a * b).sum();
406            let norm_pred: f64 = pred.iter().map(|x| x * x).sum::<f64>().sqrt();
407            let norm_ref: f64 = ref_emb.iter().map(|x| x * x).sum::<f64>().sqrt();
408            let cos = if norm_pred > 1e-10 && norm_ref > 1e-10 {
409                (dot / (norm_pred * norm_ref)).clamp(-1.0, 1.0)
410            } else {
411                0.0
412            };
413            total += cos;
414        }
415        Ok(total / validation_keys.len() as f64)
416    }
417
418    /// Derive performance-based weights from cosine similarity scores on a validation set.
419    ///
420    /// Each model's weight is proportional to its mean cosine similarity to a reference
421    /// embedding. Models with zero weight (similarity ≤ 0) are assigned a small ε = 1e-6
422    /// to keep them in the ensemble.
423    pub fn derive_weights(
424        &mut self,
425        reference: &impl Fn(&str) -> Vec<f64>,
426        validation_keys: &[&str],
427    ) -> Result<()> {
428        let mut scores = vec![0.0_f64; self.models.len()];
429        for &key in validation_keys {
430            let ref_emb = reference(key);
431            for (i, model) in self.models.iter().enumerate() {
432                let emb = model(key);
433                let dot: f64 = emb.iter().zip(ref_emb.iter()).map(|(a, b)| a * b).sum();
434                let na: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
435                let nb: f64 = ref_emb.iter().map(|x| x * x).sum::<f64>().sqrt();
436                let cos = if na > 1e-10 && nb > 1e-10 {
437                    (dot / (na * nb)).clamp(-1.0, 1.0)
438                } else {
439                    0.0
440                };
441                scores[i] += cos;
442            }
443        }
444        // Average and clamp to ε minimum
445        let n = validation_keys.len().max(1) as f64;
446        let weights: Vec<f64> = scores.iter().map(|s| (s / n).max(1e-6)).collect();
447        self.set_weights(weights)
448    }
449}
450
451// ─────────────────────────────────────────────────────────────────────────────
452// Tests
453// ─────────────────────────────────────────────────────────────────────────────
454
#[cfg(test)]
mod tests {
    use super::*;

    /// A model that returns `[value; dim]` for every key.
    fn make_model(value: f64, dim: usize) -> EnsembleModel {
        Box::new(move |_key: &str| vec![value; dim])
    }

    /// Config for `strategy` over `dim`-dimensional embeddings with normalization off,
    /// so tests can compare raw aggregated values.
    fn raw_config(strategy: EnsembleStrategy, dim: usize) -> EnsembleConfig {
        EnsembleConfig {
            strategy,
            output_dim: dim,
            normalize: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_voting_mean() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(3.0, 4)];
        let embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::Voting, 4)).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 2.0).abs() < 1e-9, "expected 2.0 got {v}");
        }
    }

    #[test]
    fn test_weighted_average() {
        let models: Vec<EnsembleModel> = vec![make_model(0.0, 4), make_model(4.0, 4)];
        let mut embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::WeightedAverage, 4))
                .unwrap();
        // 25% model-0 + 75% model-1 → expected = 3.0
        embedder.set_weights(vec![1.0, 3.0]).unwrap();
        let emb = embedder.embed("e1").unwrap();
        for v in &emb {
            assert!((v - 3.0).abs() < 1e-9, "expected 3.0 got {v}");
        }
    }

    #[test]
    fn test_zero_weight_model_excluded() {
        let models: Vec<EnsembleModel> = vec![
            make_model(100.0, 4), // would skew result if weighted
            make_model(1.0, 4),
        ];
        let mut embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::WeightedAverage, 4))
                .unwrap();
        // Weight first model to near-zero (not exactly 0 — must be non-negative)
        embedder.set_weights(vec![1e-10, 1.0]).unwrap();
        let emb = embedder.embed("e1").unwrap();
        // Result should be ≈1.0 (dominated by second model)
        for v in &emb {
            assert!((v - 1.0).abs() < 0.01, "expected ≈1.0 got {v}");
        }
    }

    #[test]
    fn test_set_weights_rejects_invalid_input() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(2.0, 4)];
        let mut embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::WeightedAverage, 4))
                .unwrap();
        assert!(embedder.set_weights(vec![1.0]).is_err(), "length mismatch");
        assert!(embedder.set_weights(vec![-1.0, 2.0]).is_err(), "negative weight");
        assert!(embedder.set_weights(vec![f64::NAN, 1.0]).is_err(), "NaN weight");
        assert!(
            embedder.set_weights(vec![f64::INFINITY, 1.0]).is_err(),
            "infinite weight"
        );
        assert!(embedder.set_weights(vec![0.0, 0.0]).is_err(), "zero sum");
        // Failed calls must leave the default uniform weights in place.
        assert_eq!(embedder.weights, vec![0.5, 0.5]);
    }

    #[test]
    fn test_dimension_mismatch_rejected() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(1.0, 3)];
        let embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::Voting, 4)).unwrap();
        assert!(embedder.embed("e1").is_err());
    }

    #[test]
    fn test_stacking_convergence() {
        let dim = 8;
        // Both models output constant vectors; target = 0.5*vec
        let models: Vec<EnsembleModel> = vec![make_model(1.0, dim), make_model(0.0, dim)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Stacking,
            output_dim: dim,
            stacking_hidden_dim: 32,
            stacking_lr: 0.01,
            stacking_epochs: 200,
            normalize: false,
        };
        let mut embedder = EnsembleEmbedder::new(models, config).unwrap();
        // Validation: target = 0.5 for each dimension. Keys are owned here and
        // borrowed by `targets`, instead of leaking heap strings.
        let keys: Vec<String> = (0..20).map(|i| format!("e{i}")).collect();
        let targets: Vec<(&str, Vec<f64>)> = keys
            .iter()
            .map(|key| (key.as_str(), vec![0.5; dim]))
            .collect();
        embedder.fit_stacking(&targets).unwrap();
        let emb = embedder.embed("e0").unwrap();
        // After training should converge towards 0.5
        for v in &emb {
            assert!(
                (v - 0.5).abs() < 0.2,
                "expected ≈0.5 after stacking, got {v}"
            );
        }
    }

    #[test]
    fn test_fit_stacking_rejects_bad_input() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4)];
        let mut embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::Stacking, 4)).unwrap();
        assert!(embedder.fit_stacking(&[]).is_err(), "empty validation set");
        assert!(
            embedder.fit_stacking(&[("e0", vec![0.5; 3])]).is_err(),
            "wrong target dimension"
        );
    }

    #[test]
    fn test_derive_weights() {
        let dim = 4;
        // Model 0 matches reference perfectly; model 1 is zeros (cosine = 0)
        let models: Vec<EnsembleModel> = vec![
            Box::new(move |_| vec![1.0; dim]),
            Box::new(move |_| vec![0.0; dim]),
        ];
        let reference = |_key: &str| vec![1.0f64; dim];
        let mut embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::WeightedAverage, dim))
                .unwrap();
        let keys = vec!["e0", "e1", "e2"];
        embedder.derive_weights(&reference, &keys).unwrap();
        // Model 0 weight should be >> model 1 weight
        assert!(embedder.weights[0] > embedder.weights[1] * 100.0);
    }

    #[test]
    fn test_eval_cosine() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4), make_model(3.0, 4)];
        let embedder =
            EnsembleEmbedder::new(models, raw_config(EnsembleStrategy::Voting, 4)).unwrap();
        // Ensemble output is [2.0; 4], which points in the same direction as [1.0; 4].
        let reference = |_key: &str| vec![1.0f64; 4];
        let score = embedder.eval_cosine(&reference, &["e0", "e1"]).unwrap();
        assert!((score - 1.0).abs() < 1e-9, "expected cosine 1.0, got {score}");
        assert_eq!(embedder.eval_cosine(&reference, &[]).unwrap(), 0.0);
    }

    #[test]
    fn test_empty_models_rejected() {
        let models: Vec<EnsembleModel> = vec![];
        let config = EnsembleConfig::default();
        assert!(EnsembleEmbedder::new(models, config).is_err());
    }

    #[test]
    fn test_stacking_requires_fit() {
        let models: Vec<EnsembleModel> = vec![make_model(1.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Stacking,
            output_dim: 4,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        assert!(embedder.embed("e1").is_err());
    }

    #[test]
    fn test_normalize_output() {
        let models: Vec<EnsembleModel> = vec![make_model(3.0, 4)];
        let config = EnsembleConfig {
            strategy: EnsembleStrategy::Voting,
            output_dim: 4,
            normalize: true,
            ..Default::default()
        };
        let embedder = EnsembleEmbedder::new(models, config).unwrap();
        let emb = embedder.embed("e1").unwrap();
        let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-9,
            "norm should be 1.0 after normalize, got {norm}"
        );
    }
}