offline_intelligence/model_runtime/
onnx_runtime.rs1use async_trait::async_trait;
7use super::runtime_trait::*;
8use std::process::{Child, Command, Stdio};
9use std::time::Duration;
10use tracing::{info, warn};
11use tokio::time::sleep;
12
/// Runtime backend that serves ONNX models by spawning an external HTTP
/// inference server process and proxying requests to it.
pub struct ONNXRuntime {
    // Config captured during initialize(); None until then.
    config: Option<RuntimeConfig>,
    // Handle to the spawned server process; killed in shutdown() and Drop.
    server_process: Option<Child>,
    // Shared HTTP client for health checks and inference calls.
    http_client: reqwest::Client,
    // "http://{host}:{port}" of the spawned server; empty until started.
    base_url: String,
}
19
20impl ONNXRuntime {
21 pub fn new() -> Self {
22 Self {
23 config: None,
24 server_process: None,
25 http_client: reqwest::Client::builder()
26 .timeout(Duration::from_secs(600))
27 .build()
28 .unwrap_or_default(),
29 base_url: String::new(),
30 }
31 }
32
33 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
36 let binary_path = config.runtime_binary.as_ref()
40 .ok_or_else(|| anyhow::anyhow!("ONNX runtime requires runtime_binary path (e.g., onnx-server)"))?;
41
42 if !binary_path.exists() {
43 return Err(anyhow::anyhow!(
44 "ONNX server binary not found at: {}. You may need to create an ONNX inference server wrapper.",
45 binary_path.display()
46 ));
47 }
48
49 info!("Starting ONNX server for model: {}", config.model_path.display());
50
51 let mut cmd = Command::new(binary_path);
52 cmd.arg("--model").arg(&config.model_path)
53 .arg("--host").arg(&config.host)
54 .arg("--port").arg(config.port.to_string())
55 .stdout(Stdio::piped())
56 .stderr(Stdio::piped());
57
58 let child = cmd.spawn()
59 .map_err(|e| anyhow::anyhow!("Failed to spawn ONNX server: {}", e))?;
60
61 self.server_process = Some(child);
62 self.base_url = format!("http://{}:{}", config.host, config.port);
63
64 let _start = std::time::Instant::now();
66 let mut delay_ms: u64 = 100;
67 let mut last_log_secs: u64 = 0;
68 loop {
69 sleep(Duration::from_millis(delay_ms)).await;
70 if self.is_ready().await {
71 info!("✅ ONNX runtime ready after {:.1}s", _start.elapsed().as_secs_f64());
72 return Ok(());
73 }
74 let elapsed_secs = _start.elapsed().as_secs();
75 if elapsed_secs >= 120 {
76 break;
77 }
78 if elapsed_secs >= last_log_secs + 10 {
79 info!("Still waiting for ONNX server... ({}/120s)", elapsed_secs);
80 last_log_secs = elapsed_secs;
81 }
82 delay_ms = (delay_ms * 2).min(2_000);
83 }
84
85 Err(anyhow::anyhow!("ONNX server failed to start within 120 seconds"))
86 }
87}
88
89impl Default for ONNXRuntime {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95#[async_trait]
96impl ModelRuntime for ONNXRuntime {
97 fn supported_format(&self) -> ModelFormat {
98 ModelFormat::ONNX
99 }
100
101 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
102 info!("Initializing ONNX runtime");
103
104 if config.format != ModelFormat::ONNX {
105 return Err(anyhow::anyhow!("ONNX runtime received wrong format: {:?}", config.format));
106 }
107
108 if !config.model_path.exists() {
109 return Err(anyhow::anyhow!("Model file not found: {}", config.model_path.display()));
110 }
111
112 self.config = Some(config.clone());
113 self.start_server(&config).await?;
114
115 Ok(())
116 }
117
118 async fn is_ready(&self) -> bool {
119 if self.base_url.is_empty() {
120 return false;
121 }
122
123 let health_url = format!("{}/health", self.base_url);
124 match self.http_client.get(&health_url).send().await {
125 Ok(resp) => resp.status().is_success(),
126 Err(_) => false,
127 }
128 }
129
130 async fn health_check(&self) -> anyhow::Result<String> {
131 if self.base_url.is_empty() {
132 return Err(anyhow::anyhow!("Runtime not initialized"));
133 }
134
135 let health_url = format!("{}/health", self.base_url);
136 let resp = self.http_client.get(&health_url)
137 .send()
138 .await
139 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
140
141 if resp.status().is_success() {
142 Ok("healthy".to_string())
143 } else {
144 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
145 }
146 }
147
148 fn base_url(&self) -> String {
149 self.base_url.clone()
150 }
151
152 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
153 let url = self.completions_url();
154
155 let payload = serde_json::json!({
156 "model": "onnx-llm",
157 "messages": request.messages,
158 "max_tokens": request.max_tokens,
159 "temperature": request.temperature,
160 "stream": false,
161 });
162
163 let resp = self.http_client.post(&url)
164 .json(&payload)
165 .send()
166 .await
167 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
168
169 if !resp.status().is_success() {
170 let status = resp.status();
171 let body = resp.text().await.unwrap_or_default();
172 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
173 }
174
175 let response: serde_json::Value = resp.json().await
176 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
177
178 let content = response["choices"][0]["message"]["content"]
179 .as_str()
180 .unwrap_or("")
181 .to_string();
182
183 let finish_reason = response["choices"][0]["finish_reason"]
184 .as_str()
185 .map(|s| s.to_string());
186
187 Ok(InferenceResponse { content, finish_reason })
188 }
189
190 async fn generate_stream(
191 &self,
192 request: InferenceRequest,
193 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
194 use futures_util::StreamExt;
195
196 let url = self.completions_url();
197
198 let payload = serde_json::json!({
199 "model": "onnx-llm",
200 "messages": request.messages,
201 "max_tokens": request.max_tokens,
202 "temperature": request.temperature,
203 "stream": true,
204 });
205
206 let resp = self.http_client.post(&url)
207 .json(&payload)
208 .send()
209 .await
210 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
211
212 if !resp.status().is_success() {
213 let status = resp.status();
214 let body = resp.text().await.unwrap_or_default();
215 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
216 }
217
218 let byte_stream = resp.bytes_stream();
219
220 let sse_stream = async_stream::try_stream! {
221 let mut buffer = String::new();
222 futures_util::pin_mut!(byte_stream);
223
224 while let Some(chunk_result) = byte_stream.next().await {
225 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
226 buffer.push_str(&String::from_utf8_lossy(&chunk));
227
228 while let Some(newline_pos) = buffer.find('\n') {
229 let line = buffer[..newline_pos].trim().to_string();
230 buffer = buffer[newline_pos + 1..].to_string();
231
232 if line.is_empty() || !line.starts_with("data: ") {
233 continue;
234 }
235
236 let data = &line[6..];
237 if data == "[DONE]" {
238 return;
239 }
240
241 yield format!("data: {}\n\n", data);
242 }
243 }
244 };
245
246 Ok(Box::new(Box::pin(sse_stream)))
247 }
248
249 async fn shutdown(&mut self) -> anyhow::Result<()> {
250 info!("Shutting down ONNX runtime");
251
252 if let Some(mut child) = self.server_process.take() {
253 match child.kill() {
254 Ok(_) => {
255 info!("ONNX server process killed successfully");
256 let _ = child.wait();
257 }
258 Err(e) => {
259 warn!("Failed to kill ONNX server process: {}", e);
260 }
261 }
262 }
263
264 self.config = None;
265 self.base_url.clear();
266 Ok(())
267 }
268
269 fn metadata(&self) -> RuntimeMetadata {
270 RuntimeMetadata {
271 format: ModelFormat::ONNX,
272 runtime_name: "ONNX Runtime".to_string(),
273 version: "1.x".to_string(),
274 supports_gpu: true,
275 supports_streaming: true,
276 }
277 }
278}
279
280impl Drop for ONNXRuntime {
281 fn drop(&mut self) {
282 if let Some(mut child) = self.server_process.take() {
283 let _ = child.kill();
284 let _ = child.wait();
285 }
286 }
287}