use crate::error::DistributedResult;
use async_trait::async_trait;
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Numeric identifier of an expert.
///
/// A thin newtype over `u64`: it is `Copy`, hashable (usable as a `HashMap` key),
/// serde-serializable, and displays as `expert-<n>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ExpertId(pub u64);

impl ExpertId {
    /// Wraps a raw `u64` id.
    ///
    /// Declared `const` so ids can also be built in constant contexts,
    /// e.g. `const FIRST: ExpertId = ExpertId::new(0);`.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}

impl std::fmt::Display for ExpertId {
    /// Formats the id as `expert-<n>`, e.g. `expert-7`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "expert-{}", self.0)
    }
}
/// Behaviour required of every expert that can be stored in the registry.
///
/// The `Send + Sync` supertraits allow an expert to be shared across threads;
/// `forward` is async (desugared by `async_trait`).
#[async_trait]
pub trait Expert: Send + Sync {
    /// Identifier of this expert; `ExpertRegistry::register_local` uses it as the map key.
    fn id(&self) -> ExpertId;

    /// Hidden dimension reported by the expert.
    /// NOTE(review): nothing visible here checks it against `forward`'s input — confirm intended use.
    fn hidden_dim(&self) -> usize;

    /// Whether the expert reports itself as ready.
    fn is_ready(&self) -> bool;

    /// Runs the expert on `input` and returns the resulting tensor, or an error.
    async fn forward(&self, input: &Tensor) -> DistributedResult<Tensor>;
}
/// Directory of the experts known to this node.
///
/// Experts are either *local* (an in-process [`Expert`] object) or *remote*
/// (identified by the id string of the peer hosting them). Any expert id may
/// additionally carry a list of fallback expert ids.
#[derive(Default)]
pub struct ExpertRegistry {
    /// Experts that can be invoked directly on this node, keyed by their id.
    local_experts: HashMap<ExpertId, Box<dyn Expert>>,
    /// Expert id -> peer id of the node hosting that expert.
    remote_experts: HashMap<ExpertId, String>,
    /// Expert id -> fallback expert ids, stored exactly as passed to `register_fallback`.
    fallbacks: HashMap<ExpertId, Vec<ExpertId>>,
}

impl ExpertRegistry {
    /// Creates an empty registry (no local, remote, or fallback entries).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers an in-process expert under the id returned by [`Expert::id`].
    ///
    /// If an expert with the same id is already registered locally it is replaced,
    /// and a warning is logged so the double registration does not go unnoticed.
    pub fn register_local(&mut self, expert: Box<dyn Expert>) {
        let id = expert.id();
        info!("Registering local expert {}", id);
        // `insert` hands back the displaced value; surface it instead of dropping it silently.
        if self.local_experts.insert(id, expert).is_some() {
            warn!("Local expert {} was already registered; previous instance replaced", id);
        }
    }

    /// Records that `expert_id` is hosted by the peer identified by `peer_id`.
    ///
    /// Re-registering an id overwrites the previous peer mapping and logs a warning
    /// naming the peer that was replaced.
    pub fn register_remote(&mut self, expert_id: ExpertId, peer_id: String) {
        info!(
            "Registering remote expert {} at peer {}",
            expert_id, peer_id
        );
        if let Some(previous) = self.remote_experts.insert(expert_id, peer_id) {
            warn!(
                "Remote expert {} was previously mapped to peer {}; mapping replaced",
                expert_id, previous
            );
        }
    }

    /// Returns `true` if `expert_id` has a local (in-process) registration.
    pub fn is_local(&self, expert_id: ExpertId) -> bool {
        self.local_experts.contains_key(&expert_id)
    }

    /// Borrows the local expert registered under `expert_id`, if any.
    pub fn get_local(&self, expert_id: ExpertId) -> Option<&dyn Expert> {
        self.local_experts.get(&expert_id).map(|e| e.as_ref())
    }

