Skip to main content

offline_intelligence/model_runtime/
tensorrt_runtime.rs

//! TensorRT Runtime Adapter
//!
//! Adapter for TensorRT-optimized models (NVIDIA).
//! Requires an NVIDIA GPU and the TensorRT runtime.
5
6use async_trait::async_trait;
7use super::runtime_trait::*;
8use std::process::{Child, Command, Stdio};
9use std::time::Duration;
10use tracing::{info, warn};
11use tokio::time::sleep;
12
13pub struct TensorRTRuntime {
14    config: Option<RuntimeConfig>,
15    server_process: Option<Child>,
16    http_client: reqwest::Client,
17    base_url: String,
18}
19
20impl TensorRTRuntime {
21    pub fn new() -> Self {
22        Self {
23            config: None,
24            server_process: None,
25            http_client: reqwest::Client::builder()
26                .timeout(Duration::from_secs(600))
27                .build()
28                .unwrap_or_default(),
29            base_url: String::new(),
30        }
31    }
32
33    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
34        let binary_path = config.runtime_binary.as_ref()
35            .ok_or_else(|| anyhow::anyhow!("TensorRT runtime requires runtime_binary path"))?;
36
37        if !binary_path.exists() {
38            return Err(anyhow::anyhow!(
39                "TensorRT server binary not found at: {}. Install TensorRT and provide a server wrapper.",
40                binary_path.display()
41            ));
42        }
43
44        info!("Starting TensorRT server for model: {}", config.model_path.display());
45        
46        let mut cmd = Command::new(binary_path);
47        cmd.arg("--model").arg(&config.model_path)
48            .arg("--host").arg(&config.host)
49            .arg("--port").arg(config.port.to_string())
50            .stdout(Stdio::piped())
51            .stderr(Stdio::piped());
52
53        let child = cmd.spawn()
54            .map_err(|e| anyhow::anyhow!("Failed to spawn TensorRT server: {}", e))?;
55
56        self.server_process = Some(child);
57        self.base_url = format!("http://{}:{}", config.host, config.port);
58
59        // Wait for server to be ready (up to 120 seconds) with exponential backoff.
60        let _start = std::time::Instant::now();
61        let mut delay_ms: u64 = 100;
62        let mut last_log_secs: u64 = 0;
63        loop {
64            sleep(Duration::from_millis(delay_ms)).await;
65            if self.is_ready().await {
66                info!("✅ TensorRT runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
67                return Ok(());
68            }
69            let elapsed_secs = _start.elapsed().as_secs();
70            if elapsed_secs >= 120 {
71                break;
72            }
73            if elapsed_secs >= last_log_secs + 10 {
74                info!("Still waiting for TensorRT server... ({}/120s)", elapsed_secs);
75                last_log_secs = elapsed_secs;
76            }
77            delay_ms = (delay_ms * 2).min(2_000);
78        }
79
80        Err(anyhow::anyhow!("TensorRT server failed to start within 120 seconds"))
81    }
82}
83
84impl Default for TensorRTRuntime {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90#[async_trait]
91impl ModelRuntime for TensorRTRuntime {
92    fn supported_format(&self) -> ModelFormat {
93        ModelFormat::TensorRT
94    }
95
96    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
97        info!("Initializing TensorRT runtime");
98        
99        if config.format != ModelFormat::TensorRT {
100            return Err(anyhow::anyhow!("TensorRT runtime received wrong format: {:?}", config.format));
101        }
102
103        self.config = Some(config.clone());
104        self.start_server(&config).await?;
105        Ok(())
106    }
107
108    async fn is_ready(&self) -> bool {
109        if self.base_url.is_empty() {
110            return false;
111        }
112
113        let health_url = format!("{}/health", self.base_url);
114        match self.http_client.get(&health_url).send().await {
115            Ok(resp) => resp.status().is_success(),
116            Err(_) => false,
117        }
118    }
119
120    async fn health_check(&self) -> anyhow::Result<String> {
121        if self.base_url.is_empty() {
122            return Err(anyhow::anyhow!("Runtime not initialized"));
123        }
124
125        let health_url = format!("{}/health", self.base_url);
126        let resp = self.http_client.get(&health_url).send().await
127            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
128
129        if resp.status().is_success() {
130            Ok("healthy".to_string())
131        } else {
132            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
133        }
134    }
135
136    fn base_url(&self) -> String {
137        self.base_url.clone()
138    }
139
140    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
141        let url = self.completions_url();
142        
143        let payload = serde_json::json!({
144            "model": "tensorrt-llm",
145            "messages": request.messages,
146            "max_tokens": request.max_tokens,
147            "temperature": request.temperature,
148            "stream": false,
149        });
150
151        let resp = self.http_client.post(&url).json(&payload).send().await
152            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
153
154        if !resp.status().is_success() {
155            let status = resp.status();
156            let body = resp.text().await.unwrap_or_default();
157            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
158        }
159
160        let response: serde_json::Value = resp.json().await
161            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
162
163        let content = response["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
164        let finish_reason = response["choices"][0]["finish_reason"].as_str().map(|s| s.to_string());
165
166        Ok(InferenceResponse { content, finish_reason })
167    }
168
169    async fn generate_stream(
170        &self,
171        request: InferenceRequest,
172    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
173        use futures_util::StreamExt;
174        
175        let url = self.completions_url();
176        let payload = serde_json::json!({
177            "model": "tensorrt-llm",
178            "messages": request.messages,
179            "max_tokens": request.max_tokens,
180            "temperature": request.temperature,
181            "stream": true,
182        });
183
184        let resp = self.http_client.post(&url).json(&payload).send().await
185            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
186
187        if !resp.status().is_success() {
188            let status = resp.status();
189            let body = resp.text().await.unwrap_or_default();
190            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
191        }
192
193        let byte_stream = resp.bytes_stream();
194        let sse_stream = async_stream::try_stream! {
195            let mut buffer = String::new();
196            futures_util::pin_mut!(byte_stream);
197
198            while let Some(chunk_result) = byte_stream.next().await {
199                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
200                buffer.push_str(&String::from_utf8_lossy(&chunk));
201
202                while let Some(newline_pos) = buffer.find('\n') {
203                    let line = buffer[..newline_pos].trim().to_string();
204                    buffer = buffer[newline_pos + 1..].to_string();
205
206                    if line.is_empty() || !line.starts_with("data: ") {
207                        continue;
208                    }
209
210                    let data = &line[6..];
211                    if data == "[DONE]" {
212                        return;
213                    }
214
215                    yield format!("data: {}\n\n", data);
216                }
217            }
218        };
219
220        Ok(Box::new(Box::pin(sse_stream)))
221    }
222
223    async fn shutdown(&mut self) -> anyhow::Result<()> {
224        info!("Shutting down TensorRT runtime");
225        
226        if let Some(mut child) = self.server_process.take() {
227            match child.kill() {
228                Ok(_) => {
229                    info!("TensorRT server process killed successfully");
230                    let _ = child.wait();
231                }
232                Err(e) => {
233                    warn!("Failed to kill TensorRT server: {}", e);
234                }
235            }
236        }
237
238        self.config = None;
239        self.base_url.clear();
240        Ok(())
241    }
242
243    fn metadata(&self) -> RuntimeMetadata {
244        RuntimeMetadata {
245            format: ModelFormat::TensorRT,
246            runtime_name: "TensorRT".to_string(),
247            version: "latest".to_string(),
248            supports_gpu: true,
249            supports_streaming: true,
250        }
251    }
252}
253
254impl Drop for TensorRTRuntime {
255    fn drop(&mut self) {
256        if let Some(mut child) = self.server_process.take() {
257            let _ = child.kill();
258            let _ = child.wait();
259        }
260    }
261}