cts-common 0.34.1-alpha.3

Common types and traits used across the CipherStash ecosystem
Documentation
use crate::{AwsRegion, Region, WorkspaceId};
use miette::Diagnostic;
use nom::{
    bytes::complete::{tag, take_while1, take_while_m_n},
    combinator::{all_consuming, opt},
    error::ErrorKind,
    sequence::{preceded, separated_pair},
    IResult, Parser,
};
use serde::{Deserialize, Serialize};
use std::{fmt::Display, str::FromStr};
use thiserror::Error;

/// Errors returned when a value cannot be turned into a [`Crn`].
#[derive(Error, Debug, Diagnostic)]
pub enum InvalidCrn {
    /// The input does not follow the CRN grammar.
    ///
    /// The payload is interpolated into the error message; the
    /// `invalid_format` constructor stores the rejected input there.
    #[error("Invalid CRN: {0}")]
    #[diagnostic(help = "CRN format: `crn:<region>:<workspace_id>[:<service_name>]`")]
    InvalidFormat(String),

    /// A region error, forwarded transparently (message and diagnostic) and
    /// convertible via `From<crate::region::RegionError>`.
    #[error(transparent)]
    #[diagnostic(transparent)]
    InvalidRegion(#[from] crate::region::RegionError),

    /// A workspace ID error, forwarded transparently (message and diagnostic)
    /// and convertible via `From<crate::workspace::InvalidWorkspaceId>`.
    #[error(transparent)]
    #[diagnostic(transparent)]
    InvalidWorkspaceId(#[from] crate::workspace::InvalidWorkspaceId),
}

impl InvalidCrn {
    /// Builds an [`InvalidCrn::InvalidFormat`] holding an owned copy of `input`.
    pub fn invalid_format(input: &str) -> Self {
        let rejected = String::from(input);
        Self::InvalidFormat(rejected)
    }
}

/// Implemented by types that can be described by a [`Crn`].
pub trait AsCrn {
    /// Converts the implementing type to a CRN
    ///
    /// Takes `&self` and returns an owned [`Crn`].
    fn as_crn(&self) -> Crn;
}

// TODO: Make some inner type variants of this to handle when a service name is present or not (and for other extensions)

#[derive(Debug, Clone, PartialEq, Eq)]
/// Represents CRNs (CipherStash Resource Names)
///
/// The textual form is `crn:<region>:<workspace_id>[:<service_name>]`. That is
/// the form written by `Display`, accepted by `TryFrom<&str>` / `FromStr`, and
/// used by the `Serialize` / `Deserialize` implementations (a plain string).
pub struct Crn {
    /// The workspace ID
    pub workspace_id: WorkspaceId,

    /// The region
    pub region: Region,

    /// An optional service name
    ///
    /// The string parser only accepts alphanumeric characters, `-` and `_`
    /// here, while this field (and `with_service_name`) stores any string
    /// as-is; a value with other characters will not parse back from its
    /// `Display` output.
    pub service_name: Option<String>,
}

impl Crn {
    /// Creates a new CRN for `workspace_id` in `region`, with no service name.
    pub fn new(region: Region, workspace_id: WorkspaceId) -> Self {
        let service_name = None;
        Self {
            workspace_id,
            region,
            service_name,
        }
    }

    /// Consumes the CRN and returns it with `service_name` set, replacing any
    /// previous service name. The value is stored as given (no validation).
    pub fn with_service_name(self, service_name: &str) -> Self {
        Self {
            service_name: Some(String::from(service_name)),
            ..self
        }
    }
}

/// Serializes a CRN as its `Display` string.
impl Serialize for Crn {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered: String = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}

impl<'de> Deserialize<'de> for Crn {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        Crn::try_from(s.as_str()).map_err(serde::de::Error::custom)
    }
}

/// Renders `crn:<region>:<workspace_id>`, followed by `:<service_name>` when a
/// service name is present.
impl Display for Crn {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            workspace_id,
            region,
            service_name,
        } = self;

        write!(f, "crn:{}:{}", region, workspace_id)?;
        match service_name {
            Some(service_name) => write!(f, ":{service_name}"),
            None => Ok(()),
        }
    }
}

