autoagents_onnx/runtime/mod.rs

use crate::error::{EdgeError, EdgeResult};
use serde_json::Value;
use std::{collections::HashMap, path::Path};

pub mod inference;
use inference::OnnxBackend;
use inference::OnnxModel;
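
/// Loads an [`OnnxModel`] from a directory on disk.
///
/// A minimal usage sketch; the path is illustrative, and the call path
/// assumes this module is reachable as `autoagents_onnx::runtime`:
///
/// ```ignore
/// let model = autoagents_onnx::runtime::onnx_model("models/my-model")
///     .expect("model directory should load");
/// ```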
pub fn onnx_model<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
    OnnxModel::from_directory(path)
}

/// Named input values plus call-level metadata for one inference request.
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
    pub inputs: HashMap<String, Value>,
    pub metadata: HashMap<String, Value>,
}

/// Named output values plus call-level metadata from one inference request.
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
    pub outputs: HashMap<String, Value>,
    pub metadata: HashMap<String, Value>,
}
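
/// Owns an `OnnxBackend` together with runtime-level metadata such as the
/// backend type, device type, and creation time.
///
/// A minimal end-to-end sketch; `MyModel` is a hypothetical implementation
/// of `crate::Model` standing in for whatever the caller provides:
///
/// ```ignore
/// let model: Box<dyn crate::Model> = Box::new(MyModel::default());
/// let mut runtime = InferenceRuntime::from_model(model).await?;
/// let output = runtime.infer(InferenceInput::new())?;
/// ```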
pub struct InferenceRuntime {
    backend: OnnxBackend,
    runtime_metadata: HashMap<String, Value>,
}

impl InferenceRuntime {
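    /// Builds a runtime for `model` on the given `device`, recording the
    /// backend type, device type, and creation time as runtime metadata.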
    pub async fn from_model_with_device(
        model: Box<dyn crate::Model>,
        device: crate::Device,
    ) -> EdgeResult<Self> {
        let backend_type = model.model_type().to_string();
        // Capture the device description before `device` is moved into the
        // backend below (safe even if `Device` is not `Copy`).
        let device_type = device.to_string();
        let backend = OnnxBackend::from_model_with_device(model, device)?;

        let mut runtime_metadata = HashMap::new();
        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
        runtime_metadata.insert("device_type".to_string(), Value::String(device_type));
        runtime_metadata.insert(
            "created_at".to_string(),
            Value::String(chrono::Utc::now().to_rfc3339()),
        );

        Ok(Self {
            backend,
            runtime_metadata,
        })
    }
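
    /// Builds a runtime on the default CPU device.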
    pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
        let device = crate::device::cpu();
        Self::from_model_with_device(model, device).await
    }
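
    /// Runs one inference pass, failing fast with a runtime error when the
    /// backend is not ready.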
    pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
        if !self.backend.is_ready() {
            return Err(EdgeError::runtime("Backend is not ready for inference"));
        }

        self.backend.infer(input)
    }
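
    /// Returns the backend's model info merged with the runtime metadata;
    /// on duplicate keys the runtime metadata wins, since it is inserted last.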
    pub fn model_info(&self) -> HashMap<String, Value> {
        let mut info = self.backend.model_info();
        info.extend(self.runtime_metadata.clone());
        info
    }

    pub fn is_ready(&self) -> bool {
        self.backend.is_ready()
    }

    pub fn backend_info(&self) -> HashMap<String, Value> {
        self.backend.backend_info()
    }
}

impl InferenceInput {
    pub fn new() -> Self {
        Self::default()
    }
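
    /// Adds a named input, returning `self` so calls can be chained.
    ///
    /// A small sketch; the input name and values are illustrative only:
    ///
    /// ```ignore
    /// let input = InferenceInput::new()
    ///     .add_input("input_ids".to_string(), serde_json::json!([1, 2, 3]))
    ///     .add_metadata("batch_size".to_string(), serde_json::json!(1));
    /// ```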
    pub fn add_input(mut self, name: String, data: Value) -> Self {
        self.inputs.insert(name, data);
        self
    }

    /// Adds a metadata entry, returning `self` so calls can be chained.
    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
        self.metadata.insert(key, value);
        self
    }
}

impl InferenceOutput {
    pub fn new() -> Self {
        Self::default()
    }

    /// Looks up a single output value by name.
    pub fn get_output(&self, name: &str) -> Option<&Value> {
        self.outputs.get(name)
    }

    /// Looks up a single metadata value by key.
    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
        self.metadata.get(key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_input_creation() {
        let input = InferenceInput::new()
            .add_input(
                "input_ids".to_string(),
                Value::Array(vec![Value::Number(1.into())]),
            )
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        assert!(input.inputs.contains_key("input_ids"));
        assert!(input.metadata.contains_key("batch_size"));
    }
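
    // A small companion check, mirroring the input test above, confirming
    // that the `InferenceOutput` getters return inserted values.
    #[test]
    fn test_inference_output_lookup() {
        let mut output = InferenceOutput::new();
        output
            .outputs
            .insert("logits".to_string(), Value::Array(vec![]));
        output
            .metadata
            .insert("latency_ms".to_string(), Value::Number(5.into()));

        assert_eq!(output.get_output("logits"), Some(&Value::Array(vec![])));
        assert!(output.get_metadata("latency_ms").is_some());
        assert!(output.get_output("missing").is_none());
    }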
}