Skip to main content

use_ml_pipeline/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        MlPipelineArtifactKind, MlPipelineDependencyKind, MlPipelineError, MlPipelineId,
10        MlPipelineName, MlPipelineRunId, MlPipelineScheduleKind, MlPipelineStatus,
11        MlPipelineStepKind, MlPipelineStepName, MlPipelineTriggerKind,
12    };
13}
14
// Generates a validated text newtype around `String`.
//
// Every constructor funnels through `non_empty_text`, so values are always trimmed and never
// empty. The generated type converts from `&str` and `String` (fallibly) and back into `String`
// (infallibly).
macro_rules! pipeline_text_newtype {
    ($name:ident) => {
        /// Validated ML pipeline text.
        ///
        /// Leading and trailing whitespace is trimmed on construction, and values that are
        /// empty after trimming are rejected with [`MlPipelineError::Empty`].
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Creates a value from `value`, trimming surrounding whitespace.
            ///
            /// # Errors
            ///
            /// Returns [`MlPipelineError::Empty`] when `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlPipelineError> {
                non_empty_text(value).map(Self)
            }

            /// Returns the trimmed text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the value and returns the owned, trimmed text.
            #[must_use]
            pub fn into_inner(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlPipelineError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlPipelineError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl TryFrom<String> for $name {
            type Error = MlPipelineError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_inner()
            }
        }
    };
}
59
// Generates a closed label enum whose variants map one-to-one onto canonical labels.
//
// `Display` writes the canonical label. `FromStr` / `TryFrom<&str>` accept that label after
// `normalized_label` has been applied to the input. Variant declaration order is significant:
// it fixes the derived `Ord` and the order of `ALL`.
macro_rules! pipeline_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of ML pipeline metadata labels.
        ///
        /// Each variant has one canonical label, written by [`Display`](fmt::Display) and
        /// accepted by [`FromStr`].
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlPipelineError;

            /// Parses a label leniently.
            ///
            /// Surrounding whitespace is ignored, matching is ASCII case-insensitive, and `_`
            /// or a space may stand in for `-`.
            ///
            /// # Errors
            ///
            /// - [`MlPipelineError::Empty`] for blank input.
            /// - [`MlPipelineError::UnknownLabel`] when no variant matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlPipelineError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl generated for the text newtypes so both families
        // of metadata types convert the same way.
        impl TryFrom<&str> for $name {
            type Error = MlPipelineError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
93
94pipeline_text_newtype!(MlPipelineName);
95pipeline_text_newtype!(MlPipelineId);
96pipeline_text_newtype!(MlPipelineStepName);
97pipeline_text_newtype!(MlPipelineRunId);
98
99pipeline_enum!(MlPipelineStepKind {
100    Ingest => "ingest",
101    Validate => "validate",
102    Clean => "clean",
103    Transform => "transform",
104    Featurize => "featurize",
105    Split => "split",
106    Train => "train",
107    Tune => "tune",
108    Evaluate => "evaluate",
109    Register => "register",
110    Deploy => "deploy",
111    Monitor => "monitor",
112    Rollback => "rollback",
113    Other => "other",
114});
115
116pipeline_enum!(MlPipelineStatus {
117    Draft => "draft",
118    Ready => "ready",
119    Running => "running",
120    Succeeded => "succeeded",
121    Failed => "failed",
122    Cancelled => "cancelled",
123    Paused => "paused",
124    Deprecated => "deprecated",
125});
126
127pipeline_enum!(MlPipelineArtifactKind {
128    Dataset => "dataset",
129    FeatureSet => "feature-set",
130    Model => "model",
131    Metrics => "metrics",
132    Report => "report",
133    Config => "config",
134    Checkpoint => "checkpoint",
135    Prediction => "prediction",
136    Log => "log",
137    Other => "other",
138});
139
140pipeline_enum!(MlPipelineDependencyKind {
141    Data => "data",
142    Model => "model",
143    Config => "config",
144    Secret => "secret",
145    Service => "service",
146    Compute => "compute",
147    HumanApproval => "human-approval",
148    Other => "other",
149});
150
151pipeline_enum!(MlPipelineTriggerKind {
152    Manual => "manual",
153    Schedule => "schedule",
154    Commit => "commit",
155    DataArrival => "data-arrival",
156    ModelChange => "model-change",
157    DriftDetected => "drift-detected",
158    Api => "api",
159    Other => "other",
160});
161
162pipeline_enum!(MlPipelineScheduleKind {
163    None => "none",
164    Once => "once",
165    Hourly => "hourly",
166    Daily => "daily",
167    Weekly => "weekly",
168    Monthly => "monthly",
169    Cron => "cron",
170    EventDriven => "event-driven",
171});
172
/// Error returned when validating or parsing ML pipeline metadata.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlPipelineError {
    /// The supplied text was empty or contained only whitespace.
    Empty,
    /// The supplied text matched none of the target enum's labels.
    UnknownLabel,
}