/// Parses a CRN from its textual form by delegating to `parse_crn`.
impl TryFrom<&str> for Crn {
    type Error = InvalidCrn;

    /// Returns the parsed [`Crn`], or the [`InvalidCrn`] produced by
    /// `parse_crn` when `value` is rejected.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        parse_crn(value)
    }
}

impl FromStr for Crn {
    type Err = InvalidCrn;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        Self::try_from(value)
    }
}

// TODO: Move all of this into a submodule

/// Parse the "geo" part of the region (e.g. "us-east-1")
/// Uses `AwsRegion::ALL` as the single source of truth for valid regions.
///
/// # Invariant
/// This loop uses prefix matching (`nom::tag`). No region identifier may be a
/// prefix of another — if one were, the shorter match would win and leave the
/// remainder of the longer identifier unparsed, breaking that region's CRN parsing.
/// The `no_region_identifier_is_prefix_of_another` test in this module enforces this.
fn region_geo(input: &str) -> IResult<&str, AwsRegion> {
    for region in AwsRegion::ALL.iter() {
        if let Ok((rest, _)) =
            tag::<&str, &str, nom::error::Error<&str>>(region.identifier())(input)
        {
            return Ok((rest, *region));
        }
    }
    // Use Alt error kind to match the semantics of nom's alt() combinator,
    // as originally implemented.
    Err(nom::Err::Error(nom::error::Error::new(
        input,
        ErrorKind::Alt,
    )))
}

/// Parse the "vendor" part of the region (e.g. "aws")
/// `aws` is the only vendor recognised at the moment.
#[inline]
fn region_vendor(input: &str) -> IResult<&str, &str> {
    let aws_vendor = tag("aws");
    aws_vendor(input)
}

/// Parse the region (e.g. "us-east-1.aws")
///
/// Expects `<geo>.<vendor>`; the vendor token is checked but discarded, and the
/// geo part is wrapped in `Region::Aws`.
#[inline]
fn region(input: &str) -> IResult<&str, Region, nom::error::Error<&str>> {
    let mut geo_dot_vendor = separated_pair(region_geo, tag("."), region_vendor);
    let (remaining, (geo, _vendor)) = geo_dot_vendor.parse(input)?;
    Ok((remaining, Region::Aws(geo)))
}

/// Parse the workspace ID (e.g. "ZVATKW3VHMFG27DY")
/// The workspace ID must be 20 alphanumeric characters
#[inline]
fn workspace_id(input: &str) -> IResult<&str, WorkspaceId, nom::error::Error<&str>> {
    // parse the workspace ID
    take_while_m_n(16, 16, |c: char| c.is_alphanumeric())(input).map(|(rest, id)| {
        // Convert the ID to a WorkspaceId
        // SAFETY: The ID is already validated to be 16 alphanumeric characters
        // TODO: use the parse method on the inner ArrayString
        let id = WorkspaceId::try_from(id).expect("Invalid workspace ID");
        (rest, id)
    })
}

/// Parse a service name: one or more characters that are alphanumeric
/// (per `char::is_alphanumeric`), `-` or `_`.
fn service_name_chars(input: &str) -> IResult<&str, &str> {
    let is_service_name_char = |c: char| c.is_alphanumeric() || matches!(c, '-' | '_');
    take_while1(is_service_name_char).parse(input)
}

