autoagents_onnx/runtime/
mod.rs

//! ONNX inference runtime for edge computing
//!
//! This module provides a generic interface for running inference on various
//! deep learning models using different backends.

6use crate::error::{EdgeError, EdgeResult};
7use serde_json::Value;
8use std::{collections::HashMap, path::Path};
9
10pub mod inference;
11use inference::OnnxBackend;
12use inference::OnnxModel;
13
14/// Convenience function to create an ONNX model
15pub fn onnx_model<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
16    OnnxModel::from_directory(path)
17}
18
/// Generic input for inference operations.
///
/// Built incrementally via the builder methods on `InferenceInput`
/// (`add_input` / `add_metadata`); `Default` yields empty maps.
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
    /// Named tensor inputs as key-value pairs.
    /// Keys are input names, values are the tensor data encoded as JSON.
    pub inputs: HashMap<String, Value>,
    /// Arbitrary input metadata carried alongside the tensors
    /// (e.g. batch size — see the tests in this module).
    pub metadata: HashMap<String, Value>,
}
28
/// Generic output from inference operations.
///
/// Mirrors [`InferenceInput`]: named JSON tensors plus free-form metadata.
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
    /// Named tensor outputs as key-value pairs.
    /// Keys are output names, values are the tensor data encoded as JSON.
    pub outputs: HashMap<String, Value>,
    /// Arbitrary output metadata produced by the backend.
    pub metadata: HashMap<String, Value>,
}
38
/// Main inference runtime that manages different backends.
///
/// Wraps a single [`OnnxBackend`] together with runtime-level metadata
/// (backend type, device type, creation timestamp) recorded at construction.
pub struct InferenceRuntime {
    // The backend that actually executes inference requests.
    backend: OnnxBackend,
    // Runtime-level metadata merged into `model_info()` results.
    runtime_metadata: HashMap<String, Value>,
}
44
45impl InferenceRuntime {
46    /// Create a new inference runtime from a model with a specific device
47    pub async fn from_model_with_device(
48        model: Box<dyn crate::Model>,
49        device: crate::Device,
50    ) -> EdgeResult<Self> {
51        let backend_type = model.model_type().to_string();
52        let backend = OnnxBackend::from_model_with_device(model, device)?;
53
54        let mut runtime_metadata = HashMap::new();
55        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
56        runtime_metadata.insert("device_type".to_string(), Value::String(device.to_string()));
57        runtime_metadata.insert(
58            "created_at".to_string(),
59            Value::String(chrono::Utc::now().to_rfc3339()),
60        );
61
62        Ok(Self {
63            backend,
64            runtime_metadata,
65        })
66    }
67
68    /// Create a new inference runtime from a model (uses CPU device by default)
69    pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
70        let device = crate::device::cpu();
71        Self::from_model_with_device(model, device).await
72    }
73
74    /// Run inference on the loaded model
75    pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
76        if !self.backend.is_ready() {
77            return Err(EdgeError::runtime("Backend is not ready for inference"));
78        }
79
80        self.backend.infer(input)
81    }
82
83    /// Get comprehensive model information
84    pub fn model_info(&self) -> HashMap<String, Value> {
85        let mut info = self.backend.model_info();
86        info.extend(self.runtime_metadata.clone());
87        info
88    }
89
90    /// Check if the runtime is ready for inference
91    pub fn is_ready(&self) -> bool {
92        self.backend.is_ready()
93    }
94
95    /// Get backend information
96    pub fn backend_info(&self) -> HashMap<String, Value> {
97        self.backend.backend_info()
98    }
99}
100
101impl InferenceInput {
102    /// Create a new empty inference input
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Add a tensor input
108    pub fn add_input(mut self, name: String, data: Value) -> Self {
109        self.inputs.insert(name, data);
110        self
111    }
112
113    /// Add metadata
114    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
115        self.metadata.insert(key, value);
116        self
117    }
118}
119
120impl InferenceOutput {
121    /// Create a new empty inference output
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    /// Get output by name
127    pub fn get_output(&self, name: &str) -> Option<&Value> {
128        self.outputs.get(name)
129    }
130
131    /// Get metadata by key
132    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
133        self.metadata.get(key)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods should land entries under the given keys.
    #[test]
    fn test_inference_input_creation() {
        let tensor = Value::Array(vec![Value::Number(1.into())]);
        let input = InferenceInput::new()
            .add_input("input_ids".to_string(), tensor)
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        assert!(input.inputs.contains_key("input_ids"));
        assert!(input.metadata.contains_key("batch_size"));
    }
}
153}