atomr_infer_runtime_vllm/lib.rs

//! # inference-runtime-vllm
//!
//! vLLM (Python) runtime — canonical local-LLM backend. Doc §2.1, §10.3.
//!
//! ## Feature flags
//!
//! - `vllm` — pull in PyO3 + the `AsyncLLMEngine` bridge. Without
//!   this feature the runner compiles to a typed-error stub, so a
//!   `cargo build --features remote-only` consumer never pulls in
//!   pyo3 / vllm / cudarc.
//! - `gemma-default` — adds the env probe + HuggingFace cache
//!   resolver + optional `hf-hub` pre-download path so an operator
//!   can auto-provision a Gemma 4 deployment when the host has a
//!   workable GPU + Python + vLLM + HF token. See
//!   `inference::defaults::gemma` for the rollup-side adapter.
//!
//! ## Lifecycle
//!
//! `VllmRunner::new` is cheap and synchronous — it only stores the
//! config. The Python `AsyncLLMEngine` is built lazily on the first
//! [`ModelRunner::execute`] call, so a runner can be instantiated
//! on hosts without a GPU (handy for config-layer tests).
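//!
//! ## Example
//!
//! A minimal sketch of the intended call flow, assuming an async
//! context and the `vllm` feature; the `ExecuteBatch` fields shown
//! are the ones exercised by the tests at the bottom of this file.
//!
//! ```ignore
//! use atomr_infer_core::batch::{ExecuteBatch, SamplingParams};
//! use atomr_infer_core::runner::ModelRunner;
//!
//! // Cheap, synchronous construction; no GPU or Python touched yet.
//! let cfg: VllmConfig =
//!     serde_json::from_str(r#"{"model": "google/gemma-4-E4B-it"}"#)
//!         .expect("valid config");
//! let mut runner = VllmRunner::new(cfg);
//!
//! // The first `execute` boots the Python `AsyncLLMEngine` lazily.
//! let handle = runner
//!     .execute(ExecuteBatch {
//!         request_id: "req-1".into(),
//!         model: "google/gemma-4-E4B-it".into(),
//!         messages: vec![],
//!         sampling: SamplingParams::default(),
//!         stream: false,
//!         estimated_tokens: 256,
//!     })
//!     .await?;
//! ```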

#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceResult;
#[cfg(not(feature = "vllm"))]
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};

#[cfg(feature = "vllm")]
mod engine;

#[cfg(feature = "gemma-default")]
pub mod defaults;
#[cfg(feature = "gemma-default")]
pub mod hf_cache;
#[cfg(feature = "gemma-default")]
pub mod probe;
/// vLLM engine configuration. Pass-through for the Python builder
/// arguments (`AsyncEngineArgs`); the perf knobs at the bottom map
/// 1:1 to vLLM's own settings of the same name.
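///
/// Every field except `model` carries a serde default, so a minimal
/// deployment is a one-key JSON document. A sketch of that path
/// (also covered by the tests at the bottom of this file):
///
/// ```ignore
/// let cfg: VllmConfig =
///     serde_json::from_str(r#"{"model": "google/gemma-4-E4B-it"}"#)
///         .expect("minimal config");
/// assert_eq!(cfg.tensor_parallel_size, 1); // default_tp
/// assert_eq!(cfg.dtype, "auto");           // default_dtype
/// assert_eq!(cfg.max_model_len, None);     // vLLM picks
/// ```
///
/// And a 16 GB-GPU Gemma sketch using the values the field docs
/// below call out (`0.5` utilization, `cpu_offload_gb: 4`):
///
/// ```ignore
/// let cfg: VllmConfig = serde_json::from_str(r#"{
///     "model": "google/gemma-4-E4B-it",
///     "gpu_memory_utilization": 0.5,
///     "max_model_len": 8192,
///     "cpu_offload_gb": 4
/// }"#).expect("small-GPU config");
/// ```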
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VllmConfig {
    /// HuggingFace repo id or local path the engine loads.
    pub model: String,
    /// Number of GPUs the model is sharded across (vLLM tensor
    /// parallelism). Defaults to `1`.
    #[serde(default = "default_tp")]
    pub tensor_parallel_size: u32,
    /// Numeric dtype: `"auto"`, `"float16"`, `"bfloat16"`, `"float32"`.
    #[serde(default = "default_dtype")]
    pub dtype: String,
    /// Fraction of GPU memory the engine pre-allocates. Defaults to
    /// vLLM's `0.9`; the Gemma auto-provisioner overrides to `0.5`
    /// to leave room for dev tools.
    #[serde(default)]
    pub gpu_memory_utilization: Option<f32>,
    /// Maximum sequence length. `None` ⇒ vLLM picks from the model
    /// config.
    #[serde(default)]
    pub max_model_len: Option<u32>,
    /// Optional HuggingFace cache directory. When set, the engine is
    /// constructed with `HF_HOME` pointing here so multi-instance
    /// deployments share a single on-disk cache.
    #[serde(default)]
    pub hf_cache_dir: Option<std::path::PathBuf>,
    /// Disable CUDA graphs (vLLM `enforce_eager`). Default `None`
    /// ⇒ vLLM picks (graphs on for most models). Toggling this is
    /// the easiest perf experiment: CUDA graphs typically give
    /// 1.5–2× decode throughput on small models, so `Some(true)`
    /// trades that away for faster startup and simpler debugging.
    #[serde(default)]
    pub enforce_eager: Option<bool>,
    /// Cache common prompt prefixes across requests. Useful for chat
    /// with shared system prompts. Default `None` ⇒ vLLM default
    /// (off in v0.6, on in v0.7+).
    #[serde(default)]
    pub enable_prefix_caching: Option<bool>,
    /// Chunked prefill: split long prompts so prefill interleaves
    /// with decode. Improves TTFT under concurrent load.
    #[serde(default)]
    pub enable_chunked_prefill: Option<bool>,
    /// Maximum concurrent sequences the scheduler runs. Higher ⇒
    /// better steady-state throughput at the cost of per-request
    /// latency. Default `None` ⇒ vLLM picks (256 in v0.6).
    #[serde(default)]
    pub max_num_seqs: Option<u32>,
    /// PagedAttention block size in tokens. Default `None` ⇒ vLLM
    /// picks (16). Larger blocks ⇒ better throughput; smaller ⇒
    /// finer-grained memory packing.
    #[serde(default)]
    pub block_size: Option<u32>,
    /// Quantization scheme: `"awq"`, `"gptq"`, `"squeezellm"`,
    /// `"fp8"`, etc. `None` ⇒ unquantized (whatever the checkpoint
    /// natively is).
    #[serde(default)]
    pub quantization: Option<String>,
    /// Per-prompt multimodal-input cap, e.g. `{"image": 0,
    /// "audio": 0}`. For multimodal models like Gemma 4, setting
    /// these to 0 tells vLLM the workload is text-only and lets it
    /// skip the worst-case vision/audio buffer allocation during
    /// KV-cache profiling — often the difference between fitting
    /// in 16 GB and OOMing. Note: vLLM 0.20's Gemma 4 text-only
    /// path is buggy (per-layer embeddings share the mm plumbing);
    /// use `cpu_offload_gb` instead on small GPUs.
    #[serde(default)]
    pub limit_mm_per_prompt: Option<std::collections::BTreeMap<String, u32>>,
    /// Offload N GB of model weights to CPU RAM (vLLM
    /// `cpu_offload_gb`). On a 16 GB GPU running Gemma 4 E4B,
    /// `Some(4)` is enough to fit the multimodal profile pass that
    /// otherwise OOMs at ~15.5 GB. Trade-off: per-token decode
    /// slows ~30–50 % because each forward pass copies offloaded
    /// weights between GPU and CPU.
    #[serde(default)]
    pub cpu_offload_gb: Option<u32>,
}

fn default_tp() -> u32 {
    1
}
fn default_dtype() -> String {
    "auto".to_string()
}

/// vLLM runner. Constructs in O(1); the engine boots lazily on the
/// first call to [`ModelRunner::execute`].
pub struct VllmRunner {
    #[cfg_attr(not(feature = "vllm"), allow(dead_code))]
    config: VllmConfig,
    #[cfg(feature = "vllm")]
    engine: tokio::sync::OnceCell<std::sync::Arc<engine::VllmEngine>>,
}

impl VllmRunner {
    pub fn new(config: VllmConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "vllm")]
            engine: tokio::sync::OnceCell::new(),
        }
    }

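    /// Boot the Python engine at most once. `get_or_try_init`
    /// serialises concurrent first callers behind a single init
    /// future, and a failed launch leaves the cell empty, so the
    /// next `execute` retries instead of caching the error.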
    #[cfg(feature = "vllm")]
    async fn ensure_engine(&self) -> InferenceResult<std::sync::Arc<engine::VllmEngine>> {
        self.engine
            .get_or_try_init(|| async {
                engine::VllmEngine::launch(&self.config)
                    .await
                    .map(std::sync::Arc::new)
            })
            .await
            .cloned()
    }
}

#[async_trait]
impl ModelRunner for VllmRunner {
    #[cfg_attr(
        feature = "vllm",
        tracing::instrument(skip(self, batch), fields(request_id = %batch.request_id, model = %batch.model))
    )]
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle> {
        #[cfg(not(feature = "vllm"))]
        {
            let _ = batch;
            Err(InferenceError::Internal(
                "vllm feature disabled at build time — rebuild with --features vllm".into(),
            ))
        }
        #[cfg(feature = "vllm")]
        {
            let engine = self.ensure_engine().await?;
            engine.generate(batch).await
        }
    }

    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
        #[cfg(feature = "vllm")]
        {
            // CudaContextPoisoned / Manual ⇒ tear down the cached
            // engine handle. The next `execute` reconstructs it; vLLM
            // V1 doesn't always release VRAM cleanly, so a hard
            // rebuild may need a process restart in practice.
            if matches!(
                cause,
                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
            ) {
                self.engine = tokio::sync::OnceCell::new();
            }
        }
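        // `cause` is otherwise unused when the `vllm` feature is off.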
        let _ = cause;
        Ok(())
    }

    fn runtime_kind(&self) -> RuntimeKind {
        RuntimeKind::Vllm
    }
    fn transport_kind(&self) -> TransportKind {
        TransportKind::LocalGpu
    }
    fn gil_pinned(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_config(model: &str) -> VllmConfig {
        VllmConfig {
            model: model.into(),
            tensor_parallel_size: 1,
            dtype: "auto".into(),
            gpu_memory_utilization: Some(0.5),
            max_model_len: Some(8192),
            hf_cache_dir: None,
            enforce_eager: None,
            enable_prefix_caching: None,
            enable_chunked_prefill: None,
            max_num_seqs: None,
            block_size: None,
            quantization: None,
            limit_mm_per_prompt: None,
            cpu_offload_gb: None,
        }
    }

    #[test]
    fn config_round_trips_through_serde() {
        let cfg = test_config("google/gemma-4-E4B-it");
        let s = serde_json::to_string(&cfg).expect("serialize");
        let back: VllmConfig = serde_json::from_str(&s).expect("deserialize");
        assert_eq!(back.model, cfg.model);
        assert_eq!(back.gpu_memory_utilization, cfg.gpu_memory_utilization);
    }

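    // A sketch of the serde-default path the struct docs describe:
    // every field except `model` is optional in the wire format.
    #[test]
    fn minimal_json_fills_serde_defaults() {
        let cfg: VllmConfig =
            serde_json::from_str(r#"{"model": "m"}"#).expect("deserialize");
        assert_eq!(cfg.tensor_parallel_size, 1);
        assert_eq!(cfg.dtype, "auto");
        assert_eq!(cfg.gpu_memory_utilization, None);
        assert_eq!(cfg.cpu_offload_gb, None);
    }
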
    #[test]
    fn runner_reports_runtime_kind() {
        let r = VllmRunner::new(test_config("test"));
        assert_eq!(r.runtime_kind(), RuntimeKind::Vllm);
        assert_eq!(r.transport_kind(), TransportKind::LocalGpu);
        assert!(r.gil_pinned());
    }

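    // A sketch of the rebuild path: with no engine ever booted,
    // tearing down the (empty) cached handle must be a no-op
    // `Ok(())` under either feature configuration.
    #[tokio::test]
    async fn rebuild_session_without_engine_is_ok() {
        let mut r = VllmRunner::new(test_config("test"));
        r.rebuild_session(SessionRebuildCause::Manual)
            .await
            .expect("rebuild should succeed");
    }
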
    #[cfg(not(feature = "vllm"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let mut r = VllmRunner::new(test_config("test"));
        let batch = ExecuteBatch {
            request_id: "t".into(),
            model: "t".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let result = r.execute(batch).await;
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
}