// offline_intelligence/model_runtime/onnx_runtime.rs

use async_trait::async_trait;
use super::runtime_trait::*;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tracing::{info, warn};
use tokio::time::sleep;
12
/// Runtime backend that serves ONNX models by spawning an external
/// HTTP inference server process and proxying OpenAI-style requests
/// to it over localhost.
pub struct ONNXRuntime {
    // Config captured at initialize(); None until initialized.
    config: Option<RuntimeConfig>,
    // Handle to the spawned inference server; killed on shutdown/drop.
    server_process: Option<Child>,
    // Reused HTTP client for health checks and inference calls.
    http_client: reqwest::Client,
    // e.g. "http://host:port"; empty until the server has been started.
    base_url: String,
}
19
20impl ONNXRuntime {
21 pub fn new() -> Self {
22 Self {
23 config: None,
24 server_process: None,
25 http_client: reqwest::Client::builder()
26 .timeout(Duration::from_secs(600))
27 .build()
28 .unwrap_or_default(),
29 base_url: String::new(),
30 }
31 }
32
33 async fn start_server(&mut self, config: &RuntimeConfig) -> anyhow::Result<()> {
36 let binary_path = config.runtime_binary.as_ref()
40 .ok_or_else(|| anyhow::anyhow!("ONNX runtime requires runtime_binary path (e.g., onnx-server)"))?;
41
42 if !binary_path.exists() {
43 return Err(anyhow::anyhow!(
44 "ONNX server binary not found at: {}. You may need to create an ONNX inference server wrapper.",
45 binary_path.display()
46 ));
47 }
48
49 info!("Starting ONNX server for model: {}", config.model_path.display());
50
51 let mut cmd = Command::new(binary_path);
52 cmd.arg("--model").arg(&config.model_path)
53 .arg("--host").arg(&config.host)
54 .arg("--port").arg(config.port.to_string())
55 .stdout(Stdio::piped())
56 .stderr(Stdio::piped());
57
58 let child = cmd.spawn()
59 .map_err(|e| anyhow::anyhow!("Failed to spawn ONNX server: {}", e))?;
60
61 self.server_process = Some(child);
62 self.base_url = format!("http://{}:{}", config.host, config.port);
63
64 for attempt in 1..=15 {
66 sleep(Duration::from_secs(2)).await;
67 if self.is_ready().await {
68 info!("✅ ONNX runtime ready after {} seconds", attempt * 2);
69 return Ok(());
70 }
71 }
72
73 Err(anyhow::anyhow!("ONNX server failed to start within 30 seconds"))
74 }
75}
76
77impl Default for ONNXRuntime {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83#[async_trait]
84impl ModelRuntime for ONNXRuntime {
85 fn supported_format(&self) -> ModelFormat {
86 ModelFormat::ONNX
87 }
88
89 async fn initialize(&mut self, config: RuntimeConfig) -> anyhow::Result<()> {
90 info!("Initializing ONNX runtime");
91
92 if config.format != ModelFormat::ONNX {
93 return Err(anyhow::anyhow!("ONNX runtime received wrong format: {:?}", config.format));
94 }
95
96 if !config.model_path.exists() {
97 return Err(anyhow::anyhow!("Model file not found: {}", config.model_path.display()));
98 }
99
100 self.config = Some(config.clone());
101 self.start_server(&config).await?;
102
103 Ok(())
104 }
105
106 async fn is_ready(&self) -> bool {
107 if self.base_url.is_empty() {
108 return false;
109 }
110
111 let health_url = format!("{}/health", self.base_url);
112 match self.http_client.get(&health_url).send().await {
113 Ok(resp) => resp.status().is_success(),
114 Err(_) => false,
115 }
116 }
117
118 async fn health_check(&self) -> anyhow::Result<String> {
119 if self.base_url.is_empty() {
120 return Err(anyhow::anyhow!("Runtime not initialized"));
121 }
122
123 let health_url = format!("{}/health", self.base_url);
124 let resp = self.http_client.get(&health_url)
125 .send()
126 .await
127 .map_err(|e| anyhow::anyhow!("Health check failed: {}", e))?;
128
129 if resp.status().is_success() {
130 Ok("healthy".to_string())
131 } else {
132 Err(anyhow::anyhow!("Health check returned: {}", resp.status()))
133 }
134 }
135
136 fn base_url(&self) -> String {
137 self.base_url.clone()
138 }
139
140 async fn generate(&self, request: InferenceRequest) -> anyhow::Result<InferenceResponse> {
141 let url = self.completions_url();
142
143 let payload = serde_json::json!({
144 "model": "onnx-llm",
145 "messages": request.messages,
146 "max_tokens": request.max_tokens,
147 "temperature": request.temperature,
148 "stream": false,
149 });
150
151 let resp = self.http_client.post(&url)
152 .json(&payload)
153 .send()
154 .await
155 .map_err(|e| anyhow::anyhow!("Inference request failed: {}", e))?;
156
157 if !resp.status().is_success() {
158 let status = resp.status();
159 let body = resp.text().await.unwrap_or_default();
160 return Err(anyhow::anyhow!("Inference failed ({}): {}", status, body));
161 }
162
163 let response: serde_json::Value = resp.json().await
164 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
165
166 let content = response["choices"][0]["message"]["content"]
167 .as_str()
168 .unwrap_or("")
169 .to_string();
170
171 let finish_reason = response["choices"][0]["finish_reason"]
172 .as_str()
173 .map(|s| s.to_string());
174
175 Ok(InferenceResponse { content, finish_reason })
176 }
177
178 async fn generate_stream(
179 &self,
180 request: InferenceRequest,
181 ) -> anyhow::Result<Box<dyn futures_util::Stream<Item = Result<String, anyhow::Error>> + Send + Unpin>> {
182 use futures_util::StreamExt;
183
184 let url = self.completions_url();
185
186 let payload = serde_json::json!({
187 "model": "onnx-llm",
188 "messages": request.messages,
189 "max_tokens": request.max_tokens,
190 "temperature": request.temperature,
191 "stream": true,
192 });
193
194 let resp = self.http_client.post(&url)
195 .json(&payload)
196 .send()
197 .await
198 .map_err(|e| anyhow::anyhow!("Stream request failed: {}", e))?;
199
200 if !resp.status().is_success() {
201 let status = resp.status();
202 let body = resp.text().await.unwrap_or_default();
203 return Err(anyhow::anyhow!("Stream failed ({}): {}", status, body));
204 }
205
206 let byte_stream = resp.bytes_stream();
207
208 let sse_stream = async_stream::try_stream! {
209 let mut buffer = String::new();
210 futures_util::pin_mut!(byte_stream);
211
212 while let Some(chunk_result) = byte_stream.next().await {
213 let chunk = chunk_result.map_err(|e| anyhow::anyhow!("Stream read error: {}", e))?;
214 buffer.push_str(&String::from_utf8_lossy(&chunk));
215
216 while let Some(newline_pos) = buffer.find('\n') {
217 let line = buffer[..newline_pos].trim().to_string();
218 buffer = buffer[newline_pos + 1..].to_string();
219
220 if line.is_empty() || !line.starts_with("data: ") {
221 continue;
222 }
223
224 let data = &line[6..];
225 if data == "[DONE]" {
226 return;
227 }
228
229 yield format!("data: {}\n\n", data);
230 }
231 }
232 };
233
234 Ok(Box::new(Box::pin(sse_stream)))
235 }
236
237 async fn shutdown(&mut self) -> anyhow::Result<()> {
238 info!("Shutting down ONNX runtime");
239
240 if let Some(mut child) = self.server_process.take() {
241 match child.kill() {
242 Ok(_) => {
243 info!("ONNX server process killed successfully");
244 let _ = child.wait();
245 }
246 Err(e) => {
247 warn!("Failed to kill ONNX server process: {}", e);
248 }
249 }
250 }
251
252 self.config = None;
253 self.base_url.clear();
254 Ok(())
255 }
256
257 fn metadata(&self) -> RuntimeMetadata {
258 RuntimeMetadata {
259 format: ModelFormat::ONNX,
260 runtime_name: "ONNX Runtime".to_string(),
261 version: "1.x".to_string(),
262 supports_gpu: true,
263 supports_streaming: true,
264 }
265 }
266}
267
268impl Drop for ONNXRuntime {
269 fn drop(&mut self) {
270 if let Some(mut child) = self.server_process.take() {
271 let _ = child.kill();
272 let _ = child.wait();
273 }
274 }
275}