1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::{error::Error, num::NonZeroUsize};
6
7pub mod prelude {
8 pub use crate::{
9 MlBenchmarkName, MlConfusionMatrixShape, MlEvalSliceKind, MlEvalSliceName,
10 MlEvaluationError, MlEvaluationKind, MlEvaluationRunId, MlEvaluationStatus,
11 MlEvaluationTarget, MlThreshold, MlValidationStrategy,
12 };
13}
14
/// Defines a validated, non-empty text newtype named `$name`.
///
/// The generated type stores its input with surrounding whitespace trimmed
/// (via `non_empty_text`) and rejects blank input with
/// [`MlEvaluationError::Empty`]. Alongside `new`, `as_str` and `into_inner` it
/// implements `AsRef<str>`, `Display`, `FromStr`, `TryFrom<&str>`,
/// `TryFrom<String>` and `From<$name> for String`.
///
/// Leading outer attributes, including `///` doc comments, are forwarded to
/// the generated struct, so each invocation can document its own type.
/// Invocations without attributes keep working unchanged.
macro_rules! evaluation_text_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims `value` and wraps it.
            ///
            /// # Errors
            ///
            /// Returns [`MlEvaluationError::Empty`] when `value` is empty or
            /// whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlEvaluationError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the trimmed text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the value and returns the owned trimmed text.
            #[must_use]
            pub fn into_inner(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlEvaluationError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlEvaluationError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl TryFrom<String> for $name {
            type Error = MlEvaluationError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.0
            }
        }
    };
}
59
/// Defines a closed, string-labelled enum named `$name`.
///
/// Each `Variant => "label"` pair becomes a unit variant whose canonical text
/// form is `label`. The generated type provides:
///
/// * `as_str` and `Display`, both producing the canonical label;
/// * `ALL`, every variant in declaration order (also the derived `Ord` order);
/// * `FromStr` and `TryFrom<&str>`, which match the input against the labels
///   after passing it through `normalized_label`.
///
/// Parsing fails with [`MlEvaluationError::Empty`] for blank input (propagated
/// from `normalized_label`) and [`MlEvaluationError::UnknownLabel`] for text
/// that does not normalise to any declared label.
macro_rules! evaluation_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlEvaluationError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlEvaluationError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` offered by the text newtypes so both
        // families of metadata types convert from text the same way.
        impl TryFrom<&str> for $name {
            type Error = MlEvaluationError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
93
94evaluation_text_newtype!(MlEvaluationRunId);
95evaluation_text_newtype!(MlEvalSliceName);
96evaluation_text_newtype!(MlBenchmarkName);
97
98#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
99pub struct MlThreshold(f64);
100
101impl MlThreshold {
102 pub fn new(value: f64) -> Result<Self, MlEvaluationError> {
103 if !value.is_finite() {
104 return Err(MlEvaluationError::NonFinite);
105 }
106 if !(0.0..=1.0).contains(&value) {
107 return Err(MlEvaluationError::OutOfRange);
108 }
109 Ok(Self(value))
110 }
111
112 pub const fn value(self) -> f64 {
113 self.0
114 }
115}
116
117#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
118pub struct MlConfusionMatrixShape {
119 rows: NonZeroUsize,
120 columns: NonZeroUsize,
121}
122
123impl MlConfusionMatrixShape {
124 pub fn new(rows: usize, columns: usize) -> Result<Self, MlEvaluationError> {
125 Ok(Self {
126 rows: NonZeroUsize::new(rows).ok_or(MlEvaluationError::Zero)?,
127 columns: NonZeroUsize::new(columns).ok_or(MlEvaluationError::Zero)?,
128 })
129 }
130
131 pub const fn rows(self) -> usize {
132 self.rows.get()
133 }
134
135 pub const fn columns(self) -> usize {
136 self.columns.get()
137 }
138
139 pub const fn is_square(self) -> bool {
140 self.rows.get() == self.columns.get()
141 }
142}
143
144evaluation_enum!(MlEvaluationKind {
145 Offline => "offline",
146 Online => "online",
147 Shadow => "shadow",
148 ABTest => "a-b-test",
149 Backtest => "backtest",
150 CrossValidation => "cross-validation",
151 Holdout => "holdout",
152 Benchmark => "benchmark",
153 HumanEval => "human-eval",
154 Other => "other",
155});
156
157evaluation_enum!(MlValidationStrategy {
158 Holdout => "holdout",
159 KFold => "k-fold",
160 StratifiedKFold => "stratified-k-fold",
161 TimeSeriesSplit => "time-series-split",
162 LeaveOneOut => "leave-one-out",
163 Bootstrap => "bootstrap",
164 Custom => "custom",
165});
166
167evaluation_enum!(MlEvaluationStatus {
168 Pending => "pending",
169 Running => "running",
170 Succeeded => "succeeded",
171 Failed => "failed",
172 Cancelled => "cancelled",
173 Inconclusive => "inconclusive",
174});
175
176evaluation_enum!(MlEvaluationTarget {
177 Model => "model",
178 Pipeline => "pipeline",
179 Dataset => "dataset",
180 Feature => "feature",
181 Label => "label",
182 Artifact => "artifact",
183 TrainingRun => "training-run",
184 Other => "other",
185});
186
187evaluation_enum!(MlEvalSliceKind {
188 Global => "global",
189 Class => "class",
190 Segment => "segment",
191 Cohort => "cohort",
192 Geography => "geography",
193 TimeWindow => "time-window",
194 Device => "device",
195 Language => "language",
196 Custom => "custom",
197});
198
/// Reasons an ML evaluation metadata value can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlEvaluationError {
    /// Text was empty or whitespace-only after trimming.
    Empty,
    /// A floating-point input was NaN or infinite.
    NonFinite,
    /// A threshold was finite but outside `0.0..=1.0`.
    OutOfRange,
    /// A count, such as a confusion-matrix dimension, was zero.
    Zero,
    /// Text did not match any known enum label.
    UnknownLabel,
}

impl fmt::Display for MlEvaluationError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "ML evaluation metadata text cannot be empty",
            Self::NonFinite => "ML evaluation value must be finite",
            Self::OutOfRange => "ML evaluation threshold must be in 0.0..=1.0",
            Self::Zero => "ML evaluation count must be positive",
            Self::UnknownLabel => "unknown ML evaluation metadata label",
        };
        formatter.write_str(message)
    }
}

/// The error is a leaf: it wraps no underlying cause.
impl Error for MlEvaluationError {}
221
222fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlEvaluationError> {
223 let trimmed = value.as_ref().trim();
224 if trimmed.is_empty() {
225 Err(MlEvaluationError::Empty)
226 } else {
227 Ok(trimmed.to_string())
228 }
229}
230
231fn normalized_label(value: &str) -> Result<String, MlEvaluationError> {
232 let trimmed = value.trim();
233 if trimmed.is_empty() {
234 Err(MlEvaluationError::Empty)
235 } else {
236 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::{
        MlConfusionMatrixShape, MlEvaluationError, MlEvaluationKind, MlEvaluationRunId,
        MlEvaluationStatus, MlThreshold, MlValidationStrategy,
    };

    #[test]
    fn validates_evaluation_run_ids() -> Result<(), MlEvaluationError> {
        let run_id = MlEvaluationRunId::new(" eval-001 ")?;

        assert_eq!(run_id.as_str(), "eval-001");
        assert_eq!(run_id.to_string(), "eval-001");
        assert_eq!("eval-001".parse::<MlEvaluationRunId>()?, run_id);
        assert_eq!(MlEvaluationRunId::try_from(" eval-001")?, run_id);
        Ok(())
    }

    #[test]
    fn rejects_blank_run_ids() {
        assert_eq!(MlEvaluationRunId::new(""), Err(MlEvaluationError::Empty));
        assert_eq!(
            "  \t ".parse::<MlEvaluationRunId>(),
            Err(MlEvaluationError::Empty)
        );
    }

    #[test]
    fn validates_thresholds_and_confusion_matrix_shapes() -> Result<(), MlEvaluationError> {
        assert_eq!(MlThreshold::new(0.0)?.value(), 0.0);
        assert_eq!(MlThreshold::new(1.0)?.value(), 1.0);
        assert_eq!(MlThreshold::new(-0.1), Err(MlEvaluationError::OutOfRange));
        assert_eq!(MlThreshold::new(1.1), Err(MlEvaluationError::OutOfRange));
        assert_eq!(
            MlThreshold::new(f64::NAN),
            Err(MlEvaluationError::NonFinite)
        );
        // Infinities are rejected as non-finite before the range check runs.
        assert_eq!(
            MlThreshold::new(f64::INFINITY),
            Err(MlEvaluationError::NonFinite)
        );
        assert_eq!(
            MlThreshold::new(f64::NEG_INFINITY),
            Err(MlEvaluationError::NonFinite)
        );

        let shape = MlConfusionMatrixShape::new(3, 3)?;
        assert_eq!(shape.rows(), 3);
        assert_eq!(shape.columns(), 3);
        assert!(shape.is_square());

        let wide = MlConfusionMatrixShape::new(2, 4)?;
        assert_eq!(wide.columns(), 4);
        assert!(!wide.is_square());

        assert_eq!(
            MlConfusionMatrixShape::new(0, 3),
            Err(MlEvaluationError::Zero)
        );
        assert_eq!(
            MlConfusionMatrixShape::new(3, 0),
            Err(MlEvaluationError::Zero)
        );
        Ok(())
    }

    #[test]
    fn displays_and_parses_evaluation_enums() -> Result<(), MlEvaluationError> {
        assert_eq!(MlEvaluationKind::ABTest.to_string(), "a-b-test");
        assert_eq!(
            MlValidationStrategy::StratifiedKFold.to_string(),
            "stratified-k-fold"
        );
        assert_eq!(MlEvaluationStatus::Cancelled.as_str(), "cancelled");

        assert_eq!(
            "a b test".parse::<MlEvaluationKind>()?,
            MlEvaluationKind::ABTest
        );
        assert_eq!(
            "stratified k fold".parse::<MlValidationStrategy>()?,
            MlValidationStrategy::StratifiedKFold
        );
        assert_eq!(
            "TIME_SERIES_SPLIT".parse::<MlValidationStrategy>()?,
            MlValidationStrategy::TimeSeriesSplit
        );
        assert_eq!(
            "cancelled".parse::<MlEvaluationStatus>()?,
            MlEvaluationStatus::Cancelled
        );

        // Display output must always parse back to the same variant.
        let kind = MlEvaluationKind::CrossValidation;
        assert_eq!(kind.to_string().parse::<MlEvaluationKind>()?, kind);
        Ok(())
    }

    #[test]
    fn rejects_unknown_or_blank_enum_labels() {
        assert_eq!(
            "k-folds".parse::<MlValidationStrategy>(),
            Err(MlEvaluationError::UnknownLabel)
        );
        assert_eq!(
            "   ".parse::<MlEvaluationStatus>(),
            Err(MlEvaluationError::Empty)
        );
    }
}