use crate::error::{EdgeError, EdgeResult};
use serde_json::Value;
use std::collections::HashMap;

#[cfg(feature = "onnx")]
pub mod onnx;
/// A generic inference request: named input values plus free-form request
/// metadata, both keyed by string.
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
    // Named input values, keyed by input name (e.g. "input_ids" -> JSON array).
    pub inputs: HashMap<String, Value>,
    // Free-form request metadata (e.g. "batch_size"); not interpreted by this module.
    pub metadata: HashMap<String, Value>,
}
22
/// A generic inference result: named output values plus free-form response
/// metadata, mirroring the shape of [`InferenceInput`].
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
    // Named output values produced by the backend, keyed by output name.
    pub outputs: HashMap<String, Value>,
    // Free-form response metadata attached by the backend.
    pub metadata: HashMap<String, Value>,
}
32
/// Backend abstraction implemented by each concrete runtime (e.g. the ONNX
/// backend in the `onnx` submodule). `Send + Sync` so an [`InferenceRuntime`]
/// can be moved across / shared between threads.
pub trait RuntimeBackend: Send + Sync {
    /// Run one inference pass. Takes `&mut self` because backends may hold
    /// mutable session state.
    fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput>;

    /// Model-level metadata; merged with runtime metadata by
    /// `InferenceRuntime::model_info`.
    fn model_info(&self) -> HashMap<String, Value>;

    /// Whether the backend can currently serve `infer` calls.
    fn is_ready(&self) -> bool;

    /// Backend-specific metadata, exposed verbatim by
    /// `InferenceRuntime::backend_info`.
    fn backend_info(&self) -> HashMap<String, Value>;
}
47
/// Facade over a boxed [`RuntimeBackend`] plus metadata recorded when the
/// runtime was constructed (backend type, device type, creation timestamp).
pub struct InferenceRuntime {
    // The selected backend, chosen from the model type at construction.
    backend: Box<dyn RuntimeBackend>,
    // Construction-time metadata; merged over the backend's model info.
    runtime_metadata: HashMap<String, Value>,
}
53
54impl InferenceRuntime {
55 pub async fn from_model_with_device(
57 model: Box<dyn crate::Model>,
58 device: crate::Device,
59 ) -> EdgeResult<Self> {
60 let backend_type = model.model_type().to_string();
61
62 let backend: Box<dyn RuntimeBackend> = match backend_type.as_str() {
63 #[cfg(feature = "onnx")]
64 "onnx" => {
65 let backend = onnx::OnnxBackend::from_model_with_device(model, device)?;
66 Box::new(backend)
67 }
68 _ => {
69 return Err(EdgeError::runtime(format!(
70 "Unsupported model type: {backend_type}"
71 )));
72 }
73 };
74
75 let mut runtime_metadata = HashMap::new();
76 runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
77 runtime_metadata.insert("device_type".to_string(), Value::String(device.to_string()));
78 runtime_metadata.insert(
79 "created_at".to_string(),
80 Value::String(chrono::Utc::now().to_rfc3339()),
81 );
82
83 Ok(Self {
84 backend,
85 runtime_metadata,
86 })
87 }
88
89 pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
91 let device = crate::device::cpu();
92 Self::from_model_with_device(model, device).await
93 }
94
95 pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
97 if !self.backend.is_ready() {
98 return Err(EdgeError::runtime("Backend is not ready for inference"));
99 }
100
101 self.backend.infer(input)
102 }
103
104 pub fn model_info(&self) -> HashMap<String, Value> {
106 let mut info = self.backend.model_info();
107 info.extend(self.runtime_metadata.clone());
108 info
109 }
110
111 pub fn is_ready(&self) -> bool {
113 self.backend.is_ready()
114 }
115
116 pub fn backend_info(&self) -> HashMap<String, Value> {
118 self.backend.backend_info()
119 }
120}
121
122impl InferenceInput {
123 pub fn new() -> Self {
125 Self::default()
126 }
127
128 pub fn add_input(mut self, name: String, data: Value) -> Self {
130 self.inputs.insert(name, data);
131 self
132 }
133
134 pub fn add_metadata(mut self, key: String, value: Value) -> Self {
136 self.metadata.insert(key, value);
137 self
138 }
139}
140
141impl InferenceOutput {
142 pub fn new() -> Self {
144 Self::default()
145 }
146
147 pub fn get_output(&self, name: &str) -> Option<&Value> {
149 self.outputs.get(name)
150 }
151
152 pub fn get_metadata(&self, key: &str) -> Option<&Value> {
154 self.metadata.get(key)
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// The builders must store inputs and metadata under the given keys with
    /// the given values — the original test only checked key presence, which
    /// would pass even if values were corrupted.
    #[test]
    fn test_inference_input_creation() {
        let input = InferenceInput::new()
            .add_input(
                "input_ids".to_string(),
                Value::Array(vec![Value::Number(1.into())]),
            )
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        assert_eq!(
            input.inputs.get("input_ids"),
            Some(&Value::Array(vec![Value::Number(1.into())]))
        );
        assert_eq!(
            input.metadata.get("batch_size"),
            Some(&Value::Number(1.into()))
        );
    }

    /// Accessors return stored values by key and `None` for missing keys
    /// (previously untested).
    #[test]
    fn test_inference_output_accessors() {
        let mut output = InferenceOutput::new();
        output
            .outputs
            .insert("logits".to_string(), Value::Number(2.into()));

        assert_eq!(output.get_output("logits"), Some(&Value::Number(2.into())));
        assert_eq!(output.get_output("missing"), None);
        assert_eq!(output.get_metadata("absent"), None);
    }
}