    /// Returns the peer id hosting `expert_id`, if it was registered as remote.
    pub fn get_remote_peer(&self, expert_id: ExpertId) -> Option<&String> {
        self.remote_experts.get(&expert_id)
    }

    /// Sets the fallback list for `expert_id`, replacing any existing list.
    pub fn register_fallback(&mut self, expert_id: ExpertId, fallbacks: Vec<ExpertId>) {
        self.fallbacks.insert(expert_id, fallbacks);
    }

    /// Returns the fallback list registered for `expert_id`, if any.
    pub fn get_fallbacks(&self, expert_id: ExpertId) -> Option<&Vec<ExpertId>> {
        self.fallbacks.get(&expert_id)
    }

    /// Reports that `expert_id` failed.
    ///
    /// This currently only logs a warning: no registry state is modified, so the
    /// expert stays registered and its fallbacks are not consulted here.
    pub fn report_failure(&mut self, expert_id: ExpertId) {
        warn!("Expert {} reported failure", expert_id);
    }

    /// Lists every known expert id (local and remote) in ascending numeric order.
    ///
    /// An id registered both locally and remotely appears only once.
    pub fn list_experts(&self) -> Vec<ExpertId> {
        let mut experts =
            Vec::with_capacity(self.local_experts.len() + self.remote_experts.len());
        experts.extend(self.local_experts.keys().copied());
        experts.extend(self.remote_experts.keys().copied());
        // Sorting makes duplicate ids adjacent so `dedup` can drop them. Stability is
        // irrelevant because equal keys are equal ids.
        experts.sort_unstable_by_key(|e| e.0);
        experts.dedup();
        experts
    }
}
pub struct LocalExpert {
id: ExpertId,
hidden_dim: usize,
}
impl LocalExpert {
pub fn new(id: u64, hidden_dim: usize) -> Self {
Self {
id: ExpertId::new(id),
hidden_dim,
}
}
}
#[async_trait]
impl Expert for LocalExpert {
fn id(&self) -> ExpertId {
self.id
}
async fn forward(&self, input: &Tensor) -> DistributedResult<Tensor> {
debug!(
"LocalExpert {} forward pass, input shape: {:?}",
self.id,
input.dims()
);
Ok(input.clone())
}
fn hidden_dim(&self) -> usize {
self.hidden_dim
}
fn is_ready(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expert_registry() {
        let mut registry = ExpertRegistry::new();
        let expert = LocalExpert::new(1, 4096);
        registry.register_local(Box::new(expert));
        assert!(registry.is_local(ExpertId::new(1)));
        assert!(!registry.is_local(ExpertId::new(2)));
        registry.register_remote(ExpertId::new(2), "peer-123".to_string());
        assert!(!registry.is_local(ExpertId::new(2)));
        assert_eq!(
            registry.get_remote_peer(ExpertId::new(2)),
            Some(&"peer-123".to_string())
        );
    }

    #[test]
    fn test_expert_id_display() {
        let id = ExpertId::new(7);
        assert_eq!(id.to_string(), "expert-7");
    }

    #[test]
    fn test_expert_id_equality_and_hash() {
        use std::collections::HashSet;
        let a = ExpertId::new(3);
        let b = ExpertId::new(3);
        let c = ExpertId::new(4);
        assert_eq!(a, b);
        assert_ne!(a, c);
        let mut s = HashSet::new();
        s.insert(a);
        s.insert(b);
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn test_list_experts_sorted() {
        let mut registry = ExpertRegistry::new();
        registry.register_local(Box::new(LocalExpert::new(3, 64)));
        registry.register_local(Box::new(LocalExpert::new(1, 64)));
        registry.register_remote(ExpertId::new(2), "peer-x".to_string());
        let list = registry.list_experts();
        assert_eq!(list, vec![ExpertId::new(1), ExpertId::new(2), ExpertId::new(3)]);
    }

    #[test]
    fn test_list_experts_deduplicates() {
        let mut registry = ExpertRegistry::new();
        registry.register_local(Box::new(LocalExpert::new(5, 64)));
        registry.register_remote(ExpertId::new(5), "peer-y".to_string());
        let list = registry.list_experts();
        assert_eq!(list.iter().filter(|&&id| id == ExpertId::new(5)).count(), 1);
    }

    #[test]
    fn test_fallback_registration_and_retrieval() {
        let mut registry = ExpertRegistry::new();
        let fallbacks = vec![ExpertId::new(10), ExpertId::new(11)];
        registry.register_fallback(ExpertId::new(5), fallbacks.clone());
        assert_eq!(registry.get_fallbacks(ExpertId::new(5)), Some(&fallbacks));
        assert_eq!(registry.get_fallbacks(ExpertId::new(99)), None);
    }

    #[test]
    fn test_report_failure_does_not_panic() {
        let mut registry = ExpertRegistry::new();
        registry.report_failure(ExpertId::new(999));
    }

    #[test]
    fn test_local_expert_properties() {
        let expert = LocalExpert::new(42, 2048);
        assert_eq!(expert.id(), ExpertId::new(42));
        assert_eq!(expert.hidden_dim(), 2048);
        assert!(expert.is_ready());
    }

    #[test]
    fn test_get_local_returns_registered_expert() {
        let mut registry = ExpertRegistry::new();
        registry.register_local(Box::new(LocalExpert::new(8, 128)));
        let expert = registry
            .get_local(ExpertId::new(8))
            .expect("expert 8 was registered just above");
        assert_eq!(expert.id(), ExpertId::new(8));
        assert_eq!(expert.hidden_dim(), 128);
        assert!(registry.get_local(ExpertId::new(9)).is_none());
    }

    #[test]
    fn test_remote_only_expert_has_no_local_handle() {
        let mut registry = ExpertRegistry::new();
        registry.register_remote(ExpertId::new(4), "peer-z".to_string());
        assert!(registry.get_local(ExpertId::new(4)).is_none());
        assert_eq!(registry.get_remote_peer(ExpertId::new(5)), None);
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ExpertRegistry::default();
        assert!(registry.list_experts().is_empty());
        assert!(!registry.is_local(ExpertId::new(0)));
    }

    #[test]
    fn test_reregistering_local_expert_replaces_it() {
        let mut registry = ExpertRegistry::new();
        registry.register_local(Box::new(LocalExpert::new(6, 64)));
        registry.register_local(Box::new(LocalExpert::new(6, 256)));
        assert_eq!(registry.list_experts(), vec![ExpertId::new(6)]);
        assert_eq!(
            registry.get_local(ExpertId::new(6)).map(|e| e.hidden_dim()),
            Some(256)
        );
    }

    #[test]
    fn test_reregistering_remote_expert_updates_peer() {
        let mut registry = ExpertRegistry::new();
        registry.register_remote(ExpertId::new(7), "peer-a".to_string());
        registry.register_remote(ExpertId::new(7), "peer-b".to_string());
        assert_eq!(
            registry.get_remote_peer(ExpertId::new(7)),
            Some(&"peer-b".to_string())
        );
        assert_eq!(registry.list_experts(), vec![ExpertId::new(7)]);
    }

    #[tokio::test]
    async fn test_local_expert_forward_passthrough() {
        use candle_core::{DType, Device, Tensor};
        let expert = LocalExpert::new(0, 4);
        let input = Tensor::zeros((2usize, 4usize), DType::F32, &Device::Cpu).unwrap();
        let output = expert.forward(&input).await.unwrap();
        assert_eq!(output.dims(), input.dims());
    }

    #[tokio::test]
    async fn test_local_expert_forward_preserves_values() {
        use candle_core::Device;
        let expert = LocalExpert::new(0, 4);
        let input =
            Tensor::new(&[[1f32, 2., 3., 4.], [5., 6., 7., 8.]], &Device::Cpu).unwrap();
        let output = expert.forward(&input).await.unwrap();
        assert_eq!(
            output.to_vec2::<f32>().unwrap(),
            input.to_vec2::<f32>().unwrap()
        );
    }
}