// liquid_edge/runtime/mod.rs

1//! Generic inference runtime for edge computing
2//!
3//! This module provides a generic interface for running inference on various
4//! deep learning models using different backends.
5
6use crate::error::{EdgeError, EdgeResult};
7use serde_json::Value;
8use std::collections::HashMap;
9
10#[cfg(feature = "onnx")]
11pub mod onnx;
12
13/// Generic input for inference operations
14#[derive(Debug, Clone, Default)]
15pub struct InferenceInput {
16    /// Named tensor inputs as key-value pairs
17    /// Keys are input names, values are the tensor data as JSON
18    pub inputs: HashMap<String, Value>,
19    /// Input metadata
20    pub metadata: HashMap<String, Value>,
21}
22
23/// Generic output from inference operations
24#[derive(Debug, Clone, Default)]
25pub struct InferenceOutput {
26    /// Named tensor outputs as key-value pairs
27    /// Keys are output names, values are the tensor data as JSON
28    pub outputs: HashMap<String, Value>,
29    /// Output metadata
30    pub metadata: HashMap<String, Value>,
31}
32
/// Generic inference runtime trait.
///
/// Implemented by concrete execution backends (currently only the ONNX
/// backend, behind the `onnx` feature). The `Send + Sync` bound lets a
/// boxed backend be moved to and shared across threads.
pub trait RuntimeBackend: Send + Sync {
    /// Run inference with the given inputs.
    ///
    /// Takes `&mut self` so implementations may hold mutable session state.
    fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput>;

    /// Get model information; the exact keys/values are backend-defined.
    fn model_info(&self) -> HashMap<String, Value>;

    /// Check if the runtime is ready for inference.
    /// `InferenceRuntime::infer` refuses to dispatch while this is `false`.
    fn is_ready(&self) -> bool;

    /// Get backend-specific metadata (backend-defined contents).
    fn backend_info(&self) -> HashMap<String, Value>;
}
47
/// Main inference runtime that manages different backends.
///
/// Construct with [`InferenceRuntime::from_model`] or
/// [`InferenceRuntime::from_model_with_device`].
pub struct InferenceRuntime {
    // The active backend, selected from the model type at construction.
    backend: Box<dyn RuntimeBackend>,
    // Construction-time metadata (backend type, device, creation timestamp);
    // merged into the result of `model_info()`.
    runtime_metadata: HashMap<String, Value>,
}
53
54impl InferenceRuntime {
55    /// Create a new inference runtime from a model with a specific device
56    pub async fn from_model_with_device(
57        model: Box<dyn crate::Model>,
58        device: crate::Device,
59    ) -> EdgeResult<Self> {
60        let backend_type = model.model_type().to_string();
61
62        let backend: Box<dyn RuntimeBackend> = match backend_type.as_str() {
63            #[cfg(feature = "onnx")]
64            "onnx" => {
65                let backend = onnx::OnnxBackend::from_model_with_device(model, device)?;
66                Box::new(backend)
67            }
68            _ => {
69                return Err(EdgeError::runtime(format!(
70                    "Unsupported model type: {backend_type}"
71                )));
72            }
73        };
74
75        let mut runtime_metadata = HashMap::new();
76        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
77        runtime_metadata.insert("device_type".to_string(), Value::String(device.to_string()));
78        runtime_metadata.insert(
79            "created_at".to_string(),
80            Value::String(chrono::Utc::now().to_rfc3339()),
81        );
82
83        Ok(Self {
84            backend,
85            runtime_metadata,
86        })
87    }
88
89    /// Create a new inference runtime from a model (uses CPU device by default)
90    pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
91        let device = crate::device::cpu();
92        Self::from_model_with_device(model, device).await
93    }
94
95    /// Run inference on the loaded model
96    pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
97        if !self.backend.is_ready() {
98            return Err(EdgeError::runtime("Backend is not ready for inference"));
99        }
100
101        self.backend.infer(input)
102    }
103
104    /// Get comprehensive model information
105    pub fn model_info(&self) -> HashMap<String, Value> {
106        let mut info = self.backend.model_info();
107        info.extend(self.runtime_metadata.clone());
108        info
109    }
110
111    /// Check if the runtime is ready for inference
112    pub fn is_ready(&self) -> bool {
113        self.backend.is_ready()
114    }
115
116    /// Get backend information
117    pub fn backend_info(&self) -> HashMap<String, Value> {
118        self.backend.backend_info()
119    }
120}
121
122impl InferenceInput {
123    /// Create a new empty inference input
124    pub fn new() -> Self {
125        Self::default()
126    }
127
128    /// Add a tensor input
129    pub fn add_input(mut self, name: String, data: Value) -> Self {
130        self.inputs.insert(name, data);
131        self
132    }
133
134    /// Add metadata
135    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
136        self.metadata.insert(key, value);
137        self
138    }
139}
140
141impl InferenceOutput {
142    /// Create a new empty inference output
143    pub fn new() -> Self {
144        Self::default()
145    }
146
147    /// Get output by name
148    pub fn get_output(&self, name: &str) -> Option<&Value> {
149        self.outputs.get(name)
150    }
151
152    /// Get metadata by key
153    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
154        self.metadata.get(key)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods must store the exact values given, not just the keys.
    #[test]
    fn test_inference_input_creation() {
        let input = InferenceInput::new()
            .add_input(
                "input_ids".to_string(),
                Value::Array(vec![Value::Number(1.into())]),
            )
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        // Assert on the stored values, not merely key presence.
        assert_eq!(
            input.inputs.get("input_ids"),
            Some(&Value::Array(vec![Value::Number(1.into())]))
        );
        assert_eq!(
            input.metadata.get("batch_size"),
            Some(&Value::Number(1.into()))
        );
        assert_eq!(input.inputs.len(), 1);
        assert_eq!(input.metadata.len(), 1);
    }

    /// Getter accessors return `Some` for present keys and `None` otherwise.
    #[test]
    fn test_inference_output_lookup() {
        let mut output = InferenceOutput::new();
        output.outputs.insert("logits".to_string(), Value::Null);

        assert_eq!(output.get_output("logits"), Some(&Value::Null));
        assert!(output.get_output("missing").is_none());
        assert!(output.get_metadata("missing").is_none());
    }
}