Skip to main content

use_ml_model/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        MlModelArchitectureKind, MlModelArtifactKind, MlModelError, MlModelFormat, MlModelId,
10        MlModelKind, MlModelLicense, MlModelName, MlModelProvider, MlModelStage, MlModelTask,
11        MlModelVersion,
12    };
13}
14
/// Generates a validated text newtype around `String`.
///
/// Every generated constructor (`new`, `FromStr`, `TryFrom<&str>`, `TryFrom<String>`)
/// routes through `non_empty_text`, so the stored text is trimmed and non-empty.
macro_rules! model_text_newtype {
    ($name:ident) => {
        /// Validated ML model metadata text.
        ///
        /// All public constructors trim surrounding whitespace and reject input that is
        /// empty after trimming.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims `value` and wraps the result.
            ///
            /// # Errors
            ///
            /// Returns [`MlModelError::Empty`] if `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlModelError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the wrapper and returns the owned, validated text.
            pub fn into_inner(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlModelError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers that already own a `String` convert it without an extra borrow dance.
        impl TryFrom<String> for $name {
            type Error = MlModelError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Gives back the owned text without cloning.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_inner()
            }
        }
    };
}
59
/// Generates a closed, `Copy` label enum with `Display`, `FromStr` and `TryFrom<&str>`.
///
/// Each `Variant => "label"` pair maps a variant to the text it displays as. Parsing
/// normalizes input with `normalized_label` before matching against the labels.
macro_rules! model_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of ML model metadata labels.
        ///
        /// Variant declaration order is significant: it determines the derived
        /// `PartialOrd`/`Ord` and the order of [`Self::ALL`].
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the label declared for this variant; `Display` writes the same text.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelError;

            /// Parses a label after normalizing it with `normalized_label`.
            ///
            /// # Errors
            ///
            /// Returns [`MlModelError::Empty`] for blank input and
            /// [`MlModelError::UnknownLabel`] when no declared label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlModelError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl that the text newtypes provide.
        impl TryFrom<&str> for $name {
            type Error = MlModelError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
93
94model_text_newtype!(MlModelName);
95model_text_newtype!(MlModelId);
96model_text_newtype!(MlModelVersion);
97model_text_newtype!(MlModelProvider);
98model_text_newtype!(MlModelLicense);
99
100model_enum!(MlModelKind {
101    Classical => "classical",
102    Linear => "linear",
103    TreeBased => "tree-based",
104    Kernel => "kernel",
105    NeuralNetwork => "neural-network",
106    Transformer => "transformer",
107    Diffusion => "diffusion",
108    Ensemble => "ensemble",
109    RuleBased => "rule-based",
110    Hybrid => "hybrid",
111    Other => "other",
112});
113
114model_enum!(MlModelTask {
115    Classification => "classification",
116    Regression => "regression",
117    Clustering => "clustering",
118    Ranking => "ranking",
119    Recommendation => "recommendation",
120    Forecasting => "forecasting",
121    Detection => "detection",
122    Segmentation => "segmentation",
123    Generation => "generation",
124    Embedding => "embedding",
125    AnomalyDetection => "anomaly-detection",
126    Other => "other",
127});
128
129model_enum!(MlModelArchitectureKind {
130    LinearModel => "linear-model",
131    DecisionTree => "decision-tree",
132    RandomForest => "random-forest",
133    GradientBoostedTree => "gradient-boosted-tree",
134    Svm => "svm",
135    Knn => "knn",
136    Mlp => "mlp",
137    Cnn => "cnn",
138    Rnn => "rnn",
139    Lstm => "lstm",
140    Transformer => "transformer",
141    Autoencoder => "autoencoder",
142    Diffusion => "diffusion",
143    Other => "other",
144});
145
146model_enum!(MlModelArtifactKind {
147    Weights => "weights",
148    Config => "config",
149    Tokenizer => "tokenizer",
150    Vocabulary => "vocabulary",
151    Preprocessor => "preprocessor",
152    Pipeline => "pipeline",
153    Checkpoint => "checkpoint",
154    Bundle => "bundle",
155    Card => "card",
156    Metrics => "metrics",
157    Other => "other",
158});
159
160model_enum!(MlModelFormat {
161    Onnx => "onnx",
162    Safetensors => "safetensors",
163    Pickle => "pickle",
164    Joblib => "joblib",
165    TorchScript => "torch-script",
166    TensorFlowSavedModel => "tensorflow-saved-model",
167    CoreMl => "core-ml",
168    Tflite => "tflite",
169    OpenVino => "open-vino",
170    Pmml => "pmml",
171    Custom => "custom",
172});
173
174model_enum!(MlModelStage {
175    Experimental => "experimental",
176    Development => "development",
177    Staging => "staging",
178    Production => "production",
179    Archived => "archived",
180    Deprecated => "deprecated",
181});
182
/// Validation failure for ML model metadata text or labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlModelError {
    /// The input was empty, or contained only whitespace.
    Empty,
    /// The input did not match any label of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for MlModelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            MlModelError::Empty => "ML model metadata text cannot be empty",
            MlModelError::UnknownLabel => "unknown ML model metadata label",
        };
        f.write_str(message)
    }
}

