liquid_edge/runtime/onnx/
mod.rs

//! ONNX Runtime backend for liquid-edge inference
2
3use crate::error::{EdgeError, EdgeResult};
4use crate::runtime::{InferenceInput, InferenceOutput, RuntimeBackend};
5use crate::{Device, Model};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10use ndarray::{ArrayD, IxDyn};
11
12#[cfg(feature = "onnx")]
13use ort::{session::Session, value::Value as OrtValue};
14
15#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
16use ort::execution_providers::ExecutionProvider;
17
/// ONNX Runtime backend for edge inference.
///
/// Owns a loaded `ort` session plus the input/output metadata captured at
/// construction time (see `create_backend`).
pub struct OnnxBackend {
    // Loaded ONNX Runtime session; created once in new_with_device.
    session: Session,
    // Per-input name/type/shape info. NOTE(review): shapes are placeholder
    // values (-1 = dynamic), not the session-reported shapes — see create_backend.
    input_info: Vec<InputInfo>,
    // Per-output name/type/shape info (same placeholder-shape caveat).
    output_info: Vec<OutputInfo>,
}
24
/// Metadata describing one model input tensor.
#[derive(Debug, Clone)]
struct InputInfo {
    // ONNX graph input name, used as the key when feeding the session.
    name: String,
    // Declared shape; -1 marks a dynamic dimension (currently always placeholders).
    shape: Vec<i64>,
    // Debug-formatted ort input type (informational only).
    data_type: String,
}
31
/// Metadata describing one model output tensor.
#[derive(Debug, Clone)]
struct OutputInfo {
    // ONNX graph output name, used to look up tensors in the session result.
    name: String,
    // Declared shape; -1 marks a dynamic dimension (currently always placeholders).
    shape: Vec<i64>,
    // Debug-formatted ort output type (informational only).
    data_type: String,
}
38
impl OnnxBackend {
    /// Create a new ONNX backend from a model with a specific device.
    ///
    /// Validates the model, checks device availability, resolves the concrete
    /// `.onnx` file (the path itself when it is a `.onnx` file, otherwise
    /// `<dir>/model.onnx`), then delegates to [`Self::new_with_device`].
    ///
    /// # Errors
    /// Returns an error if validation fails, the device is unavailable, or the
    /// resolved ONNX file does not exist.
    pub fn from_model_with_device(model: Box<dyn Model>, device: Device) -> EdgeResult<Self> {
        // Validate the model
        model.validate()?;

        // Check if device is available.
        // NOTE(review): new_with_device performs this same check again; the
        // duplication is harmless but could be consolidated.
        if !device.is_available() {
            return Err(EdgeError::runtime(format!(
                "Device {device} is not available"
            )));
        }

        // Get the actual ONNX model file path: accept either a direct .onnx
        // file or a directory expected to contain model.onnx.
        let model_path = model.model_path();
        let onnx_file = if model_path.is_file()
            && model_path.extension().and_then(|e| e.to_str()) == Some("onnx")
        {
            model_path.to_path_buf()
        } else {
            model_path.join("model.onnx")
        };

        if !onnx_file.exists() {
            return Err(EdgeError::model(format!(
                "ONNX model file not found: {}",
                onnx_file.display()
            )));
        }

        Self::new_with_device(onnx_file, device)
    }

    /// Create a new ONNX backend from a model (uses CPU device by default).
    pub fn from_model(model: Box<dyn Model>) -> EdgeResult<Self> {
        let device = crate::device::cpu();
        Self::from_model_with_device(model, device)
    }

