liquid_edge/runtime/onnx/
mod.rs1use crate::error::{EdgeError, EdgeResult};
4use crate::runtime::{InferenceInput, InferenceOutput, RuntimeBackend};
5use crate::{Device, Model};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10use ndarray::{ArrayD, IxDyn};
11
12#[cfg(feature = "onnx")]
13use ort::{session::Session, value::Value as OrtValue};
14
15#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
16use ort::execution_providers::ExecutionProvider;
17
/// Inference backend that executes ONNX models through the `ort` crate.
pub struct OnnxBackend {
    // Owned ONNX Runtime session; `Session::run` takes `&mut self`, which is
    // why `RuntimeBackend::infer` needs `&mut self` as well.
    session: Session,
    // Input metadata captured once at session creation (see `create_backend`).
    input_info: Vec<InputInfo>,
    // Output metadata captured once at session creation.
    output_info: Vec<OutputInfo>,
}
24
/// Describes one graph input as recorded at session creation.
#[derive(Debug, Clone)]
struct InputInfo {
    // ONNX graph input name; keys of the incoming inference request must match it.
    name: String,
    // Dimensions recorded at session creation; -1 denotes an unknown/dynamic axis.
    shape: Vec<i64>,
    // Debug-formatted `ort` value type string — diagnostics/`model_info` only.
    data_type: String,
}
31
/// Describes one graph output as recorded at session creation.
#[derive(Debug, Clone)]
struct OutputInfo {
    // ONNX graph output name, used to pull tensors out of the run results.
    name: String,
    // Dimensions recorded at session creation; -1 denotes an unknown/dynamic axis.
    shape: Vec<i64>,
    // Debug-formatted `ort` value type string — diagnostics/`model_info` only.
    data_type: String,
}
38
39impl OnnxBackend {
40 pub fn from_model_with_device(model: Box<dyn Model>, device: Device) -> EdgeResult<Self> {
42 model.validate()?;
44
45 if !device.is_available() {
47 return Err(EdgeError::runtime(format!(
48 "Device {device} is not available"
49 )));
50 }
51
52 let model_path = model.model_path();
54 let onnx_file = if model_path.is_file()
55 && model_path.extension().and_then(|e| e.to_str()) == Some("onnx")
56 {
57 model_path.to_path_buf()
58 } else {
59 model_path.join("model.onnx")
60 };
61
62 if !onnx_file.exists() {
63 return Err(EdgeError::model(format!(
64 "ONNX model file not found: {}",
65 onnx_file.display()
66 )));
67 }
68
69 Self::new_with_device(onnx_file, device)
70 }
71
72 pub fn from_model(model: Box<dyn Model>) -> EdgeResult<Self> {
74 let device = crate::device::cpu();
75 Self::from_model_with_device(model, device)
76 }
77
78 pub fn new_with_device<P: AsRef<Path>>(model_path: P, device: Device) -> EdgeResult<Self> {
80 if !device.is_available() {
82 return Err(EdgeError::runtime(format!(
83 "Device {device} is not available"
84 )));
85 }
86
87 let mut builder = Session::builder()
89 .map_err(|e| EdgeError::runtime(format!("Failed to create session builder: {e}")))?;
90
91 match device {
93 #[allow(unused_variables)]
94 Device::Cuda(id) => {
95 #[cfg(feature = "cuda")]
96 {
97 use ort::execution_providers::CUDAExecutionProvider;
98 let ep = CUDAExecutionProvider::default().with_device_id(id as i32);
99 match ep.is_available() {
100 Ok(true) => {
101 ep.register(&mut builder).map_err(|e| {
102 EdgeError::runtime(format!("Failed to register CUDA: {e}"))
103 })?;
104 }
105 _ => {
106 return Err(EdgeError::runtime("CUDA execution provider not available"))
107 }
108 }
109 }
110 #[cfg(not(feature = "cuda"))]
111 {
112 return Err(EdgeError::runtime("CUDA feature not enabled"));
113 }
114 }
115 Device::Cpu(_) => {
116 use ort::execution_providers::CPUExecutionProvider;
117 let ep = CPUExecutionProvider::default();
118 ep.register(&mut builder)
119 .map_err(|e| EdgeError::runtime(format!("Failed to register CPU: {e}")))?;
120 }
121 }
122
123 let session = builder
124 .commit_from_file(model_path)
125 .map_err(|e| EdgeError::model(format!("Failed to load ONNX model: {e}")))?;
126
127 Self::create_backend(session)
128 }
129
130 pub fn new<P: AsRef<Path>>(model_path: P) -> EdgeResult<Self> {
132 let device = crate::device::cpu();
133 Self::new_with_device(model_path, device)
134 }
135
136 fn create_backend(session: Session) -> EdgeResult<Self> {
138 let input_info: Vec<InputInfo> = session
140 .inputs
141 .iter()
142 .map(|input| {
143 let shape = vec![-1, -1]; InputInfo {
146 name: input.name.clone(),
147 shape,
148 data_type: format!("{:?}", input.input_type),
149 }
150 })
151 .collect();
152
153 let output_info: Vec<OutputInfo> = session
155 .outputs
156 .iter()
157 .map(|output| {
158 let shape = vec![-1, -1, -1]; OutputInfo {
161 name: output.name.clone(),
162 shape,
163 data_type: format!("{:?}", output.output_type),
164 }
165 })
166 .collect();
167
168 log::info!(
169 "ONNX Backend initialized with {} inputs and {} outputs",
170 input_info.len(),
171 output_info.len()
172 );
173
174 for (i, input) in input_info.iter().enumerate() {
175 log::info!(
176 " Input {}: name='{}', type={}, shape={:?}",
177 i,
178 input.name,
179 input.data_type,
180 input.shape
181 );
182 }
183
184 for (i, output) in output_info.iter().enumerate() {
185 log::info!(
186 " Output {}: name='{}', type={}, shape={:?}",
187 i,
188 output.name,
189 output.data_type,
190 output.shape
191 );
192 }
193
194 Ok(Self {
195 session,
196 input_info,
197 output_info,
198 })
199 }
200
201 fn json_to_tensor(
203 &self,
204 name: &str,
205 data: &Value,
206 ) -> EdgeResult<ort::value::Value<ort::value::DynValueTypeMarker>> {
207 match data {
208 Value::Array(arr) => {
209 if let Ok(i64_values) = arr
210 .iter()
211 .map(|v| v.as_i64().ok_or("Invalid i64"))
212 .collect::<Result<Vec<_>, _>>()
213 {
214 let len = i64_values.len();
215 let array = ArrayD::<i64>::from_shape_vec(IxDyn(&[1, len]), i64_values)
216 .map_err(|e| {
217 EdgeError::inference(format!(
218 "Failed to create i64 tensor for {name}: {e}"
219 ))
220 })?;
221
222 Ok(OrtValue::from_array(array)
223 .map_err(|e| {
224 EdgeError::inference(format!(
225 "Failed to create ONNX value for {name}: {e}"
226 ))
227 })?
228 .into_dyn())
229 }
230 else if let Ok(f32_values) = arr
232 .iter()
233 .map(|v| v.as_f64().map(|f| f as f32).ok_or("Invalid f32"))
234 .collect::<Result<Vec<_>, _>>()
235 {
236 let len = f32_values.len();
237 let array = ArrayD::<f32>::from_shape_vec(IxDyn(&[1, len]), f32_values)
238 .map_err(|e| {
239 EdgeError::inference(format!(
240 "Failed to create f32 tensor for {name}: {e}"
241 ))
242 })?;
243
244 Ok(OrtValue::from_array(array)
245 .map_err(|e| {
246 EdgeError::inference(format!(
247 "Failed to create ONNX value for {name}: {e}"
248 ))
249 })?
250 .into_dyn())
251 } else {
252 Err(EdgeError::inference(format!(
253 "Unsupported data type in array for input: {name}"
254 )))
255 }
256 }
257 _ => Err(EdgeError::inference(format!(
258 "Unsupported JSON type for input: {name}"
259 ))),
260 }
261 }
262
263 fn tensor_to_json_static(
265 tensor: &ort::value::Value<ort::value::DynValueTypeMarker>,
266 ) -> EdgeResult<Value> {
267 if let Ok((_, data)) = tensor.try_extract_tensor::<f32>() {
269 let values: Vec<Value> = data
270 .iter()
271 .map(|&x| {
272 Value::Number(
273 serde_json::Number::from_f64(x as f64)
274 .unwrap_or(serde_json::Number::from(0)),
275 )
276 })
277 .collect();
278 return Ok(Value::Array(values));
279 }
280
281 if let Ok((_, data)) = tensor.try_extract_tensor::<i64>() {
283 let values: Vec<Value> = data.iter().map(|&x| Value::Number(x.into())).collect();
284 return Ok(Value::Array(values));
285 }
286
287 Err(EdgeError::inference(
288 "Unsupported tensor type for output conversion",
289 ))
290 }
291}
292
293impl RuntimeBackend for OnnxBackend {
294 fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
295 let mut onnx_inputs = HashMap::new();
297
298 for input_info in &self.input_info {
299 if let Some(data) = input.inputs.get(&input_info.name) {
300 let tensor = self.json_to_tensor(&input_info.name, data)?;
301 onnx_inputs.insert(input_info.name.clone(), tensor);
302 } else {
303 return Err(EdgeError::inference(format!(
304 "Missing required input: {}",
305 input_info.name
306 )));
307 }
308 }
309
310 let outputs = self
312 .session
313 .run(onnx_inputs)
314 .map_err(|e| EdgeError::inference(format!("ONNX inference failed: {e}")))?;
315
316 let mut result_outputs = HashMap::new();
318 for output_info in &self.output_info {
319 if let Some(tensor) = outputs.get(&output_info.name) {
320 let json_data = Self::tensor_to_json_static(tensor)?;
321 result_outputs.insert(output_info.name.clone(), json_data);
322 }
323 }
324
325 let mut metadata = HashMap::new();
326 metadata.insert("backend".to_string(), Value::String("onnx".to_string()));
327 metadata.insert("inference_time_ms".to_string(), Value::Number(0.into())); Ok(InferenceOutput {
330 outputs: result_outputs,
331 metadata,
332 })
333 }
334
335 fn model_info(&self) -> HashMap<String, Value> {
336 let mut info = HashMap::new();
337 info.insert(
338 "backend_type".to_string(),
339 Value::String("onnx".to_string()),
340 );
341 info.insert(
342 "num_inputs".to_string(),
343 Value::Number(self.input_info.len().into()),
344 );
345 info.insert(
346 "num_outputs".to_string(),
347 Value::Number(self.output_info.len().into()),
348 );
349
350 let inputs: Vec<Value> = self
351 .input_info
352 .iter()
353 .map(|input| {
354 serde_json::json!({
355 "name": input.name,
356 "data_type": input.data_type,
357 "shape": input.shape
358 })
359 })
360 .collect();
361 info.insert("inputs".to_string(), Value::Array(inputs));
362
363 let outputs: Vec<Value> = self
364 .output_info
365 .iter()
366 .map(|output| {
367 serde_json::json!({
368 "name": output.name,
369 "data_type": output.data_type,
370 "shape": output.shape
371 })
372 })
373 .collect();
374 info.insert("outputs".to_string(), Value::Array(outputs));
375
376 info
377 }
378
379 fn is_ready(&self) -> bool {
380 true }
382
383 fn backend_info(&self) -> HashMap<String, Value> {
384 let mut info = HashMap::new();
385 info.insert(
386 "name".to_string(),
387 Value::String("ONNX Runtime".to_string()),
388 );
389 info.insert("version".to_string(), Value::String("2.0".to_string()));
390 info.insert("supports_gpu".to_string(), Value::Bool(false)); info
392 }
393}
394
/// Filesystem-backed description of an ONNX model (a `.onnx` file or a
/// directory containing `model.onnx`), plus metadata gathered at load time.
#[derive(Debug, Clone)]
pub struct OnnxModel {
    // Path to either a `.onnx` file or a directory expected to hold `model.onnx`.
    path: PathBuf,
    // Metadata copied from `config.json` (when present) plus "format"/"path" entries.
    metadata: HashMap<String, Value>,
}
401
402impl OnnxModel {
403 pub fn from_directory<P: AsRef<Path>>(path: P) -> EdgeResult<Self> {
405 let path = path.as_ref().to_path_buf();
406 let mut metadata = HashMap::new();
407
408 if !path.exists() {
410 return Err(EdgeError::model(format!(
411 "Model directory does not exist: {}",
412 path.display()
413 )));
414 }
415
416 let config_path = path.join("config.json");
418 if config_path.exists() {
419 let config_content = std::fs::read_to_string(&config_path)?;
420 let config: Value = serde_json::from_str(&config_content)?;
421
422 if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
424 metadata.insert(
425 "model_type".to_string(),
426 Value::String(model_type.to_string()),
427 );
428 }
429
430 if let Some(vocab_size) = config.get("vocab_size") {
432 metadata.insert("vocab_size".to_string(), vocab_size.clone());
433 }
434 if let Some(hidden_size) = config.get("hidden_size") {
435 metadata.insert("hidden_size".to_string(), hidden_size.clone());
436 }
437 if let Some(max_position_embeddings) = config.get("max_position_embeddings") {
438 metadata.insert(
439 "max_position_embeddings".to_string(),
440 max_position_embeddings.clone(),
441 );
442 }
443 if let Some(bos_token_id) = config.get("bos_token_id") {
444 metadata.insert("bos_token_id".to_string(), bos_token_id.clone());
445 }
446 if let Some(eos_token_id) = config.get("eos_token_id") {
447 metadata.insert("eos_token_id".to_string(), eos_token_id.clone());
448 }
449 if let Some(pad_token_id) = config.get("pad_token_id") {
450 metadata.insert("pad_token_id".to_string(), pad_token_id.clone());
451 }
452 }
453
454 metadata.insert("format".to_string(), Value::String("onnx".to_string()));
456 metadata.insert(
457 "path".to_string(),
458 Value::String(path.display().to_string()),
459 );
460
461 Ok(Self { path, metadata })
462 }
463
464 pub fn from_file<P: AsRef<Path>>(path: P) -> EdgeResult<Self> {
466 let path = path.as_ref().to_path_buf();
467
468 if !path.exists() {
469 return Err(EdgeError::model(format!(
470 "Model file does not exist: {}",
471 path.display()
472 )));
473 }
474
475 if path.extension().and_then(|e| e.to_str()) != Some("onnx") {
476 return Err(EdgeError::model("File must have .onnx extension"));
477 }
478
479 let mut metadata = HashMap::new();
480 metadata.insert("format".to_string(), Value::String("onnx".to_string()));
481 metadata.insert(
482 "path".to_string(),
483 Value::String(path.display().to_string()),
484 );
485
486 Ok(Self { path, metadata })
487 }
488
489 pub fn with_metadata(mut self, key: String, value: Value) -> Self {
491 self.metadata.insert(key, value);
492 self
493 }
494}
495
496impl Model for OnnxModel {
497 fn model_type(&self) -> &str {
498 "onnx"
499 }
500
501 fn model_path(&self) -> &Path {
502 &self.path
503 }
504
505 fn metadata(&self) -> &HashMap<String, Value> {
506 &self.metadata
507 }
508
509 fn config(&self) -> EdgeResult<Value> {
510 let config_path = self.path.join("config.json");
511 if config_path.exists() {
512 let config_content = std::fs::read_to_string(&config_path)?;
513 let config: Value = serde_json::from_str(&config_content)?;
514 Ok(config)
515 } else {
516 Ok(serde_json::json!({
518 "model_type": "onnx",
519 "path": self.path.display().to_string()
520 }))
521 }
522 }
523
524 fn validate(&self) -> EdgeResult<()> {
525 if !self.path.exists() {
526 return Err(EdgeError::model(format!(
527 "Model path does not exist: {}",
528 self.path.display()
529 )));
530 }
531
532 let onnx_file = if self.path.is_file() {
534 self.path.clone()
536 } else {
537 self.path.join("model.onnx")
539 };
540
541 if !onnx_file.exists() {
542 return Err(EdgeError::model(format!(
543 "ONNX model file not found: {}",
544 onnx_file.display()
545 )));
546 }
547
548 Ok(())
549 }
550}
551
/// Namespace of convenience constructors for supported model formats.
pub struct ModelBuilder;
554
impl ModelBuilder {
    /// Load an ONNX model from a directory expected to contain `model.onnx`.
    /// Thin delegation to [`OnnxModel::from_directory`].
    pub fn onnx_from_directory<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
        OnnxModel::from_directory(path)
    }

    /// Load an ONNX model from a single `.onnx` file.
    /// Thin delegation to [`OnnxModel::from_file`].
    pub fn onnx_from_file<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
        OnnxModel::from_file(path)
    }
}
566
/// Free-function shorthand for [`OnnxModel::from_directory`].
pub fn onnx_model<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
    OnnxModel::from_directory(path)
}