impl fmt::Display for MlPipelineError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "ML pipeline metadata text cannot be empty",
            Self::UnknownLabel => "unknown ML pipeline metadata label",
        };
        formatter.write_str(message)
    }
}

// Uses the default `source()` (`None`); neither variant wraps another error.
impl Error for MlPipelineError {}
189
190fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlPipelineError> {
191    let trimmed = value.as_ref().trim();
192    if trimmed.is_empty() {
193        Err(MlPipelineError::Empty)
194    } else {
195        Ok(trimmed.to_string())
196    }
197}
198
199fn normalized_label(value: &str) -> Result<String, MlPipelineError> {
200    let trimmed = value.trim();
201    if trimmed.is_empty() {
202        Err(MlPipelineError::Empty)
203    } else {
204        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::{
        MlPipelineArtifactKind, MlPipelineDependencyKind, MlPipelineError, MlPipelineName,
        MlPipelineRunId, MlPipelineScheduleKind, MlPipelineStatus, MlPipelineStepKind,
        MlPipelineTriggerKind,
    };

    #[test]
    fn validates_pipeline_names() -> Result<(), MlPipelineError> {
        let name = MlPipelineName::new(" training ")?;

        assert_eq!(name.as_str(), "training");
        assert_eq!(name.to_string(), "training");
        assert_eq!("training".parse::<MlPipelineName>()?, name);
        assert_eq!(MlPipelineName::new("  "), Err(MlPipelineError::Empty));
        Ok(())
    }

    #[test]
    fn text_newtypes_share_validation() -> Result<(), MlPipelineError> {
        assert_eq!(MlPipelineRunId::try_from(" run-42 ")?.as_str(), "run-42");
        assert_eq!(MlPipelineRunId::new("\t\n"), Err(MlPipelineError::Empty));
        assert_eq!("".parse::<MlPipelineRunId>(), Err(MlPipelineError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_pipeline_enums() -> Result<(), MlPipelineError> {
        assert_eq!(
            "featurize".parse::<MlPipelineStepKind>()?,
            MlPipelineStepKind::Featurize
        );
        assert_eq!(
            "ready".parse::<MlPipelineStatus>()?,
            MlPipelineStatus::Ready
        );
        assert_eq!(
            "event driven".parse::<MlPipelineScheduleKind>()?,
            MlPipelineScheduleKind::EventDriven
        );
        assert_eq!(
            "Feature_Set".parse::<MlPipelineArtifactKind>()?,
            MlPipelineArtifactKind::FeatureSet
        );
        assert_eq!(
            " HUMAN-APPROVAL ".parse::<MlPipelineDependencyKind>()?,
            MlPipelineDependencyKind::HumanApproval
        );
        assert_eq!(
            MlPipelineTriggerKind::DriftDetected.to_string(),
            "drift-detected"
        );
        Ok(())
    }

    #[test]
    fn enum_display_round_trips_through_from_str() -> Result<(), MlPipelineError> {
        let schedule = MlPipelineScheduleKind::EventDriven;
        assert_eq!(schedule.to_string().parse::<MlPipelineScheduleKind>()?, schedule);

        let status = MlPipelineStatus::Cancelled;
        assert_eq!(status.as_str().parse::<MlPipelineStatus>()?, status);
        Ok(())
    }

    #[test]
    fn rejects_empty_and_unknown_enum_labels() {
        assert_eq!("".parse::<MlPipelineStatus>(), Err(MlPipelineError::Empty));
        assert_eq!("   ".parse::<MlPipelineStepKind>(), Err(MlPipelineError::Empty));
        assert_eq!(
            "finished".parse::<MlPipelineStatus>(),
            Err(MlPipelineError::UnknownLabel)
        );
    }
}