Skip to main content

kwaai_distributed/
moe.rs

1//! Mixture of Experts (MoE) implementation
2//!
3//! Enables arbitrarily large models by distributing "expert" sublayers
4//! across network participants.
5
6use crate::error::{DistributedError, DistributedResult};
7use crate::expert::{ExpertId, ExpertRegistry};
8use async_trait::async_trait;
9use candle_core::Tensor;
10use serde::{Deserialize, Serialize};
11use tracing::{debug, info};
12
13/// Routing information for MoE layer
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Routing {
16    /// Expert indices for each token [batch, seq_len, top_k]
17    pub expert_indices: Vec<Vec<ExpertId>>,
18    /// Expert weights for each token [batch, seq_len, top_k]
19    pub expert_weights: Vec<Vec<f32>>,
20    /// Auxiliary load balancing loss
21    pub aux_loss: f32,
22}
23
24/// Trait for expert routing
25///
26/// The router determines which experts handle each token.
27#[async_trait]
28pub trait ExpertRouter: Send + Sync {
29    /// Route tokens to experts
30    ///
31    /// Returns routing information including expert assignments
32    /// and weights for each token.
33    fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing>;
34
35    /// Number of experts to route to per token
36    fn top_k(&self) -> usize;
37
38    /// Total number of experts
39    fn num_experts(&self) -> usize;
40}
41
42/// Trait for Mixture of Experts layer
43#[async_trait]
44pub trait MixtureOfExperts: Send + Sync {
45    /// Forward pass through MoE layer
46    ///
47    /// Routes tokens to experts and combines results.
48    async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor>;
49
50    /// Get the expert registry
51    fn registry(&self) -> &ExpertRegistry;
52
53    /// Get the router
54    fn router(&self) -> &dyn ExpertRouter;
55}
56
57/// Simple top-k router implementation
58pub struct TopKRouter {
59    /// Gating weights [hidden_size, num_experts]
60    gate_weights: Tensor,
61    /// Number of experts to select per token
62    top_k: usize,
63    /// Total number of experts
64    num_experts: usize,
65    /// Auxiliary loss coefficient
66    #[allow(dead_code)]
67    aux_loss_coef: f32,
68}
69
70impl TopKRouter {
71    /// Create a new top-k router
72    pub fn new(gate_weights: Tensor, top_k: usize, num_experts: usize, aux_loss_coef: f32) -> Self {
73        Self {
74            gate_weights,
75            top_k,
76            num_experts,
77            aux_loss_coef,
78        }
79    }
80}
81
82#[async_trait]
83impl ExpertRouter for TopKRouter {
84    fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing> {
85        // Compute gating scores
86        let scores = hidden_states
87            .matmul(&self.gate_weights)
88            .map_err(|e| DistributedError::RoutingFailed(e.to_string()))?;
89
90        // Get dimensions
91        let dims = scores.dims();
92        let _batch_size = if dims.len() > 2 { dims[0] } else { 1 };
93        let seq_len = if dims.len() > 2 { dims[1] } else { dims[0] };
94
95        // Placeholder routing - in real implementation, would do proper softmax + top-k
96        // For now, return uniform routing to first top_k experts
97        let expert_indices: Vec<Vec<ExpertId>> = (0..seq_len)
98            .map(|_| (0..self.top_k).map(|i| ExpertId::new(i as u64)).collect())
99            .collect();
100
101        let expert_weights: Vec<Vec<f32>> = (0..seq_len)
102            .map(|_| vec![1.0 / self.top_k as f32; self.top_k])
103            .collect();
104
105        Ok(Routing {
106            expert_indices,
107            expert_weights,
108            aux_loss: 0.0,
109        })
110    }
111
112    fn top_k(&self) -> usize {
113        self.top_k
114    }
115
116    fn num_experts(&self) -> usize {
117        self.num_experts
118    }
119}
120
121/// Distributed MoE layer implementation
122pub struct DistributedMoE {
123    /// Expert router
124    router: Box<dyn ExpertRouter>,
125    /// Expert registry
126    registry: ExpertRegistry,
127    /// Configuration
128    #[allow(dead_code)]
129    config: MoEConfig,
130}
131
/// Settings for a Mixture of Experts layer.
#[derive(Clone, Debug)]
pub struct MoEConfig {
    /// Size of the hidden dimension.
    pub hidden_dim: usize,
    /// Total number of experts.
    pub num_experts: usize,
    /// Number of experts selected per token.
    pub top_k: usize,
    /// Timeout for remote calls, in milliseconds.
    pub timeout_ms: u64,
}

