// inferox_core/model.rs

1use crate::Backend;
2use std::collections::HashMap;
3use std::path::Path;
4
5pub trait Model: Send + Sync {
6    type Backend: Backend;
7    type Input;
8    type Output;
9
10    fn name(&self) -> &str;
11
12    fn forward(
13        &self,
14        input: Self::Input,
15    ) -> Result<Self::Output, <Self::Backend as Backend>::Error>;
16
17    fn metadata(&self) -> ModelMetadata {
18        ModelMetadata::default()
19    }
20
21    fn memory_requirements(&self) -> MemoryRequirements {
22        MemoryRequirements::default()
23    }
24}
25
/// Descriptive, free-form information about a model.
///
/// No field has an enforced format. `Default` yields empty strings and empty
/// collections.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Version string. The format is not validated here.
    pub version: String,
    /// Free-text description.
    pub description: String,
    /// Author attribution.
    pub author: String,
    /// License identifier or text. The format is not validated here.
    pub license: String,
    /// Arbitrary labels attached to the model.
    pub tags: Vec<String>,
    /// Extra key/value pairs not covered by the fixed fields.
    pub custom: HashMap<String, String>,
}
36
/// Estimated memory footprint of a model.
///
/// This type does not fix the units. Presumably they are bytes; confirm this
/// against the backends that populate it. `Default` yields zero for every field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MemoryRequirements {
    /// Memory attributed to the model's parameters.
    pub parameters: usize,
    /// Memory attributed to activations.
    pub activations: usize,
    /// Peak memory figure.
    pub peak: usize,
}
43
44pub trait BatchedModel: Model {
45    fn forward_batch(
46        &self,
47        batch: Vec<Self::Input>,
48    ) -> Result<Vec<Self::Output>, <Self::Backend as Backend>::Error> {
49        batch.into_iter().map(|input| self.forward(input)).collect()
50    }
51}
52
53pub trait SaveLoadModel: Model {
54    fn save(&self, path: &Path) -> Result<(), <Self::Backend as Backend>::Error>;
55
56    fn load(&mut self, path: &Path) -> Result<(), <Self::Backend as Backend>::Error>;
57}