//! Gating networks for expert routing
//!
//! Path: aprender/ensemble/gating.rs

use serde::{Deserialize, Serialize};

/// Trait for gating networks that route inputs to experts.
///
/// Implementors map an input feature vector to a per-expert weight
/// vector used for mixture-of-experts routing.
pub trait GatingNetwork: Send + Sync {
    /// Compute expert weights for the input feature vector `x`.
    fn forward(&self, x: &[f32]) -> Vec<f32>;

    /// Number of input features expected by [`GatingNetwork::forward`].
    fn n_features(&self) -> usize;

    /// Number of experts (length of the vector returned by `forward`).
    fn n_experts(&self) -> usize;
}
17/// Softmax gating with learnable weights
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct SoftmaxGating {
20    n_features: usize,
21    n_experts: usize,
22    temperature: f32,
23    weights: Vec<f32>,
24}
26impl SoftmaxGating {
27    /// Create new softmax gating
28    #[must_use]
29    pub fn new(n_features: usize, n_experts: usize) -> Self {
30        let scale = (2.0 / (n_features + n_experts) as f32).sqrt();
31        let weights: Vec<f32> = (0..n_features * n_experts)
32            .map(|i| {
33                let row = i / n_experts;
34                let col = i % n_experts;
35                scale * ((row + col) as f32 * 0.1 - 0.5)
36            })
37            .collect();
38
39        Self {
40            n_features,
41            n_experts,
42            temperature: 1.0,
43            weights,
44        }
45    }
46
47    /// Set temperature
48    #[must_use]
49    pub fn with_temperature(mut self, temp: f32) -> Self {
50        self.temperature = temp;
51        self
52    }
53
54    /// Get temperature
55    #[must_use]
56    pub fn temperature(&self) -> f32 {
57        self.temperature
58    }
59
60    fn softmax(&self, logits: &[f32]) -> Vec<f32> {
61        let scaled: Vec<f32> = logits.iter().map(|&x| x / self.temperature).collect();
62        let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
63        let exp_vals: Vec<f32> = scaled.iter().map(|&x| (x - max).exp()).collect();
64        let sum: f32 = exp_vals.iter().sum();
65        exp_vals.iter().map(|&x| x / sum).collect()
66    }
67}
69impl GatingNetwork for SoftmaxGating {
70    fn forward(&self, x: &[f32]) -> Vec<f32> {
71        let mut logits = vec![0.0f32; self.n_experts];
72        for (j, logit) in logits.iter_mut().enumerate() {
73            for (i, &xi) in x.iter().take(self.n_features).enumerate() {
74                *logit += xi * self.weights[i * self.n_experts + j];
75            }
76        }
77        self.softmax(&logits)
78    }
79
80    fn n_features(&self) -> usize {
81        self.n_features
82    }
83
84    fn n_experts(&self) -> usize {
85        self.n_experts
86    }
87}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_gating_new() {
        let gating = SoftmaxGating::new(10, 4);
        assert_eq!(gating.n_features(), 10);
        assert_eq!(gating.n_experts(), 4);
        // Compare floats with a tolerance, consistent with the other
        // temperature tests, instead of exact `assert_eq!` on f32.
        assert!((gating.temperature() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_gating_with_temperature() {
        let gating = SoftmaxGating::new(5, 3).with_temperature(0.5);
        assert!((gating.temperature() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_gating_forward() {
        let gating = SoftmaxGating::new(3, 2);
        let input = vec![1.0, 0.5, 0.2];
        let weights = gating.forward(&input);

        // Should return one weight per expert.
        assert_eq!(weights.len(), 2);

        // Weights should sum to ~1.0 (softmax property).
        let sum: f32 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // All weights should be non-negative.
        for w in &weights {
            assert!(*w >= 0.0);
        }
    }

    #[test]
    fn test_softmax_gating_forward_longer_input() {
        let gating = SoftmaxGating::new(3, 2);
        // Input longer than n_features - should only use first n_features.
        let input = vec![1.0, 0.5, 0.2, 0.8, 0.9];
        let weights = gating.forward(&input);
        assert_eq!(weights.len(), 2);
    }

    #[test]
    fn test_softmax_gating_temperature_effect() {
        let gating_high_temp = SoftmaxGating::new(3, 4).with_temperature(10.0);
        let gating_low_temp = SoftmaxGating::new(3, 4).with_temperature(0.1);

        let input = vec![1.0, 2.0, 3.0];
        let weights_high = gating_high_temp.forward(&input);
        let weights_low = gating_low_temp.forward(&input);

        // High temperature should give a more uniform distribution.
        let high_max = weights_high.iter().copied().fold(0.0f32, f32::max);
        let low_max = weights_low.iter().copied().fold(0.0f32, f32::max);

        // Low temperature should have a more peaked distribution.
        assert!(low_max > high_max);
    }

    #[test]
    fn test_softmax_gating_clone() {
        let gating = SoftmaxGating::new(5, 3).with_temperature(2.0);
        let cloned = gating.clone();
        assert_eq!(cloned.n_features(), gating.n_features());
        assert_eq!(cloned.n_experts(), gating.n_experts());
        assert!((cloned.temperature() - gating.temperature()).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_gating_debug() {
        let gating = SoftmaxGating::new(3, 2);
        let debug_str = format!("{gating:?}");
        assert!(debug_str.contains("SoftmaxGating"));
    }

    #[test]
    fn test_softmax_gating_weights_initialized() {
        let gating = SoftmaxGating::new(4, 3);
        assert_eq!(gating.weights.len(), 4 * 3); // n_features * n_experts
    }
}