// ratproto-did 0.0.3
//
// A highly-optimized library for atproto DIDs.
use std::{
    cmp::Ordering,
    fmt::{Debug, Display, Formatter},
    hash::{Hash, Hasher},
    mem::ManuallyDrop,
    ptr, slice,
    str::FromStr,
};

/// An owned, validated `did:web` domain.
///
/// The domain bytes live in a heap allocation (created from a `Box<[u8]>`) that this value
/// owns and frees on drop. They are stored as a raw pointer plus a one-byte length rather
/// than as a `Box<[u8]>` — presumably to keep the type small; TODO confirm intent.
pub struct DidWebDomain {
    /// Start of the owned heap allocation holding the domain bytes.
    ptr: *const u8,
    len: u8, // A web domain is at most 255 chars long
}

// SAFETY: `DidWebDomain` uniquely owns its heap bytes (just like a `Box<[u8]>`) and never
// mutates them after construction; only `Drop` takes `&mut self`. Moving it to, or sharing
// it between, threads is therefore as safe as doing so with a `Box<[u8]>`. The raw pointer
// field merely opts the type out of these auto traits.
unsafe impl Send for DidWebDomain {}
unsafe impl Sync for DidWebDomain {}

impl DidWebDomain {
    /// Domain bytes, valid ASCII.
    pub fn as_bytes(&self) -> &[u8] {
        // SAFETY: ptr and len come from a boxed slice
        unsafe { slice::from_raw_parts(self.ptr, self.len as usize) }
    }

    /// Domain as a string slice, valid ASCII.
    pub fn as_str(&self) -> &str {
        // SAFETY: domain validation ensures the contents are valid ASCII
        unsafe { str::from_utf8_unchecked(self.as_bytes()) }
    }

    /// Domain length
    #[allow(dead_code)] // May be useful internally later
    #[allow(clippy::len_without_is_empty)] // Never empty
    pub const fn len(&self) -> usize {
        self.len as usize
    }

    /// Convenience method for [`Self::len()`], since domain length fits in a byte.
    #[allow(dead_code)] // May be useful internally later
    pub const fn len_u8(&self) -> u8 {
        self.len
    }

    /// Clones a byte string into a new `WebDomain`.
    ///
    /// Expects the input to be a valid domain, and at most 255 chars.
    unsafe fn clone_from_byte_string_unchecked(bytes: &[u8]) -> Self {
        // Clone to a boxed array, and extract the pointer and length
        // The length fits within a byte
        let boxed = clone_to_heap(bytes);
        let heap_slice = Box::into_raw(boxed); // This type takes ownership

        let ptr = heap_slice as *const u8;
        debug_assert!(heap_slice.len() <= 255);
        let len = heap_slice.len() as u8;
        DidWebDomain { ptr, len }
    }

    /// Turn a pointer and length back into a DidWebDomain. This transfers ownership!
    ///
    /// DidWebDomain is an owning type, so use [`ManuallyDrop`] or [`std::mem::forget()`]
    /// to avoid dropping the contents.
    pub(crate) unsafe fn from_raw_parts(ptr: *const u8, len: u8) -> Self {
        DidWebDomain { ptr, len }
    }

    /// Returns the contained pointer and length without dropping the data.
    pub(crate) fn into_raw_parts(self) -> (*const u8, u8) {
        let this = ManuallyDrop::new(self);
        (this.ptr, this.len)
    }
}

impl Drop for DidWebDomain {
    /// Frees the heap bytes by handing them back to a `Box<[u8]>`.
    fn drop(&mut self) {
        // Rebuild the fat slice pointer that `Box::into_raw` produced at construction.
        let raw = ptr::slice_from_raw_parts_mut(self.ptr.cast_mut(), usize::from(self.len));
        // SAFETY: `ptr`/`len` describe a boxed slice owned by this value, and `drop` runs
        // at most once, so ownership can be returned to a `Box`, which deallocates it.
        let owned: Box<[u8]> = unsafe { Box::from_raw(raw) };
        drop(owned);
    }
}

impl FromStr for DidWebDomain {
    type Err = ();

    /// Parses `input` as a did:web domain, copying it to the heap on success.
    ///
    /// Returns `Err(())` whenever `validate_domain` rejects the input.
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        if validate_domain(input) {
            // SAFETY: input has been validated (valid domain, len <= 255)
            let domain = unsafe { Self::clone_from_byte_string_unchecked(input.as_bytes()) };
            Ok(domain)
        } else {
            Err(())
        }
    }
}

/// Copies `slice` into a new heap allocation of exactly `slice.len()` bytes.
fn clone_to_heap(slice: &[u8]) -> Box<[u8]> {
    // `From<&[T]> for Box<[T]>` allocates an exactly-sized buffer and copies into it,
    // with no `unsafe` initialization bookkeeping needed.
    Box::from(slice)
}

/// Validates a domain. Case-insensitive.
fn validate_domain(s: &str) -> bool {
    // Expected to be non-empty
    if s.is_empty() {
        return false;
    }

    // Note: this expects an ASCII or punycoded domain
    if !s.is_ascii() {
        return false;
    }
    let s = s.as_bytes();

    // Max domain length
    if s.len() > 255 {
        return false;
    }

    // port is allowed specifically for localhost (`:` is percent-encoded)
    // Should this care about case-insensitivity? dunno
    if let Some(port_bytes) = s.strip_prefix(b"localhost%3A") {
        // SAFETY: comes from valid ASCII
        let port_str = unsafe { str::from_utf8_unchecked(port_bytes) };
        return u16::from_str(port_str).is_ok();
    }

    for label in s.split(|x| *x == b'.') {
        if !validate_label(label) {
            return false;
        }
    }

    true
}

