use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use crate::batch::ExecuteBatch;
use crate::deployment::RateLimits;
use crate::error::{InferenceError, InferenceResult};
use crate::runtime::{RuntimeKind, TransportKind};
use crate::tokens::TokenChunk;
/// Opaque handle to an in-flight inference run.
///
/// Wraps the stream of token chunks produced by [`ModelRunner::execute`];
/// consume it via [`RunHandle::into_stream`].
pub struct RunHandle {
    // Boxed, 'static stream; each item is either a token chunk or an
    // inference error surfaced mid-stream.
    inner: BoxStream<'static, InferenceResult<TokenChunk>>,
}
impl RunHandle {
pub fn streaming(inner: BoxStream<'static, InferenceResult<TokenChunk>>) -> Self {
Self { inner }
}
pub fn into_stream(self) -> BoxStream<'static, InferenceResult<TokenChunk>> {
self.inner
}
}
impl std::fmt::Debug for RunHandle {
    // The boxed stream has no `Debug` impl, so render the struct
    // non-exhaustively as `RunHandle { .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RunHandle");
        builder.finish_non_exhaustive()
    }
}
/// Where a runner should obtain model weights when loading.
#[derive(Debug, Clone)]
pub enum WeightSource {
    /// Download from a Hugging Face Hub repository.
    HuggingFace {
        /// Repository identifier (presumably `org/name` form — confirm with callers).
        repo: String,
        /// Optional revision (branch, tag, or commit); `None` means the
        /// runner's default revision.
        revision: Option<String>,
    },
    /// Load from a path on the local filesystem.
    LocalPath {
        path: std::path::PathBuf,
    },
    /// The runtime manages its own weights; no source is supplied.
    RuntimeManaged,
}
/// Reason a runner session is being torn down and rebuilt
/// (passed to [`ModelRunner::rebuild_session`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionRebuildCause {
    /// The CUDA context is in an unrecoverable state.
    CudaContextPoisoned,
    /// Authentication against a remote backend failed.
    RemoteAuthFailure,
    /// Remote backend configuration changed out from under us.
    RemoteConfigChange,
    /// Operator-initiated rebuild.
    Manual,
}
/// Type-erased, shareable handle to a CUDA context. Kept as `dyn Any` so this
/// crate avoids a hard dependency on a specific CUDA binding; consumers
/// downcast to the concrete context type they expect.
pub type CudaContextHandle = Arc<dyn std::any::Any + Send + Sync>;
/// Core abstraction over an inference backend (local CUDA runtime, Python
/// process, remote service, ...). Implementations execute batches and manage
/// their own session/weight lifecycle.
#[async_trait]
pub trait ModelRunner: Send + Sync {
    /// Executes a batch and returns a handle streaming the resulting tokens.
    ///
    /// # Errors
    /// Returns an [`InferenceError`] if the batch cannot be started.
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle>;

    /// Loads model weights from `_source`, optionally inside the given CUDA
    /// context. The default is a no-op for runners whose weights are managed
    /// externally (see [`WeightSource::RuntimeManaged`]).
    async fn load_weights(
        &mut self,
        _ctx: Option<&CudaContextHandle>,
        _source: WeightSource,
    ) -> InferenceResult<()> {
        Ok(())
    }

    /// Tears down and re-establishes the backend session for the given cause.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()>;

    /// Which runtime backs this runner.
    fn runtime_kind(&self) -> RuntimeKind;

    /// Which transport this runner speaks to its backend over.
    fn transport_kind(&self) -> TransportKind;

    /// Whether calls into this runner are serialized by a Python GIL.
    /// By default this is true exactly for the vLLM and Python runtimes.
    fn gil_pinned(&self) -> bool {
        matches!(self.runtime_kind(), RuntimeKind::Vllm | RuntimeKind::Python(_))
    }

    /// Rate limits to apply to this runner, if any. Default: unlimited.
    fn rate_limits(&self) -> Option<&RateLimits> {
        None
    }

    /// Estimated dollar cost of executing `_batch`. Default is 0.0 —
    /// presumably meaning "free / not metered"; remote runners should
    /// override this.
    fn estimate_cost_usd(&self, _batch: &ExecuteBatch) -> f64 {
        0.0
    }
}
/// Lifts any displayable error into [`InferenceError::Internal`],
/// preserving only its rendered message.
pub fn lift_internal<E: std::fmt::Display>(err: E) -> InferenceError {
    let message = format!("{err}");
    InferenceError::Internal(message)
}