1use crate::cleanroom::{HealthStatus, ServiceHandle, ServicePlugin};
7use crate::error::{CleanroomError, Result};
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Connection and sizing settings for a Text Generation Inference (TGI) service.
#[derive(Debug, Clone)]
pub struct TgiConfig {
    /// Base URL of the TGI server. Request URLs are formed as
    /// `{endpoint}/health`, `{endpoint}/generate` and `{endpoint}/info`.
    pub endpoint: String,
    /// Model identifier; copied into the service handle metadata on start.
    pub model_id: String,
    /// Optional total-token limit; copied into the service handle metadata when set.
    pub max_total_tokens: Option<u32>,
    /// Optional input-length limit.
    // NOTE(review): not read anywhere in the visible code — confirm where it is consumed.
    pub max_input_length: Option<u32>,
    /// Optional batch prefill token limit.
    // NOTE(review): not read anywhere in the visible code — confirm where it is consumed.
    pub max_batch_prefill_tokens: Option<u32>,
    /// Optional concurrent-request limit.
    // NOTE(review): not read anywhere in the visible code — confirm where it is consumed.
    pub max_concurrent_requests: Option<u32>,
    /// Optional batch total token limit.
    // NOTE(review): not read anywhere in the visible code — confirm where it is consumed.
    pub max_batch_total_tokens: Option<u32>,
    /// Request timeout, in seconds, applied to the HTTP client.
    pub timeout_seconds: u64,
}
34
35#[derive(Debug)]
37pub struct TgiPlugin {
38 name: String,
39 config: TgiConfig,
40 client: Arc<RwLock<Option<reqwest::Client>>>,
41}
42
43impl TgiPlugin {
44 pub fn new(name: &str, config: TgiConfig) -> Self {
46 Self {
47 name: name.to_string(),
48 config,
49 client: Arc::new(RwLock::new(None)),
50 }
51 }
52
53 async fn init_client(&self) -> Result<reqwest::Client> {
55 let client = reqwest::Client::builder()
56 .timeout(std::time::Duration::from_secs(self.config.timeout_seconds))
57 .build()
58 .map_err(|e| {
59 CleanroomError::internal_error(format!("Failed to create HTTP client: {}", e))
60 })?;
61
62 Ok(client)
63 }
64
65 async fn test_connection(&self) -> Result<()> {
67 let mut client_guard = self.client.write().await;
68 if client_guard.is_none() {
69 *client_guard = Some(self.init_client().await?);
70 }
71 let client = client_guard
72 .as_ref()
73 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
74
75 let url = format!("{}/health", self.config.endpoint);
76
77 let response = client.get(&url).send().await.map_err(|e| {
78 CleanroomError::service_error(format!("Failed to connect to TGI: {}", e))
79 })?;
80
81 if response.status().is_success() {
82 Ok(())
83 } else {
84 Err(CleanroomError::service_error("TGI service not responding"))
85 }
86 }
87
88 pub async fn generate_text(
90 &self,
91 inputs: &str,
92 parameters: Option<TgiParameters>,
93 ) -> Result<TgiResponse> {
94 let mut client_guard = self.client.write().await;
95 if client_guard.is_none() {
96 *client_guard = Some(self.init_client().await?);
97 }
98 let client = client_guard
99 .as_ref()
100 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
101
102 let url = format!("{}/generate", self.config.endpoint);
103
104 let mut payload = json!({
105 "inputs": inputs,
106 "parameters": parameters.unwrap_or_default()
107 });
108
109 let params = payload["parameters"].as_object_mut().ok_or_else(|| {
111 CleanroomError::internal_error(
112 "Invalid JSON structure: parameters field missing or not an object",
113 )
114 })?;
115 params.entry("max_new_tokens").or_insert(json!(100));
116 params.entry("temperature").or_insert(json!(0.7));
117 params.entry("do_sample").or_insert(json!(true));
118
119 let response = client
120 .post(&url)
121 .header("Content-Type", "application/json")
122 .json(&payload)
123 .send()
124 .await
125 .map_err(|e| {
126 CleanroomError::service_error(format!("Failed to generate text: {}", e))
127 })?;
128
129 if response.status().is_success() {
130 let tgi_response: TgiResponse = response.json().await.map_err(|e| {
131 CleanroomError::service_error(format!("Failed to parse response: {}", e))
132 })?;
133
134 Ok(tgi_response)
135 } else {
136 let error_text = response
137 .text()
138 .await
139 .unwrap_or_else(|_| "Unknown error".to_string());
140
141 Err(CleanroomError::service_error(format!(
142 "TGI API error: {}",
143 error_text
144 )))
145 }
146 }
147
148 pub async fn get_info(&self) -> Result<TgiInfo> {
150 let mut client_guard = self.client.write().await;
151 if client_guard.is_none() {
152 *client_guard = Some(self.init_client().await?);
153 }
154 let client = client_guard
155 .as_ref()
156 .ok_or_else(|| CleanroomError::internal_error("HTTP client not initialized"))?;
157
158 let url = format!("{}/info", self.config.endpoint);
159
160 let response = client
161 .get(&url)
162 .send()
163 .await
164 .map_err(|e| CleanroomError::service_error(format!("Failed to get info: {}", e)))?;
165
166 if response.status().is_success() {
167 let info: TgiInfo = response.json().await.map_err(|e| {
168 CleanroomError::service_error(format!("Failed to parse info: {}", e))
169 })?;
170
171 Ok(info)
172 } else {
173 Err(CleanroomError::service_error(
174 "Failed to retrieve service info",
175 ))
176 }
177 }
178}
179
180#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
182pub struct TgiParameters {
183 #[serde(rename = "max_new_tokens")]
185 pub max_new_tokens: Option<u32>,
186 pub temperature: Option<f32>,
188 #[serde(rename = "top_p")]
190 pub top_p: Option<f32>,
191 #[serde(rename = "top_k")]
193 pub top_k: Option<u32>,
194 #[serde(rename = "do_sample")]
196 pub do_sample: Option<bool>,
197 #[serde(rename = "num_beams")]
199 pub num_beams: Option<u32>,
200 #[serde(rename = "repetition_penalty")]
202 pub repetition_penalty: Option<f32>,
203}
204
205impl Default for TgiParameters {
206 fn default() -> Self {
207 Self {
208 max_new_tokens: Some(100),
209 temperature: Some(0.7),
210 top_p: None,
211 top_k: None,
212 do_sample: Some(true),
213 num_beams: None,
214 repetition_penalty: None,
215 }
216 }
217}
218
219#[derive(Debug, serde::Deserialize)]
221pub struct TgiResponse {
222 pub generated_text: String,
224 pub prompt: Option<String>,
226 pub details: Option<TgiDetails>,
228 pub warnings: Option<Vec<String>>,
230}
231
232#[derive(Debug, serde::Deserialize)]
234pub struct TgiDetails {
235 pub finish_reason: String,
237 pub generated_tokens: u32,
239 pub seed: Option<u64>,
241 pub parameters: Option<TgiParameters>,
243}
244
245#[derive(Debug, serde::Deserialize)]
247pub struct TgiInfo {
248 pub model_id: String,
250 pub model_sha: Option<String>,
252 pub max_total_tokens: u32,
254 pub max_input_length: u32,
256 pub max_batch_prefill_tokens: u32,
258 pub max_concurrent_requests: u32,
260 pub max_batch_total_tokens: u32,
262 pub tokenization: Option<TgiTokenization>,
264 pub model_dtype: Option<String>,
266}
267
268#[derive(Debug, serde::Deserialize)]
270pub struct TgiTokenization {
271 pub tokenizer_class: Option<String>,
273 pub tokenizer_slow: Option<bool>,
275}
276
277impl ServicePlugin for TgiPlugin {
278 fn name(&self) -> &str {
279 &self.name
280 }
281
282 fn start(&self) -> Result<ServiceHandle> {
283 tokio::task::block_in_place(|| {
285 tokio::runtime::Handle::current().block_on(async {
286 let health_check = async {
288 match self.test_connection().await {
289 Ok(_) => HealthStatus::Healthy,
290 Err(_) => HealthStatus::Unhealthy,
291 }
292 };
293
294 let health = health_check.await;
295
296 let mut metadata = HashMap::new();
297 metadata.insert("endpoint".to_string(), self.config.endpoint.clone());
298 metadata.insert("model_id".to_string(), self.config.model_id.clone());
299 metadata.insert(
300 "timeout_seconds".to_string(),
301 self.config.timeout_seconds.to_string(),
302 );
303 metadata.insert("health_status".to_string(), format!("{:?}", health));
304
305 if let Some(max_total_tokens) = self.config.max_total_tokens {
306 metadata.insert("max_total_tokens".to_string(), max_total_tokens.to_string());
307 }
308
309 Ok(ServiceHandle {
310 id: Uuid::new_v4().to_string(),
311 service_name: self.name.clone(),
312 metadata,
313 })
314 })
315 })
316 }
317
318 fn stop(&self, _handle: ServiceHandle) -> Result<()> {
319 Ok(())
321 }
322
323 fn health_check(&self, handle: &ServiceHandle) -> HealthStatus {
324 if let Some(health_status) = handle.metadata.get("health_status") {
325 match health_status.as_str() {
326 "Healthy" => HealthStatus::Healthy,
327 "Unhealthy" => HealthStatus::Unhealthy,
328 _ => HealthStatus::Unknown,
329 }
330 } else {
331 HealthStatus::Unknown
332 }
333 }
334}