1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlModelArchitectureKind, MlModelArtifactKind, MlModelError, MlModelFormat, MlModelId,
10 MlModelKind, MlModelLicense, MlModelName, MlModelProvider, MlModelStage, MlModelTask,
11 MlModelVersion,
12 };
13}
14
/// Generates a validated text newtype wrapping a `String`.
///
/// Values can only be built from text that is non-empty after trimming surrounding
/// whitespace (via `non_empty_text`); the stored value is the trimmed text.
macro_rules! model_text_newtype {
    ($name:ident) => {
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Creates a value from `value` with surrounding whitespace trimmed.
            ///
            /// # Errors
            ///
            /// Returns `MlModelError::Empty` if `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlModelError> {
                non_empty_text(value).map(Self)
            }

            /// Returns the validated text as a string slice.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the value and returns the owned, validated text without cloning.
            pub fn into_inner(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlModelError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Accept owned strings directly so callers holding a `String` need not borrow it.
        impl TryFrom<String> for $name {
            type Error = MlModelError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Hand the validated text back as an owned `String`.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.0
            }
        }
    };
}
59
/// Generates a fieldless metadata enum whose variants map one-to-one onto text labels.
///
/// Each `Variant => "label"` pair supplies the canonical label returned by `as_str`,
/// `AsRef<str>` and `Display`. Parsing (`FromStr` / `TryFrom<&str>`) first passes the
/// input through `normalized_label`, then requires an exact match against one of the
/// labels. The derived `Ord` follows declaration order, as does the generated `ALL` slice.
macro_rules! model_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        // Mirrors the `AsRef<str>` impl generated for the text newtypes.
        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelError;

            /// Parses a label after normalizing it with `normalized_label`.
            ///
            /// # Errors
            ///
            /// Propagates `normalized_label`'s error for blank input, and returns
            /// `MlModelError::UnknownLabel` when no label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlModelError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl generated for the text newtypes.
        impl TryFrom<&str> for $name {
            type Error = MlModelError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::from_str(value)
            }
        }
    };
}
93
94model_text_newtype!(MlModelName);
95model_text_newtype!(MlModelId);
96model_text_newtype!(MlModelVersion);
97model_text_newtype!(MlModelProvider);
98model_text_newtype!(MlModelLicense);
99
// `MlModelKind` labels. Each string is the canonical text produced by `as_str`/`Display`
// and the value `FromStr` matches after normalization.
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order.
model_enum!(MlModelKind {
    Classical => "classical",
    Linear => "linear",
    TreeBased => "tree-based",
    Kernel => "kernel",
    NeuralNetwork => "neural-network",
    Transformer => "transformer",
    Diffusion => "diffusion",
    Ensemble => "ensemble",
    RuleBased => "rule-based",
    Hybrid => "hybrid",
    Other => "other",
});
113
// `MlModelTask` labels. Multi-word labels are hyphenated (e.g. "anomaly-detection").
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order.
model_enum!(MlModelTask {
    Classification => "classification",
    Regression => "regression",
    Clustering => "clustering",
    Ranking => "ranking",
    Recommendation => "recommendation",
    Forecasting => "forecasting",
    Detection => "detection",
    Segmentation => "segmentation",
    Generation => "generation",
    Embedding => "embedding",
    AnomalyDetection => "anomaly-detection",
    Other => "other",
});
128
// `MlModelArchitectureKind` labels. Acronym variants (`Svm`, `Knn`, `Mlp`, ...) use
// lowercase labels without hyphens.
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order.
model_enum!(MlModelArchitectureKind {
    LinearModel => "linear-model",
    DecisionTree => "decision-tree",
    RandomForest => "random-forest",
    GradientBoostedTree => "gradient-boosted-tree",
    Svm => "svm",
    Knn => "knn",
    Mlp => "mlp",
    Cnn => "cnn",
    Rnn => "rnn",
    Lstm => "lstm",
    Transformer => "transformer",
    Autoencoder => "autoencoder",
    Diffusion => "diffusion",
    Other => "other",
});
145
// `MlModelArtifactKind` labels, all single lowercase words.
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order.
model_enum!(MlModelArtifactKind {
    Weights => "weights",
    Config => "config",
    Tokenizer => "tokenizer",
    Vocabulary => "vocabulary",
    Preprocessor => "preprocessor",
    Pipeline => "pipeline",
    Checkpoint => "checkpoint",
    Bundle => "bundle",
    Card => "card",
    Metrics => "metrics",
    Other => "other",
});
159
// `MlModelFormat` labels. Note that the hyphen placement is exact: "torch-script",
// "core-ml" and "open-vino" parse, but unhyphenated forms such as "torchscript" do not,
// because parsing only normalizes case and separators before an exact match.
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order.
model_enum!(MlModelFormat {
    Onnx => "onnx",
    Safetensors => "safetensors",
    Pickle => "pickle",
    Joblib => "joblib",
    TorchScript => "torch-script",
    TensorFlowSavedModel => "tensorflow-saved-model",
    CoreMl => "core-ml",
    Tflite => "tflite",
    OpenVino => "open-vino",
    Pmml => "pmml",
    Custom => "custom",
});
173
// `MlModelStage` labels.
// Variant order is load-bearing: the derived `Ord`/`PartialOrd` compares in this order
// (e.g. `Experimental < Production`). Reordering changes comparison and sort results.
model_enum!(MlModelStage {
    Experimental => "experimental",
    Development => "development",
    Staging => "staging",
    Production => "production",
    Archived => "archived",
    Deprecated => "deprecated",
});
182
/// Error returned when ML model metadata fails validation or parsing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlModelError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The input did not match any label of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for MlModelError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            MlModelError::Empty => "ML model metadata text cannot be empty",
            MlModelError::UnknownLabel => "unknown ML model metadata label",
        };
        formatter.write_str(message)
    }
}