impl Error for MlModelError {}
199
200fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlModelError> {
201    let trimmed = value.as_ref().trim();
202    if trimmed.is_empty() {
203        Err(MlModelError::Empty)
204    } else {
205        Ok(trimmed.to_string())
206    }
207}
208
209fn normalized_label(value: &str) -> Result<String, MlModelError> {
210    let trimmed = value.trim();
211    if trimmed.is_empty() {
212        Err(MlModelError::Empty)
213    } else {
214        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::{
        MlModelArchitectureKind, MlModelArtifactKind, MlModelError, MlModelFormat, MlModelKind,
        MlModelName, MlModelStage, MlModelTask,
    };

    #[test]
    fn validates_model_names() -> Result<(), MlModelError> {
        let name = MlModelName::new(" baseline ")?;

        assert_eq!(name.as_str(), "baseline");
        assert_eq!(name.to_string(), "baseline");
        assert_eq!("baseline".parse::<MlModelName>()?, name);
        Ok(())
    }

    #[test]
    fn rejects_empty_model_names() {
        assert_eq!(MlModelName::new("  "), Err(MlModelError::Empty));
        assert_eq!("".parse::<MlModelName>(), Err(MlModelError::Empty));
        assert_eq!(MlModelName::try_from("\t\n"), Err(MlModelError::Empty));
    }

    #[test]
    fn displays_and_parses_model_enums() -> Result<(), MlModelError> {
        assert_eq!("tree based".parse::<MlModelKind>()?, MlModelKind::TreeBased);
        assert_eq!(
            "anomaly_detection".parse::<MlModelTask>()?,
            MlModelTask::AnomalyDetection
        );
        assert_eq!(
            "random forest".parse::<MlModelArchitectureKind>()?,
            MlModelArchitectureKind::RandomForest
        );
        assert_eq!(
            "checkpoint".parse::<MlModelArtifactKind>()?,
            MlModelArtifactKind::Checkpoint
        );
        assert_eq!(
            "torch_script".parse::<MlModelFormat>()?,
            MlModelFormat::TorchScript
        );
        assert_eq!(MlModelStage::Production.to_string(), "production");
        Ok(())
    }

    #[test]
    fn enum_parsing_trims_and_ignores_ascii_case() -> Result<(), MlModelError> {
        assert_eq!("  PRODUCTION ".parse::<MlModelStage>()?, MlModelStage::Production);
        assert_eq!("Neural_Network".parse::<MlModelKind>()?, MlModelKind::NeuralNetwork);
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_enum_labels() {
        assert_eq!("   ".parse::<MlModelStage>(), Err(MlModelError::Empty));
        assert_eq!("".parse::<MlModelTask>(), Err(MlModelError::Empty));
        assert_eq!("quantum".parse::<MlModelKind>(), Err(MlModelError::UnknownLabel));
        assert_eq!("gguf".parse::<MlModelFormat>(), Err(MlModelError::UnknownLabel));
    }

    #[test]
    fn format_labels_round_trip_through_display() -> Result<(), MlModelError> {
        let formats = [
            MlModelFormat::Onnx,
            MlModelFormat::Safetensors,
            MlModelFormat::Pickle,
            MlModelFormat::Joblib,
            MlModelFormat::TorchScript,
            MlModelFormat::TensorFlowSavedModel,
            MlModelFormat::CoreMl,
            MlModelFormat::Tflite,
            MlModelFormat::OpenVino,
            MlModelFormat::Pmml,
            MlModelFormat::Custom,
        ];
        for format in formats {
            assert_eq!(format.to_string().parse::<MlModelFormat>()?, format);
        }
        Ok(())
    }

    #[test]
    fn error_messages_are_descriptive() {
        assert_eq!(
            MlModelError::Empty.to_string(),
            "ML model metadata text cannot be empty"
        );
        assert_eq!(
            MlModelError::UnknownLabel.to_string(),
            "unknown ML model metadata label"
        );
    }
}