clnrm_core/services/
vllm.rs1use crate::cleanroom::{HealthStatus, ServiceHandle, ServicePlugin};
7use crate::error::{CleanroomError, Result};
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Connection and tuning settings for a vLLM server used by [`VllmPlugin`].
#[derive(Debug, Clone)]
pub struct VllmConfig {
    /// Base URL of the vLLM server; request paths such as `/health`,
    /// `/v1/completions` and `/v1/models` are appended to it.
    pub endpoint: String,
    /// Model identifier sent as the `"model"` field of completion requests.
    pub model: String,
    /// Optional sequence limit; copied into the service metadata when set.
    pub max_num_seqs: Option<u32>,
    /// Optional model length limit; copied into the service metadata when set.
    pub max_model_len: Option<u32>,
    /// NOTE(review): not read by the plugin code in this file — presumably a
    /// vLLM launch setting; confirm where it is consumed.
    pub tensor_parallel_size: Option<u32>,
    /// NOTE(review): not read by the plugin code in this file — presumably a
    /// fraction of GPU memory; confirm range and consumer.
    pub gpu_memory_utilization: Option<f32>,
    /// NOTE(review): not read by the plugin code in this file — confirm where
    /// it is consumed.
    pub enable_prefix_caching: Option<bool>,
    /// Timeout, in seconds, applied to the HTTP client built by the plugin.
    pub timeout_seconds: u64,
}
34
35#[derive(Debug)]
37pub struct VllmPlugin {
38 name: String,
39 config: VllmConfig,
40 client: Arc<RwLock<Option<reqwest::Client>>>,
41}
42
43impl VllmPlugin {
44 pub fn new(name: &str, config: VllmConfig) -> Self {
46 Self {
47 name: name.to_string(),
48 config,
49 client: Arc::new(RwLock::new(None)),
50 }
51 }
52
53 async fn init_client(&self) -> Result<reqwest::Client> {
55 let client = reqwest::Client::builder()
56 .timeout(std::time::Duration::from_secs(self.config.timeout_seconds))
57 .build()
58 .map_err(|e| {
59 CleanroomError::internal_error(format!("Failed to create HTTP client: {}", e))
60 })?;
61
62 Ok(client)
63 }
64
65 async fn test_connection(&self) -> Result<()> {
67 let mut client_guard = self.client.write().await;
68 if client_guard.is_none() {
69 *client_guard = Some(self.init_client().await?);
70 }
71 let client = client_guard
72 .as_ref()
73 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
74
75 let url = format!("{}/health", self.config.endpoint);
76
77 let response = client.get(&url).send().await.map_err(|e| {
78 CleanroomError::service_error(format!("Failed to connect to vLLM: {}", e))
79 })?;
80
81 if response.status().is_success() {
82 Ok(())
83 } else {
84 Err(CleanroomError::service_error("vLLM service not responding"))
85 }
86 }
87
88 pub async fn generate_text(
90 &self,
91 prompt: &str,
92 max_tokens: Option<u32>,
93 temperature: Option<f32>,
94 ) -> Result<VllmResponse> {
95 let mut client_guard = self.client.write().await;
96 if client_guard.is_none() {
97 *client_guard = Some(self.init_client().await?);
98 }
99 let client = client_guard
100 .as_ref()
101 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
102
103 let url = format!("{}/v1/completions", self.config.endpoint);
104
105 let mut payload = json!({
106 "model": self.config.model,
107 "prompt": prompt,
108 "stream": false
109 });
110
111 if let Some(max_tokens) = max_tokens {
112 payload["max_tokens"] = json!(max_tokens);
113 }
114
115 if let Some(temperature) = temperature {
116 payload["temperature"] = json!(temperature);
117 }
118
119 let response = client
120 .post(&url)
121 .header("Content-Type", "application/json")
122 .json(&payload)
123 .send()
124 .await
125 .map_err(|e| {
126 CleanroomError::service_error(format!("Failed to generate text: {}", e))
127 })?;
128
129 if response.status().is_success() {
130 let vllm_response: VllmResponse = response.json().await.map_err(|e| {
131 CleanroomError::service_error(format!("Failed to parse response: {}", e))
132 })?;
133
134 Ok(vllm_response)
135 } else {
136 let error_text = response
137 .text()
138 .await
139 .unwrap_or_else(|_| "Unknown error".to_string());
140
141 Err(CleanroomError::service_error(format!(
142 "vLLM API error: {}",
143 error_text
144 )))
145 }
146 }
147
148 pub async fn get_model_info(&self) -> Result<VllmModelInfo> {
150 let mut client_guard = self.client.write().await;
151 if client_guard.is_none() {
152 *client_guard = Some(self.init_client().await?);
153 }
154 let client = client_guard
155 .as_ref()
156 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
157
158 let url = format!("{}/v1/models", self.config.endpoint);
159
160 let response = client.get(&url).send().await.map_err(|e| {
161 CleanroomError::service_error(format!("Failed to get model info: {}", e))
162 })?;
163
164 if response.status().is_success() {
165 let model_info: VllmModelInfo = response.json().await.map_err(|e| {
166 CleanroomError::service_error(format!("Failed to parse model info: {}", e))
167 })?;
168
169 Ok(model_info)
170 } else {
171 Err(CleanroomError::service_error(
172 "Failed to retrieve model info",
173 ))
174 }
175 }
176}
177
178#[derive(Debug, serde::Deserialize)]
180pub struct VllmResponse {
181 pub id: String,
183 pub object: String,
185 pub created: u64,
187 pub model: String,
189 pub choices: Vec<VllmChoice>,
191 pub usage: VllmUsage,
193}
194
195#[derive(Debug, serde::Deserialize)]
197pub struct VllmChoice {
198 pub text: String,
200 pub index: u32,
202 pub logprobs: Option<Value>,
204 pub finish_reason: String,
206}
207
208#[derive(Debug, serde::Deserialize)]
210pub struct VllmUsage {
211 pub prompt_tokens: u32,
213 pub completion_tokens: u32,
215 pub total_tokens: u32,
217}
218
219#[derive(Debug, serde::Deserialize)]
221pub struct VllmModelInfo {
222 pub object: String,
224 pub data: Vec<VllmModelData>,
226}
227
228#[derive(Debug, serde::Deserialize)]
230pub struct VllmModelData {
231 pub id: String,
233 pub object: String,
235 pub created: u64,
237 pub owned_by: String,
239}
240
241impl ServicePlugin for VllmPlugin {
242 fn name(&self) -> &str {
243 &self.name
244 }
245
246 fn start(&self) -> Result<ServiceHandle> {
247 tokio::task::block_in_place(|| {
249 tokio::runtime::Handle::current().block_on(async {
250 let health_check = async {
252 match self.test_connection().await {
253 Ok(_) => HealthStatus::Healthy,
254 Err(_) => HealthStatus::Unhealthy,
255 }
256 };
257
258 let health = health_check.await;
259
260 let mut metadata = HashMap::new();
261 metadata.insert("endpoint".to_string(), self.config.endpoint.clone());
262 metadata.insert("model".to_string(), self.config.model.clone());
263 metadata.insert(
264 "timeout_seconds".to_string(),
265 self.config.timeout_seconds.to_string(),
266 );
267 metadata.insert("health_status".to_string(), format!("{:?}", health));
268
269 if let Some(max_num_seqs) = self.config.max_num_seqs {
270 metadata.insert("max_num_seqs".to_string(), max_num_seqs.to_string());
271 }
272
273 if let Some(max_model_len) = self.config.max_model_len {
274 metadata.insert("max_model_len".to_string(), max_model_len.to_string());
275 }
276
277 Ok(ServiceHandle {
278 id: Uuid::new_v4().to_string(),
279 service_name: self.name.clone(),
280 metadata,
281 })
282 })
283 })
284 }
285
286 fn stop(&self, _handle: ServiceHandle) -> Result<()> {
287 Ok(())
289 }
290
291 fn health_check(&self, handle: &ServiceHandle) -> HealthStatus {
292 if let Some(health_status) = handle.metadata.get("health_status") {
293 match health_status.as_str() {
294 "Healthy" => HealthStatus::Healthy,
295 "Unhealthy" => HealthStatus::Unhealthy,
296 _ => HealthStatus::Unknown,
297 }
298 } else {
299 HealthStatus::Unknown
300 }
301 }
302}