1use crate::error::{DistributedError, DistributedResult};
7use crate::expert::{ExpertId, ExpertRegistry};
8use async_trait::async_trait;
9use candle_core::Tensor;
10use serde::{Deserialize, Serialize};
11use tracing::{debug, info};
12
/// Outcome of a routing decision, as returned by `ExpertRouter::route`.
///
/// Derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
    /// Selected expert ids, one inner vector per routed row.
    pub expert_indices: Vec<Vec<ExpertId>>,
    /// Mixing weights; `TopKRouter` emits one inner vector per row, with one
    /// weight for each id in the matching `expert_indices` row.
    pub expert_weights: Vec<Vec<f32>>,
    /// Auxiliary loss value attached to this routing. `TopKRouter` in this
    /// file always reports `0.0`.
    /// NOTE(review): presumably a load-balancing term — confirm with callers.
    pub aux_loss: f32,
}
23
/// Strategy that decides which experts handle each row of a hidden-state tensor.
///
/// Implementors must be `Send + Sync`.
// NOTE(review): none of the methods below is `async`, so `#[async_trait]` has
// nothing to rewrite here — confirm whether the attribute is still wanted.
#[async_trait]
pub trait ExpertRouter: Send + Sync {
    /// Computes the routing for `hidden_states`.
    ///
    /// # Errors
    /// Returns `Err` when the implementation fails to compute a routing.
    fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing>;

    /// Number of experts selected per row (the `k` in top-k).
    fn top_k(&self) -> usize;

    /// Number of experts this router is configured for.
    fn num_experts(&self) -> usize;
}
41
/// A mixture-of-experts layer: an async forward pass plus read access to the
/// expert registry and router behind it.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait MixtureOfExperts: Send + Sync {
    /// Runs the layer on `input` and resolves to the output tensor.
    ///
    /// # Errors
    /// Resolves to `Err` when the implementation cannot complete the pass.
    async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor>;

    /// The registry of experts backing this layer.
    fn registry(&self) -> &ExpertRegistry;

    /// The router this layer uses to pick experts.
    fn router(&self) -> &dyn ExpertRouter;
}
56
57pub struct TopKRouter {
59 gate_weights: Tensor,
61 top_k: usize,
63 num_experts: usize,
65 #[allow(dead_code)]
67 aux_loss_coef: f32,
68}
69
70impl TopKRouter {
71 pub fn new(gate_weights: Tensor, top_k: usize, num_experts: usize, aux_loss_coef: f32) -> Self {
73 Self {
74 gate_weights,
75 top_k,
76 num_experts,
77 aux_loss_coef,
78 }
79 }
80}
81
82#[async_trait]
83impl ExpertRouter for TopKRouter {
84 fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing> {
85 let scores = hidden_states
87 .matmul(&self.gate_weights)
88 .map_err(|e| DistributedError::RoutingFailed(e.to_string()))?;
89
90 let dims = scores.dims();
92 let _batch_size = if dims.len() > 2 { dims[0] } else { 1 };
93 let seq_len = if dims.len() > 2 { dims[1] } else { dims[0] };
94
95 let expert_indices: Vec<Vec<ExpertId>> = (0..seq_len)
98 .map(|_| (0..self.top_k).map(|i| ExpertId::new(i as u64)).collect())
99 .collect();
100
101 let expert_weights: Vec<Vec<f32>> = (0..seq_len)
102 .map(|_| vec![1.0 / self.top_k as f32; self.top_k])
103 .collect();
104
105 Ok(Routing {
106 expert_indices,
107 expert_weights,
108 aux_loss: 0.0,
109 })
110 }
111
112 fn top_k(&self) -> usize {
113 self.top_k
114 }
115
116 fn num_experts(&self) -> usize {
117 self.num_experts
118 }
119}
120
/// Mixture-of-experts layer that pairs a router with a registry of local and
/// remote experts.
pub struct DistributedMoE {
    /// Routing strategy consulted by `forward`.
    router: Box<dyn ExpertRouter>,
    /// Experts known to this layer; filled through `register_expert` and
    /// `register_remote_expert`.
    registry: ExpertRegistry,
    /// Configuration passed to `new`.
    // `new` logs some of its values before storing it; nothing in this file
    // reads the stored field afterwards, hence the lint override.
    #[allow(dead_code)]
    config: MoEConfig,
}
131
/// Settings for a `DistributedMoE` layer.
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Hidden dimension of the layer.
    pub hidden_dim: usize,
    /// Number of experts.
    pub num_experts: usize,
    /// Number of experts selected per token.
    pub top_k: usize,
    /// Timeout in milliseconds.
    pub timeout_ms: u64,
}