    /// Create a new ONNX backend with a specific device.
    ///
    /// Builds an `ort` session builder, registers the execution provider that
    /// matches `device` (CUDA only when the `cuda` feature is compiled in),
    /// and loads the model from `model_path`.
    ///
    /// # Errors
    /// Returns an error if the device is unavailable, the execution provider
    /// cannot be registered, or the model file fails to load.
    pub fn new_with_device<P: AsRef<Path>>(model_path: P, device: Device) -> EdgeResult<Self> {
        // Check if device is available
        if !device.is_available() {
            return Err(EdgeError::runtime(format!(
                "Device {device} is not available"
            )));
        }

        // Create session with device-specific execution provider following USLS pattern
        let mut builder = Session::builder()
            .map_err(|e| EdgeError::runtime(format!("Failed to create session builder: {e}")))?;

        // Register execution provider based on device type (following USLS pattern)
        match device {
            // `id` is only read when the `cuda` feature is compiled in,
            // hence the allow(unused_variables).
            #[allow(unused_variables)]
            Device::Cuda(id) => {
                #[cfg(feature = "cuda")]
                {
                    use ort::execution_providers::CUDAExecutionProvider;
                    let ep = CUDAExecutionProvider::default().with_device_id(id as i32);
                    // Probe availability before registering; a probe error is
                    // treated the same as "not available".
                    match ep.is_available() {
                        Ok(true) => {
                            ep.register(&mut builder).map_err(|e| {
                                EdgeError::runtime(format!("Failed to register CUDA: {e}"))
                            })?;
                        }
                        _ => {
                            return Err(EdgeError::runtime("CUDA execution provider not available"))
                        }
                    }
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(EdgeError::runtime("CUDA feature not enabled"));
                }
            }
            Device::Cpu(_) => {
                use ort::execution_providers::CPUExecutionProvider;
                let ep = CPUExecutionProvider::default();
                ep.register(&mut builder)
                    .map_err(|e| EdgeError::runtime(format!("Failed to register CPU: {e}")))?;
            }
        }

        let session = builder
            .commit_from_file(model_path)
            .map_err(|e| EdgeError::model(format!("Failed to load ONNX model: {e}")))?;

        Self::create_backend(session)
    }

    /// Create a new ONNX backend (uses CPU device by default).
    pub fn new<P: AsRef<Path>>(model_path: P) -> EdgeResult<Self> {
        let device = crate::device::cpu();
        Self::new_with_device(model_path, device)
    }

    /// Common backend creation logic: capture per-tensor metadata from the
    /// session, log it, and assemble the backend.
    fn create_backend(session: Session) -> EdgeResult<Self> {
        // Extract input information.
        // NOTE(review): shapes are placeholders (-1 = dynamic), not the real
        // shapes reported by the session; callers must not rely on them.
        let input_info: Vec<InputInfo> = session
            .inputs
            .iter()
            .map(|input| {
                let shape = vec![-1, -1]; // Dynamic shape for now

                InputInfo {
                    name: input.name.clone(),
                    shape,
                    data_type: format!("{:?}", input.input_type),
                }
            })
            .collect();

        // Extract output information (same placeholder-shape caveat as inputs).
        let output_info: Vec<OutputInfo> = session
            .outputs
            .iter()
            .map(|output| {
                let shape = vec![-1, -1, -1]; // Dynamic shape for now

                OutputInfo {
                    name: output.name.clone(),
                    shape,
                    data_type: format!("{:?}", output.output_type),
                }
            })
            .collect();

        log::info!(
            "ONNX Backend initialized with {} inputs and {} outputs",
            input_info.len(),
            output_info.len()
        );

        for (i, input) in input_info.iter().enumerate() {
            log::info!(
                "  Input {}: name='{}', type={}, shape={:?}",
                i,
                input.name,
                input.data_type,
                input.shape
            );
        }

        for (i, output) in output_info.iter().enumerate() {
            log::info!(
                "  Output {}: name='{}', type={}, shape={:?}",
                i,
                output.name,
                output.data_type,
                output.shape
            );
        }

        Ok(Self {
            session,
            input_info,
            output_info,
        })
    }

