use regex::Regex;
use std::fmt;
use std::str::FromStr;
/// Upper bound on the length of an accepted DID, measured with `str::len` (bytes).
const MAX_DID_LENGTH: usize = 2 * 1024;
/// Full DID syntax accepted by this module.
///
/// - a lowercase `did:` prefix,
/// - a method of one or more lowercase ASCII letters followed by `:`,
/// - an identifier drawn from `[a-zA-Z0-9._:%-]` that is non-empty and does not end in
///   `:` or `%`.
///
/// Without the multi-line flag, `$` in the `regex` crate anchors at the very end of the
/// input, so trailing newlines are rejected.
static DID_REGEX: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
    Regex::new(r"^did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]$")
        .expect("DID_REGEX pattern is a hard-coded, valid regular expression")
});
/// A validated Decentralized Identifier, e.g. `did:plc:asdf123`.
///
/// The wrapped string is private to this module. Equality, ordering and hashing all
/// operate on the raw DID string.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Did(String);
#[derive(Debug, Clone, thiserror::Error)]
#[error("Invalid DID: {reason}")]
pub struct InvalidDidError {
pub reason: String,
}
impl Did {
    /// Validates `s` and, on success, copies it into a new [`Did`].
    ///
    /// # Errors
    ///
    /// Returns [`InvalidDidError`] if `s` exceeds the maximum DID length or does not
    /// match the DID syntax; the error's `reason` describes the first problem found.
    pub fn new(s: &str) -> Result<Self, InvalidDidError> {
        ensure_valid_did(s)?;
        Ok(Self(s.to_string()))
    }

    /// Returns `true` if `s` would be accepted by [`Did::new`].
    #[must_use]
    pub fn is_valid(s: &str) -> bool {
        ensure_valid_did(s).is_ok()
    }

    /// Returns the method segment, e.g. `"plc"` for `did:plc:asdf123`.
    #[must_use]
    pub fn method(&self) -> &str {
        // Every `Did` passes validation on construction, which requires a
        // `did:<method>:` prefix, so a second segment always exists.
        self.0
            .split(':')
            .nth(1)
            .expect("validated DID always contains a method segment")
    }

