use crate::tensor::Tensor;
use super::proto::shard_service_client::ShardServiceClient;
use super::proto::{ForwardRequest, ResetRequest};
use super::tensor_transfer::{tensor_from_proto, tensor_to_proto};
use super::{DistributedError, DistributedResult};
/// A connection to a single remote shard, plus the metadata used to identify
/// it in error messages.
pub struct ShardConnection {
/// gRPC client for the shard service, communicating over a tonic `Channel`.
pub client: ShardServiceClient<tonic::transport::Channel>,
/// Name of the shard; included in error messages produced by the executor.
pub name: String,
/// First layer of the range associated with this shard (reported in errors
/// as `layers start..end`).
pub layer_start: usize,
/// End of the layer range associated with this shard.
// NOTE(review): whether this bound is inclusive or exclusive is not visible
// here — confirm against the code that builds `ShardConnection`.
pub layer_end: usize,
}
/// Drives a sequence of remote shards, passing a hidden state through each
/// one in turn.
pub struct PipelineExecutor {
// Shard connections, visited in `Vec` order by `forward` and
// `reset_kv_caches`.
shards: Vec<ShardConnection>,
}
impl PipelineExecutor {
    /// Creates an executor over `shards`, which are visited in the given order.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::NoShards`] if `shards` is empty.
    pub fn new(shards: Vec<ShardConnection>) -> DistributedResult<Self> {
        if shards.is_empty() {
            return Err(DistributedError::NoShards);
        }
        Ok(Self { shards })
    }

    /// Returns the number of shards in the pipeline.
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Converts a `usize` position into the `(position, seq_len)` pair placed
    /// in a `ForwardRequest`, where `seq_len` is `position + 1`.
    ///
    /// Fails instead of silently truncating when either value does not fit in
    /// the `u32` request fields.
    fn wire_position_fields(position: usize) -> DistributedResult<(u32, u32)> {
        // Fully qualified so the call resolves even on editions whose prelude
        // does not include `TryFrom`.
        <u32 as std::convert::TryFrom<usize>>::try_from(position)
            .ok()
            .and_then(|pos| pos.checked_add(1).map(|seq_len| (pos, seq_len)))
            .ok_or_else(|| {
                DistributedError::Shard(format!(
                    "position {} is too large to encode in a forward request",
                    position
                ))
            })
    }

    /// Runs `hidden` through every shard in order and returns the final
    /// hidden state.
    ///
    /// Each shard receives the previous shard's output (the first shard
    /// receives `hidden`) together with `position` and a sequence length of
    /// `position + 1`.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::Shard`] if
    /// - `position` or `position + 1` does not fit in a `u32`,
    /// - the RPC to a shard fails,
    /// - a shard reports `success == false`, or
    /// - a shard's response carries no hidden state.
    ///
    /// Errors from decoding a shard's returned tensor are propagated as-is.
    /// Processing stops at the first failure.
    pub async fn forward(
        &mut self,
        hidden: &Tensor,
        position: usize,
    ) -> DistributedResult<Tensor> {
        // Validate and convert once: the values are identical for every shard.
        let (wire_position, seq_len) = Self::wire_position_fields(position)?;

        // Output of the most recent shard. `None` until the first shard has
        // answered, so the caller's tensor is borrowed rather than cloned.
        let mut latest: Option<Tensor> = None;

        for shard in &mut self.shards {
            let hidden_proto = tensor_to_proto(latest.as_ref().unwrap_or(hidden));
            let request = ForwardRequest {
                hidden_state: Some(hidden_proto),
                position: wire_position,
                seq_len,
            };

            let response = shard
                .client
                .forward(request)
                .await
                .map_err(|e| {
                    DistributedError::Shard(format!(
                        "forward failed on shard '{}' (layers {}..{}): {}",
                        shard.name, shard.layer_start, shard.layer_end, e
                    ))
                })?
                .into_inner();

            // The transport succeeded but the shard itself reported a failure.
            if !response.success {
                return Err(DistributedError::Shard(format!(
                    "shard '{}' returned error: {}",
                    shard.name, response.error
                )));
            }

            let output_proto = response.hidden_state.ok_or_else(|| {
                DistributedError::Shard(format!(
                    "shard '{}' returned empty hidden state",
                    shard.name
                ))
            })?;

            latest = Some(tensor_from_proto(&output_proto)?);
        }

        // `new` rejects an empty shard list, so `latest` is normally set; the
        // fallback keeps the "no shards means identity" result without a panic.
        Ok(latest.unwrap_or_else(|| hidden.clone()))
    }

    /// Asks every shard, in order, to reset its KV cache.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::Shard`] for the first shard whose RPC
    /// fails. Shards after the failing one are not contacted.
    pub async fn reset_kv_caches(&mut self) -> DistributedResult<()> {
        for shard in &mut self.shards {
            shard
                .client
                .reset_kv_cache(ResetRequest {})
                .await
                .map_err(|e| {
                    DistributedError::Shard(format!(
                        "failed to reset KV cache on shard '{}': {}",
                        shard.name, e
                    ))
                })?;
        }
        Ok(())
    }
}