    /// Convert a JSON value to an ONNX tensor.
    ///
    /// Accepts a flat JSON array: an all-integer array becomes an `i64`
    /// tensor, otherwise an all-numeric array becomes an `f32` tensor. The
    /// tensor shape is fixed at `[1, len]` — assumes a single-batch, rank-2
    /// input; TODO confirm this matches the target models' expected ranks.
    ///
    /// # Errors
    /// Returns an inference error for non-array JSON or mixed/unsupported
    /// element types.
    fn json_to_tensor(
        &self,
        name: &str,
        data: &Value,
    ) -> EdgeResult<ort::value::Value<ort::value::DynValueTypeMarker>> {
        match data {
            Value::Array(arr) => {
                // Try i64 first: serde_json's as_i64 only succeeds for
                // integral numbers, so e.g. [1.5, 2.0] falls through to the
                // f32 branch below. An empty array lands here as an empty
                // [1, 0] i64 tensor.
                if let Ok(i64_values) = arr
                    .iter()
                    .map(|v| v.as_i64().ok_or("Invalid i64"))
                    .collect::<Result<Vec<_>, _>>()
                {
                    let len = i64_values.len();
                    let array = ArrayD::<i64>::from_shape_vec(IxDyn(&[1, len]), i64_values)
                        .map_err(|e| {
                            EdgeError::inference(format!(
                                "Failed to create i64 tensor for {name}: {e}"
                            ))
                        })?;

                    Ok(OrtValue::from_array(array)
                        .map_err(|e| {
                            EdgeError::inference(format!(
                                "Failed to create ONNX value for {name}: {e}"
                            ))
                        })?
                        .into_dyn())
                }
                // Try f32 array
                else if let Ok(f32_values) = arr
                    .iter()
                    .map(|v| v.as_f64().map(|f| f as f32).ok_or("Invalid f32"))
                    .collect::<Result<Vec<_>, _>>()
                {
                    let len = f32_values.len();
                    let array = ArrayD::<f32>::from_shape_vec(IxDyn(&[1, len]), f32_values)
                        .map_err(|e| {
                            EdgeError::inference(format!(
                                "Failed to create f32 tensor for {name}: {e}"
                            ))
                        })?;

                    Ok(OrtValue::from_array(array)
                        .map_err(|e| {
                            EdgeError::inference(format!(
                                "Failed to create ONNX value for {name}: {e}"
                            ))
                        })?
                        .into_dyn())
                } else {
                    Err(EdgeError::inference(format!(
                        "Unsupported data type in array for input: {name}"
                    )))
                }
            }
            _ => Err(EdgeError::inference(format!(
                "Unsupported JSON type for input: {name}"
            ))),
        }
    }

