// liquid_edge/runtime/mod.rs

use crate::error::{EdgeError, EdgeResult};
use serde_json::Value;
use std::collections::HashMap;

#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
pub mod onnx;

// WebAssembly backend, referenced by the wasm32 "onnx" arm below.
#[cfg(target_arch = "wasm32")]
pub mod wasm;

/// Named inputs and per-request metadata for a single inference call.
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
    /// Input values keyed by the model's input names.
    pub inputs: HashMap<String, Value>,
    /// Arbitrary request metadata, e.g. a batch size.
    pub metadata: HashMap<String, Value>,
}

/// Named outputs and metadata produced by a single inference call.
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
    /// Output values keyed by the model's output names.
    pub outputs: HashMap<String, Value>,
    /// Metadata attached by the backend.
    pub metadata: HashMap<String, Value>,
}

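/// Contract implemented by every inference backend.
///
/// A minimal sketch of an implementation (assumption: `EchoBackend` is written
/// here purely for illustration and is not part of this module):
///
/// ```ignore
/// struct EchoBackend;
///
/// impl RuntimeBackend for EchoBackend {
///     fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
///         // Hand the inputs straight back as outputs.
///         Ok(InferenceOutput { outputs: input.inputs, metadata: input.metadata })
///     }
///     fn model_info(&self) -> HashMap<String, Value> { HashMap::new() }
///     fn is_ready(&self) -> bool { true }
///     fn backend_info(&self) -> HashMap<String, Value> { HashMap::new() }
/// }
/// ```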
pub trait RuntimeBackend: Send + Sync {
    /// Runs one inference pass over the given inputs.
    fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput>;

    /// Describes the loaded model.
    fn model_info(&self) -> HashMap<String, Value>;

    /// Reports whether the backend can currently serve inference.
    fn is_ready(&self) -> bool;

    /// Describes the backend itself.
    fn backend_info(&self) -> HashMap<String, Value>;
}

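/// Routes inference requests to the [`RuntimeBackend`] selected for the model
/// type, carrying runtime-level metadata alongside the backend's own.
///
/// A usage sketch (assumption: `load_model` is a hypothetical helper yielding
/// a boxed `Model`; the crate's actual loading API may differ):
///
/// ```ignore
/// let model: Box<dyn liquid_edge::Model> = load_model("model.onnx")?;
/// let mut runtime = InferenceRuntime::from_model(model).await?;
/// let output = runtime.infer(InferenceInput::new())?;
/// ```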
pub struct InferenceRuntime {
    backend: Box<dyn RuntimeBackend>,
    runtime_metadata: HashMap<String, Value>,
}

impl InferenceRuntime {
    /// Builds a runtime for `model` on `device`, selecting a backend from the
    /// model type.
    pub async fn from_model_with_device(
        model: Box<dyn crate::Model>,
        device: crate::Device,
    ) -> EdgeResult<Self> {
        let backend_type = model.model_type().to_string();
        // Capture the device label before `device` is moved into the backend.
        let device_label = device.to_string();

        let backend: Box<dyn RuntimeBackend> = match backend_type.as_str() {
            #[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
            "onnx" => {
                let backend = onnx::OnnxBackend::from_model_with_device(model, device)?;
                Box::new(backend)
            }
            #[cfg(target_arch = "wasm32")]
            "onnx" => {
                let backend = wasm::WasmBackend::from_model_with_device(model, device)?;
                Box::new(backend)
            }
            _ => {
                return Err(EdgeError::runtime(format!(
                    "Unsupported model type: {backend_type}"
                )));
            }
        };

        let mut runtime_metadata = HashMap::new();
        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
        runtime_metadata.insert("device_type".to_string(), Value::String(device_label));
        runtime_metadata.insert(
            "created_at".to_string(),
            Value::String(chrono::Utc::now().to_rfc3339()),
        );

        Ok(Self {
            backend,
            runtime_metadata,
        })
    }

    /// Builds a runtime on the default device for the current target:
    /// WebGPU on wasm32, CPU otherwise.
    pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
        #[cfg(target_arch = "wasm32")]
        let device = crate::device::webgpu();
        #[cfg(not(target_arch = "wasm32"))]
        let device = crate::device::cpu();

        Self::from_model_with_device(model, device).await
    }

    /// Runs inference, refusing requests while the backend is not ready.
    pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
        if !self.backend.is_ready() {
            return Err(EdgeError::runtime("Backend is not ready for inference"));
        }

        self.backend.infer(input)
    }

    /// Returns the backend's model info merged with the runtime metadata;
    /// runtime entries overwrite backend entries that share a key.
    pub fn model_info(&self) -> HashMap<String, Value> {
        let mut info = self.backend.model_info();
        info.extend(self.runtime_metadata.clone());
        info
    }

    /// Reports whether the underlying backend is ready for inference.
    pub fn is_ready(&self) -> bool {
        self.backend.is_ready()
    }

    /// Returns the underlying backend's self-description.
    pub fn backend_info(&self) -> HashMap<String, Value> {
        self.backend.backend_info()
    }
}

impl InferenceInput {
    /// Creates an empty input set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a named input, consuming and returning `self` for chaining.
    pub fn add_input(mut self, name: String, data: Value) -> Self {
        self.inputs.insert(name, data);
        self
    }

    /// Adds a metadata entry, consuming and returning `self` for chaining.
    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
        self.metadata.insert(key, value);
        self
    }
}

impl InferenceOutput {
    /// Creates an empty output set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Looks up an output by name.
    pub fn get_output(&self, name: &str) -> Option<&Value> {
        self.outputs.get(name)
    }

    /// Looks up a metadata entry by key.
    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
        self.metadata.get(key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_input_creation() {
        let input = InferenceInput::new()
            .add_input(
                "input_ids".to_string(),
                Value::Array(vec![Value::Number(1.into())]),
            )
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        assert!(input.inputs.contains_key("input_ids"));
        assert!(input.metadata.contains_key("batch_size"));
    }
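
    // A hedged end-to-end sketch: `MockBackend` is a stand-in written for
    // this test and is not a backend shipped by this crate. It exercises the
    // readiness gate in `InferenceRuntime::infer`; the runtime is built via a
    // struct literal, which works here because `tests` is a child module of
    // `runtime` and can reach its private fields.
    struct MockBackend {
        ready: bool,
    }

    impl RuntimeBackend for MockBackend {
        fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
            // Echo inputs back as outputs so assertions are easy to write.
            Ok(InferenceOutput {
                outputs: input.inputs,
                metadata: input.metadata,
            })
        }

        fn model_info(&self) -> HashMap<String, Value> {
            HashMap::new()
        }

        fn is_ready(&self) -> bool {
            self.ready
        }

        fn backend_info(&self) -> HashMap<String, Value> {
            HashMap::new()
        }
    }

    #[test]
    fn test_infer_requires_ready_backend() {
        // A runtime over a not-ready backend must refuse to run inference.
        let mut runtime = InferenceRuntime {
            backend: Box::new(MockBackend { ready: false }),
            runtime_metadata: HashMap::new(),
        };
        assert!(!runtime.is_ready());
        assert!(runtime.infer(InferenceInput::new()).is_err());

        // Once the backend reports ready, inputs flow through to outputs.
        let mut runtime = InferenceRuntime {
            backend: Box::new(MockBackend { ready: true }),
            runtime_metadata: HashMap::new(),
        };
        let output = runtime
            .infer(InferenceInput::new().add_input("x".to_string(), Value::Number(1.into())))
            .expect("ready backend should serve inference");
        assert_eq!(output.get_output("x"), Some(&Value::Number(1.into())));
    }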
}