// use_ml_training/lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::{error::Error, num::NonZeroUsize};
6
7pub mod prelude {
8    pub use crate::{
9        MlBatchSize, MlCheckpointKind, MlEpochCount, MlHyperparameterName, MlHyperparameterValue,
10        MlLearningRate, MlLossKind, MlOptimizerKind, MlTrainingError, MlTrainingJobName,
11        MlTrainingPhase, MlTrainingRunId, MlTrainingStatus,
12    };
13}
14
/// Generates a validated text newtype named `$name`.
///
/// The generated type wraps a `String` that has had surrounding whitespace
/// trimmed and is never empty. Besides `new` and `as_str`, it implements
/// `AsRef<str>`, `Display`, `FromStr` and `TryFrom<&str>`; every fallible
/// entry point reports [`MlTrainingError::Empty`] for blank input.
macro_rules! training_text_newtype {
    ($name:ident) => {
        /// Trimmed, non-empty text value.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims surrounding whitespace from `value` and wraps the result.
            ///
            /// # Errors
            ///
            /// Returns [`MlTrainingError::Empty`] when `value` is empty or
            /// contains only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlTrainingError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored, already-trimmed text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honours width, fill, alignment and precision flags;
                // `write_str` would silently drop them (e.g. `{:>10}`).
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlTrainingError;

            /// Parses by delegating to `new`, so the same trimming applies.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlTrainingError;

            /// Converts by delegating to `new`, so the same trimming applies.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
59
/// Generates a fieldless, `Copy` enum named `$name` whose variants each map
/// to a fixed string label.
///
/// The generated type exposes `as_str` and implements `Display`, `FromStr`
/// and `TryFrom<&str>`. Parsing runs the input through `normalized_label`
/// first, so labels passed to this macro must already be in that normalized
/// form (lowercase ASCII, `-` as separator) to be reachable.
macro_rules! training_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of labelled values; ordering follows declaration order.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Returns the fixed label associated with this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honours width, fill, alignment and precision flags;
                // `write_str` would silently drop them (e.g. `{:<12}`).
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlTrainingError;

            /// Parses a label after normalization.
            ///
            /// Fails with `Empty` for blank input (propagated from
            /// `normalized_label`) and `UnknownLabel` when nothing matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let label = normalized_label(value)?;
                match label.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlTrainingError::UnknownLabel),
                }
            }
        }

        // Mirrors the text newtypes, which also offer `TryFrom<&str>`.
        impl TryFrom<&str> for $name {
            type Error = MlTrainingError;

            /// Converts by delegating to `FromStr`.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
93
94training_text_newtype!(MlTrainingRunId);
95training_text_newtype!(MlTrainingJobName);
96training_text_newtype!(MlHyperparameterName);
97training_text_newtype!(MlHyperparameterValue);
98
99#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
100pub struct MlBatchSize(NonZeroUsize);
101
102impl MlBatchSize {
103    pub fn new(value: usize) -> Result<Self, MlTrainingError> {
104        NonZeroUsize::new(value)
105            .map(Self)
106            .ok_or(MlTrainingError::Zero)
107    }
108
109    pub const fn get(self) -> usize {
110        self.0.get()
111    }
112}
113
114#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
115pub struct MlEpochCount(NonZeroUsize);
116
117impl MlEpochCount {
118    pub fn new(value: usize) -> Result<Self, MlTrainingError> {
119        NonZeroUsize::new(value)
120            .map(Self)
121            .ok_or(MlTrainingError::Zero)
122    }
123
124    pub const fn get(self) -> usize {
125        self.0.get()
126    }
127}
128
129#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
130pub struct MlLearningRate(f64);
131
132impl MlLearningRate {
133    pub fn new(value: f64) -> Result<Self, MlTrainingError> {
134        if !value.is_finite() {
135            return Err(MlTrainingError::NonFinite);
136        }
137        if value <= 0.0 {
138            return Err(MlTrainingError::NonPositive);
139        }
140        Ok(Self(value))
141    }
142
143    pub const fn value(self) -> f64 {
144        self.0
145    }
146}
147
148training_enum!(MlTrainingStatus {
149    Queued => "queued",
150    Running => "running",
151    Succeeded => "succeeded",
152    Failed => "failed",
153    Cancelled => "cancelled",
154    TimedOut => "timed-out",
155    Paused => "paused",
156    Unknown => "unknown",
157});
158
159training_enum!(MlTrainingPhase {
160    PrepareData => "prepare-data",
161    Initialize => "initialize",
162    Train => "train",
163    Validate => "validate",
164    Tune => "tune",
165    Checkpoint => "checkpoint",
166    Evaluate => "evaluate",
167    Export => "export",
168    Complete => "complete",
169});
170
171training_enum!(MlOptimizerKind {
172    Sgd => "sgd",
173    Momentum => "momentum",
174    Adam => "adam",
175    AdamW => "adamw",
176    RmsProp => "rmsprop",
177    Adagrad => "adagrad",
178    Adadelta => "adadelta",
179    Lbfgs => "lbfgs",
180    Custom => "custom",
181});
182
183training_enum!(MlLossKind {
184    CrossEntropy => "cross-entropy",
185    BinaryCrossEntropy => "binary-cross-entropy",
186    MeanSquaredError => "mean-squared-error",
187    MeanAbsoluteError => "mean-absolute-error",
188    Huber => "huber",
189    Hinge => "hinge",
190    Triplet => "triplet",
191    Contrastive => "contrastive",
192    Custom => "custom",
193});
194
195training_enum!(MlCheckpointKind {
196    Best => "best",
197    Latest => "latest",
198    Epoch => "epoch",
199    Step => "step",
200    Manual => "manual",
201    Final => "final",
202});
203
/// Validation failures reported by the constructors and parsers in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlTrainingError {
    /// Text input was empty once surrounding whitespace was trimmed.
    Empty,
    /// A count was zero.
    Zero,
    /// A floating-point value was NaN or infinite.
    NonFinite,
    /// A floating-point value was zero or negative.
    NonPositive,
    /// A label did not match any variant of the target enum.
    UnknownLabel,
}

