clnrm_core/services/
tgi.rs

//! Text Generation Inference (TGI) Service Plugin
//!
//! Provides integration with Hugging Face's Text Generation Inference server.
//! TGI is optimized for high-performance text generation with LLMs.

6use crate::cleanroom::{HealthStatus, ServiceHandle, ServicePlugin};
7use crate::error::{CleanroomError, Result};
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Configuration for a TGI (Text Generation Inference) service instance.
#[derive(Debug, Clone)]
pub struct TgiConfig {
    /// Base URL of the TGI server, e.g. `http://localhost:8080`.
    pub endpoint: String,
    /// Hugging Face model identifier served by the instance.
    pub model_id: String,
    /// Upper bound on total (input + generated) tokens per request.
    pub max_total_tokens: Option<u32>,
    /// Upper bound on input length, in tokens.
    pub max_input_length: Option<u32>,
    /// Upper bound on tokens in a prefill batch.
    pub max_batch_prefill_tokens: Option<u32>,
    /// Upper bound on concurrently served requests.
    pub max_concurrent_requests: Option<u32>,
    /// Upper bound on total tokens across a batch.
    pub max_batch_total_tokens: Option<u32>,
    /// HTTP request timeout, in seconds.
    pub timeout_seconds: u64,
}
34
35/// TGI service plugin
36#[derive(Debug)]
37pub struct TgiPlugin {
38    name: String,
39    config: TgiConfig,
40    client: Arc<RwLock<Option<reqwest::Client>>>,
41}
42
43impl TgiPlugin {
44    /// Create a new TGI plugin instance
45    pub fn new(name: &str, config: TgiConfig) -> Self {
46        Self {
47            name: name.to_string(),
48            config,
49            client: Arc::new(RwLock::new(None)),
50        }
51    }
52
53    /// Initialize the HTTP client for TGI API calls
54    async fn init_client(&self) -> Result<reqwest::Client> {
55        let client = reqwest::Client::builder()
56            .timeout(std::time::Duration::from_secs(self.config.timeout_seconds))
57            .build()
58            .map_err(|e| {
59                CleanroomError::internal_error(format!("Failed to create HTTP client: {}", e))
60            })?;
61
62        Ok(client)
63    }
64
65    /// Test connection to TGI service
66    async fn test_connection(&self) -> Result<()> {
67        let mut client_guard = self.client.write().await;
68        if client_guard.is_none() {
69            *client_guard = Some(self.init_client().await?);
70        }
71        let client = client_guard
72            .as_ref()
73            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
74
75        let url = format!("{}/health", self.config.endpoint);
76
77        let response = client.get(&url).send().await.map_err(|e| {
78            CleanroomError::service_error(format!("Failed to connect to TGI: {}", e))
79        })?;
80
81        if response.status().is_success() {
82            Ok(())
83        } else {
84            Err(CleanroomError::service_error("TGI service not responding"))
85        }
86    }
87
88    /// Generate text using TGI API
89    pub async fn generate_text(
90        &self,
91        inputs: &str,
92        parameters: Option<TgiParameters>,
93    ) -> Result<TgiResponse> {
94        let mut client_guard = self.client.write().await;
95        if client_guard.is_none() {
96            *client_guard = Some(self.init_client().await?);
97        }
98        let client = client_guard
99            .as_ref()
100            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
101
102        let url = format!("{}/generate", self.config.endpoint);
103
104        let mut payload = json!({
105            "inputs": inputs,
106            "parameters": parameters.unwrap_or_default()
107        });
108
109        // Set default parameters if not provided
110        let params = payload["parameters"].as_object_mut().ok_or_else(|| {
111            CleanroomError::internal_error(
112                "Invalid JSON structure: parameters field missing or not an object",
113            )
114        })?;
115        params.entry("max_new_tokens").or_insert(json!(100));
116        params.entry("temperature").or_insert(json!(0.7));
117        params.entry("do_sample").or_insert(json!(true));
118
119        let response = client
120            .post(&url)
121            .header("Content-Type", "application/json")
122            .json(&payload)
123            .send()
124            .await
125            .map_err(|e| {
126                CleanroomError::service_error(format!("Failed to generate text: {}", e))
127            })?;
128
129        if response.status().is_success() {
130            let tgi_response: TgiResponse = response.json().await.map_err(|e| {
131                CleanroomError::service_error(format!("Failed to parse response: {}", e))
132            })?;
133
134            Ok(tgi_response)
135        } else {
136            let error_text = response
137                .text()
138                .await
139                .unwrap_or_else(|_| "Unknown error".to_string());
140
141            Err(CleanroomError::service_error(format!(
142                "TGI API error: {}",
143                error_text
144            )))
145        }
146    }
147
148    /// Get model information from TGI
149    pub async fn get_info(&self) -> Result<TgiInfo> {
150        let mut client_guard = self.client.write().await;
151        if client_guard.is_none() {
152            *client_guard = Some(self.init_client().await?);
153        }
154        let client = client_guard
155            .as_ref()
156            .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
157
158        let url = format!("{}/info", self.config.endpoint);
159
160        let response = client
161            .get(&url)
162            .send()
163            .await
164            .map_err(|e| CleanroomError::service_error(format!("Failed to get info: {}", e)))?;
165
166        if response.status().is_success() {
167            let info: TgiInfo = response.json().await.map_err(|e| {
168                CleanroomError::service_error(format!("Failed to parse info: {}", e))
169            })?;
170
171            Ok(info)
172        } else {
173            Err(CleanroomError::service_error(
174                "Failed to retrieve service info",
175            ))
176        }
177    }
178}
179
180/// TGI generation parameters
181#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
182pub struct TgiParameters {
183    /// Maximum number of new tokens to generate
184    #[serde(rename = "max_new_tokens")]
185    pub max_new_tokens: Option<u32>,
186    /// Sampling temperature (0.0 to 2.0)
187    pub temperature: Option<f32>,
188    /// Top-p sampling parameter
189    #[serde(rename = "top_p")]
190    pub top_p: Option<f32>,
191    /// Top-k sampling parameter
192    #[serde(rename = "top_k")]
193    pub top_k: Option<u32>,
194    /// Whether to use sampling (vs greedy decoding)
195    #[serde(rename = "do_sample")]
196    pub do_sample: Option<bool>,
197    /// Number of beams for beam search
198    #[serde(rename = "num_beams")]
199    pub num_beams: Option<u32>,
200    /// Repetition penalty
201    #[serde(rename = "repetition_penalty")]
202    pub repetition_penalty: Option<f32>,
203}
204
205impl Default for TgiParameters {
206    fn default() -> Self {
207        Self {
208            max_new_tokens: Some(100),
209            temperature: Some(0.7),
210            top_p: None,
211            top_k: None,
212            do_sample: Some(true),
213            num_beams: None,
214            repetition_penalty: None,
215        }
216    }
217}
218
219/// Response from TGI text generation
220#[derive(Debug, serde::Deserialize)]
221pub struct TgiResponse {
222    /// Generated text
223    pub generated_text: String,
224    /// Input prompt that was used
225    pub prompt: Option<String>,
226    /// Generation details
227    pub details: Option<TgiDetails>,
228    /// Generation warnings
229    pub warnings: Option<Vec<String>>,
230}
231
232/// Generation details
233#[derive(Debug, serde::Deserialize)]
234pub struct TgiDetails {
235    /// Whether generation finished
236    pub finish_reason: String,
237    /// Number of generated tokens
238    pub generated_tokens: u32,
239    /// Seed used for generation
240    pub seed: Option<u64>,
241    /// Generation parameters used
242    pub parameters: Option<TgiParameters>,
243}
244
245/// TGI service information
246#[derive(Debug, serde::Deserialize)]
247pub struct TgiInfo {
248    /// Model ID being served
249    pub model_id: String,
250    /// Model SHA hash
251    pub model_sha: Option<String>,
252    /// Maximum total tokens supported
253    pub max_total_tokens: u32,
254    /// Maximum input length supported
255    pub max_input_length: u32,
256    /// Maximum batch size for prefill
257    pub max_batch_prefill_tokens: u32,
258    /// Maximum number of concurrent requests
259    pub max_concurrent_requests: u32,
260    /// Maximum batch size for total tokens
261    pub max_batch_total_tokens: u32,
262    /// Tokenization details
263    pub tokenization: Option<TgiTokenization>,
264    /// Model dtype
265    pub model_dtype: Option<String>,
266}
267
268/// Tokenization information
269#[derive(Debug, serde::Deserialize)]
270pub struct TgiTokenization {
271    /// Tokenizer class
272    pub tokenizer_class: Option<String>,
273    /// Whether tokenizer is slow
274    pub tokenizer_slow: Option<bool>,
275}
276
277impl ServicePlugin for TgiPlugin {
278    fn name(&self) -> &str {
279        &self.name
280    }
281
282    fn start(&self) -> Result<ServiceHandle> {
283        // Use tokio::task::block_in_place for async operations
284        tokio::task::block_in_place(|| {
285            tokio::runtime::Handle::current().block_on(async {
286                // Test connection to TGI service
287                let health_check = async {
288                    match self.test_connection().await {
289                        Ok(_) => HealthStatus::Healthy,
290                        Err(_) => HealthStatus::Unhealthy,
291                    }
292                };
293
294                let health = health_check.await;
295
296                let mut metadata = HashMap::new();
297                metadata.insert("endpoint".to_string(), self.config.endpoint.clone());
298                metadata.insert("model_id".to_string(), self.config.model_id.clone());
299                metadata.insert(
300                    "timeout_seconds".to_string(),
301                    self.config.timeout_seconds.to_string(),
302                );
303                metadata.insert("health_status".to_string(), format!("{:?}", health));
304
305                if let Some(max_total_tokens) = self.config.max_total_tokens {
306                    metadata.insert("max_total_tokens".to_string(), max_total_tokens.to_string());
307                }
308
309                Ok(ServiceHandle {
310                    id: Uuid::new_v4().to_string(),
311                    service_name: self.name.clone(),
312                    metadata,
313                })
314            })
315        })
316    }
317
318    fn stop(&self, _handle: ServiceHandle) -> Result<()> {
319        // HTTP-based service, no cleanup needed beyond dropping the client
320        Ok(())
321    }
322
323    fn health_check(&self, handle: &ServiceHandle) -> HealthStatus {
324        if let Some(health_status) = handle.metadata.get("health_status") {
325            match health_status.as_str() {
326                "Healthy" => HealthStatus::Healthy,
327                "Unhealthy" => HealthStatus::Unhealthy,
328                _ => HealthStatus::Unknown,
329            }
330        } else {
331            HealthStatus::Unknown
332        }
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a fully-populated config pointing at a local server.
    /// (Previously duplicated verbatim in both tests.)
    fn sample_config() -> TgiConfig {
        TgiConfig {
            endpoint: "http://localhost:8080".to_string(),
            model_id: "microsoft/DialoGPT-medium".to_string(),
            max_total_tokens: Some(2048),
            max_input_length: Some(1024),
            max_batch_prefill_tokens: Some(4096),
            max_concurrent_requests: Some(32),
            max_batch_total_tokens: Some(8192),
            timeout_seconds: 60,
        }
    }

    #[test]
    fn test_tgi_plugin_creation() {
        // Construction is purely local; no network I/O happens here.
        let plugin = TgiPlugin::new("test_tgi", sample_config());
        assert_eq!(plugin.name(), "test_tgi");
    }

    #[test]
    fn test_tgi_config() {
        let config = sample_config();
        assert_eq!(config.endpoint, "http://localhost:8080");
        assert_eq!(config.model_id, "microsoft/DialoGPT-medium");
        assert_eq!(config.timeout_seconds, 60);
        assert_eq!(config.max_total_tokens, Some(2048));
        assert_eq!(config.max_input_length, Some(1024));
    }
}