use tokio::runtime::Handle;
use crate::backend::Backend;
use crate::model::{
    Architecture, InferenceContext, Model, ModelConfig, ModelError, ModelResult,
};
use crate::model::layers::{Linear, NormLayer};
use crate::tensor::{DType, Tensor};
use super::pipeline::PipelineExecutor;
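/// A model split across pipeline stages: the token embedding, final norm, and
/// output projection run locally, while the transformer stack is driven
/// through an async [`PipelineExecutor`] (per the "distributed" naming,
/// presumably spanning remote workers). Tokens are processed one at a time.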
pub struct DistributedModel {
    token_embedding: Tensor,
    norm: NormLayer,
    output: Linear,
    /// Wrapped in an async mutex so concurrent `forward` calls cannot
    /// interleave their per-token pipeline requests.
    pipeline: tokio::sync::Mutex<PipelineExecutor>,
    config: ModelConfig,
    architecture: Architecture,
}
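// SAFETY: these impls assert that every field is safe to send and share
// across threads. `forward` takes `&self` and serializes all pipeline access
// through the async mutex; the assumption here is that `Tensor` (which may
// hold raw pointers internally) carries no thread-affine state. If that
// invariant ever breaks, these impls become unsound.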
unsafe impl Send for DistributedModel {}
unsafe impl Sync for DistributedModel {}
impl DistributedModel {
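    /// Assembles a `DistributedModel`, wrapping the pipeline executor in an
    /// async mutex for shared access.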
    pub fn new(
        token_embedding: Tensor,
        norm: NormLayer,
        output: Linear,
        pipeline: PipelineExecutor,
        config: ModelConfig,
        architecture: Architecture,
    ) -> Self {
        Self {
            token_embedding,
            norm,
            output,
            pipeline: tokio::sync::Mutex::new(pipeline),
            config,
            architecture,
        }
    }
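    /// Returns the embedding table as a flat `f32` vector, dequantizing via
    /// the backend when the stored weights are not already `F32`.
    ///
    /// Note: for quantized tables this materializes all
    /// `vocab_size * hidden_size` values, and `forward` calls it once per
    /// invocation; callers may want to cache the result.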
    fn dequantize_embeddings(&self, backend: &dyn Backend) -> ModelResult<Vec<f32>> {
        // Fast path: the table is already f32, so just copy it out.
        if self.token_embedding.dtype() == DType::F32 {
            return Ok(self.token_embedding.as_f32()?.to_vec());
        }
        let vocab_size = self.config.vocab_size;
        let hidden_size = self.config.hidden_size;
        let mut out = Tensor::zeros(vec![vocab_size, hidden_size], DType::F32);
        backend.dequantize(&self.token_embedding, &mut out)?;
        Ok(out.as_f32()?.to_vec())
    }
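    /// Applies the final norm followed by the output projection, turning a
    /// hidden state into a `vocab_size`-length logits tensor.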
    fn compute_logits(&self, hidden: &Tensor, backend: &dyn Backend) -> ModelResult<Tensor> {
        let mut normed = Tensor::zeros(hidden.shape().to_vec(), DType::F32);
        self.norm.forward(hidden, &mut normed, backend)?;
        let mut logits = Tensor::zeros(vec![self.config.vocab_size], DType::F32);
        self.output.forward(&normed, &mut logits, backend)?;
        Ok(logits)
    }
}
impl Model for DistributedModel {
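    /// Feeds `tokens` through the pipeline one position at a time and returns
    /// the logits for the last token only. Must run inside a multi-threaded
    /// tokio runtime: the sync-to-async bridge below uses `block_in_place`,
    /// which panics on a current-thread runtime.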
    fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor> {
        let backend = ctx.backend.as_ref();
        // Reject empty input up front rather than falling out of the loop.
        if tokens.is_empty() {
            return Err(ModelError::ConfigError("No tokens to process".into()));
        }
        let new_pos = ctx.position + tokens.len();
        if new_pos > self.config.max_seq_len {
            return Err(ModelError::ContextLengthExceeded {
                current: new_pos,
                max: self.config.max_seq_len,
            });
        }
        // Dequantize the full embedding table up front so the per-token
        // lookups below are plain slice copies.
        let embedding_data = self.dequantize_embeddings(backend)?;
        let hidden_size = self.config.hidden_size;
        let vocab_size = self.config.vocab_size;
        // `block_in_place` below needs a runtime handle; fail early with a
        // clear error if we are not inside one.
        let handle = Handle::try_current().map_err(|_| {
            ModelError::ConfigError("distributed model requires a tokio runtime".into())
        })?;
        let mut hidden_buf = vec![0.0f32; hidden_size];
        for (token_offset, &token) in tokens.iter().enumerate() {
            let current_pos = ctx.position + token_offset;
            let token_idx = token as usize;
            if token_idx >= vocab_size {
                return Err(ModelError::InvalidMetadata {
                    key: "token".into(),
                    message: format!("Token ID {} exceeds vocab size {}", token, vocab_size),
                });
            }
            // Embedding lookup: copy this token's row out of the table.
            let src = token_idx * hidden_size;
            hidden_buf.copy_from_slice(&embedding_data[src..src + hidden_size]);
            let hidden = Tensor::from_f32(&hidden_buf, vec![hidden_size])?;
            // Bridge sync -> async: `block_in_place` tells the scheduler this
            // worker thread is about to block, then we drive the pipeline
            // future to completion. This requires a multi-threaded runtime;
            // `block_in_place` panics on a current-thread runtime.
            let pipeline_result = tokio::task::block_in_place(|| {
                handle.block_on(async {
                    let mut pipeline = self.pipeline.lock().await;
                    pipeline.forward(&hidden, current_pos).await
                })
            });
            let output_hidden = pipeline_result.map_err(|e| {
                ModelError::ConfigError(format!("distributed forward failed: {}", e))
            })?;
            // Intermediate tokens only advance the pipeline's state; logits
            // are computed for the final token alone.
            if token_offset + 1 == tokens.len() {
                ctx.position = new_pos;
                ctx.kv_cache.seq_len = new_pos;
                return self.compute_logits(&output_hidden, backend);
            }
        }
        // The loop always returns on the last token, and empty input was
        // rejected above.
        unreachable!("forward returns from within the loop")
    }
    fn config(&self) -> &ModelConfig {
        &self.config
    }
    fn architecture(&self) -> Architecture {
        self.architecture
    }
}