atomr_infer_runtime_vllm/lib.rs

#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceResult;
#[cfg(not(feature = "vllm"))]
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};

#[cfg(feature = "vllm")]
mod engine;

#[cfg(feature = "gemma-default")]
pub mod defaults;
#[cfg(feature = "gemma-default")]
pub mod hf_cache;
#[cfg(feature = "gemma-default")]
pub mod probe;

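/// Configuration for a locally launched vLLM engine.
///
/// Field names are intended to mirror vLLM's engine arguments; `None` defers
/// to vLLM's own default. A minimal document deserializes with the serde
/// defaults filling in the rest (illustrative sketch; the model ID is just a
/// placeholder):
///
/// ```ignore
/// let cfg: VllmConfig =
///     serde_json::from_str(r#"{"model": "google/gemma-3n-E4B-it"}"#).expect("parse");
/// assert_eq!(cfg.tensor_parallel_size, 1); // from default_tp
/// assert_eq!(cfg.dtype, "auto");           // from default_dtype
/// ```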
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VllmConfig {
    /// Hugging Face model ID or local path to load.
    pub model: String,
    /// Number of GPUs to shard the model across via tensor parallelism.
    #[serde(default = "default_tp")]
    pub tensor_parallel_size: u32,
    /// Weight dtype: "auto", "float16", "bfloat16", ...
    #[serde(default = "default_dtype")]
    pub dtype: String,
    /// Fraction of GPU memory the engine may reserve for weights and KV cache.
    #[serde(default)]
    pub gpu_memory_utilization: Option<f32>,
    /// Maximum context length; derived from the model config when unset.
    #[serde(default)]
    pub max_model_len: Option<u32>,
    /// Override for the Hugging Face cache directory.
    #[serde(default)]
    pub hf_cache_dir: Option<std::path::PathBuf>,
    /// Skip CUDA graph capture and always execute eagerly.
    #[serde(default)]
    pub enforce_eager: Option<bool>,
    /// Reuse KV-cache blocks across prompts that share a prefix.
    #[serde(default)]
    pub enable_prefix_caching: Option<bool>,
    /// Split long prefills into chunks that can be batched with decode steps.
    #[serde(default)]
    pub enable_chunked_prefill: Option<bool>,
    /// Maximum number of sequences scheduled concurrently.
    #[serde(default)]
    pub max_num_seqs: Option<u32>,
    /// KV-cache block size in tokens.
    #[serde(default)]
    pub block_size: Option<u32>,
    /// Quantization scheme, e.g. "awq" or "gptq".
    #[serde(default)]
    pub quantization: Option<String>,
    /// Per-modality cap on multimodal items per prompt, e.g. {"image": 1}.
    #[serde(default)]
    pub limit_mm_per_prompt: Option<std::collections::BTreeMap<String, u32>>,
    /// GiB of weights to offload to CPU memory.
    #[serde(default)]
    pub cpu_offload_gb: Option<u32>,
}

fn default_tp() -> u32 {
    1
}
fn default_dtype() -> String {
    "auto".to_string()
}

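/// A `ModelRunner` backed by a lazily launched vLLM engine.
///
/// Construction is cheap and infallible: the engine launches on the first
/// `execute` call and is cached for reuse. When built without the `vllm`
/// feature, `execute` returns an `Internal` error instead.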
pub struct VllmRunner {
    #[cfg_attr(not(feature = "vllm"), allow(dead_code))]
    config: VllmConfig,
    #[cfg(feature = "vllm")]
    engine: tokio::sync::OnceCell<std::sync::Arc<engine::VllmEngine>>,
}

impl VllmRunner {
    pub fn new(config: VllmConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "vllm")]
            engine: tokio::sync::OnceCell::new(),
        }
    }

    /// Launch the engine on first use; concurrent callers share one launch.
    #[cfg(feature = "vllm")]
    async fn ensure_engine(&self) -> InferenceResult<std::sync::Arc<engine::VllmEngine>> {
        self.engine
            .get_or_try_init(|| async {
                engine::VllmEngine::launch(&self.config)
                    .await
                    .map(std::sync::Arc::new)
            })
            .await
            .cloned()
    }
}

#[async_trait]
impl ModelRunner for VllmRunner {
    #[cfg_attr(
        feature = "vllm",
        tracing::instrument(skip(self, batch), fields(request_id = %batch.request_id, model = %batch.model))
    )]
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle> {
        #[cfg(not(feature = "vllm"))]
        {
            let _ = batch;
            Err(InferenceError::Internal(
                "vllm feature disabled at build time; rebuild with --features vllm".into(),
            ))
        }
        #[cfg(feature = "vllm")]
        {
            let engine = self.ensure_engine().await?;
            engine.generate(batch).await
        }
    }

    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
        #[cfg(feature = "vllm")]
        {
            // Replacing the OnceCell drops any cached engine; the next
            // execute() call launches a fresh one.
            if matches!(
                cause,
                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
            ) {
                self.engine = tokio::sync::OnceCell::new();
            }
        }
        #[cfg(not(feature = "vllm"))]
        let _ = cause;
        Ok(())
    }

    fn runtime_kind(&self) -> RuntimeKind {
        RuntimeKind::Vllm
    }
    fn transport_kind(&self) -> TransportKind {
        TransportKind::LocalGpu
    }
    fn gil_pinned(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_config(model: &str) -> VllmConfig {
        VllmConfig {
            model: model.into(),
            tensor_parallel_size: 1,
            dtype: "auto".into(),
            gpu_memory_utilization: Some(0.5),
            max_model_len: Some(8192),
            hf_cache_dir: None,
            enforce_eager: None,
            enable_prefix_caching: None,
            enable_chunked_prefill: None,
            max_num_seqs: None,
            block_size: None,
            quantization: None,
            limit_mm_per_prompt: None,
            cpu_offload_gb: None,
        }
    }

    #[test]
    fn config_round_trips_through_serde() {
        let cfg = test_config("google/gemma-3n-E4B-it");
        let s = serde_json::to_string(&cfg).expect("serialize");
        let back: VllmConfig = serde_json::from_str(&s).expect("deserialize");
        assert_eq!(back.model, cfg.model);
        assert_eq!(back.gpu_memory_utilization, cfg.gpu_memory_utilization);
    }
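
    // A minimal JSON document should pick up the serde defaults declared on
    // VllmConfig. Illustrative companion to the round-trip test above.
    #[test]
    fn config_defaults_apply_when_fields_omitted() {
        let cfg: VllmConfig = serde_json::from_str(r#"{"model": "m"}"#).expect("deserialize");
        assert_eq!(cfg.tensor_parallel_size, 1);
        assert_eq!(cfg.dtype, "auto");
        assert_eq!(cfg.gpu_memory_utilization, None);
    }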

    #[test]
    fn runner_reports_runtime_kind() {
        let r = VllmRunner::new(test_config("test"));
        assert_eq!(r.runtime_kind(), RuntimeKind::Vllm);
        assert_eq!(r.transport_kind(), TransportKind::LocalGpu);
        assert!(r.gil_pinned());
    }

    #[cfg(not(feature = "vllm"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let mut r = VllmRunner::new(test_config("test"));
        let batch = ExecuteBatch {
            request_id: "t".into(),
            model: "t".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let result = r.execute(batch).await;
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
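
    // Sketch of a teardown check: rebuild_session always returns Ok in this
    // runner, and with no engine launched there is nothing to tear down.
    // Assumes the same tokio test runtime the suite already uses.
    #[tokio::test]
    async fn rebuild_session_succeeds_without_launched_engine() {
        let mut r = VllmRunner::new(test_config("test"));
        r.rebuild_session(SessionRebuildCause::Manual)
            .await
            .expect("rebuild_session should return Ok");
    }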
}