#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceResult;
#[cfg(any(not(feature = "vllm"), all(test, not(feature = "vllm"))))]
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
#[cfg(feature = "vllm")]
mod engine;
#[cfg(feature = "gemma-default")]
pub mod defaults;
#[cfg(feature = "gemma-default")]
pub mod hf_cache;
#[cfg(feature = "gemma-default")]
pub mod probe;
/// Configuration for a [`VllmRunner`], loadable through serde (e.g. from JSON).
///
/// Only `model` is required. When omitted, `tensor_parallel_size` falls back to
/// `1`, `dtype` falls back to `"auto"`, and every `Option` field falls back to
/// `None`.
///
/// Field names mirror vLLM engine arguments of the same name. They are assumed
/// to be forwarded as-is by the engine launch code, which is not part of this
/// file — confirm there. How a `None` is interpreted is decided by that code,
/// not here.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VllmConfig {
    /// Model to load, e.g. a Hugging Face repo id such as `"google/gemma-4-E4B-it"`.
    pub model: String,
    /// Tensor-parallel degree; `1` when omitted.
    #[serde(default = "default_tp")]
    pub tensor_parallel_size: u32,
    /// Model dtype name; `"auto"` when omitted.
    #[serde(default = "default_dtype")]
    pub dtype: String,
    /// Share of GPU memory the engine may claim. Presumably in `0.0..=1.0`;
    /// not validated here.
    #[serde(default)]
    pub gpu_memory_utilization: Option<f32>,
    /// Maximum model context length override.
    #[serde(default)]
    pub max_model_len: Option<u32>,
    /// Directory to use as the Hugging Face cache, if overridden.
    #[serde(default)]
    pub hf_cache_dir: Option<std::path::PathBuf>,
    /// vLLM `enforce_eager` override.
    #[serde(default)]
    pub enforce_eager: Option<bool>,
    /// vLLM `enable_prefix_caching` override.
    #[serde(default)]
    pub enable_prefix_caching: Option<bool>,
    /// vLLM `enable_chunked_prefill` override.
    #[serde(default)]
    pub enable_chunked_prefill: Option<bool>,
    /// vLLM `max_num_seqs` override.
    #[serde(default)]
    pub max_num_seqs: Option<u32>,
    /// vLLM `block_size` override.
    #[serde(default)]
    pub block_size: Option<u32>,
    /// Quantization method name, if any.
    #[serde(default)]
    pub quantization: Option<String>,
    /// Per-modality prompt limits (modality name → count). A `BTreeMap` keeps
    /// serialization order deterministic.
    #[serde(default)]
    pub limit_mm_per_prompt: Option<std::collections::BTreeMap<String, u32>>,
    /// vLLM `cpu_offload_gb` override.
    #[serde(default)]
    pub cpu_offload_gb: Option<u32>,
}
/// Serde fallback for [`VllmConfig::tensor_parallel_size`]: a tensor-parallel
/// degree of one.
fn default_tp() -> u32 {
    const SINGLE_SHARD: u32 = 1;
    SINGLE_SHARD
}
/// Serde fallback for [`VllmConfig::dtype`]: the literal `"auto"`.
fn default_dtype() -> String {
    String::from("auto")
}
/// [`ModelRunner`] backed by a vLLM engine.
///
/// With the `vllm` feature enabled, the engine is launched lazily on first use
/// and cached. `rebuild_session` resets it when the cause is a poisoned CUDA
/// context or a manual rebuild. Without the feature, the runner only holds its
/// config and every `execute` call fails.
pub struct VllmRunner {
    // Only read by the feature-gated engine launch path. It is dead code when
    // the `vllm` feature is off, hence the conditional allow.
    #[cfg_attr(not(feature = "vllm"), allow(dead_code))]
    config: VllmConfig,
    // Engine handle, filled on first use. The `Arc` lets callers hold the
    // engine independently of this cell.
    #[cfg(feature = "vllm")]
    engine: tokio::sync::OnceCell<std::sync::Arc<engine::VllmEngine>>,
}
impl VllmRunner {
    /// Creates a runner for `config` without launching any engine.
    ///
    /// With the `vllm` feature, the engine cell starts empty and is filled on
    /// first use, so construction is cheap and infallible.
    pub fn new(config: VllmConfig) -> Self {
        #[cfg(feature = "vllm")]
        let engine = tokio::sync::OnceCell::new();
        Self {
            #[cfg(feature = "vllm")]
            engine,
            config,
        }
    }