impl Default for MoEConfig {
    /// Defaults: 4096 hidden units, 8 experts, top-2 routing, 5000 ms timeout.
    fn default() -> Self {
        let hidden_dim = 4096;
        let num_experts = 8;
        let top_k = 2;
        let timeout_ms = 5000;
        MoEConfig {
            hidden_dim,
            num_experts,
            top_k,
            timeout_ms,
        }
    }
}
155
156impl DistributedMoE {
157 pub fn new(router: Box<dyn ExpertRouter>, config: MoEConfig) -> Self {
159 info!(
160 num_experts = config.num_experts,
161 top_k = config.top_k,
162 hidden_dim = config.hidden_dim,
163 "Creating DistributedMoE layer"
164 );
165 Self {
166 router,
167 registry: ExpertRegistry::new(),
168 config,
169 }
170 }
171
172 pub fn register_expert(&mut self, expert: Box<dyn crate::expert::Expert>) {
174 debug!("Registering local expert in MoE layer");
175 self.registry.register_local(expert);
176 }
177
178 pub fn register_remote_expert(&mut self, expert_id: ExpertId, peer_id: String) {
180 debug!(
181 "Registering remote expert {} in MoE layer at peer {}",
182 expert_id, peer_id
183 );
184 self.registry.register_remote(expert_id, peer_id);
185 }
186}
187
188#[async_trait]
189impl MixtureOfExperts for DistributedMoE {
190 async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor> {
191 debug!("MoE forward pass, input shape: {:?}", input.dims());
192 let routing = self.router.route(input)?;
194 debug!("Routing computed: aux_loss={:.4}", routing.aux_loss);
195
196 Ok(input.clone())
204 }
205
206 fn registry(&self) -> &ExpertRegistry {
207 &self.registry
208 }
209
210 fn router(&self) -> &dyn ExpertRouter {
211 self.router.as_ref()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expert::{ExpertId, LocalExpert};
    use candle_core::{DType, Device, Tensor};

    /// All-zero F32 tensor of shape `(rows, cols)` on the CPU.
    fn zeros(rows: usize, cols: usize) -> Tensor {
        Tensor::zeros((rows, cols), DType::F32, &Device::Cpu).unwrap()
    }

    /// Router whose gating matrix is all zeros.
    fn make_router(hidden: usize, num_experts: usize, top_k: usize) -> TopKRouter {
        TopKRouter::new(zeros(hidden, num_experts), top_k, num_experts, 0.01)
    }

    /// MoE layer backed by a zero-gated router that matches `cfg`.
    fn make_moe(cfg: MoEConfig) -> DistributedMoE {
        let router = make_router(cfg.hidden_dim, cfg.num_experts, cfg.top_k);
        DistributedMoE::new(Box::new(router), cfg)
    }

    #[test]
    fn test_router_accessors() {
        let router = make_router(64, 8, 2);
        assert_eq!(router.top_k(), 2);
        assert_eq!(router.num_experts(), 8);
    }

    #[test]
    fn test_routing_output_matches_seq_len() {
        let seq_len = 5usize;
        let routing = make_router(16, 4, 2)
            .route(&zeros(seq_len, 16))
            .unwrap();

        assert_eq!(routing.expert_indices.len(), seq_len);
        assert_eq!(routing.expert_weights.len(), seq_len);
        routing.expert_indices.iter().for_each(|row| {
            assert_eq!(row.len(), 2, "top_k=2 → each token routed to 2 experts");
        });
    }

    #[test]
    fn test_routing_weights_sum_to_one() {
        let routing = make_router(8, 4, 2).route(&zeros(3, 8)).unwrap();

        for row in &routing.expert_weights {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "weights sum={sum}");
        }
    }

    #[test]
    fn test_moe_register_local_expert() {
        let mut moe = make_moe(MoEConfig {
            hidden_dim: 16,
            num_experts: 4,
            top_k: 2,
            timeout_ms: 1000,
        });

        moe.register_expert(Box::new(LocalExpert::new(0, 16)));

        assert!(moe.registry().is_local(ExpertId::new(0)));
        assert!(!moe.registry().is_local(ExpertId::new(1)));
    }

    #[test]
    fn test_moe_register_remote_expert() {
        let router = make_router(16, 4, 2);
        let mut moe = DistributedMoE::new(Box::new(router), MoEConfig::default());

        moe.register_remote_expert(ExpertId::new(5), String::from("peer-abc"));

        let expected = String::from("peer-abc");
        assert_eq!(
            moe.registry().get_remote_peer(ExpertId::new(5)),
            Some(&expected)
        );
    }

    #[tokio::test]
    async fn test_moe_forward_returns_same_shape() {
        let mut moe = make_moe(MoEConfig {
            hidden_dim: 8,
            num_experts: 2,
            top_k: 1,
            timeout_ms: 1000,
        });

        let input = zeros(2, 8);
        let output = moe.forward(&input).await.unwrap();

        assert_eq!(output.dims(), input.dims());
    }
}