liquid_edge/runtime/
mod.rs

//! Generic inference runtime for edge computing
//!
//! This module provides a generic interface for running inference on various
//! deep learning models using different backends.

use crate::error::{EdgeError, EdgeResult};
use serde_json::Value;
use std::collections::HashMap;

#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
pub mod onnx;

#[cfg(target_arch = "wasm32")]
pub mod wasm;
12
/// Generic input for inference operations.
///
/// Built fluently via [`InferenceInput::new`], `add_input`, and
/// `add_metadata`; both maps start empty (`Default`).
#[derive(Debug, Clone, Default)]
pub struct InferenceInput {
    /// Named tensor inputs as key-value pairs.
    /// Keys are input names, values are the tensor data as JSON.
    pub inputs: HashMap<String, Value>,
    /// Free-form input metadata, keyed by name (e.g. `"batch_size"` in tests).
    pub metadata: HashMap<String, Value>,
}
22
/// Generic output from inference operations.
///
/// Mirrors [`InferenceInput`]: named JSON tensors plus free-form metadata,
/// both starting empty (`Default`).
#[derive(Debug, Clone, Default)]
pub struct InferenceOutput {
    /// Named tensor outputs as key-value pairs.
    /// Keys are output names, values are the tensor data as JSON.
    pub outputs: HashMap<String, Value>,
    /// Free-form output metadata, keyed by name.
    pub metadata: HashMap<String, Value>,
}
32
/// Generic inference runtime trait implemented by each concrete backend.
///
/// Bounded by `Send + Sync` so a boxed backend can be moved to and shared
/// across threads (see [`InferenceRuntime`], which stores one as
/// `Box<dyn RuntimeBackend>`).
pub trait RuntimeBackend: Send + Sync {
    /// Run inference with the given inputs.
    ///
    /// Takes `&mut self`, so implementations may mutate internal state
    /// (e.g. a session) per call.
    fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput>;

    /// Get model information as a free-form JSON map.
    fn model_info(&self) -> HashMap<String, Value>;

    /// Check if the runtime is ready for inference.
    /// [`InferenceRuntime::infer`] refuses to run while this is `false`.
    fn is_ready(&self) -> bool;

    /// Get backend-specific metadata as a free-form JSON map.
    fn backend_info(&self) -> HashMap<String, Value>;
}
47
/// Main inference runtime that manages different backends.
///
/// Owns a single boxed [`RuntimeBackend`] chosen at construction time, plus
/// runtime-level metadata that gets merged into [`InferenceRuntime::model_info`].
pub struct InferenceRuntime {
    // Dynamically-dispatched backend selected in `from_model_with_device`.
    backend: Box<dyn RuntimeBackend>,
    // Keys written at construction: "backend_type", "device_type", "created_at".
    runtime_metadata: HashMap<String, Value>,
}
53
54impl InferenceRuntime {
55    /// Create a new inference runtime from a model with a specific device
56    pub async fn from_model_with_device(
57        model: Box<dyn crate::Model>,
58        device: crate::Device,
59    ) -> EdgeResult<Self> {
60        let backend_type = model.model_type().to_string();
61
62        let backend: Box<dyn RuntimeBackend> = match backend_type.as_str() {
63            #[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
64            "onnx" => {
65                let backend = onnx::OnnxBackend::from_model_with_device(model, device)?;
66                Box::new(backend)
67            }
68            #[cfg(target_arch = "wasm32")]
69            "onnx" => {
70                let backend = wasm::WasmBackend::from_model_with_device(model, device)?;
71                Box::new(backend)
72            }
73            _ => {
74                return Err(EdgeError::runtime(format!(
75                    "Unsupported model type: {backend_type}"
76                )));
77            }
78        };
79
80        let mut runtime_metadata = HashMap::new();
81        runtime_metadata.insert("backend_type".to_string(), Value::String(backend_type));
82        runtime_metadata.insert("device_type".to_string(), Value::String(device.to_string()));
83        runtime_metadata.insert(
84            "created_at".to_string(),
85            Value::String(chrono::Utc::now().to_rfc3339()),
86        );
87
88        Ok(Self {
89            backend,
90            runtime_metadata,
91        })
92    }
93
94    /// Create a new inference runtime from a model (uses CPU device by default, WebGPU on WASM)
95    pub async fn from_model(model: Box<dyn crate::Model>) -> EdgeResult<Self> {
96        #[cfg(target_arch = "wasm32")]
97        let device = crate::device::webgpu();
98        #[cfg(not(target_arch = "wasm32"))]
99        let device = crate::device::cpu();
100
101        Self::from_model_with_device(model, device).await
102    }
103
104    /// Run inference on the loaded model
105    pub fn infer(&mut self, input: InferenceInput) -> EdgeResult<InferenceOutput> {
106        if !self.backend.is_ready() {
107            return Err(EdgeError::runtime("Backend is not ready for inference"));
108        }
109
110        self.backend.infer(input)
111    }
112
113    /// Get comprehensive model information
114    pub fn model_info(&self) -> HashMap<String, Value> {
115        let mut info = self.backend.model_info();
116        info.extend(self.runtime_metadata.clone());
117        info
118    }
119
120    /// Check if the runtime is ready for inference
121    pub fn is_ready(&self) -> bool {
122        self.backend.is_ready()
123    }
124
125    /// Get backend information
126    pub fn backend_info(&self) -> HashMap<String, Value> {
127        self.backend.backend_info()
128    }
129}
130
131impl InferenceInput {
132    /// Create a new empty inference input
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Add a tensor input
138    pub fn add_input(mut self, name: String, data: Value) -> Self {
139        self.inputs.insert(name, data);
140        self
141    }
142
143    /// Add metadata
144    pub fn add_metadata(mut self, key: String, value: Value) -> Self {
145        self.metadata.insert(key, value);
146        self
147    }
148}
149
150impl InferenceOutput {
151    /// Create a new empty inference output
152    pub fn new() -> Self {
153        Self::default()
154    }
155
156    /// Get output by name
157    pub fn get_output(&self, name: &str) -> Option<&Value> {
158        self.outputs.get(name)
159    }
160
161    /// Get metadata by key
162    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
163        self.metadata.get(key)
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder round-trip: values inserted via the fluent API are stored
    /// under the given keys, and nothing else is added.
    #[test]
    fn test_inference_input_creation() {
        let input = InferenceInput::new()
            .add_input(
                "input_ids".to_string(),
                Value::Array(vec![Value::Number(1.into())]),
            )
            .add_metadata("batch_size".to_string(), Value::Number(1.into()));

        // Assert the stored values, not merely key presence.
        assert_eq!(
            input.inputs.get("input_ids"),
            Some(&Value::Array(vec![Value::Number(1.into())]))
        );
        assert_eq!(
            input.metadata.get("batch_size"),
            Some(&Value::Number(1.into()))
        );
        assert_eq!(input.inputs.len(), 1);
        assert_eq!(input.metadata.len(), 1);
    }
}
183}