// offline_intelligence/model_runtime/onnx_runtime.rs
//! ONNX Runtime Adapter
//!
//! Adapter for ONNX format models using ONNX Runtime.
//! This implementation provides a basic HTTP server wrapper around ONNX Runtime.

use std::process::{Child, Command, Stdio};
use std::time::Duration;

use async_trait::async_trait;
use tokio::time::sleep;
use tracing::{info, warn};

use super::runtime_trait::*;

13pub struct ONNXRuntime {
14    config: Option<RuntimeConfig>,
15    server_process: Option<Child>,
16    http_client: reqwest::Client,
17    base_url: String,
18}
19
20impl ONNXRuntime {
21    pub fn new() -> Self {
22        Self {
23            config: None,
24            server_process: None,
25            http_client: reqwest::Client::builder()
26                .timeout(Duration::from_secs(600))
27                .build()
28                .unwrap_or_default(),
29            base_url: String::new(),
30        }
31    }
32
33    /// Start ONNX inference server
34    /// Note: This requires a separate ONNX server implementation or wrapper
35    async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
36        // For ONNX, we need a server wrapper (could be Python-based or custom Rust)
37        // This is a placeholder implementation
38        
39        let binary_path = config.runtime_binary.as_ref()
40            .ok_or_else(|| anyhow::anyhow!("ONNX runtime requires runtime_binary path (e.g., onnx-server)"))?;
41
42        if !binary_path.exists() {
43            return Err(anyhow::anyhow!(
44                "ONNX server binary not found at: {}. You may need to create an ONNX inference server wrapper.",
45                binary_path.display()
46            ));
47        }
48
49        info!("Starting ONNX server for model: {}", config.model_path.display());
50        
51        let mut cmd = Command::new(binary_path);
52        cmd.arg("--model").arg(&config.model_path)
53            .arg("--host").arg(&config.host)
54            .arg("--port").arg(config.port.to_string())
55            .stdout(Stdio::piped())
56            .stderr(Stdio::piped());
57
58        let child = cmd.spawn()
59            .map_err(|e| anyhow::anyhow!("Failed to spawn ONNX server: {}", e))?;
60
61        self.server_process = Some(child);
62        self.base_url = format!("http://{}:{}", config.host, config.port);
63
64        // Wait for server to be ready (up to 120 seconds) with exponential backoff.
65        let _start = std::time::Instant::now();
66        let mut delay_ms: u64 = 100;
67        let mut last_log_secs: u64 = 0;
68        loop {
69            sleep(Duration::from_millis(delay_ms)).await;
70            if self.is_ready().await {
71                info!("✅ ONNX runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
72                return Ok(());
73            }
74            let elapsed_secs = _start.elapsed().as_secs();
75            if elapsed_secs >= 120 {
76                break;
77            }
78            if elapsed_secs >= last_log_secs + 10 {
79                info!("Still waiting for ONNX server... ({}/120s)", elapsed_secs);
80                last_log_secs = elapsed_secs;
81            }
82            delay_ms = (delay_ms * 2).min(2_000);
83        }
84
85        Err(anyhow::anyhow!("ONNX server failed to start within 120 seconds"))
86    }
87}
88
89impl Default for ONNXRuntime {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95#[async_trait]
96impl ModelRuntime for ONNXRuntime {
97    fn supported_format(&self) -> ModelFormat {
98        ModelFormat::ONNX
99    }
100
101    async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
102        info!("Initializing ONNX runtime");
103        
104        if config.format != ModelFormat::ONNX {
105            return Err(anyhow::anyhow!("ONNX runtime received wrong format: {:?}", config.format));
106        }
107
108        if !config.model_path.exists() {
109            return Err(anyhow::anyhow!("Model file not found: {}", config.model_path.display()));
110        }
111
112        self.config = Some(config.clone());
113        self.start_server(&config).await?;
114        
115        Ok(())
116    }
117
118    async fn is_ready(&self) -> bool {
119        if self.base_url.is_empty() {
120            return false;
121        }
122
123        let health_url = format!("{}/health", self.base_url);
124        match self.http_client.get(&health_url).send().await {
125            Ok(resp) => resp.status().is_success(),
126            Err(_) => false,
127        }
128    }
129
130    async fn health_check(&self) -> anyhow::Result<String> {
131        if self.base_url.is_empty() {
132            return Err(anyhow::anyhow!("Runtime not initialized"));
133        }
134
135        let health_url = format!("{}/health", self.base_url);
136        let resp = self.http_client.get(&health_url)
137            .send()
138            .await
139            .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
140
141        if resp.status().is_success() {
142            Ok("healthy".to_string())
143        } else {
144            Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
145        }
146    }
147
148    fn base_url(&self) -> String {
149        self.base_url.clone()
150    }
151
152    async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
153        let url = self.completions_url();
154        
155        let payload = serde_json::json!({
156            "model": "onnx-llm",
157            "messages": request.messages,
158            "max_tokens": request.max_tokens,
159            "temperature": request.temperature,
160            "stream": false,
161        });
162
163        let resp = self.http_client.post(&url)
164            .json(&payload)
165            .send()
166            .await
167            .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
168
169        if !resp.status().is_success() {
170            let status = resp.status();
171            let body = resp.text().await.unwrap_or_default();
172            return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
173        }
174
175        let response: serde_json::Value = resp.json().await
176            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
177
178        let content = response["choices"][0]["message"]["content"]
179            .as_str()
180            .unwrap_or("")
181            .to_string();
182
183        let finish_reason = response["choices"][0]["finish_reason"]
184            .as_str()
185            .map(|s| s.to_string());
186
187        Ok(InferenceResponse { content, finish_reason })
188    }
189
190    async fn generate_stream(
191        &self,
192        request: InferenceRequest,
193    ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
194        use futures_util::StreamExt;
195        
196        let url = self.completions_url();
197        
198        let payload = serde_json::json!({
199            "model": "onnx-llm",
200            "messages": request.messages,
201            "max_tokens": request.max_tokens,
202            "temperature": request.temperature,
203            "stream": true,
204        });
205
206        let resp = self.http_client.post(&url)
207            .json(&payload)
208            .send()
209            .await
210            .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
211
212        if !resp.status().is_success() {
213            let status = resp.status();
214            let body = resp.text().await.unwrap_or_default();
215            return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
216        }
217
218        let byte_stream = resp.bytes_stream();
219        
220        let sse_stream = async_stream::try_stream! {
221            let mut buffer = String::new();
222            futures_util::pin_mut!(byte_stream);
223
224            while let Some(chunk_result) = byte_stream.next().await {
225                let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
226                buffer.push_str(&String::from_utf8_lossy(&chunk));
227
228                while let Some(newline_pos) = buffer.find('\n') {
229                    let line = buffer[..newline_pos].trim().to_string();
230                    buffer = buffer[newline_pos + 1..].to_string();
231
232                    if line.is_empty() || !line.starts_with("data: ") {
233                        continue;
234                    }
235
236                    let data = &line[6..];
237                    if data == "[DONE]" {
238                        return;
239                    }
240
241                    yield format!("data: {}\n\n", data);
242                }
243            }
244        };
245
246        Ok(Box::new(Box::pin(sse_stream)))
247    }
248
249    async fn shutdown(&mut self) -> anyhow::Result<()> {
250        info!("Shutting down ONNX runtime");
251        
252        if let Some(mut child) = self.server_process.take() {
253            match child.kill() {
254                Ok(_) => {
255                    info!("ONNX server process killed successfully");
256                    let _ = child.wait();
257                }
258                Err(e) => {
259                    warn!("Failed to kill ONNX server process: {}", e);
260                }
261            }
262        }
263
264        self.config = None;
265        self.base_url.clear();
266        Ok(())
267    }
268
269    fn metadata(&self) -> RuntimeMetadata {
270        RuntimeMetadata {
271            format: ModelFormat::ONNX,
272            runtime_name: "ONNX Runtime".to_string(),
273            version: "1.x".to_string(),
274            supports_gpu: true,
275            supports_streaming: true,
276        }
277    }
278}
279
280impl Drop for ONNXRuntime {
281    fn drop(&mut self) {
282        if let Some(mut child) = self.server_process.take() {
283            let _ = child.kill();
284            let _ = child.wait();
285        }
286    }
287}