// clnrm_core/services/vllm.rs

1//! vLLM Service Plugin
2//!
3//! Provides integration with vLLM (Very Large Language Model) inference server.
4//! vLLM is a high-throughput, memory-efficient inference engine for LLMs.
5
6use crate::cleanroom::{HealthStatus, ServiceHandle, ServicePlugin};
7use crate::error::{CleanroomError, Result};
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Configuration for connecting to a vLLM inference server.
#[derive(Debug, Clone)]
pub struct VllmConfig {
    /// Base URL of the vLLM server, without a trailing slash
    /// (paths such as `/health` and `/v1/completions` are appended to it)
    pub endpoint: String,
    /// Model identifier to request completions from
    pub model: String,
    /// Maximum number of sequences per iteration (server-side tuning knob;
    /// recorded in handle metadata, not sent with requests)
    pub max_num_seqs: Option<u32>,
    /// Maximum model context length (recorded in handle metadata)
    pub max_model_len: Option<u32>,
    /// Tensor parallelism degree
    pub tensor_parallel_size: Option<u32>,
    /// GPU memory utilization — presumably a fraction in (0, 1]; confirm against vLLM docs
    pub gpu_memory_utilization: Option<f32>,
    /// Enable automatic prefix caching on the server
    pub enable_prefix_caching: Option<bool>,
    /// Per-request HTTP timeout in seconds (applied to the reqwest client)
    pub timeout_seconds: u64,
}
34
/// Service plugin that talks to a running vLLM server over HTTP.
#[derive(Debug)]
pub struct VllmPlugin {
    // Plugin instance name, reported via `ServicePlugin::name`.
    name: String,
    // Connection and model configuration.
    config: VllmConfig,
    // Lazily-initialized shared HTTP client; `None` until the first request.
    client: Arc<RwLock<Option<reqwest::Client>>>,
}
42
43impl VllmPlugin {
44    /// Create a new vLLM plugin instance
45    pub fn new(name: &str, config: VllmConfig) -> Self {
46        Self {
47            name: name.to_string(),
48            config,
49            client: Arc::new(RwLock::new(None)),
50        }
51    }
52
53    /// Initialize the HTTP client for vLLM API calls
54    async fn init_client(&self) -> Result<reqwest::Client> {
55        let client = reqwest::Client::builder()
56            .timeout(std::time::Duration::from_secs(self.config.timeout_seconds))
57            .build()
58            .map_err(|e| {
59                CleanroomError::internal_error(format!("Failed to create HTTP client: {}", e))
60            })?;
61
62        Ok(client)
63    }
64
65    /// Test connection to vLLM service
66    async fn test_connection(&self) -> Result<()> {
67        let mut client_guard = self.client.write().await;
68        if client_guard.is_none() {
69            *client_guard = Some(self.init_client().await?);
70        }
71        let client = client_guard
72            .as_ref()
73            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
74
75        let url = format!("{}/health", self.config.endpoint);
76
77        let response = client.get(&url).send().await.map_err(|e| {
78            CleanroomError::service_error(format!("Failed to connect to vLLM: {}", e))
79        })?;
80
81        if response.status().is_success() {
82            Ok(())
83        } else {
84            Err(CleanroomError::service_error("vLLM service not responding"))
85        }
86    }
87
88    /// Generate text using vLLM API (OpenAI-compatible)
89    pub async fn generate_text(
90        &self,
91        prompt: &str,
92        max_tokens: Option<u32>,
93        temperature: Option<f32>,
94    ) -> Result<VllmResponse> {
95        let mut client_guard = self.client.write().await;
96        if client_guard.is_none() {
97            *client_guard = Some(self.init_client().await?);
98        }
99        let client = client_guard
100            .as_ref()
101            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
102
103        let url = format!("{}/v1/completions", self.config.endpoint);
104
105        let mut payload = json!({
106            "model": self.config.model,
107            "prompt": prompt,
108            "stream": false
109        });
110
111        if let Some(max_tokens) = max_tokens {
112            payload["max_tokens"] = json!(max_tokens);
113        }
114
115        if let Some(temperature) = temperature {
116            payload["temperature"] = json!(temperature);
117        }
118
119        let response = client
120            .post(&url)
121            .header("Content-Type", "application/json")
122            .json(&payload)
123            .send()
124            .await
125            .map_err(|e| {
126                CleanroomError::service_error(format!("Failed to generate text: {}", e))
127            })?;
128
129        if response.status().is_success() {
130            let vllm_response: VllmResponse = response.json().await.map_err(|e| {
131                CleanroomError::service_error(format!("Failed to parse response: {}", e))
132            })?;
133
134            Ok(vllm_response)
135        } else {
136            let error_text = response
137                .text()
138                .await
139                .unwrap_or_else(|_| "Unknown error".to_string());
140
141            Err(CleanroomError::service_error(format!(
142                "vLLM API error: {}",
143                error_text
144            )))
145        }
146    }
147
148    /// Get model information from vLLM
149    pub async fn get_model_info(&self) -> Result<VllmModelInfo> {
150        let mut client_guard = self.client.write().await;
151        if client_guard.is_none() {
152            *client_guard = Some(self.init_client().await?);
153        }
154        let client = client_guard
155            .as_ref()
156            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
157
158        let url = format!("{}/v1/models", self.config.endpoint);
159
160        let response = client.get(&url).send().await.map_err(|e| {
161            CleanroomError::service_error(format!("Failed to get model info: {}", e))
162        })?;
163
164        if response.status().is_success() {
165            let model_info: VllmModelInfo = response.json().await.map_err(|e| {
166                CleanroomError::service_error(format!("Failed to parse model info: {}", e))
167            })?;
168
169            Ok(model_info)
170        } else {
171            Err(CleanroomError::service_error(
172                "Failed to retrieve model info",
173            ))
174        }
175    }
176}
177
/// Response from vLLM text generation (OpenAI-compatible completion object).
#[derive(Debug, serde::Deserialize)]
pub struct VllmResponse {
    /// Unique ID assigned by the server to this request
    pub id: String,
    /// Object type (always "text_completion")
    pub object: String,
    /// Creation timestamp — presumably Unix seconds, per the OpenAI format
    pub created: u64,
    /// Model that produced the completion
    pub model: String,
    /// Generated completions, one entry per requested choice
    pub choices: Vec<VllmChoice>,
    /// Token usage statistics for the request
    pub usage: VllmUsage,
}
194
/// Individual completion choice within a [`VllmResponse`].
#[derive(Debug, serde::Deserialize)]
pub struct VllmChoice {
    /// Generated completion text
    pub text: String,
    /// Zero-based index of this choice in the response
    pub index: u32,
    /// Per-token log probabilities (present only if requested)
    pub logprobs: Option<Value>,
    /// Why generation stopped (e.g. "stop", "length")
    // NOTE(review): OpenAI-style APIs can send `null` here; if deserialization
    // ever fails on this field, consider `Option<String>` — confirm with vLLM.
    pub finish_reason: String,
}
207
/// Token usage accounting returned with each completion.
#[derive(Debug, serde::Deserialize)]
pub struct VllmUsage {
    /// Number of tokens in the prompt
    pub prompt_tokens: u32,
    /// Number of tokens in the generated completion
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens
    pub total_tokens: u32,
}
218
/// Response from the `/v1/models` endpoint (OpenAI-compatible model list).
#[derive(Debug, serde::Deserialize)]
pub struct VllmModelInfo {
    /// Object type (always "list")
    pub object: String,
    /// Models currently served by this vLLM instance
    pub data: Vec<VllmModelData>,
}
227
/// A single entry in the [`VllmModelInfo`] model list.
#[derive(Debug, serde::Deserialize)]
pub struct VllmModelData {
    /// Model identifier (usable as `VllmConfig::model`)
    pub id: String,
    /// Object type (always "model")
    pub object: String,
    /// Creation timestamp — presumably Unix seconds, per the OpenAI format
    pub created: u64,
    /// Owner reported by the server
    pub owned_by: String,
}
240
241impl ServicePlugin for VllmPlugin {
242    fn name(&self) -> &str {
243        &self.name
244    }
245
246    fn start(&self) -> Result<ServiceHandle> {
247        // Use tokio::task::block_in_place for async operations
248        tokio::task::block_in_place(|| {
249            tokio::runtime::Handle::current().block_on(async {
250                // Test connection to vLLM service
251                let health_check = async {
252                    match self.test_connection().await {
253                        Ok(_) => HealthStatus::Healthy,
254                        Err(_) => HealthStatus::Unhealthy,
255                    }
256                };
257
258                let health = health_check.await;
259
260                let mut metadata = HashMap::new();
261                metadata.insert("endpoint".to_string(), self.config.endpoint.clone());
262                metadata.insert("model".to_string(), self.config.model.clone());
263                metadata.insert(
264                    "timeout_seconds".to_string(),
265                    self.config.timeout_seconds.to_string(),
266                );
267                metadata.insert("health_status".to_string(), format!("{:?}", health));
268
269                if let Some(max_num_seqs) = self.config.max_num_seqs {
270                    metadata.insert("max_num_seqs".to_string(), max_num_seqs.to_string());
271                }
272
273                if let Some(max_model_len) = self.config.max_model_len {
274                    metadata.insert("max_model_len".to_string(), max_model_len.to_string());
275                }
276
277                Ok(ServiceHandle {
278                    id: Uuid::new_v4().to_string(),
279                    service_name: self.name.clone(),
280                    metadata,
281                })
282            })
283        })
284    }
285
286    fn stop(&self, _handle: ServiceHandle) -> Result<()> {
287        // HTTP-based service, no cleanup needed beyond dropping the client
288        Ok(())
289    }
290
291    fn health_check(&self, handle: &ServiceHandle) -> HealthStatus {
292        if let Some(health_status) = handle.metadata.get("health_status") {
293            match health_status.as_str() {
294                "Healthy" => HealthStatus::Healthy,
295                "Unhealthy" => HealthStatus::Unhealthy,
296                _ => HealthStatus::Unknown,
297            }
298        } else {
299            HealthStatus::Unknown
300        }
301    }
302}