use crate::error::{DistributedError, DistributedResult};
use crate::expert::{ExpertId, ExpertRegistry};
use async_trait::async_trait;
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
/// Outcome of a routing decision, as returned by [`ExpertRouter::route`].
///
/// `TopKRouter::route` builds `expert_indices` and `expert_weights` with the
/// same outer length (one inner vector per routed row) and `top_k` values in
/// every inner vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
/// Experts selected for each routed row.
pub expert_indices: Vec<Vec<ExpertId>>,
/// Weight attached to each selected expert.
// NOTE(review): presumably parallel to `expert_indices` (same row/slot order) — confirm with consumers.
pub expert_weights: Vec<Vec<f32>>,
/// Auxiliary loss value reported by the router; `TopKRouter` always reports `0.0`.
pub aux_loss: f32,
}
/// Strategy that decides which experts handle which rows of a hidden-state tensor.
///
/// Implementations must be `Send + Sync`.
// NOTE(review): the trait carries `#[async_trait]` although none of its methods is `async`.
#[async_trait]
pub trait ExpertRouter: Send + Sync {
/// Computes a [`Routing`] for `hidden_states`, or an error if routing fails.
fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing>;
/// Number of experts selected per routed row.
fn top_k(&self) -> usize;
/// Number of experts the router is configured with.
fn num_experts(&self) -> usize;
}
/// A mixture-of-experts layer: an async forward pass plus access to the
/// expert registry and router it is built from.
#[async_trait]
pub trait MixtureOfExperts: Send + Sync {
/// Runs the layer on `input` and returns the resulting tensor.
///
/// Takes `&mut self`, so implementations may update internal state.
async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor>;
/// Borrows the registry holding this layer's experts.
fn registry(&self) -> &ExpertRegistry;
/// Borrows the router used by this layer.
fn router(&self) -> &dyn ExpertRouter;
}
/// Router that scores hidden states against a gate matrix and assigns
/// `top_k` experts to every routed row.
pub struct TopKRouter {
/// Gate matrix; used as the right-hand operand of the matmul that turns
/// hidden states into per-expert scores.
// assumes shape (hidden_dim, num_experts) — TODO confirm against callers
gate_weights: Tensor,
/// Number of experts assigned to each routed row.
top_k: usize,
/// Number of experts reported by `ExpertRouter::num_experts`.
num_experts: usize,
/// Coefficient for the auxiliary loss. Stored at construction but not read
/// anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
aux_loss_coef: f32,
}
impl TopKRouter {
pub fn new(gate_weights: Tensor, top_k: usize, num_experts: usize, aux_loss_coef: f32) -> Self {
Self {
gate_weights,
top_k,
num_experts,
aux_loss_coef,
}
}
}
#[async_trait]
impl ExpertRouter for TopKRouter {
fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing> {
let scores = hidden_states
.matmul(&self.gate_weights)
.map_err(|e| DistributedError::RoutingFailed(e.to_string()))?;
let dims = scores.dims();
let _batch_size = if dims.len() > 2 { dims[0] } else { 1 };
let seq_len = if dims.len() > 2 { dims[1] } else { dims[0] };
let expert_indices: Vec<Vec<ExpertId>> = (0..seq_len)
.map(|_| (0..self.top_k).map(|i| ExpertId::new(i as u64)).collect())
.collect();
let expert_weights: Vec<Vec<f32>> = (0..seq_len)
.map(|_| vec![1.0 / self.top_k as f32; self.top_k])
.collect();
Ok(Routing {
expert_indices,
expert_weights,
aux_loss: 0.0,
})
}
fn top_k(&self) -> usize {
self.top_k
}
fn num_experts(&self) -> usize {
self.num_experts
}
}
/// Mixture-of-experts layer whose experts may be local or hosted on remote peers.
pub struct DistributedMoE {
/// Routing strategy consulted on every forward pass.
router: Box<dyn ExpertRouter>,
/// Registry of the local and remote experts registered with this layer.
registry: ExpertRegistry,
/// Configuration the layer was created with. Not read after construction in
/// the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
config: MoEConfig,
}
/// Configuration values for a [`DistributedMoE`] layer.
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Hidden dimension of the layer (defaults to 4096).
    pub hidden_dim: usize,
    /// Number of experts (defaults to 8).
    pub num_experts: usize,
    /// Number of experts selected per token (defaults to 2).
    pub top_k: usize,
    /// Timeout in milliseconds (defaults to 5000).
    // NOTE(review): not read in the visible code; presumably bounds remote expert calls — confirm.
    pub timeout_ms: u64,
}

impl Default for MoEConfig {
    /// Returns the stock configuration: 4096 hidden units, 8 experts,
    /// top-2 selection and a 5000 ms timeout.
    fn default() -> Self {
        const HIDDEN_DIM: usize = 4096;
        const NUM_EXPERTS: usize = 8;
        const TOP_K: usize = 2;
        const TIMEOUT_MS: u64 = 5000;

        MoEConfig {
            hidden_dim: HIDDEN_DIM,
            num_experts: NUM_EXPERTS,
            top_k: TOP_K,
            timeout_ms: TIMEOUT_MS,
        }
    }
}
impl DistributedMoE {
pub fn new(router: Box<dyn ExpertRouter>, config: MoEConfig) -> Self {
info!(
num_experts = config.num_experts,
top_k = config.top_k,
hidden_dim = config.hidden_dim,
"Creating DistributedMoE layer"
);
Self {
router,
registry: ExpertRegistry::new(),
config,
}
}
pub fn register_expert(&mut self, expert: Box<dyn crate::expert::Expert>) {
debug!("Registering local expert in MoE layer");
self.registry.register_local(expert);
}
pub fn register_remote_expert(&mut self, expert_id: ExpertId, peer_id: String) {
debug!(
"Registering remote expert {} in MoE layer at peer {}",
expert_id, peer_id
);
self.registry.register_remote(expert_id, peer_id);
}
}
#[async_trait]
impl MixtureOfExperts for DistributedMoE {
async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor> {
debug!("MoE forward pass, input shape: {:?}", input.dims());
let routing = self.router.route(input)?;
debug!("Routing computed: aux_loss={:.4}", routing.aux_loss);
Ok(input.clone())
}
fn registry(&self) -> &ExpertRegistry {
&self.registry
}
fn router(&self) -> &dyn ExpertRouter {
self.router.as_ref()
}
}