//! Inference runtime and I/O types built on top of the ONNX backend.

use crate::error::{EdgeError, EdgeResult};
use serde_json::Value;
use std::{collections::HashMap, path::Path};
pub mod inference;
use inference::{OnnxBackend, OnnxModel};
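
/// Convenience helper that loads an [`OnnxModel`] from a directory on disk.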
pub fn onnx_model<P: AsRef<Path>>(path: P) -> EdgeResult<OnnxModel> {
OnnxModel::from_directory(path)
}
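
/// Named input values plus free-form metadata for a single inference call.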
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
pub inputs: HashMap<String, Value>,
pub metadata: HashMap<String, Value>,
}
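
/// Named output values plus metadata produced by an inference call.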
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
pub outputs: HashMap<String, Value>,
pub metadata: HashMap<String, Value>,
}
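
/// Runtime that pairs an [`OnnxBackend`] with metadata recorded at
/// construction time (backend type, device, creation timestamp).
///
/// A minimal usage sketch, not a verbatim doctest: the model path and the
/// crate name `edge` are placeholders, and it assumes [`OnnxModel`]
/// implements [`crate::Model`].
///
/// ```ignore
/// let model = edge::onnx_model("./models/example")?;
/// let mut runtime = edge::InferenceRuntime::from_model(Box::new(model)).await?;
/// let input = edge::InferenceInput::new()
///     .add_input("input_ids".to_string(), serde_json::json!([1, 2, 3]));
/// let output = runtime.infer(input)?;
/// println!("outputs: {:?}", output.outputs);
/// ```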
pub struct InferenceRuntime {
backend: OnnxBackend,
runtime_metadata: HashMap<String, Value>,
}
impl InferenceRuntime {
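    /// Builds a runtime that executes `model` on the given `device`,
    /// recording backend, device, and creation-timestamp metadata.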
pub async fn from_model_with_device(
model: Box<dyn crate::Model>,
device: crate::Device,
) -> EdgeResult<Self> {
        let backend_type = model.model_type().to_string();
        // Record the device description before `device` is moved into the
        // backend, so it is still available for the metadata below.
        let device_type = device.to_string();
        let backend = OnnxBackend::from_model_with_device(model, device)?;
        let mut runtime_metadata = HashMap::new();
        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
        runtime_metadata.insert("device_type".to_string(), Value::String(device_type));
runtime_metadata.insert(
"created_at".to_string(),
Value::String(chrono::Utc::now().to_rfc3339()),
);
Ok(Self {
backend,
runtime_metadata,
})
}
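
    /// Builds a runtime for `model` on the default CPU device.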
pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
let device = crate::device::cpu();
Self::from_model_with_device(model, device).await
}
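
    /// Runs a single inference call, returning an error if the backend
    /// is not ready.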
pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
if !self.backend.is_ready() {
return Err(EdgeError::runtime("Backend is not ready for inference"));
}
self.backend.infer(input)
}
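
    /// Returns the backend's model info merged with this runtime's
    /// metadata; runtime entries override backend entries that share a key.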
pub fn model_info(&self) -> HashMap<String, Value> {
let mut info = self.backend.model_info();
info.extend(self.runtime_metadata.clone());
info
}
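
    /// Reports whether the backend is ready to serve inference requests.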
pub fn is_ready(&self) -> bool {
self.backend.is_ready()
}
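
    /// Returns backend-specific information as reported by the backend.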
pub fn backend_info(&self) -> HashMap<String, Value> {
self.backend.backend_info()
}
}
impl InferenceInput {
    /// Creates an empty input.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a named input value, builder-style.
    pub fn add_input(mut self, name: String, data: Value) -> Self {
        self.inputs.insert(name, data);
        self
    }

    /// Attaches a metadata entry, builder-style.
    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
impl InferenceOutput {
    /// Creates an empty output.
    pub fn new() -> Self {
        Self::default()
    }

    /// Looks up an output value by name.
    pub fn get_output(&self, name: &str) -> Option<&Value> {
        self.outputs.get(name)
    }

    /// Looks up a metadata entry by key.
    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
        self.metadata.get(key)
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_inference_input_creation() {
let input = InferenceInput::new()
.add_input(
"input_ids".to_string(),
Value::Array(vec![Value::Number(1.into())]),
)
.add_metadata("batch_size".to_string(), Value::Number(1.into()));
assert!(input.inputs.contains_key("input_ids"));
assert!(input.metadata.contains_key("batch_size"));
}
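
    // Companion test for the output accessors; it touches only types
    // defined in this module, so nothing here is assumed.
    #[test]
    fn test_inference_output_accessors() {
        let mut output = InferenceOutput::new();
        output
            .outputs
            .insert("logits".to_string(), Value::Array(vec![Value::Number(0.into())]));
        output
            .metadata
            .insert("latency_ms".to_string(), Value::Number(5.into()));
        assert_eq!(
            output.get_output("logits"),
            Some(&Value::Array(vec![Value::Number(0.into())]))
        );
        assert!(output.get_metadata("latency_ms").is_some());
        assert!(output.get_output("missing").is_none());
    }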
}