Skip to main content

proto_blue_syntax/
did.rs

//! DID (Decentralized Identifier) validation and types.
//!
//! DIDs follow the format `did:method:method-specific-id`.
//! See: <https://www.w3.org/TR/did-core/>
5
6use regex::Regex;
7use std::fmt;
8use std::str::FromStr;
9
10/// Maximum length of a DID string.
11const MAX_DID_LENGTH: usize = 2048;
12
13static DID_REGEX: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
14    Regex::new(r"^did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]$").unwrap()
15});
16
/// Owned, syntax-checked DID string of the form `did:method:method-specific-id`.
///
/// The field is private, so code outside this module obtains values through the validating
/// constructors (such as [`Did::new`]). Equality, ordering and hashing are those of the
/// underlying string.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct Did(String);
22
23/// Error returned when a DID string is invalid.
24#[derive(Debug, Clone, thiserror::Error)]
25#[error("Invalid DID: {reason}")]
26pub struct InvalidDidError {
27    pub reason: String,
28}
29
30impl Did {
31    /// Create a new `Did` from a string, validating the format.
32    pub fn new(s: &str) -> Result<Self, InvalidDidError> {
33        ensure_valid_did(s)?;
34        Ok(Self(s.to_string()))
35    }
36
37    /// Check whether a string is a valid DID without allocating.
38    #[must_use]
39    pub fn is_valid(s: &str) -> bool {
40        ensure_valid_did(s).is_ok()
41    }
42
43    /// Return the DID method (e.g., `"plc"` for `did:plc:...`).
44    #[must_use]
45    pub fn method(&self) -> &str {
46        // Safe: we validated the format in the constructor
47        self.0.split(':').nth(1).unwrap()
48    }
49
50    /// Return the inner string.
51    #[must_use]
52    pub fn as_str(&self) -> &str {
53        &self.0
54    }
55
56    /// Consume and return the inner string.
57    #[must_use]
58    pub fn into_inner(self) -> String {
59        self.0
60    }
61}
62
63fn ensure_valid_did(s: &str) -> Result<(), InvalidDidError> {
64    let err = |reason: &str| InvalidDidError {
65        reason: reason.to_string(),
66    };
67
68    if s.len() > MAX_DID_LENGTH {
69        return Err(err(&format!(
70            "DID is too long ({} chars, max {})",
71            s.len(),
72            MAX_DID_LENGTH
73        )));
74    }
75
76    if !DID_REGEX.is_match(s) {
77        // Provide more specific error messages
78        if !s.starts_with("did:") {
79            return Err(err("DID requires \"did:\" prefix"));
80        }
81        if s.ends_with(':') || s.ends_with('%') {
82            return Err(err("DID cannot end with ':' or '%'"));
83        }
84        let parts: Vec<&str> = s.splitn(4, ':').collect();
85        if parts.len() < 3 {
86            return Err(err(
87                "DID requires prefix, method, and method-specific content",
88            ));
89        }
90        if parts[1].is_empty() || !parts[1].chars().all(|c| c.is_ascii_lowercase()) {
91            return Err(err("DID method must be lowercase letters only"));
92        }
93        return Err(err("DID contains invalid characters"));
94    }
95
96    Ok(())
97}
98
99impl fmt::Display for Did {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.write_str(&self.0)
102    }
103}
104
105impl FromStr for Did {
106    type Err = InvalidDidError;
107    fn from_str(s: &str) -> Result<Self, Self::Err> {
108        Self::new(s)
109    }
110}
111
112impl AsRef<str> for Did {
113    fn as_ref(&self) -> &str {
114        &self.0
115    }
116}
117
118impl serde::Serialize for Did {
119    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
120        self.0.serialize(serializer)
121    }
122}
123
124impl<'de> serde::Deserialize<'de> for Did {
125    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
126        let s = String::deserialize(deserializer)?;
127        Self::new(&s).map_err(serde::de::Error::custom)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Inputs that must be accepted.
    const VALID: &[&str] = &[
        "did:plc:asdf123",
        "did:web:example.com",
        "did:method:val:two",
        "did:m:v",
        "did:method:%3A",
        "did:method:val-two",
        "did:method:val_two",
        "did:method:val.two",
    ];

    /// Inputs that must be rejected, each with a short description of why.
    const INVALID: &[(&str, &str)] = &[
        ("", "empty"),
        ("did:", "no method"),
        ("did:m", "no method-specific id"),
        ("did::val", "empty method"),
        ("did:m:", "ends with colon"),
        ("did:m:%", "ends with percent"),
        ("DID:method:val", "uppercase prefix"),
        ("did:UPPER:val", "uppercase method"),
        ("did:m:v!v", "invalid character"),
        ("did:m:v v", "embedded space"),
        ("did:m:v\n", "trailing newline"),
        ("randomstring", "no prefix"),
        ("did:method:", "ends with colon"),
    ];

    #[test]
    fn valid_dids() {
        for did in VALID {
            assert!(Did::new(did).is_ok(), "should be valid: {did}");
        }
    }

    #[test]
    fn invalid_dids() {
        for (input, desc) in INVALID {
            assert!(
                Did::new(input).is_err(),
                "should be invalid ({desc}): {input:?}"
            );
        }
    }

    #[test]
    fn is_valid_agrees_with_new() {
        let long = format!("did:m:{}", "a".repeat(MAX_DID_LENGTH));
        let inputs = VALID
            .iter()
            .copied()
            .chain(INVALID.iter().map(|(input, _)| *input))
            .chain(std::iter::once(long.as_str()));
        for input in inputs {
            assert_eq!(
                Did::is_valid(input),
                Did::new(input).is_ok(),
                "is_valid disagrees with new for {input:?}"
            );
        }
    }

    #[test]
    fn error_reasons() {
        let reason = |s: &str| Did::new(s).unwrap_err().reason;
        assert_eq!(reason("randomstring"), "DID requires \"did:\" prefix");
        assert_eq!(reason("did:m:%"), "DID cannot end with ':' or '%'");
        assert_eq!(
            reason("did:m"),
            "DID requires prefix, method, and method-specific content"
        );
        assert_eq!(
            reason("did:UPPER:val"),
            "DID method must be lowercase letters only"
        );
        assert_eq!(reason("did:m:v!v"), "DID contains invalid characters");

        let err = Did::new("randomstring").unwrap_err();
        assert!(err.to_string().starts_with("Invalid DID: "));
    }

    #[test]
    fn method_extraction() {
        let did = Did::new("did:plc:asdf123").unwrap();
        assert_eq!(did.method(), "plc");

        let did = Did::new("did:web:example.com").unwrap();
        assert_eq!(did.method(), "web");

        let did = Did::new("did:method:val:two").unwrap();
        assert_eq!(did.method(), "method");
    }

    #[test]
    fn accessors_and_parsing() {
        let did: Did = "did:plc:asdf123".parse().unwrap();
        assert_eq!(did.as_str(), "did:plc:asdf123");
        assert_eq!(did.to_string(), "did:plc:asdf123");
        assert_eq!(AsRef::<str>::as_ref(&did), "did:plc:asdf123");
        assert_eq!(did.into_inner(), "did:plc:asdf123");
        assert!("not-a-did".parse::<Did>().is_err());
    }

    #[test]
    fn serde_roundtrip() {
        let did = Did::new("did:plc:asdf123").unwrap();
        let json = serde_json::to_string(&did).unwrap();
        assert_eq!(json, "\"did:plc:asdf123\"");
        let parsed: Did = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, did);
    }

    #[test]
    fn serde_rejects_invalid() {
        assert!(serde_json::from_str::<Did>("\"did:UPPER:val\"").is_err());
        assert!(serde_json::from_str::<Did>("\"randomstring\"").is_err());
    }

    #[test]
    fn max_length() {
        let prefix = "did:m:";
        let at_limit = format!("{prefix}{}", "a".repeat(MAX_DID_LENGTH - prefix.len()));
        assert_eq!(at_limit.len(), MAX_DID_LENGTH);
        assert!(Did::new(&at_limit).is_ok(), "exactly MAX_DID_LENGTH must be accepted");

        let over_limit = format!("{at_limit}a");
        assert!(Did::new(&over_limit).is_err());
        assert!(!Did::is_valid(&over_limit));
    }
}