    /// Convert an ONNX tensor to a JSON value.
    ///
    /// The tensor is flattened into a single JSON array; shape information is
    /// discarded. Supports `f32` and `i64` element types only.
    fn tensor_to_json_static(
        tensor: &ort::value::Value<ort::value::DynValueTypeMarker>,
    ) -> EdgeResult<Value> {
        // Try to extract as f32 first
        if let Ok((_, data)) = tensor.try_extract_tensor::<f32>() {
            let values: Vec<Value> = data
                .iter()
                .map(|&x| {
                    // Non-finite floats (NaN/inf) have no JSON representation;
                    // from_f64 rejects them and they are mapped to 0 here.
                    Value::Number(
                        serde_json::Number::from_f64(x as f64)
                            .unwrap_or(serde_json::Number::from(0)),
                    )
                })
                .collect();
            return Ok(Value::Array(values));
        }

        // Try to extract as i64
        if let Ok((_, data)) = tensor.try_extract_tensor::<i64>() {
            let values: Vec<Value> = data.iter().map(|&x| Value::Number(x.into())).collect();
            return Ok(Value::Array(values));
        }

        Err(EdgeError::inference(
            "Unsupported tensor type for output conversion",
        ))
    }
}
292
293impl RuntimeBackend for OnnxBackend {
294    fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
295        // Convert inputs to ONNX format
296        let mut onnx_inputs = HashMap::new();
297
298        for input_info in &self.input_info {
299            if let Some(data) = input.inputs.get(&input_info.name) {
300                let tensor = self.json_to_tensor(&input_info.name, data)?;
301                onnx_inputs.insert(input_info.name.clone(), tensor);
302            } else {
303                return Err(EdgeError::inference(format!(
304                    "Missing required input: {}",
305                    input_info.name
306                )));
307            }
308        }
309
310        // Run inference
311        let outputs = self
312            .session
313            .run(onnx_inputs)
314            .map_err(|e| EdgeError::inference(format!("ONNX inference failed: {e}")))?;
315
316        // Convert outputs back to JSON
317        let mut result_outputs = HashMap::new();
318        for output_info in &self.output_info {
319            if let Some(tensor) = outputs.get(&output_info.name) {
320                let json_data = Self::tensor_to_json_static(tensor)?;
321                result_outputs.insert(output_info.name.clone(), json_data);
322            }
323        }
324
325        let mut metadata = HashMap::new();
326        metadata.insert("backend".to_string(), Value::String("onnx".to_string()));
327        metadata.insert("inference_time_ms".to_string(), Value::Number(0.into())); // TODO: Add timing
328
329        Ok(InferenceOutput {
330            outputs: result_outputs,
331            metadata,
332        })
333    }
334
335    fn model_info(&self) -> HashMap<String, Value> {
336        let mut info = HashMap::new();
337        info.insert(
338            "backend_type".to_string(),
339            Value::String("onnx".to_string()),
340        );
341        info.insert(
342            "num_inputs".to_string(),
343            Value::Number(self.input_info.len().into()),
344        );
345        info.insert(
346            "num_outputs".to_string(),
347            Value::Number(self.output_info.len().into()),
348        );
349
350        let inputs: Vec<Value> = self
351            .input_info
352            .iter()
353            .map(|input| {
354                serde_json::json!({
355                    "name": input.name,
356                    "data_type": input.data_type,
357                    "shape": input.shape
358                })
359            })
360            .collect();
361        info.insert("inputs".to_string(), Value::Array(inputs));
362
363        let outputs: Vec<Value> = self
364            .output_info
365            .iter()
366            .map(|output| {
367                serde_json::json!({
368                    "name": output.name,
369                    "data_type": output.data_type,
370                    "shape": output.shape
371                })
372            })
373            .collect();
374        info.insert("outputs".to_string(), Value::Array(outputs));
375
376        info
377    }
378
379    fn is_ready(&self) -> bool {
380        true // ONNX session is ready once created
381    }
382
383    fn backend_info(&self) -> HashMap<String, Value> {
384        let mut info = HashMap::new();
385        info.insert(
386            "name".to_string(),
387            Value::String("ONNX Runtime".to_string()),
388        );
389        info.insert("version".to_string(), Value::String("2.0".to_string()));
390        info.insert("supports_gpu".to_string(), Value::Bool(false)); // TODO: Detect GPU support
391        info
392    }
393}
394
/// ONNX model implementation.
///
/// Wraps a filesystem path (a `.onnx` file or a directory containing
/// `model.onnx`) plus metadata gathered from an optional `config.json`.
#[derive(Debug, Clone)]
pub struct OnnxModel {
    // Path to the model file or its containing directory.
    path: PathBuf,
    // Metadata: always contains "format" and "path"; may contain fields
    // copied from config.json (see from_directory).
    metadata: HashMap<String, Value>,
}
401
402impl OnnxModel {
403    /// Create a new ONNX model from a directory path
404    pub fn from_directory<P: AsRef<Path>>(path: P) -> EdgeResult<Self> {
405        let path = path.as_ref().to_path_buf();
406        let mut metadata = HashMap::new();
407
408        // Check if path exists
409        if !path.exists() {
410            return Err(EdgeError::model(format!(
411                "Model directory does not exist: {}",
412                path.display()
413            )));
414        }
415
416        // Try to load config.json if it exists
417        let config_path = path.join("config.json");
418        if config_path.exists() {
419            let config_content = std::fs::read_to_string(&config_path)?;
420            let config: Value = serde_json::from_str(&config_content)?;
421
422            // Extract metadata from config
423            if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
424                metadata.insert(
425                    "model_type".to_string(),
426                    Value::String(model_type.to_string()),
427                );
428            }
429
430            // Add other relevant config fields to metadata
431            if let Some(vocab_size) = config.get("vocab_size") {
432                metadata.insert("vocab_size".to_string(), vocab_size.clone());
433            }
434            if let Some(hidden_size) = config.get("hidden_size") {
435                metadata.insert("hidden_size".to_string(), hidden_size.clone());
436            }
437            if let Some(max_position_embeddings) = config.get("max_position_embeddings") {
438                metadata.insert(
439                    "max_position_embeddings".to_string(),
440                    max_position_embeddings.clone(),
441                );
442            }
443            if let Some(bos_token_id) = config.get("bos_token_id") {
444                metadata.insert("bos_token_id".to_string(), bos_token_id.clone());
445            }
446            if let Some(eos_token_id) = config.get("eos_token_id") {
447                metadata.insert("eos_token_id".to_string(), eos_token_id.clone());
448            }
449            if let Some(pad_token_id) = config.get("pad_token_id") {
450                metadata.insert("pad_token_id".to_string(), pad_token_id.clone());
451            }
452        }
453
454        // Add model format info
455        metadata.insert("format".to_string(), Value::String("onnx".to_string()));
456        metadata.insert(
457            "path".to_string(),
458            Value::String(path.display().to_string()),
459        );
460
461        Ok(Self { path, metadata })
462    }
463
464    /// Create a new ONNX model from a single .onnx file
465    pub fn from_file<P: AsRef<Path>>(path: P) -> EdgeResult<Self> {
466        let path = path.as_ref().to_path_buf();
467
468        if !path.exists() {
469            return Err(EdgeError::model(format!(
470                "Model file does not exist: {}",
471                path.display()
472            )));
473        }
474
475        if path.extension().and_then(|e| e.to_str()) != Some("onnx") {
476            return Err(EdgeError::model("File must have .onnx extension"));
477        }
478
479        let mut metadata = HashMap::new();
480        metadata.insert("format".to_string(), Value::String("onnx".to_string()));
481        metadata.insert(
482            "path".to_string(),
483            Value::String(path.display().to_string()),
484        );
485
486        Ok(Self { path, metadata })
487    }
488
489    /// Add metadata to the model
490    pub fn with_metadata(mut self, key: String, value: Value) -> Self {
491        self.metadata.insert(key, value);
492        self
493    }
494}
495
496impl Model for OnnxModel {
497    fn model_type(&self) -> &str {
498        "onnx"
499    }
500
501    fn model_path(&self) -> &Path {
502        &self.path
503    }
504
505    fn metadata(&self) -> &HashMap<String, Value> {
506        &self.metadata
507    }
508
509    fn config(&self) -> EdgeResult<Value> {
510        let config_path = self.path.join("config.json");
511        if config_path.exists() {
512            let config_content = std::fs::read_to_string(&config_path)?;
513            let config: Value = serde_json::from_str(&config_content)?;
514            Ok(config)
515        } else {
516            // Return basic config if no config.json exists
517            Ok(serde_json::json!({
518                "model_type": "onnx",
519                "path": self.path.display().to_string()
520            }))
521        }
522    }
523
524    fn validate(&self) -> EdgeResult<()> {
525        if !self.path.exists() {
526            return Err(EdgeError::model(format!(
527                "Model path does not exist: {}",
528                self.path.display()
529            )));
530        }
531
532        // Check for ONNX model file
533        let onnx_file = if self.path.is_file() {
534            // Direct .onnx file
535            self.path.clone()
536        } else {
537            // Directory containing model.onnx
538            self.path.join("model.onnx")
539        };
540
541        if !onnx_file.exists() {
542            return Err(EdgeError::model(format!(
543                "ONNX model file not found: {}",
544                onnx_file.display()
545            )));
546        }
547
548        Ok(())
549    }
550}
551
/// Builder for creating models.
///
/// Stateless namespace: all constructors are associated functions.
pub struct ModelBuilder;
554
555impl ModelBuilder {
556    /// Create an ONNX model from a directory
557    pub fn onnx_from_directory<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
558        OnnxModel::from_directory(path)
559    }
560
561    /// Create an ONNX model from a file
562    pub fn onnx_from_file<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
563        OnnxModel::from_file(path)
564    }
565}
566
567/// Convenience function to create an ONNX model
568pub fn onnx_model<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
569    OnnxModel::from_directory(path)
570}