// aprender/ensemble/gating.rs
use serde::{Deserialize, Serialize};

/// A gating network: scores an input feature vector and produces one weight
/// per expert.
///
/// The `Send + Sync` supertraits let a gate be shared across threads
/// (e.g. behind an `Arc`).
pub trait GatingNetwork: Send + Sync {
    /// Number of input features the gate reads from each sample.
    fn n_features(&self) -> usize;

    /// Number of experts; presumably the length of the vector returned by
    /// [`forward`](GatingNetwork::forward) (true for `SoftmaxGating`).
    fn n_experts(&self) -> usize;

    /// Maps the feature vector `x` to per-expert gating weights.
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct SoftmaxGating {
20 n_features: usize,
21 n_experts: usize,
22 temperature: f32,
23 weights: Vec<f32>,
24}
25
26impl SoftmaxGating {
27 #[must_use]
29 pub fn new(n_features: usize, n_experts: usize) -> Self {
30 let scale = (2.0 / (n_features + n_experts) as f32).sqrt();
31 let weights: Vec<f32> = (0..n_features * n_experts)
32 .map(|i| {
33 let row = i / n_experts;
34 let col = i % n_experts;
35 scale * ((row + col) as f32 * 0.1 - 0.5)
36 })
37 .collect();
38
39 Self {
40 n_features,
41 n_experts,
42 temperature: 1.0,
43 weights,
44 }
45 }
46
47 #[must_use]
49 pub fn with_temperature(mut self, temp: f32) -> Self {
50 self.temperature = temp;
51 self
52 }
53
54 #[must_use]
56 pub fn temperature(&self) -> f32 {
57 self.temperature
58 }
59
60 fn softmax(&self, logits: &[f32]) -> Vec<f32> {
61 let scaled: Vec<f32> = logits.iter().map(|&x| x / self.temperature).collect();
62 let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
63 let exp_vals: Vec<f32> = scaled.iter().map(|&x| (x - max).exp()).collect();
64 let sum: f32 = exp_vals.iter().sum();
65 exp_vals.iter().map(|&x| x / sum).collect()
66 }
67}
68
69impl GatingNetwork for SoftmaxGating {
70 fn forward(&self, x: &[f32]) -> Vec<f32> {
71 let mut logits = vec![0.0f32; self.n_experts];
72 for (j, logit) in logits.iter_mut().enumerate() {
73 for (i, &xi) in x.iter().take(self.n_features).enumerate() {
74 *logit += xi * self.weights[i * self.n_experts + j];
75 }
76 }
77 self.softmax(&logits)
78 }
79
80 fn n_features(&self) -> usize {
81 self.n_features
82 }
83
84 fn n_experts(&self) -> usize {
85 self.n_experts
86 }
87}
88
/// Unit tests for `SoftmaxGating` construction, forward pass and temperature.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_gating_new() {
        let gating = SoftmaxGating::new(10, 4);
        assert_eq!(gating.n_features(), 10);
        assert_eq!(gating.n_experts(), 4);
        assert_eq!(gating.temperature(), 1.0);
    }

    #[test]
    fn test_softmax_gating_with_temperature() {
        let gating = SoftmaxGating::new(5, 3).with_temperature(0.5);
        assert!((gating.temperature() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_gating_forward() {
        let gating = SoftmaxGating::new(3, 2);
        let input = vec![1.0, 0.5, 0.2];
        let weights = gating.forward(&input);

        assert_eq!(weights.len(), 2);

        // Softmax output must be a probability distribution.
        let sum: f32 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        for w in &weights {
            assert!(*w >= 0.0);
        }
    }

    #[test]
    fn test_softmax_gating_forward_longer_input() {
        let gating = SoftmaxGating::new(3, 2);
        // Two more values than n_features; the surplus must be ignored.
        let input = vec![1.0, 0.5, 0.2, 0.8, 0.9];
        let weights = gating.forward(&input);
        assert_eq!(weights.len(), 2);
        assert_eq!(weights, gating.forward(&input[..3]));
    }

    #[test]
    fn test_softmax_gating_forward_shorter_input() {
        let gating = SoftmaxGating::new(3, 2);
        // Fewer values than n_features: missing features contribute nothing.
        let weights = gating.forward(&[1.0]);
        assert_eq!(weights.len(), 2);
        let sum: f32 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_softmax_gating_zero_input_is_uniform() {
        let gating = SoftmaxGating::new(5, 4);
        // All logits are zero, so every expert gets the same weight.
        let weights = gating.forward(&[0.0; 5]);
        assert_eq!(weights.len(), 4);
        for w in &weights {
            assert!((w - 0.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_softmax_gating_zero_experts() {
        let gating = SoftmaxGating::new(3, 0);
        assert!(gating.weights.is_empty());
        assert!(gating.forward(&[1.0, 2.0, 3.0]).is_empty());
    }

    #[test]
    fn test_softmax_gating_temperature_effect() {
        let gating_high_temp = SoftmaxGating::new(3, 4).with_temperature(10.0);
        let gating_low_temp = SoftmaxGating::new(3, 4).with_temperature(0.1);

        let input = vec![1.0, 2.0, 3.0];
        let weights_high = gating_high_temp.forward(&input);
        let weights_low = gating_low_temp.forward(&input);

        let high_max = weights_high.iter().copied().fold(0.0f32, f32::max);
        let low_max = weights_low.iter().copied().fold(0.0f32, f32::max);

        // A lower temperature concentrates more mass on the top expert.
        assert!(low_max > high_max);
    }

    #[test]
    fn test_softmax_gating_clone() {
        let gating = SoftmaxGating::new(5, 3).with_temperature(2.0);
        let cloned = gating.clone();
        assert_eq!(cloned.n_features(), gating.n_features());
        assert_eq!(cloned.n_experts(), gating.n_experts());
        assert!((cloned.temperature() - gating.temperature()).abs() < 1e-6);
        assert_eq!(cloned.weights, gating.weights);
    }

    #[test]
    fn test_softmax_gating_debug() {
        let gating = SoftmaxGating::new(3, 2);
        let debug_str = format!("{:?}", gating);
        assert!(debug_str.contains("SoftmaxGating"));
    }

    #[test]
    fn test_softmax_gating_weights_initialized() {
        let gating = SoftmaxGating::new(4, 3);
        assert_eq!(gating.weights.len(), 4 * 3);
    }

    #[test]
    fn test_softmax_gating_deterministic_init() {
        // Initialisation uses no RNG, so two gates with equal shape match exactly.
        let a = SoftmaxGating::new(6, 3);
        let b = SoftmaxGating::new(6, 3);
        assert_eq!(a.weights, b.weights);
    }
}