1use crate::Backend;
2use std::collections::HashMap;
3use std::path::Path;
4
5pub trait Model: Send + Sync {
6 type Backend: Backend;
7 type Input;
8 type Output;
9
10 fn name(&self) -> &str;
11
12 fn forward(
13 &self,
14 input: Self::Input,
15 ) -> Result<Self::Output, <Self::Backend as Backend>::Error>;
16
17 fn metadata(&self) -> ModelMetadata {
18 ModelMetadata::default()
19 }
20
21 fn memory_requirements(&self) -> MemoryRequirements {
22 MemoryRequirements::default()
23 }
24}
25
/// Free-form descriptive information about a model.
///
/// Every field defaults to an empty value (`Default`). `PartialEq`/`Eq` are
/// derived so metadata can be compared directly (e.g. in tests or caches);
/// all field types support them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Version string; no format is enforced by this type.
    pub version: String,
    /// Human-readable description.
    pub description: String,
    /// Author of the model.
    pub author: String,
    /// License of the model. NOTE(review): format not enforced here — confirm
    /// whether callers expect e.g. an SPDX identifier.
    pub license: String,
    /// Free-form labels attached to the model.
    pub tags: Vec<String>,
    /// Additional key/value pairs not covered by the fields above.
    pub custom: HashMap<String, String>,
}
36
/// Memory figures reported by a model.
///
/// NOTE(review): the unit (presumably bytes) is not stated anywhere in this
/// type — confirm with implementors before relying on it. All fields default
/// to zero. `PartialEq`/`Eq` are derived so requirements can be compared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryRequirements {
    /// Memory attributed to model parameters.
    pub parameters: usize,
    /// Memory attributed to activations.
    pub activations: usize,
    /// Peak memory figure.
    pub peak: usize,
}
43
44pub trait BatchedModel: Model {
45 fn forward_batch(
46 &self,
47 batch: Vec<Self::Input>,
48 ) -> Result<Vec<Self::Output>, <Self::Backend as Backend>::Error> {
49 batch.into_iter().map(|input| self.forward(input)).collect()
50 }
51}
52
53pub trait SaveLoadModel: Model {
54 fn save(&self, path: &Path) -> Result<(), <Self::Backend as Backend>::Error>;
55
56 fn load(&mut self, path: &Path) -> Result<(), <Self::Backend as Backend>::Error>;
57}