use crate::error::DistributedResult;
use async_trait::async_trait;
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Identifier of an expert.
///
/// Thin newtype over a `u64`; its `Display` form is `expert-<id>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ExpertId(pub u64);

impl ExpertId {
    /// Wraps a raw numeric id.
    ///
    /// Declared `const` so well-known ids can be defined as constants,
    /// e.g. `const FIRST: ExpertId = ExpertId::new(0);`.
    #[must_use]
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}

impl std::fmt::Display for ExpertId {
    /// Formats the id as `expert-<numeric id>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "expert-{}", self.0)
    }
}
/// An expert that can run a forward pass over a tensor.
///
/// The `Send + Sync` supertraits allow implementors to be shared across
/// threads (e.g. stored as `Box<dyn Expert>` in a registry).
#[async_trait]
pub trait Expert: Send + Sync {
    /// Identifier this expert reports for itself.
    fn id(&self) -> ExpertId;

    /// Hidden dimension size reported by this expert.
    fn hidden_dim(&self) -> usize;

    /// Reports whether the expert is ready; the exact meaning is defined by
    /// each implementor.
    fn is_ready(&self) -> bool;

    /// Runs the expert on `input`, returning the output tensor or an error.
    async fn forward(&self, input: &Tensor) -> DistributedResult<Tensor>;
}
/// Registry of known experts.
///
/// Tracks experts hosted in this process, experts hosted on remote peers,
/// and per-expert fallback lists. Remote experts are only recorded as the
/// peer identifier string supplied at registration.
pub struct ExpertRegistry {
    /// Experts hosted in this process, keyed by their own `id()`.
    local_experts: HashMap<ExpertId, Box<dyn Expert>>,
    /// Experts hosted elsewhere, mapped to the peer id they were registered with.
    remote_experts: HashMap<ExpertId, String>,
    /// Alternative experts per expert id, in the order given to `register_fallback`.
    fallbacks: HashMap<ExpertId, Vec<ExpertId>>,
}

impl ExpertRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            local_experts: HashMap::new(),
            remote_experts: HashMap::new(),
            fallbacks: HashMap::new(),
        }
    }

    /// Registers an expert hosted in this process under its own `id()`.
    ///
    /// If an expert with the same id was already registered locally, it is
    /// replaced (and the previous instance dropped). A warning is logged so
    /// the overwrite does not go unnoticed.
    pub fn register_local(&mut self, expert: Box<dyn Expert>) {
        let id = expert.id();
        info!("Registering local expert {}", id);
        if self.local_experts.insert(id, expert).is_some() {
            warn!(
                "Local expert {} was already registered; previous instance replaced",
                id
            );
        }
    }

    /// Records that `expert_id` is served by the remote peer `peer_id`.
    ///
    /// An existing peer mapping for the same id is overwritten, with a warning.
    pub fn register_remote(&mut self, expert_id: ExpertId, peer_id: String) {
        info!(
            "Registering remote expert {} at peer {}",
            expert_id, peer_id
        );
        if let Some(previous) = self.remote_experts.insert(expert_id, peer_id) {
            warn!(
                "Remote expert {} was previously mapped to peer {}; mapping replaced",
                expert_id, previous
            );
        }
    }

    /// Returns `true` if `expert_id` was registered via `register_local`.
    #[must_use]
    pub fn is_local(&self, expert_id: ExpertId) -> bool {
        self.local_experts.contains_key(&expert_id)
    }

    /// Borrows the locally hosted expert for `expert_id`, if any.
    #[must_use]
    pub fn get_local(&self, expert_id: ExpertId) -> Option<&dyn Expert> {
        self.local_experts.get(&expert_id).map(|e| e.as_ref())
    }

    /// Returns the peer id recorded for a remote expert, if any.
    #[must_use]
    pub fn get_remote_peer(&self, expert_id: ExpertId) -> Option<&String> {
        self.remote_experts.get(&expert_id)
    }

    /// Sets the fallback experts for `expert_id`, replacing any previous list.
    pub fn register_fallback(&mut self, expert_id: ExpertId, fallbacks: Vec<ExpertId>) {
        self.fallbacks.insert(expert_id, fallbacks);
    }

    /// Returns the fallback list registered for `expert_id`, if any.
    #[must_use]
    pub fn get_fallbacks(&self, expert_id: ExpertId) -> Option<&Vec<ExpertId>> {
        self.fallbacks.get(&expert_id)
    }

    /// Records a failure reported for `expert_id`.
    ///
    /// Currently this only emits a warning log; no registry state is changed.
    pub fn report_failure(&mut self, expert_id: ExpertId) {
        warn!("Expert {} reported failure", expert_id);
    }

    /// Lists every known expert id (local and remote).
    ///
    /// The result is sorted ascending by numeric id, with duplicates removed.
    #[must_use]
    pub fn list_experts(&self) -> Vec<ExpertId> {
        let mut experts =
            Vec::with_capacity(self.local_experts.len() + self.remote_experts.len());
        experts.extend(self.local_experts.keys().copied());
        experts.extend(self.remote_experts.keys().copied());
        // Equal keys are identical ids, so an unstable sort gives the same result.
        experts.sort_unstable_by_key(|e| e.0);
        // An id registered both locally and remotely would otherwise appear twice.
        experts.dedup();
        experts
    }
}