// No wrapped source error; the default `Error` methods suffice.
impl Error for MlModelError {}
199
200fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlModelError> {
201 let trimmed = value.as_ref().trim();
202 if trimmed.is_empty() {
203 Err(MlModelError::Empty)
204 } else {
205 Ok(trimmed.to_string())
206 }
207}
208
209fn normalized_label(value: &str) -> Result<String, MlModelError> {
210 let trimmed = value.trim();
211 if trimmed.is_empty() {
212 Err(MlModelError::Empty)
213 } else {
214 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::{
        MlModelArchitectureKind, MlModelArtifactKind, MlModelError, MlModelFormat, MlModelKind,
        MlModelName, MlModelStage, MlModelTask,
    };

    #[test]
    fn validates_model_names() -> Result<(), MlModelError> {
        let name = MlModelName::new(" baseline ")?;

        assert_eq!(name.as_str(), "baseline");
        assert_eq!(name.to_string(), "baseline");
        assert_eq!("baseline".parse::<MlModelName>()?, name);
        Ok(())
    }

    #[test]
    fn rejects_empty_model_names() {
        assert_eq!(MlModelName::new(" "), Err(MlModelError::Empty));
        // The `FromStr` and `TryFrom<&str>` paths share the same validation.
        assert_eq!("".parse::<MlModelName>(), Err(MlModelError::Empty));
        assert_eq!(MlModelName::try_from("\t"), Err(MlModelError::Empty));
    }

    #[test]
    fn displays_and_parses_model_enums() -> Result<(), MlModelError> {
        assert_eq!("tree based".parse::<MlModelKind>()?, MlModelKind::TreeBased);
        assert_eq!(
            "anomaly_detection".parse::<MlModelTask>()?,
            MlModelTask::AnomalyDetection
        );
        assert_eq!(
            "random forest".parse::<MlModelArchitectureKind>()?,
            MlModelArchitectureKind::RandomForest
        );
        assert_eq!(
            "checkpoint".parse::<MlModelArtifactKind>()?,
            MlModelArtifactKind::Checkpoint
        );
        assert_eq!(
            "torch_script".parse::<MlModelFormat>()?,
            MlModelFormat::TorchScript
        );
        assert_eq!(MlModelStage::Production.to_string(), "production");
        Ok(())
    }

    #[test]
    fn enum_parsing_ignores_case_and_surrounding_whitespace() -> Result<(), MlModelError> {
        assert_eq!("TRANSFORMER".parse::<MlModelKind>()?, MlModelKind::Transformer);
        assert_eq!(" Production ".parse::<MlModelStage>()?, MlModelStage::Production);
        Ok(())
    }

    #[test]
    fn enum_parsing_rejects_unknown_and_blank_labels() {
        assert_eq!(
            "quantum".parse::<MlModelKind>(),
            Err(MlModelError::UnknownLabel)
        );
        assert_eq!("  ".parse::<MlModelStage>(), Err(MlModelError::Empty));
        assert_eq!("".parse::<MlModelFormat>(), Err(MlModelError::Empty));
    }

    #[test]
    fn enum_display_round_trips_through_from_str() -> Result<(), MlModelError> {
        for format in [
            MlModelFormat::Onnx,
            MlModelFormat::TensorFlowSavedModel,
            MlModelFormat::CoreMl,
            MlModelFormat::OpenVino,
        ] {
            assert_eq!(format.to_string().parse::<MlModelFormat>()?, format);
        }
        Ok(())
    }
}