    /// Borrows the full DID as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Consumes the `Did`, returning the owned DID string.
    #[must_use]
    pub fn into_inner(self) -> String {
        self.0
    }
}
/// Validates `s` as a DID, describing the first problem found when it is rejected.
///
/// The length cap is checked before the regex, so oversized input is rejected without
/// being scanned. The cap is measured with `str::len`, i.e. in bytes.
///
/// [`DID_REGEX`] alone cannot say *why* a string is invalid. When it does not match,
/// narrower checks run in priority order and the first one that fails supplies the
/// reason. If none of them fails, a generic "invalid characters" reason is reported.
fn ensure_valid_did(s: &str) -> Result<(), InvalidDidError> {
    let fail = |reason: &str| -> Result<(), InvalidDidError> {
        Err(InvalidDidError {
            reason: reason.to_string(),
        })
    };

    let byte_len = s.len();
    if byte_len > MAX_DID_LENGTH {
        return fail(&format!(
            "DID is too long ({} chars, max {})",
            byte_len, MAX_DID_LENGTH
        ));
    }

    if DID_REGEX.is_match(s) {
        return Ok(());
    }

    // Diagnostics only: `s` is already known to be invalid from here on.
    if !s.starts_with("did:") {
        return fail("DID requires \"did:\" prefix");
    }
    if matches!(s.chars().next_back(), Some(':' | '%')) {
        return fail("DID cannot end with ':' or '%'");
    }

    // Segments after the leading "did". The first is the method, and a second one must
    // exist for the DID to carry any method-specific content.
    let mut after_prefix = s.split(':').skip(1);
    let method = after_prefix.next().unwrap_or_default();
    if after_prefix.next().is_none() {
        return fail("DID requires prefix, method, and method-specific content");
    }

    let method_is_lowercase_word =
        !method.is_empty() && method.chars().all(|c| c.is_ascii_lowercase());
    if !method_is_lowercase_word {
        return fail("DID method must be lowercase letters only");
    }

    fail("DID contains invalid characters")
}
impl fmt::Display for Did {
    /// Writes the DID verbatim, e.g. `did:plc:asdf123`.
    ///
    /// The text goes through `Formatter::write_str`, so width/alignment flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = &self.0;
        f.write_str(text)
    }
}
impl FromStr for Did {
type Err = InvalidDidError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::new(s)
}
}
impl AsRef<str> for Did {
    /// Borrows the full DID string.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl serde::Serialize for Did {
    /// Serializes the DID as a plain string (e.g. a JSON string).
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.0.as_str())
    }
}
impl<'de> serde::Deserialize<'de> for Did {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Self::new(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn valid_dids() {
        let cases = [
            "did:plc:asdf123",
            "did:web:example.com",
            "did:method:val:two",
            "did:m:v",
            "did:method:%3A",
            "did:method:val-two",
            "did:method:val_two",
            "did:method:val.two",
        ];
        for did in &cases {
            assert!(Did::new(did).is_ok(), "should be valid: {did}");
            assert!(Did::is_valid(did), "is_valid should agree: {did}");
        }
    }

    #[test]
    fn invalid_dids() {
        let cases = [
            ("", "empty"),
            ("did:", "no method"),
            ("did:m:", "ends with colon"),
            ("did:m:%", "ends with percent"),
            ("DID:method:val", "uppercase prefix"),
            ("did:UPPER:val", "uppercase method"),
            ("did:m:v!v", "invalid character"),
            ("randomstring", "no prefix"),
            ("did:method:", "ends with colon"),
            ("did:m:v\n", "trailing newline"),
        ];
        for (input, desc) in &cases {
            assert!(
                Did::new(input).is_err(),
                "should be invalid ({desc}): {input}"
            );
            assert!(!Did::is_valid(input), "is_valid should agree ({desc}): {input}");
        }
    }

    // Each rejection path reports its own reason; pin them so diagnostics don't regress.
    #[test]
    fn invalid_did_reasons() {
        let cases = [
            ("randomstring", "DID requires \"did:\" prefix"),
            ("did:m:", "DID cannot end with ':' or '%'"),
            ("did:m:%", "DID cannot end with ':' or '%'"),
            ("did:m", "DID requires prefix, method, and method-specific content"),
            ("did:UPPER:val", "DID method must be lowercase letters only"),
            ("did::val", "DID method must be lowercase letters only"),
            ("did:m:v!v", "DID contains invalid characters"),
        ];
        for (input, expected) in &cases {
            let err = Did::new(input).unwrap_err();
            assert_eq!(err.reason, *expected, "wrong reason for {input}");
        }
    }

    #[test]
    fn method_extraction() {
        let did = Did::new("did:plc:asdf123").unwrap();
        assert_eq!(did.method(), "plc");
        let did = Did::new("did:web:example.com").unwrap();
        assert_eq!(did.method(), "web");
        let did = Did::new("did:method:val:two").unwrap();
        assert_eq!(did.method(), "method");
    }

    #[test]
    fn display_from_str_and_accessors() {
        let did: Did = "did:web:example.com".parse().unwrap();
        assert_eq!(did.to_string(), "did:web:example.com");
        assert_eq!(did.as_str(), "did:web:example.com");
        assert_eq!(AsRef::<str>::as_ref(&did), "did:web:example.com");
        assert_eq!(did.into_inner(), "did:web:example.com");
        assert!("not a did".parse::<Did>().is_err());
    }

    #[test]
    fn serde_roundtrip() {
        let did = Did::new("did:plc:asdf123").unwrap();
        let json = serde_json::to_string(&did).unwrap();
        assert_eq!(json, "\"did:plc:asdf123\"");
        let parsed: Did = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, did);
    }

    #[test]
    fn serde_rejects_invalid() {
        let result: Result<Did, _> = serde_json::from_str("\"did:UPPER:val\"");
        assert!(result.is_err());
    }

    #[test]
    fn max_length() {
        let prefix = "did:m:";

        // Exactly at the limit is still accepted.
        let at_limit = format!("{prefix}{}", "a".repeat(MAX_DID_LENGTH - prefix.len()));
        assert_eq!(at_limit.len(), MAX_DID_LENGTH);
        assert!(Did::new(&at_limit).is_ok());

        // Anything longer is rejected by the length check, not the regex.
        let over_limit = format!("{prefix}{}", "a".repeat(MAX_DID_LENGTH));
        let err = Did::new(&over_limit).unwrap_err();
        assert!(err.reason.starts_with("DID is too long"), "{}", err.reason);
    }
}