1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlPipelineArtifactKind, MlPipelineDependencyKind, MlPipelineError, MlPipelineId,
10 MlPipelineName, MlPipelineRunId, MlPipelineScheduleKind, MlPipelineStatus,
11 MlPipelineStepKind, MlPipelineStepName, MlPipelineTriggerKind,
12 };
13}
14
// Generates a public `String`-backed newtype named `$name`.
//
// The wrapped text is produced by `non_empty_text`, so it is always trimmed
// and never empty. The generated type implements `AsRef<str>`, `Display`,
// `FromStr`, `TryFrom<&str>` and `TryFrom<String>`; every fallible conversion
// goes through `$name::new` so they all share the same validation.
macro_rules! pipeline_text_newtype {
    ($name:ident) => {
        #[doc = concat!(
            "Validated text value `",
            stringify!($name),
            "`: always trimmed and never empty."
        )]
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims `value` and wraps the result.
            ///
            /// # Errors
            ///
            /// Returns [`MlPipelineError::Empty`] when `value` is empty or
            /// consists only of whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlPipelineError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honors width, fill, alignment and precision flags
                // (e.g. `{:>12}`); `write_str` would silently ignore them.
                // Without flags the output is the text unchanged.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlPipelineError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlPipelineError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl TryFrom<String> for $name {
            type Error = MlPipelineError;

            /// Same validation as [`Self::new`], for callers that already
            /// hold an owned `String`.
            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