/// Parse a complete CRN of the form `crn:<region>:<workspace_id>[:<service_name>]`.
///
/// The whole input must be consumed; any parser failure (including trailing
/// text) is reported as `InvalidCrn::InvalidFormat` carrying the original input.
fn parse_crn(input: &str) -> Result<Crn, InvalidCrn> {
    let region_part = preceded(tag("crn:"), region);
    let workspace_part = preceded(tag(":"), workspace_id);
    let service_part = opt(preceded(tag(":"), service_name_chars));

    let outcome = all_consuming((region_part, workspace_part, service_part)).parse(input);

    match outcome {
        Ok((_, (region, workspace_id, service_name))) => Ok(Crn {
            region,
            workspace_id,
            service_name: service_name.map(|name| name.to_owned()),
        }),
        Err(_) => Err(InvalidCrn::invalid_format(input)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::AwsRegion;

    /// Parsing CRNs from strings via `TryFrom<&str>`.
    mod try_from_str {
        use super::*;

        #[test]
        fn success_valid_with_service() {
            let region = Region::new("us-east-1.aws").unwrap();
            let workspace_id = WorkspaceId::try_from("ZVATKW3VHMFG27DY").unwrap();

            // Both `_` and `-` are accepted inside a service name.
            for service in ["service_name", "service-name"] {
                let input = format!("crn:us-east-1.aws:ZVATKW3VHMFG27DY:{service}");
                assert_eq!(
                    Crn::try_from(input.as_str()).unwrap(),
                    Crn::new(region, workspace_id).with_service_name(service)
                );
            }
        }

        #[test]
        fn success_valid_without_service() {
            let parsed = Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY").unwrap();
            assert_eq!(parsed.region, Region::Aws(AwsRegion::UsEast1));
            assert_eq!(parsed.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert_eq!(parsed.service_name, None);
        }

        #[test]
        fn success_ca_central_1() {
            let parsed = Crn::try_from("crn:ca-central-1.aws:ZVATKW3VHMFG27DY").unwrap();
            assert_eq!(parsed.region, Region::Aws(AwsRegion::CaCentral1));
            assert_eq!(parsed.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert_eq!(parsed.service_name, None);
        }

        #[test]
        fn all_regions_roundtrip_in_crn() {
            let workspace_id = "ZVATKW3VHMFG27DY";
            for region in AwsRegion::all() {
                let crn_str = format!("crn:{}.aws:{}", region.identifier(), workspace_id);
                let crn = match Crn::try_from(crn_str.as_str()) {
                    Ok(crn) => crn,
                    Err(err) => panic!(
                        "Failed to parse CRN for region {}: {}",
                        region.identifier(),
                        err
                    ),
                };
                assert_eq!(crn.region, Region::Aws(region));
                // Also verify round-trip through Display
                assert_eq!(crn.to_string(), crn_str);
            }
        }

        #[test]
        fn test_invalid_crn() {
            let rejected = [
                "invalid_crn",
                "crn:invalid_crn",
                // Trailing colon
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:",
                // Extra parts
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra",
                // Extra extra parts
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra:extra",
                // Invalid workspace ID
                "crn:us-east-1.aws:ZVATKW3VH",
                // Invalid region
                "crn:us-east-1:ZVATKW3VHMFG27DY",
                // Missing CRN prefix
                "us-east-1.aws:ZVATKW3VHMFG27DY:service_name",
            ];

            for input in rejected {
                assert!(Crn::try_from(input).is_err());
            }
        }
    }

    /// Rendering CRNs through `Display`.
    mod display {
        use super::*;

        #[test]
        fn test_with_workspace_id() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("us-east-1.aws").unwrap();
            let rendered = Crn::new(region, workspace_id).to_string();
            assert_eq!(rendered, format!("crn:us-east-1.aws:{workspace_id}"));
        }

        #[test]
        fn test_ca_central_1_round_trip() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("ca-central-1.aws").unwrap();
            let rendered = Crn::new(region, workspace_id).to_string();
            assert_eq!(rendered, format!("crn:ca-central-1.aws:{workspace_id}"));
        }

        #[test]
        fn test_with_workspace_id_and_service() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("us-east-1.aws").unwrap();
            let rendered = Crn::new(region, workspace_id)
                .with_service_name("zerokms")
                .to_string();
            assert_eq!(
                rendered,
                format!("crn:us-east-1.aws:{workspace_id}:zerokms")
            );
        }
    }

    /// Guards the invariant `region_geo` relies on: it picks the first
    /// identifier that prefixes the input, so no identifier may prefix another.
    #[test]
    fn no_region_identifier_is_prefix_of_another() {
        let identifiers: Vec<&str> = AwsRegion::ALL.iter().map(|r| r.identifier()).collect();
        for (prefix_idx, a) in identifiers.iter().enumerate() {
            let others = identifiers
                .iter()
                .enumerate()
                .filter(|(idx, _)| *idx != prefix_idx);
            for (_, b) in others {
                assert!(
                    !b.starts_with(a),
                    "region identifier {:?} is a prefix of {:?} — \
                         region_geo() would match {:?} first, making {:?} unparseable",
                    a,
                    b,
                    a,
                    b
                );
            }
        }
    }
}