// offline_intelligence/model_runtime/tensorrt_runtime.rs
use async_trait::async_trait;
use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;

/// Runtime backend that manages an external TensorRT server process and
/// communicates with it over HTTP (OpenAI-compatible chat endpoints).
pub struct TensorRTRuntime {
    // Set by `initialize`; None until then and again after `shutdown`.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server process; killed on shutdown and on drop.
    server_process: Option<Child>,
    // Shared HTTP client (built with a 600 s timeout for long generations).
    http_client: reqwest::Client,
    // "http://{host}:{port}" once the server is started; empty otherwise.
    base_url: String,
}
19
20impl TensorRTRuntime {
21 pub fn new() -> Self {
22 Self {
23 config: None,
24 server_process: None,
25 http_client: reqwest::Client::builder()
26 .timeout(Duration::from_secs(600))
27 .build()
28 .unwrap_or_default(),
29 base_url: String::new(),
30 }
31 }
32
33 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
34 let binary_path = config.runtime_binary.as_ref()
35 .ok_or_else(|| anyhow::anyhow!("TensorRT runtime requires runtime_binary path"))?;
36
37 if !binary_path.exists() {
38 return Err(anyhow::anyhow!(
39 "TensorRT server binary not found at: {}. Install TensorRT and provide a server wrapper.",
40 binary_path.display()
41 ));
42 }
43
44 info!("Starting TensorRT server for model: {}", config.model_path.display());
45
46 let mut cmd = Command::new(binary_path);
47 cmd.arg("--model").arg(&config.model_path)
48 .arg("--host").arg(&config.host)
49 .arg("--port").arg(config.port.to_string())
50 .stdout(Stdio::piped())
51 .stderr(Stdio::piped());
52
53 let child = cmd.spawn()
54 .map_err(|e| anyhow::anyhow!("Failed to spawn TensorRT server: {}", e))?;
55
56 self.server_process = Some(child);
57 self.base_url = format!("http://{}:{}", config.host, config.port);
58
59 let _start = std::time::Instant::now();
61 let mut delay_ms: u64 = 100;
62 let mut last_log_secs: u64 = 0;
63 loop {
64 sleep(Duration::from_millis(delay_ms)).await;
65 if self.is_ready().await {
66 info!("✅ TensorRT runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
67 return Ok(());
68 }
69 let elapsed_secs = _start.elapsed().as_secs();
70 if elapsed_secs >= 120 {
71 break;
72 }
73 if elapsed_secs >= last_log_secs + 10 {
74 info!("Still waiting for TensorRT server... ({}/120s)", elapsed_secs);
75 last_log_secs = elapsed_secs;
76 }
77 delay_ms = (delay_ms * 2).min(2_000);
78 }
79
80 Err(anyhow::anyhow!("TensorRT server failed to start within 120 seconds"))
81 }
82}
83
84impl Default for TensorRTRuntime {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90#[async_trait]
91impl ModelRuntime for TensorRTRuntime {
92 fn supported_format(&self) -> ModelFormat {
93 ModelFormat::TensorRT
94 }
95
96 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
97 info!("Initializing TensorRT runtime");
98
99 if config.format != ModelFormat::TensorRT {
100 return Err(anyhow::anyhow!("TensorRT runtime received wrong format: {:?}", config.format));
101 }
102
103 self.config = Some(config.clone());
104 self.start_server(&config).await?;
105 Ok(())
106 }
107
108 async fn is_ready(&self) -> bool {
109 if self.base_url.is_empty() {
110 return false;
111 }
112
113 let health_url = format!("{}/health", self.base_url);
114 match self.http_client.get(&health_url).send().await {
115 Ok(resp) => resp.status().is_success(),
116 Err(_) => false,
117 }
118 }
119
120 async fn health_check(&self) -> anyhow::Result<String> {
121 if self.base_url.is_empty() {
122 return Err(anyhow::anyhow!("Runtime not initialized"));
123 }
124
125 let health_url = format!("{}/health", self.base_url);
126 let resp = self.http_client.get(&health_url).send().await
127 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
128
129 if resp.status().is_success() {
130 Ok("healthy".to_string())
131 } else {
132 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
133 }
134 }
135
136 fn base_url(&self) -> String {
137 self.base_url.clone()
138 }
139
140 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
141 let url = self.completions_url();
142
143 let payload = serde_json::json!({
144 "model": "tensorrt-llm",
145 "messages": request.messages,
146 "max_tokens": request.max_tokens,
147 "temperature": request.temperature,
148 "stream": false,
149 });
150
151 let resp = self.http_client.post(&url).json(&payload).send().await
152 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
153
154 if !resp.status().is_success() {
155 let status = resp.status();
156 let body = resp.text().await.unwrap_or_default();
157 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
158 }
159
160 let response: serde_json::Value = resp.json().await
161 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
162
163 let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
164 let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
165
166 Ok(InferenceResponse { content, finish_reason })
167 }
168
169 async fn generate_stream(
170 &self,
171 request: InferenceRequest,
172 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
173 use futures_util::StreamExt;
174
175 let url = self.completions_url();
176 let payload = serde_json::json!({
177 "model": "tensorrt-llm",
178 "messages": request.messages,
179 "max_tokens": request.max_tokens,
180 "temperature": request.temperature,
181 "stream": true,
182 });
183
184 let resp = self.http_client.post(&url).json(&payload).send().await
185 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
186
187 if !resp.status().is_success() {
188 let status = resp.status();
189 let body = resp.text().await.unwrap_or_default();
190 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
191 }
192
193 let byte_stream = resp.bytes_stream();
194 let sse_stream = async_stream::try_stream! {
195 let mut buffer = String::new();
196 futures_util::pin_mut!(byte_stream);
197
198 while let Some(chunk_result) = byte_stream.next().await {
199 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
200 buffer.push_str(&String::from_utf8_lossy(&chunk));
201
202 while let Some(newline_pos) = buffer.find('\n') {
203 let line = buffer[..newline_pos].trim().to_string();
204 buffer = buffer[newline_pos + 1..].to_string();
205
206 if line.is_empty() || !line.starts_with("data: ") {
207 continue;
208 }
209
210 let data = &line[6..];
211 if data == "[DONE]" {
212 return;
213 }
214
215 yield format!("data: {}\n\n", data);
216 }
217 }
218 };
219
220 Ok(Box::new(Box::pin(sse_stream)))
221 }
222
223 async fn shutdown(&mut self) -> anyhow::Result<()> {
224 info!("Shutting down TensorRT runtime");
225
226 if let Some(mut child) = self.server_process.take() {
227 match child.kill() {
228 Ok(_) => {
229 info!("TensorRT server process killed successfully");
230 let _ = child.wait();
231 }
232 Err(e) => {
233 warn!("Failed to kill TensorRT server: {}", e);
234 }
235 }
236 }
237
238 self.config = None;
239 self.base_url.clear();
240 Ok(())
241 }
242
243 fn metadata(&self) -> RuntimeMetadata {
244 RuntimeMetadata {
245 format: ModelFormat::TensorRT,
246 runtime_name: "TensorRT".to_string(),
247 version: "latest".to_string(),
248 supports_gpu: true,
249 supports_streaming: true,
250 }
251 }
252}
253
254impl Drop for TensorRTRuntime {
255 fn drop(&mut self) {
256 if let Some(mut child) = self.server_process.take() {
257 let _ = child.kill();
258 let _ = child.wait();
259 }
260 }
261}