impl Default for ExpertRegistry {
    /// Equivalent to [`ExpertRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
pub struct LocalExpert {
id: ExpertId,
hidden_dim: usize,
}
impl LocalExpert {
pub fn new(id: u64, hidden_dim: usize) -> Self {
Self {
id: ExpertId::new(id),
hidden_dim,
}
}
}
#[async_trait]
impl Expert for LocalExpert {
fn id(&self) -> ExpertId {
self.id
}
async fn forward(&self, input: &Tensor) -> DistributedResult<Tensor> {
debug!(
"LocalExpert {} forward pass, input shape: {:?}",
self.id,
input.dims()
);
Ok(input.clone())
}
fn hidden_dim(&self) -> usize {
self.hidden_dim
}
fn is_ready(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expert_registry() {
        let mut registry = ExpertRegistry::new();
        let expert = LocalExpert::new(1, 4096);
        registry.register_local(Box::new(expert));
        assert!(registry.is_local(ExpertId::new(1)));
        assert!(!registry.is_local(ExpertId::new(2)));
        registry.register_remote(ExpertId::new(2), "peer-123".to_string());
        assert!(!registry.is_local(ExpertId::new(2)));
        assert_eq!(
            registry.get_remote_peer(ExpertId::new(2)),
            Some(&"peer-123".to_string())
        );
    }

    #[test]
    fn test_expert_id_display() {
        assert_eq!(ExpertId::new(7).to_string(), "expert-7");
    }

    #[test]
    fn test_list_experts_sorted_and_deduplicated() {
        let mut registry = ExpertRegistry::new();
        registry.register_local(Box::new(LocalExpert::new(3, 16)));
        registry.register_remote(ExpertId::new(1), "peer-a".to_string());
        // An id known both locally and remotely must be listed only once.
        registry.register_remote(ExpertId::new(3), "peer-b".to_string());
        assert_eq!(
            registry.list_experts(),
            vec![ExpertId::new(1), ExpertId::new(3)]
        );
    }

    #[test]
    fn test_fallbacks_and_local_lookup() {
        let mut registry = ExpertRegistry::default();
        assert!(registry.list_experts().is_empty());
        assert!(registry.get_fallbacks(ExpertId::new(1)).is_none());

        registry.register_fallback(ExpertId::new(1), vec![ExpertId::new(2), ExpertId::new(3)]);
        assert_eq!(
            registry.get_fallbacks(ExpertId::new(1)),
            Some(&vec![ExpertId::new(2), ExpertId::new(3)])
        );

        registry.register_local(Box::new(LocalExpert::new(5, 8)));
        let expert = registry
            .get_local(ExpertId::new(5))
            .expect("expert 5 was registered locally");
        assert_eq!(expert.id(), ExpertId::new(5));
        assert_eq!(expert.hidden_dim(), 8);
        assert!(expert.is_ready());
        assert!(registry.get_local(ExpertId::new(6)).is_none());
    }
}