// offline_intelligence/model_runtime/tensorrt_runtime.rs

use async_trait::async_trait;

use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

/// Runtime backend that serves a TensorRT-compiled model through a spawned
/// local HTTP server process and proxies OpenAI-style chat requests to it.
pub struct TensorRTRuntime {
    // Configuration captured by `initialize`; `None` until then.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server wrapper process, if one is running.
    server_process: Option<Child>,
    // Shared HTTP client used for health checks and inference calls.
    http_client: reqwest::Client,
    // `http://{host}:{port}` of the running server; empty when not initialized.
    base_url: String,
}
19
20impl TensorRTRuntime {
21 pub fn new() -> Self {
22 Self {
23 config: None,
24 server_process: None,
25 http_client: reqwest::Client::builder()
26 .timeout(Duration::from_secs(600))
27 .build()
28 .unwrap_or_default(),
29 base_url: String::new(),
30 }
31 }
32
33 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
34 let binary_path = config.runtime_binary.as_ref()
35 .ok_or_else(|| anyhow::anyhow!("TensorRT runtime requires runtime_binary path"))?;
36
37 if !binary_path.exists() {
38 return Err(anyhow::anyhow!(
39 "TensorRT server binary not found at: {}. Install TensorRT and provide a server wrapper.",
40 binary_path.display()
41 ));
42 }
43
44 info!("Starting TensorRT server for model: {}", config.model_path.display());
45
46 let mut cmd = Command::new(binary_path);
47 cmd.arg("--model").arg(&config.model_path)
48 .arg("--host").arg(&config.host)
49 .arg("--port").arg(config.port.to_string())
50 .stdout(Stdio::piped())
51 .stderr(Stdio::piped());
52
53 let child = cmd.spawn()
54 .map_err(|e| anyhow::anyhow!("Failed to spawn TensorRT server: {}", e))?;
55
56 self.server_process = Some(child);
57 self.base_url = format!("http://{}:{}", config.host, config.port);
58
59 for attempt in 1..=15 {
60 sleep(Duration::from_secs(2)).await;
61 if self.is_ready().await {
62 info!("✅ TensorRT runtime ready after {} seconds", attempt * 2);
63 return Ok(());
64 }
65 }
66
67 Err(anyhow::anyhow!("TensorRT server failed to start within 30 seconds"))
68 }
69}
70
71impl Default for TensorRTRuntime {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77#[async_trait]
78impl ModelRuntime for TensorRTRuntime {
79 fn supported_format(&self) -> ModelFormat {
80 ModelFormat::TensorRT
81 }
82
83 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
84 info!("Initializing TensorRT runtime");
85
86 if config.format != ModelFormat::TensorRT {
87 return Err(anyhow::anyhow!("TensorRT runtime received wrong format: {:?}", config.format));
88 }
89
90 self.config = Some(config.clone());
91 self.start_server(&config).await?;
92 Ok(())
93 }
94
95 async fn is_ready(&self) -> bool {
96 if self.base_url.is_empty() {
97 return false;
98 }
99
100 let health_url = format!("{}/health", self.base_url);
101 match self.http_client.get(&health_url).send().await {
102 Ok(resp) => resp.status().is_success(),
103 Err(_) => false,
104 }
105 }
106
107 async fn health_check(&self) -> anyhow::Result<String> {
108 if self.base_url.is_empty() {
109 return Err(anyhow::anyhow!("Runtime not initialized"));
110 }
111
112 let health_url = format!("{}/health", self.base_url);
113 let resp = self.http_client.get(&health_url).send().await
114 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
115
116 if resp.status().is_success() {
117 Ok("healthy".to_string())
118 } else {
119 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
120 }
121 }
122
123 fn base_url(&self) -> String {
124 self.base_url.clone()
125 }
126
127 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
128 let url = self.completions_url();
129
130 let payload = serde_json::json!({
131 "model": "tensorrt-llm",
132 "messages": request.messages,
133 "max_tokens": request.max_tokens,
134 "temperature": request.temperature,
135 "stream": false,
136 });
137
138 let resp = self.http_client.post(&url).json(&payload).send().await
139 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
140
141 if !resp.status().is_success() {
142 let status = resp.status();
143 let body = resp.text().await.unwrap_or_default();
144 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
145 }
146
147 let response: serde_json::Value = resp.json().await
148 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
149
150 let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
151 let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
152
153 Ok(InferenceResponse { content, finish_reason })
154 }
155
156 async fn generate_stream(
157 &self,
158 request: InferenceRequest,
159 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
160 use futures_util::StreamExt;
161
162 let url = self.completions_url();
163 let payload = serde_json::json!({
164 "model": "tensorrt-llm",
165 "messages": request.messages,
166 "max_tokens": request.max_tokens,
167 "temperature": request.temperature,
168 "stream": true,
169 });
170
171 let resp = self.http_client.post(&url).json(&payload).send().await
172 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
173
174 if !resp.status().is_success() {
175 let status = resp.status();
176 let body = resp.text().await.unwrap_or_default();
177 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
178 }
179
180 let byte_stream = resp.bytes_stream();
181 let sse_stream = async_stream::try_stream! {
182 let mut buffer = String::new();
183 futures_util::pin_mut!(byte_stream);
184
185 while let Some(chunk_result) = byte_stream.next().await {
186 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
187 buffer.push_str(&String::from_utf8_lossy(&chunk));
188
189 while let Some(newline_pos) = buffer.find('\n') {
190 let line = buffer[..newline_pos].trim().to_string();
191 buffer = buffer[newline_pos + 1..].to_string();
192
193 if line.is_empty() || !line.starts_with("data: ") {
194 continue;
195 }
196
197 let data = &line[6..];
198 if data == "[DONE]" {
199 return;
200 }
201
202 yield format!("data: {}\n\n", data);
203 }
204 }
205 };
206
207 Ok(Box::new(Box::pin(sse_stream)))
208 }
209
210 async fn shutdown(&mut self) -> anyhow::Result<()> {
211 info!("Shutting down TensorRT runtime");
212
213 if let Some(mut child) = self.server_process.take() {
214 match child.kill() {
215 Ok(_) => {
216 info!("TensorRT server process killed successfully");
217 let _ = child.wait();
218 }
219 Err(e) => {
220 warn!("Failed to kill TensorRT server: {}", e);
221 }
222 }
223 }
224
225 self.config = None;
226 self.base_url.clear();
227 Ok(())
228 }
229
230 fn metadata(&self) -> RuntimeMetadata {
231 RuntimeMetadata {
232 format: ModelFormat::TensorRT,
233 runtime_name: "TensorRT".to_string(),
234 version: "latest".to_string(),
235 supports_gpu: true,
236 supports_streaming: true,
237 }
238 }
239}
240
241impl Drop for TensorRTRuntime {
242 fn drop(&mut self) {
243 if let Some(mut child) = self.server_process.take() {
244 let _ = child.kill();
245 let _ = child.wait();
246 }
247 }
248}