/// Validates a domain segment/label. The input should be an ASCII byte string.
///
/// A label passes when all of the following hold:
/// - it is 1 to 63 bytes long;
/// - it starts with an ASCII letter;
/// - it ends with an ASCII letter or digit;
/// - every byte is an ASCII letter, digit or hyphen.
fn validate_label(label: &[u8]) -> bool {
    // Note: atproto handles do not disallow non-TLD segments starting with a digit.
    // RFC1035 seems to forbid this, requiring that each segment begins exactly with a letter.
    let (head, tail) = match (label.first(), label.last()) {
        (Some(&head), Some(&tail)) => (head, tail),
        _ => return false, // empty label
    };

    let allowed = |b: &u8| *b == b'-' || b.is_ascii_alphanumeric();

    label.len() <= 63
        && head.is_ascii_alphabetic()
        && tail.is_ascii_alphanumeric()
        && label.iter().all(allowed)
}

impl PartialEq for DidWebDomain {
    /// Byte-wise, and therefore case-sensitive, comparison of the two domains.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs == rhs
    }
}

/// Byte-wise equality is a full equivalence relation.
impl Eq for DidWebDomain {}

impl Clone for DidWebDomain {
    fn clone(&self) -> Self {
        // SAFETY: own bytes are valid
        unsafe { DidWebDomain::clone_from_byte_string_unchecked(self.as_bytes()) }
    }
}

impl Hash for DidWebDomain {
    /// Hashes exactly like the underlying byte slice, consistent with the byte-wise `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let bytes: &[u8] = self.as_bytes();
        Hash::hash(bytes, state);
    }
}

impl Ord for DidWebDomain {
    /// Orders domains lexicographically by their raw (case-sensitive) bytes.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(self.as_bytes(), other.as_bytes())
    }
}

impl PartialOrd for DidWebDomain {
    /// Always `Some`: the ordering is total and defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Display for DidWebDomain {
    /// Writes the bare domain, honouring width, fill, alignment and precision like `str` does.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

impl Debug for DidWebDomain {
    /// Formats as a tuple struct: `DidWebDomain("<domain>")`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let domain = self.as_str();
        let mut tuple = f.debug_tuple("DidWebDomain");
        tuple.field(&domain);
        tuple.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test_case::test_case("metaflame.dev")]
    #[test_case::test_case("example.com")]
    #[test_case::test_case("Example.Com"; "capitalization")]
    #[test_case::test_case("a.example.com")]
    #[test_case::test_case("abc.example.com")]
    #[test_case::test_case("a1-2.example2.com")]
    #[test_case::test_case(
        "sixtythreecharacterslongsegment-sixtythreecharacterslongsegment.example.com"
    )]
    #[test_case::test_case("localhost%3A1234")]
    #[test_case::test_case("localhost")]
    #[test_case::test_case("abc")]
    fn valid_domain(domain: &str) {
        let wd = DidWebDomain::from_str(domain).unwrap();
        assert_eq!(wd.as_str(), domain);
    }

    #[test_case::test_case("example-.com"; "trailing hyphen")]
    #[test_case::test_case("-example.com"; "leading hyphen")]
    #[test_case::test_case("exam_ple.com"; "underscore")]
    #[test_case::test_case("exam%ple.com"; "percent")]
    #[test_case::test_case("1example.com"; "leading digit")]
    #[test_case::test_case("abc..com"; "empty segment")]
    #[test_case::test_case(
        "sixtyfourcharacterslongsegment----sixtyfourcharacterslongsegment.example.com"; "64-char segment"
    )]
    #[test_case::test_case("localhost%3A65536"; "localhost port 65536")]
    #[test_case::test_case("localhost%3A"; "localhost trailing %3A")]
    fn invalid_domain(domain: &str) {
        DidWebDomain::from_str(domain).expect_err("parsing should fail");
    }

    /// The whole domain is capped at 255 bytes because its length is stored in a `u8`.
    #[test]
    fn total_length_limit() {
        let max = ["a"; 128].join("."); // 128 one-byte labels + 127 dots = 255 bytes
        assert_eq!(max.len(), 255);
        let wd = DidWebDomain::from_str(&max).unwrap();
        assert_eq!(wd.len(), 255);
        assert_eq!(wd.as_str(), max);

        let too_long = format!("b{max}"); // first label becomes "ba"; 256 bytes in total
        assert_eq!(too_long.len(), 256);
        DidWebDomain::from_str(&too_long).expect_err("a 256-byte domain should be rejected");
    }

    /// Clones own a separate allocation but compare, order and hash like the original.
    #[test]
    fn clone_is_deep_and_consistent() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let original = DidWebDomain::from_str("example.com").unwrap();
        let copy = original.clone();
        assert_eq!(copy, original);
        assert_eq!(copy.cmp(&original), std::cmp::Ordering::Equal);
        assert_ne!(copy.as_bytes().as_ptr(), original.as_bytes().as_ptr());

        let hash_of = |d: &DidWebDomain| {
            let mut hasher = DefaultHasher::new();
            d.hash(&mut hasher);
            hasher.finish()
        };
        assert_eq!(hash_of(&copy), hash_of(&original));

        // The clone must stay valid after the original's allocation is freed.
        drop(original);
        assert_eq!(copy.as_str(), "example.com");
    }
}