kwaai-distributed 0.4.63

Distributed ML operations for KwaaiNet - MoE, parameter averaging (Hivemind patterns)
Documentation
//! Mixture of Experts (MoE) implementation
//!
//! Enables arbitrarily large models by distributing "expert" sublayers
//! across network participants.

use crate::error::{DistributedError, DistributedResult};
use crate::expert::{ExpertId, ExpertRegistry};
use async_trait::async_trait;
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

/// Routing information for MoE layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
    /// Expert indices for each token [batch, seq_len, top_k]
    pub expert_indices: Vec<Vec<ExpertId>>,
    /// Expert weights for each token [batch, seq_len, top_k]
    pub expert_weights: Vec<Vec<f32>>,
    /// Auxiliary load balancing loss
    pub aux_loss: f32,
}

/// Trait for expert routing
///
/// The router determines which experts handle each token.
#[async_trait]
pub trait ExpertRouter: Send + Sync {
    /// Route tokens to experts
    ///
    /// Returns routing information including expert assignments
    /// and weights for each token.
    fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing>;

    /// Number of experts to route to per token
    fn top_k(&self) -> usize;

    /// Total number of experts
    fn num_experts(&self) -> usize;
}

/// Trait for Mixture of Experts layer
#[async_trait]
pub trait MixtureOfExperts: Send + Sync {
    /// Forward pass through MoE layer
    ///
    /// Routes tokens to experts and combines results.
    async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor>;

    /// Get the expert registry
    fn registry(&self) -> &ExpertRegistry;

    /// Get the router
    fn router(&self) -> &dyn ExpertRouter;
}

/// Simple top-k router implementation
pub struct TopKRouter {
    /// Gating weights [hidden_size, num_experts]
    gate_weights: Tensor,
    /// Number of experts to select per token
    top_k: usize,
    /// Total number of experts
    num_experts: usize,
    /// Auxiliary loss coefficient
    #[allow(dead_code)]
    aux_loss_coef: f32,
}

impl TopKRouter {
    /// Create a new top-k router
    pub fn new(gate_weights: Tensor, top_k: usize, num_experts: usize, aux_loss_coef: f32) -> Self {
        Self {
            gate_weights,
            top_k,
            num_experts,
            aux_loss_coef,
        }
    }
}

#[async_trait]
impl ExpertRouter for TopKRouter {
    fn route(&self, hidden_states: &Tensor) -> DistributedResult<Routing> {
        // Compute gating scores
        let scores = hidden_states
            .matmul(&self.gate_weights)
            .map_err(|e| DistributedError::RoutingFailed(e.to_string()))?;

        // Get dimensions
        let dims = scores.dims();
        let _batch_size = if dims.len() > 2 { dims[0] } else { 1 };
        let seq_len = if dims.len() > 2 { dims[1] } else { dims[0] };

        // Placeholder routing - in real implementation, would do proper softmax + top-k
        // For now, return uniform routing to first top_k experts
        let expert_indices: Vec<Vec<ExpertId>> = (0..seq_len)
            .map(|_| (0..self.top_k).map(|i| ExpertId::new(i as u64)).collect())
            .collect();

        let expert_weights: Vec<Vec<f32>> = (0..seq_len)
            .map(|_| vec![1.0 / self.top_k as f32; self.top_k])
            .collect();

        Ok(Routing {
            expert_indices,
            expert_weights,
            aux_loss: 0.0,
        })
    }

    fn top_k(&self) -> usize {
        self.top_k
    }

    fn num_experts(&self) -> usize {
        self.num_experts
    }
}

/// Distributed MoE layer implementation
pub struct DistributedMoE {
    /// Expert router
    router: Box<dyn ExpertRouter>,
    /// Expert registry
    registry: ExpertRegistry,
    /// Configuration
    #[allow(dead_code)]
    config: MoEConfig,
}

/// Configuration for MoE layer
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Number of experts
    pub num_experts: usize,
    /// Top-k experts per token
    pub top_k: usize,
    /// Timeout for remote calls (ms)
    pub timeout_ms: u64,
}

impl Default for MoEConfig {
    fn default() -> Self {
        Self {
            hidden_dim: 4096,
            num_experts: 8,
            top_k: 2,
            timeout_ms: 5000,
        }
    }
}

impl DistributedMoE {
    /// Create a new distributed MoE layer
    pub fn new(router: Box<dyn ExpertRouter>, config: MoEConfig) -> Self {
        info!(
            num_experts = config.num_experts,
            top_k = config.top_k,
            hidden_dim = config.hidden_dim,
            "Creating DistributedMoE layer"
        );
        Self {
            router,
            registry: ExpertRegistry::new(),
            config,
        }
    }

    /// Register an expert (local or remote)
    pub fn register_expert(&mut self, expert: Box<dyn crate::expert::Expert>) {
        debug!("Registering local expert in MoE layer");
        self.registry.register_local(expert);
    }

    /// Register remote expert location
    pub fn register_remote_expert(&mut self, expert_id: ExpertId, peer_id: String) {
        debug!(
            "Registering remote expert {} in MoE layer at peer {}",
            expert_id, peer_id
        );
        self.registry.register_remote(expert_id, peer_id);
    }
}

#[async_trait]
impl MixtureOfExperts for DistributedMoE {
    async fn forward(&mut self, input: &Tensor) -> DistributedResult<Tensor> {
        debug!("MoE forward pass, input shape: {:?}", input.dims());
        // 1. Route tokens to experts
        let routing = self.router.route(input)?;
        debug!("Routing computed: aux_loss={:.4}", routing.aux_loss);

        // 2. For now, just return input (placeholder)
        // Real implementation would:
        // - Partition tokens by expert assignment
        // - Call local experts directly
        // - Call remote experts via P2P
        // - Combine results weighted by routing weights

        Ok(input.clone())
    }

    fn registry(&self) -> &ExpertRegistry {
        &self.registry
    }

    fn router(&self) -> &dyn ExpertRouter {
        self.router.as_ref()
    }
}