Skip to main content

dag_ml_core/
ids.rs

1use std::fmt;
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{DagMlError, Result};
6
7fn validate_identifier(value: &str) -> Result<()> {
8    if value.is_empty() {
9        return Err(DagMlError::InvalidIdentifier {
10            value: value.to_string(),
11            reason: "identifier is empty",
12        });
13    }
14    if value.len() > 128 {
15        return Err(DagMlError::InvalidIdentifier {
16            value: value.to_string(),
17            reason: "identifier is longer than 128 bytes",
18        });
19    }
20    if !value
21        .bytes()
22        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' | b':'))
23    {
24        return Err(DagMlError::InvalidIdentifier {
25            value: value.to_string(),
26            reason: "identifier contains unsupported characters",
27        });
28    }
29    Ok(())
30}
31
/// Generates a validated string-identifier newtype called `$name`.
///
/// The generated type wraps a `String` that has passed `validate_identifier`. It is
/// built with `new` or `TryFrom<String>`, converts back into a `String` via `From`,
/// displays as the raw identifier text, and (de)serializes through serde as a plain
/// string. Deserialization goes through `TryFrom<String>`, so the text is re-validated.
macro_rules! define_id {
    ($name:ident) => {
        /// A validated identifier: non-empty, at most 128 bytes, and limited to ASCII
        /// letters, ASCII digits, `_`, `-`, `.` and `:`.
        #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
        #[serde(try_from = "String", into = "String")]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and wraps it.
            ///
            /// # Errors
            ///
            /// Returns `DagMlError::InvalidIdentifier` when `value` breaks any of the
            /// identifier rules.
            pub fn new(value: impl Into<String>) -> Result<Self> {
                let owned: String = value.into();
                validate_identifier(&owned).map(|()| Self(owned))
            }

            /// Borrows the underlying identifier text.
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl TryFrom<String> for $name {
            type Error = DagMlError;

            fn try_from(raw: String) -> Result<Self> {
                $name::new(raw)
            }
        }

        impl From<$name> for String {
            fn from(id: $name) -> Self {
                let $name(inner) = id;
                inner
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` is exactly what `str`'s own `Display` impl calls, so width,
                // alignment and precision flags behave the same as formatting the
                // inner string directly.
                f.pad(self.as_str())
            }
        }
    };
}
71
72define_id!(NodeId);
73define_id!(ObservationId);
74define_id!(SampleId);
75define_id!(FoldId);
76define_id!(TargetId);
77define_id!(GroupId);
78define_id!(ControllerId);
79define_id!(VariantId);
80define_id!(RunId);
81define_id!(BundleId);
82define_id!(ArtifactId);
83define_id!(LineageId);
84define_id!(BranchId);
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accepts_pipeline_style_node_ids() {
        assert!(NodeId::new("model:rf.v1").is_ok());
    }

    #[test]
    fn rejects_ambiguous_ids() {
        assert!(NodeId::new("model/rf").is_err());
        assert!(SampleId::new("").is_err());
    }

    /// The length limit is inclusive: exactly 128 bytes is accepted, 129 is not.
    #[test]
    fn enforces_length_limit_at_boundary() {
        assert!(RunId::new("a".repeat(128)).is_ok());
        assert!(RunId::new("a".repeat(129)).is_err());
    }

    /// Only ASCII is allowed, so multi-byte characters and whitespace are rejected.
    #[test]
    fn rejects_non_ascii_and_whitespace() {
        assert!(NodeId::new("modèle").is_err());
        assert!(NodeId::new("model rf").is_err());
        assert!(NodeId::new(" model").is_err());
    }

    /// The error names the first violated rule and keeps the offending input.
    #[test]
    fn reports_the_violated_rule() {
        assert!(matches!(
            SampleId::new(""),
            Err(DagMlError::InvalidIdentifier { reason: "identifier is empty", .. })
        ));
        assert!(matches!(
            SampleId::new("a".repeat(129)),
            Err(DagMlError::InvalidIdentifier {
                reason: "identifier is longer than 128 bytes",
                ..
            })
        ));
        match NodeId::new("model/rf") {
            Err(DagMlError::InvalidIdentifier { value, reason, .. }) => {
                assert_eq!(value, "model/rf");
                assert_eq!(reason, "identifier contains unsupported characters");
            }
            _ => panic!("expected an InvalidIdentifier error for `model/rf`"),
        }
    }

    #[test]
    fn converts_to_and_from_string() {
        let id = NodeId::try_from(String::from("model:rf.v1")).expect("valid identifier");
        assert_eq!(id.as_str(), "model:rf.v1");
        assert_eq!(id.to_string(), "model:rf.v1");
        assert_eq!(String::from(id), "model:rf.v1");
        assert!(NodeId::try_from(String::from("model/rf")).is_err());
    }

    /// `Display` delegates to the inner string, so formatting flags still apply.
    #[test]
    fn display_honours_format_width() {
        let id = NodeId::new("ab").expect("valid identifier");
        assert_eq!(format!("{:>4}", id), "  ab");
        assert_eq!(format!("{:<4}|", id), "ab  |");
    }
}
100}