impl fmt::Display for MlTrainingError {
    /// Writes the fixed human-readable message for this error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "ML training metadata text cannot be empty",
            Self::Zero => "ML training count must be positive",
            Self::NonFinite => "ML training value must be finite",
            Self::NonPositive => "ML training value must be positive",
            Self::UnknownLabel => "unknown ML training metadata label",
        };
        formatter.write_str(message)
    }
}

impl Error for MlTrainingError {}
226
227fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlTrainingError> {
228    let trimmed = value.as_ref().trim();
229    if trimmed.is_empty() {
230        Err(MlTrainingError::Empty)
231    } else {
232        Ok(trimmed.to_string())
233    }
234}
235
236fn normalized_label(value: &str) -> Result<String, MlTrainingError> {
237    let trimmed = value.trim();
238    if trimmed.is_empty() {
239        Err(MlTrainingError::Empty)
240    } else {
241        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
242    }
243}
244
// Unit tests for the validated wrappers and label enums. In addition to the
// happy paths, these pin the error variant returned by each rejection path
// and the case/separator normalization applied when parsing labels.
#[cfg(test)]
mod tests {
    use super::{
        MlBatchSize, MlCheckpointKind, MlEpochCount, MlLearningRate, MlLossKind, MlOptimizerKind,
        MlTrainingError, MlTrainingPhase, MlTrainingRunId, MlTrainingStatus,
    };

    #[test]
    fn validates_training_ids() -> Result<(), MlTrainingError> {
        let run_id = MlTrainingRunId::new(" run-001 ")?;

        assert_eq!(run_id.as_str(), "run-001");
        assert_eq!("run-001".parse::<MlTrainingRunId>()?, run_id);
        Ok(())
    }

    #[test]
    fn rejects_blank_training_ids() {
        assert_eq!(MlTrainingRunId::new(""), Err(MlTrainingError::Empty));
        assert_eq!(MlTrainingRunId::new(" \t\n"), Err(MlTrainingError::Empty));
        assert_eq!(
            "   ".parse::<MlTrainingRunId>(),
            Err(MlTrainingError::Empty)
        );
        assert_eq!(
            MlTrainingRunId::try_from(""),
            Err(MlTrainingError::Empty)
        );
    }

    #[test]
    fn validates_positive_counts_and_learning_rates() -> Result<(), MlTrainingError> {
        assert_eq!(MlBatchSize::new(32)?.get(), 32);
        assert_eq!(MlEpochCount::new(10)?.get(), 10);
        assert_eq!(MlLearningRate::new(0.001)?.value(), 0.001);
        assert_eq!(MlBatchSize::new(0), Err(MlTrainingError::Zero));
        assert_eq!(MlEpochCount::new(0), Err(MlTrainingError::Zero));
        assert_eq!(MlLearningRate::new(0.0), Err(MlTrainingError::NonPositive));
        assert_eq!(
            MlLearningRate::new(f64::NAN),
            Err(MlTrainingError::NonFinite)
        );
        Ok(())
    }

    #[test]
    fn rejects_negative_and_infinite_learning_rates() {
        assert_eq!(
            MlLearningRate::new(-0.5),
            Err(MlTrainingError::NonPositive)
        );
        assert_eq!(
            MlLearningRate::new(f64::INFINITY),
            Err(MlTrainingError::NonFinite)
        );
        // The finiteness check runs before the sign check.
        assert_eq!(
            MlLearningRate::new(f64::NEG_INFINITY),
            Err(MlTrainingError::NonFinite)
        );
    }

    #[test]
    fn displays_and_parses_training_enums() -> Result<(), MlTrainingError> {
        assert_eq!(
            "timed out".parse::<MlTrainingStatus>()?,
            MlTrainingStatus::TimedOut
        );
        assert_eq!("adamw".parse::<MlOptimizerKind>()?, MlOptimizerKind::AdamW);
        assert_eq!(
            "mean squared error".parse::<MlLossKind>()?,
            MlLossKind::MeanSquaredError
        );
        assert_eq!(MlCheckpointKind::Latest.to_string(), "latest");
        Ok(())
    }

    #[test]
    fn normalizes_case_and_separators_when_parsing() -> Result<(), MlTrainingError> {
        assert_eq!(
            " TIMED_OUT ".parse::<MlTrainingStatus>()?,
            MlTrainingStatus::TimedOut
        );
        assert_eq!(
            "Prepare Data".parse::<MlTrainingPhase>()?,
            MlTrainingPhase::PrepareData
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!(
            "   ".parse::<MlTrainingStatus>(),
            Err(MlTrainingError::Empty)
        );
        assert_eq!(
            "warp-drive".parse::<MlOptimizerKind>(),
            Err(MlTrainingError::UnknownLabel)
        );
    }

    #[test]
    fn checkpoint_labels_round_trip_through_display() {
        let kinds = [
            MlCheckpointKind::Best,
            MlCheckpointKind::Latest,
            MlCheckpointKind::Epoch,
            MlCheckpointKind::Step,
            MlCheckpointKind::Manual,
            MlCheckpointKind::Final,
        ];
        for kind in kinds.iter().copied() {
            assert_eq!(kind.to_string().parse::<MlCheckpointKind>(), Ok(kind));
        }
    }
}
290}