use std::ops::Range;
use crate::model::{LlamaModel, Model, ModelLoader, RopeType};
use crate::model::layers::{Linear, NormLayer, RMSNorm};
use super::config::ClusterConfig;
use super::model::DistributedModel;
use super::pipeline::{PipelineExecutor, ShardConnection};
use super::proto::shard_service_client::ShardServiceClient;
use super::proto::{ConfigureRequest, HealthRequest, LayerData, NamedTensor};
use super::tensor_transfer::tensor_to_proto;
use super::{DistributedError, DistributedResult};
/// Drives the setup of a distributed cluster described by a [`ClusterConfig`].
///
/// The coordinator only stores the configuration; shards are contacted when
/// its `setup` method is awaited.
pub struct Coordinator {
    /// Cluster description read during setup: shard list, model path,
    /// timeouts, sequence-length override and GPU flag.
    config: ClusterConfig,
}
impl Coordinator {
pub fn new(config: ClusterConfig) -> Self {
Self { config }
}
pub async fn setup(&self) -> DistributedResult<DistributedModel> {
tracing::info!(
"Setting up distributed cluster with {} shards",
self.config.shards.len()
);
let loader = ModelLoader::load(&self.config.model_path)?;
let model = loader.build_model()?;
let model_config = model.config().clone();
let num_layers = model_config.num_layers;
let architecture = model.architecture();
let assignments = self.config.compute_layer_assignments(num_layers)?;
tracing::info!("Layer assignments:");
for (i, range) in assignments.iter().enumerate() {
tracing::info!(
" {} -> layers {}..{}",
self.config.shards[i].name,
range.start,
range.end
);
}
let mut shard_connections = Vec::new();
for (i, shard_spec) in self.config.shards.iter().enumerate() {
let range = &assignments[i];
tracing::info!(
"Connecting to shard '{}' at {}",
shard_spec.name,
shard_spec.address
);
let endpoint = tonic::transport::Channel::from_shared(format!(
"http://{}",
shard_spec.address
))
.map_err(|e| {
DistributedError::Config(format!(
"invalid shard address '{}': {}",
shard_spec.address, e
))
})?
.connect_timeout(self.config.connect_timeout())
.timeout(self.config.request_timeout());
let channel = endpoint.connect().await.map_err(|e| {
DistributedError::Shard(format!(
"failed to connect to shard '{}' at {}: {}",
shard_spec.name, shard_spec.address, e
))
})?;
let mut client = ShardServiceClient::new(channel)
.max_decoding_message_size(256 * 1024 * 1024) .max_encoding_message_size(256 * 1024 * 1024);
let health = client.health(HealthRequest {}).await.map_err(|e| {
DistributedError::HealthCheckFailed {
name: shard_spec.name.clone(),
reason: e.to_string(),
}
})?;
let health_resp = health.into_inner();
if !health_resp.healthy {
return Err(DistributedError::HealthCheckFailed {
name: shard_spec.name.clone(),
reason: "shard reported unhealthy".into(),
});
}
tracing::info!(
"Shard '{}' healthy (backend: {})",
shard_spec.name,
health_resp.backend_name
);
let use_neox = matches!(
model_config.rope_config.rope_type,
RopeType::NeoX
);
let max_seq = if self.config.max_seq_len > 0 {
self.config.max_seq_len
} else {
model_config.max_seq_len
};
client
.configure(ConfigureRequest {
hidden_size: model_config.hidden_size as u32,
intermediate_size: model_config.intermediate_size as u32,
num_layers: num_layers as u32,
num_heads: model_config.num_heads as u32,
num_kv_heads: model_config.num_kv_heads as u32,
head_dim: model_config.head_dim as u32,
max_seq_len: max_seq as u32,
norm_eps: model_config.norm_eps,
rope_freq_base: model_config.rope_config.freq_base,
rope_freq_scale: model_config.rope_config.freq_scale,
use_neox_rope: use_neox,
layer_start: range.start as u32,
layer_end: range.end as u32,
use_gpu: self.config.use_gpu,
})
.await
.map_err(|e| {
DistributedError::Shard(format!(
"failed to configure shard '{}': {}",
shard_spec.name, e
))
})?;
self.send_layers(&model, &mut client, range).await?;
shard_connections.push(ShardConnection {
client,
name: shard_spec.name.clone(),
layer_start: range.start,
layer_end: range.end,
});
}
let pipeline = PipelineExecutor::new(shard_connections)?;
let token_embedding = model.token_embedding().clone();
let norm = NormLayer::RMS(
RMSNorm::new(model.norm().weight().clone(), model.norm().eps())
.map_err(DistributedError::Model)?,
);
let output = Linear::new(model.output().weight.clone(), model.output().bias.clone())
.map_err(DistributedError::Model)?;
let distributed_model = DistributedModel::new(
token_embedding,
norm,
output,
pipeline,
model_config,
architecture,
);
tracing::info!("Distributed cluster setup complete");
Ok(distributed_model)
}
async fn send_layers(
&self,
model: &LlamaModel,
client: &mut ShardServiceClient<tonic::transport::Channel>,
range: &Range<usize>,
) -> DistributedResult<()> {
let layers: Vec<LayerData> = model
.layers()
.iter()
.filter(|l| range.contains(&l.layer_idx))
.map(|layer| {
let mut tensors = Vec::new();
tensors.push(NamedTensor {
name: "attn_norm.weight".into(),
tensor: Some(tensor_to_proto(layer.attn_norm.weight())),
});
if let Some(attn) = layer.attention() {
tensors.push(NamedTensor {
name: "attn_q.weight".into(),
tensor: Some(tensor_to_proto(&attn.wq.weight)),
});
if let Some(ref bias) = attn.wq.bias {
tensors.push(NamedTensor {
name: "attn_q.bias".into(),
tensor: Some(tensor_to_proto(bias)),
});
}
tensors.push(NamedTensor {
name: "attn_k.weight".into(),
tensor: Some(tensor_to_proto(&attn.wk.weight)),
});
if let Some(ref bias) = attn.wk.bias {
tensors.push(NamedTensor {
name: "attn_k.bias".into(),
tensor: Some(tensor_to_proto(bias)),
});
}
tensors.push(NamedTensor {
name: "attn_v.weight".into(),
tensor: Some(tensor_to_proto(&attn.wv.weight)),
});
if let Some(ref bias) = attn.wv.bias {
tensors.push(NamedTensor {
name: "attn_v.bias".into(),
tensor: Some(tensor_to_proto(bias)),
});
}
tensors.push(NamedTensor {
name: "attn_output.weight".into(),
tensor: Some(tensor_to_proto(&attn.wo.weight)),
});
if let Some(ref bias) = attn.wo.bias {
tensors.push(NamedTensor {
name: "attn_output.bias".into(),
tensor: Some(tensor_to_proto(bias)),
});
}
}
tensors.push(NamedTensor {
name: "ffn_norm.weight".into(),
tensor: Some(tensor_to_proto(layer.ffn_norm.weight())),
});
if let Some(ffn) = layer.ffn() {
tensors.push(NamedTensor {
name: "ffn_gate.weight".into(),
tensor: Some(tensor_to_proto(&ffn.w_gate.weight)),
});
tensors.push(NamedTensor {
name: "ffn_up.weight".into(),
tensor: Some(tensor_to_proto(&ffn.w_up.weight)),
});
tensors.push(NamedTensor {
name: "ffn_down.weight".into(),
tensor: Some(tensor_to_proto(&ffn.w_down.weight)),
});
}
LayerData {
layer_index: layer.layer_idx as u32,
tensors,
}
})
.collect();
let num_layers = layers.len();
let stream = futures::stream::iter(layers);
client
.load_layers(stream)
.await
.map_err(|e| {
DistributedError::Shard(format!("failed to stream layers: {}", e))
})?;
tracing::info!("Streamed {} layers to shard", num_layers);
Ok(())
}
}