impl Default for MoEConfig {
    /// Default configuration: 4096 hidden units, 8 experts, top-2 routing
    /// and a 5000 ms remote-call timeout.
    fn default() -> Self {
        MoEConfig {
            hidden_dim: 4_096,
            num_experts: 8,
            top_k: 2,
            timeout_ms: 5_000,
        }
    }
}
155
156impl DistributedMoE {
157    /// Create a new distributed MoE layer
158    pub fn new(router: Box<dyn ExpertRouter>, config: MoEConfig) -> Self {
159        info!(
160            num_experts = config.num_experts,
161            top_k = config.top_k,
162            hidden_dim = config.hidden_dim,
163            "Creating DistributedMoE layer"
164        );
165        Self {
166            router,
167            registry: ExpertRegistry::new(),
168            config,
169        }
170    }
171
172    /// Register an expert (local or remote)
173    pub fn register_expert(&mut self, expert: Box<dyn crate::expert::Expert>) {
174        debug!("Registering local expert in MoE layer");
175        self.registry.register_local(expert);
176    }
177
178    /// Register remote expert location
179    pub fn register_remote_expert(&mut self, expert_id: ExpertId, peer_id: String) {
180        debug!(
181            "Registering remote expert {} in MoE layer at peer {}",
182            expert_id, peer_id
183        );
184        self.registry.register_remote(expert_id, peer_id);
185    }
186}
187
188#[async_trait]
189impl MixtureOfExperts for DistributedMoE {
190    async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor> {
191        debug!("MoE forward pass, input shape: {:?}", input.dims());
192        // 1. Route tokens to experts
193        let routing = self.router.route(input)?;
194        debug!("Routing computed: aux_loss={:.4}", routing.aux_loss);
195
196        // 2. For now, just return input (placeholder)
197        // Real implementation would:
198        // - Partition tokens by expert assignment
199        // - Call local experts directly
200        // - Call remote experts via P2P
201        // - Combine results weighted by routing weights
202
203        Ok(input.clone())
204    }
205
206    fn registry(&self) -> &ExpertRegistry {
207        &self.registry
208    }
209
210    fn router(&self) -> &dyn ExpertRouter {
211        self.router.as_ref()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expert::{ExpertId, LocalExpert};
    use candle_core::{DType, Device, Tensor};

    /// All-zero `f32` CPU tensor of shape `(rows, cols)`.
    fn zeros(rows: usize, cols: usize) -> Tensor {
        Tensor::zeros((rows, cols), DType::F32, &Device::Cpu).unwrap()
    }

    /// Router with an all-zero gate of shape `(hidden, num_experts)`.
    fn make_router(hidden: usize, num_experts: usize, top_k: usize) -> TopKRouter {
        TopKRouter::new(zeros(hidden, num_experts), top_k, num_experts, 0.01)
    }

    // The accessors report the values passed to the constructor.
    #[test]
    fn test_router_accessors() {
        let router = make_router(64, 8, 2);
        assert_eq!(router.top_k(), 2);
        assert_eq!(router.num_experts(), 8);
    }

    // One routing row per input row, each holding `top_k` experts.
    #[test]
    fn test_routing_output_matches_seq_len() {
        let seq_len = 5usize;
        let routing = make_router(16, 4, 2).route(&zeros(seq_len, 16)).unwrap();

        assert_eq!(routing.expert_indices.len(), seq_len);
        assert_eq!(routing.expert_weights.len(), seq_len);
        routing.expert_indices.iter().for_each(|row| {
            assert_eq!(row.len(), 2, "top_k=2 → each token routed to 2 experts");
        });
    }

    // The weights of every token add up to one.
    #[test]
    fn test_routing_weights_sum_to_one() {
        let routing = make_router(8, 4, 2).route(&zeros(3, 8)).unwrap();

        for weights in &routing.expert_weights {
            let sum = weights.iter().sum::<f32>();
            assert!((sum - 1.0).abs() < 1e-5, "weights sum={sum}");
        }
    }

    // A registered local expert is reported as local; an unknown one is not.
    #[test]
    fn test_moe_register_local_expert() {
        let config = MoEConfig {
            hidden_dim: 16,
            num_experts: 4,
            top_k: 2,
            timeout_ms: 1000,
        };
        let mut layer = DistributedMoE::new(Box::new(make_router(16, 4, 2)), config);

        layer.register_expert(Box::new(LocalExpert::new(0, 16)));

        assert!(layer.registry().is_local(ExpertId::new(0)));
        assert!(!layer.registry().is_local(ExpertId::new(1)));
    }

    // A registered remote expert resolves to the peer it was registered with.
    #[test]
    fn test_moe_register_remote_expert() {
        let mut layer =
            DistributedMoE::new(Box::new(make_router(16, 4, 2)), MoEConfig::default());

        layer.register_remote_expert(ExpertId::new(5), "peer-abc".to_string());

        let expected_peer = "peer-abc".to_string();
        assert_eq!(
            layer.registry().get_remote_peer(ExpertId::new(5)),
            Some(&expected_peer)
        );
    }

    // The placeholder forward pass keeps the input shape.
    #[tokio::test]
    async fn test_moe_forward_returns_same_shape() {
        let config = MoEConfig {
            hidden_dim: 8,
            num_experts: 2,
            top_k: 1,
            timeout_ms: 1000,
        };
        let mut layer = DistributedMoE::new(Box::new(make_router(8, 2, 1)), config);

        let input = zeros(2, 8);
        let output = layer.forward(&input).await.unwrap();

        assert_eq!(output.dims(), input.dims());
    }
}