    /// Returns a shared handle to the engine, launching it on the first call.
    ///
    /// # Errors
    /// Propagates the error from `engine::VllmEngine::launch`. A failed launch
    /// leaves the cell uninitialised, so a later call tries again.
    #[cfg(feature = "vllm")]
    async fn ensure_engine(&self) -> InferenceResult<std::sync::Arc<engine::VllmEngine>> {
        let config = &self.config;
        let shared = self
            .engine
            .get_or_try_init(|| async move {
                engine::VllmEngine::launch(config)
                    .await
                    .map(std::sync::Arc::new)
            })
            .await;
        shared.map(std::sync::Arc::clone)
    }
}
#[async_trait]
impl ModelRunner for VllmRunner {
    /// Runs `batch` on the vLLM engine, launching the engine on first use.
    ///
    /// # Errors
    /// Without the `vllm` feature, this always returns
    /// `InferenceError::Internal`. Otherwise, it surfaces engine launch or
    /// generation errors.
    #[cfg_attr(
        feature = "vllm",
        tracing::instrument(skip(self, batch), fields(request_id = %batch.request_id, model = %batch.model))
    )]
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle> {
        #[cfg(feature = "vllm")]
        {
            let live_engine = self.ensure_engine().await?;
            live_engine.generate(batch).await
        }
        #[cfg(not(feature = "vllm"))]
        {
            // No engine exists in this build: ignore the batch and explain why.
            let _ = batch;
            Err(InferenceError::Internal(
                "vllm feature disabled at build time — rebuild with --features vllm".into(),
            ))
        }
    }

    /// Discards the cached engine for a manual rebuild or a poisoned CUDA
    /// context; any other cause is a no-op.
    ///
    /// Without the `vllm` feature, this does nothing. The method always
    /// returns `Ok(())`; the next `execute` relaunches the engine on demand.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
        #[cfg(feature = "vllm")]
        {
            let must_relaunch = matches!(
                cause,
                SessionRebuildCause::Manual | SessionRebuildCause::CudaContextPoisoned
            );
            if must_relaunch {
                // Replacing the cell drops this runner's `Arc`. The engine itself
                // is dropped once no other clone of that `Arc` remains.
                self.engine = tokio::sync::OnceCell::new();
            }
        }
        let _ = cause;
        Ok(())
    }

    /// Always [`RuntimeKind::Vllm`].
    fn runtime_kind(&self) -> RuntimeKind {
        RuntimeKind::Vllm
    }

    /// Always [`TransportKind::LocalGpu`].
    fn transport_kind(&self) -> TransportKind {
        TransportKind::LocalGpu
    }

    /// Always `true`.
    ///
    /// NOTE(review): presumably the vLLM engine is driven through Python and
    /// needs the GIL — confirm against the engine module.
    fn gil_pinned(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config shared by the tests; only two optional knobs are set.
    fn test_config(model: &str) -> VllmConfig {
        VllmConfig {
            model: model.into(),
            tensor_parallel_size: 1,
            dtype: "auto".into(),
            gpu_memory_utilization: Some(0.5),
            max_model_len: Some(8192),
            hf_cache_dir: None,
            enforce_eager: None,
            enable_prefix_caching: None,
            enable_chunked_prefill: None,
            max_num_seqs: None,
            block_size: None,
            quantization: None,
            limit_mm_per_prompt: None,
            cpu_offload_gb: None,
        }
    }

    #[test]
    fn config_round_trips_through_serde() {
        // Populate a spread of optional fields so more than `None` is exercised.
        let mut limits = std::collections::BTreeMap::new();
        limits.insert("image".to_string(), 2);
        let cfg = VllmConfig {
            hf_cache_dir: Some(std::path::PathBuf::from("/tmp/hf")),
            enforce_eager: Some(true),
            max_num_seqs: Some(16),
            quantization: Some("awq".into()),
            limit_mm_per_prompt: Some(limits),
            cpu_offload_gb: Some(4),
            ..test_config("google/gemma-4-E4B-it")
        };
        let s = serde_json::to_string(&cfg).expect("serialize");
        let back: VllmConfig = serde_json::from_str(&s).expect("deserialize");
        assert_eq!(back.model, cfg.model);
        assert_eq!(back.gpu_memory_utilization, cfg.gpu_memory_utilization);
        // Compare the whole structure so any field that fails to round-trip
        // is caught, not just the two checked above.
        assert_eq!(
            serde_json::to_value(&back).expect("re-serialize"),
            serde_json::to_value(&cfg).expect("serialize to value"),
        );
    }

    #[test]
    fn config_applies_defaults_for_omitted_fields() {
        let cfg: VllmConfig =
            serde_json::from_str(r#"{"model":"m"}"#).expect("deserialize minimal config");
        assert_eq!(cfg.model, "m");
        assert_eq!(cfg.tensor_parallel_size, 1);
        assert_eq!(cfg.dtype, "auto");
        assert_eq!(cfg.gpu_memory_utilization, None);
        assert_eq!(cfg.max_model_len, None);
        assert!(cfg.hf_cache_dir.is_none());
        assert_eq!(cfg.enforce_eager, None);
        assert_eq!(cfg.enable_prefix_caching, None);
        assert_eq!(cfg.enable_chunked_prefill, None);
        assert_eq!(cfg.max_num_seqs, None);
        assert_eq!(cfg.block_size, None);
        assert_eq!(cfg.quantization, None);
        assert!(cfg.limit_mm_per_prompt.is_none());
        assert_eq!(cfg.cpu_offload_gb, None);
    }

    #[test]
    fn config_requires_model() {
        assert!(serde_json::from_str::<VllmConfig>("{}").is_err());
    }

    #[test]
    fn runner_reports_runtime_kind() {
        let r = VllmRunner::new(test_config("test"));
        assert_eq!(r.runtime_kind(), RuntimeKind::Vllm);
        assert_eq!(r.transport_kind(), TransportKind::LocalGpu);
        assert!(r.gil_pinned());
    }

    #[tokio::test]
    async fn rebuild_session_succeeds_without_launched_engine() {
        // No engine has been launched yet, so this never touches a GPU,
        // with or without the `vllm` feature.
        let mut r = VllmRunner::new(test_config("test"));
        assert!(r.rebuild_session(SessionRebuildCause::Manual).await.is_ok());
    }

    #[cfg(not(feature = "vllm"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;
        let mut r = VllmRunner::new(test_config("test"));
        let batch = ExecuteBatch {
            request_id: "t".into(),
            model: "t".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let result = r.execute(batch).await;
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
}