59
// Generates a public fieldless enum `$name` from `Variant => "label"` pairs.
//
// Each label is the exact text returned by `as_str` and written by `Display`.
// Parsing first runs the input through `normalized_label` (trim, ASCII
// lowercase, `_`/space -> `-`), so a label can only ever match if it is
// itself lowercase and uses `-` as its separator.
macro_rules! pipeline_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[doc = concat!(
            "Label enum `",
            stringify!($name),
            "`; every variant has a fixed text label used for display and parsing."
        )]
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("Label: `", $label, "`.")]
                $variant
            ),+
        }

        impl $name {
            /// Returns the fixed text label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` honors width, fill, alignment and precision flags
                // (e.g. `{:<12}` in a table column); `write_str` would
                // silently ignore them. Without flags the output is unchanged.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlPipelineError;

            /// Parses a label, ignoring surrounding whitespace and ASCII case
            /// and treating `_` and spaces as `-`.
            ///
            /// # Errors
            ///
            /// [`MlPipelineError::Empty`] for blank input,
            /// [`MlPipelineError::UnknownLabel`] when no variant matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlPipelineError::UnknownLabel),
                }
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlPipelineError;

            // The return type names the error directly rather than via
            // `Self::Error`, which would be ambiguous for an enum that has a
            // variant called `Error`.
            fn try_from(value: &str) -> Result<Self, MlPipelineError> {
                value.parse()
            }
        }
    };
}
93
// Validated text newtypes. Each invocation expands to a `String`-backed
// struct whose constructor trims its input and rejects empty text with
// `MlPipelineError::Empty`. The four types are distinct, so values of one
// cannot be passed where another is expected.
pipeline_text_newtype!(MlPipelineName);
pipeline_text_newtype!(MlPipelineId);
pipeline_text_newtype!(MlPipelineStepName);
pipeline_text_newtype!(MlPipelineRunId);
98
// Expands to the `MlPipelineStepKind` enum. Each `Variant => "label"` pair
// gives a variant and the exact text it displays as; parsing ignores
// surrounding whitespace and ASCII case. The derived `Ord` follows the
// declaration order below, so reordering variants changes comparisons.
pipeline_enum!(MlPipelineStepKind {
    Ingest => "ingest",
    Validate => "validate",
    Clean => "clean",
    Transform => "transform",
    Featurize => "featurize",
    Split => "split",
    Train => "train",
    Tune => "tune",
    Evaluate => "evaluate",
    Register => "register",
    Deploy => "deploy",
    Monitor => "monitor",
    Rollback => "rollback",
    Other => "other",
});
115
// Expands to the `MlPipelineStatus` enum. Labels are the exact display text;
// parsing ignores surrounding whitespace and ASCII case. The derived `Ord`
// follows the declaration order below.
pipeline_enum!(MlPipelineStatus {
    Draft => "draft",
    Ready => "ready",
    Running => "running",
    Succeeded => "succeeded",
    Failed => "failed",
    Cancelled => "cancelled",
    Paused => "paused",
    Deprecated => "deprecated",
});
126
// Expands to the `MlPipelineArtifactKind` enum. Labels are the exact display
// text. Because parsing maps `_` and spaces to `-`, the hyphenated label
// "feature-set" is also accepted as "feature_set" or "feature set".
pipeline_enum!(MlPipelineArtifactKind {
    Dataset => "dataset",
    FeatureSet => "feature-set",
    Model => "model",
    Metrics => "metrics",
    Report => "report",
    Config => "config",
    Checkpoint => "checkpoint",
    Prediction => "prediction",
    Log => "log",
    Other => "other",
});
139
// Expands to the `MlPipelineDependencyKind` enum. Labels are the exact
// display text. Because parsing maps `_` and spaces to `-`, "human-approval"
// is also accepted as "human_approval" or "human approval".
pipeline_enum!(MlPipelineDependencyKind {
    Data => "data",
    Model => "model",
    Config => "config",
    Secret => "secret",
    Service => "service",
    Compute => "compute",
    HumanApproval => "human-approval",
    Other => "other",
});
150
// Expands to the `MlPipelineTriggerKind` enum. Labels are the exact display
// text. Because parsing maps `_` and spaces to `-`, hyphenated labels such as
// "data-arrival" are also accepted as "data_arrival" or "data arrival".
pipeline_enum!(MlPipelineTriggerKind {
    Manual => "manual",
    Schedule => "schedule",
    Commit => "commit",
    DataArrival => "data-arrival",
    ModelChange => "model-change",
    DriftDetected => "drift-detected",
    Api => "api",
    Other => "other",
});
161
// Expands to the `MlPipelineScheduleKind` enum. Labels are the exact display
// text. Because parsing maps `_` and spaces to `-`, "event-driven" is also
// accepted as "event_driven" or "event driven".
//
// NOTE: the `None` variant is unrelated to `Option::None`. Refer to it as
// `MlPipelineScheduleKind::None`; glob-importing this enum's variants would
// shadow the prelude's `None` in that scope.
pipeline_enum!(MlPipelineScheduleKind {
    None => "none",
    Once => "once",
    Hourly => "hourly",
    Daily => "daily",
    Weekly => "weekly",
    Monthly => "monthly",
    Cron => "cron",
    EventDriven => "event-driven",
});
172
/// Error returned when validating text or parsing a label in this crate.
///
/// Derives `Hash` like every other public type in this file, so errors can be
/// used as set or map keys (for example when tallying failures by kind).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MlPipelineError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The input was non-empty but matched none of the enum's labels.
    UnknownLabel,
}
178
179impl fmt::Display for MlPipelineError {
180 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
181 match self {
182 Self::Empty => formatter.write_str("ML pipeline metadata text cannot be empty"),
183 Self::UnknownLabel => formatter.write_str("unknown ML pipeline metadata label"),
184 }
185 }
186}
187
// Empty impl: both variants are fieldless, so there is no underlying source
// error to expose and the trait's default methods are sufficient.
impl Error for MlPipelineError {}
189
190fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlPipelineError> {
191 let trimmed = value.as_ref().trim();
192 if trimmed.is_empty() {
193 Err(MlPipelineError::Empty)
194 } else {
195 Ok(trimmed.to_string())
196 }
197}
198
199fn normalized_label(value: &str) -> Result<String, MlPipelineError> {
200 let trimmed = value.trim();
201 if trimmed.is_empty() {
202 Err(MlPipelineError::Empty)
203 } else {
204 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::{
        MlPipelineError, MlPipelineName, MlPipelineScheduleKind, MlPipelineStatus,
        MlPipelineStepKind,
    };

    /// Names are trimmed on construction, every conversion path agrees, and
    /// blank input is rejected.
    #[test]
    fn validates_pipeline_names() -> Result<(), MlPipelineError> {
        let name = MlPipelineName::new(" training ")?;

        assert_eq!(name.as_str(), "training");
        assert_eq!(name.to_string(), "training");
        assert_eq!("training".parse::<MlPipelineName>()?, name);
        assert_eq!(MlPipelineName::try_from("training")?, name);
        assert_eq!(MlPipelineName::new(" "), Err(MlPipelineError::Empty));
        assert_eq!(MlPipelineName::new(""), Err(MlPipelineError::Empty));
        Ok(())
    }

    /// Enums display as their label and parse it back, tolerating case and
    /// `_`/space separators.
    #[test]
    fn displays_and_parses_pipeline_enums() -> Result<(), MlPipelineError> {
        // Display side: the label text is emitted verbatim.
        assert_eq!(MlPipelineStepKind::Featurize.to_string(), "featurize");
        assert_eq!(MlPipelineStatus::Ready.as_str(), "ready");
        assert_eq!(
            MlPipelineScheduleKind::EventDriven.to_string(),
            "event-driven"
        );

        // Parse side.
        assert_eq!(
            "featurize".parse::<MlPipelineStepKind>()?,
            MlPipelineStepKind::Featurize
        );
        assert_eq!(
            "ready".parse::<MlPipelineStatus>()?,
            MlPipelineStatus::Ready
        );
        assert_eq!(
            "event driven".parse::<MlPipelineScheduleKind>()?,
            MlPipelineScheduleKind::EventDriven
        );
        assert_eq!(
            " EVENT_DRIVEN ".parse::<MlPipelineScheduleKind>()?,
            MlPipelineScheduleKind::EventDriven
        );
        Ok(())
    }

    /// Blank and unrecognized labels map to distinct errors.
    #[test]
    fn rejects_blank_and_unknown_enum_labels() {
        assert_eq!(
            "  ".parse::<MlPipelineStatus>(),
            Err(MlPipelineError::Empty)
        );
        assert_eq!(
            "finished".parse::<MlPipelineStatus>(),
            Err(MlPipelineError::UnknownLabel)
        );
    }
}