// cts-common 0.34.1-alpha.3
//
// Common types and traits used across the CipherStash ecosystem.
// Documentation
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use std::{fmt::Display, str::FromStr};
use thiserror::Error;
use utoipa::{
    openapi::{schema::SchemaType, Type},
    PartialSchema, ToSchema,
};

#[cfg(feature = "test_utils")]
use fake::{Dummy, Faker};

/// Defines the region of a CipherStash service.
///
/// A region in CipherStash is defined by the region identifier and the provider, separated by a dot.
/// For example, `us-west-2.aws` is a valid region identifier and refers to the AWS region `us-west-2`.
///
/// The serde attributes on each variant route (de)serialization through the
/// `*_with_suffix` helpers on the provider-specific type, so the provider suffix
/// is part of the serialized string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Region {
    /// A region hosted on AWS, serialized as `<region>.aws`.
    #[serde(
        serialize_with = "AwsRegion::serialize_with_suffix",
        deserialize_with = "AwsRegion::deserialize_with_suffix"
    )]
    Aws(AwsRegion),
}

impl ToSchema for Region {
    /// Name under which this type is registered in the generated OpenAPI document.
    fn name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("Region")
    }
}
impl PartialSchema for Region {
    /// Describes a region as a string schema whose allowed values are the
    /// identifiers of every region returned by `Region::all()`.
    fn schema() -> utoipa::openapi::RefOr<utoipa::openapi::schema::Schema> {
        let identifiers = Region::all().into_iter().map(|region| region.identifier());

        let builder = utoipa::openapi::ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::String))
            .enum_values(Some(identifiers));

        builder.into()
    }
}

impl FromStr for Region {
    type Err = RegionError;

    #[inline]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Region::new(s)
    }
}

impl Display for Region {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Region::Aws(region) => write!(f, "{region}.aws"),
        }
    }
}

/// Allows comparing a region directly against its full string identifier,
/// e.g. `region == "us-west-2.aws"`.
impl PartialEq<&str> for Region {
    fn eq(&self, other: &&str) -> bool {
        let rendered: String = self.identifier();
        rendered.as_str() == *other
    }
}

/// Errors produced while parsing or resolving a region.
#[derive(Debug, Error, Diagnostic, PartialEq, Eq)]
pub enum RegionError {
    // TODO: Use miette to specify parts of the region that are invalid
    // Consider making this a separate error type
    /// The identifier does not name a supported region or provider.
    ///
    /// The payload is either the offending region identifier or a longer
    /// explanation, depending on where the error was raised.
    #[error("Invalid region: {0}")]
    #[diagnostic(help(
        "Region identifiers are in the format `<region>.<provider>` (e.g. 'us-west-2.aws')"
    ))]
    InvalidRegion(String),

    /// A host or endpoint string did not contain a valid region.
    ///
    /// The payload is interpolated into the error message; this variant is not
    /// constructed anywhere in this part of the file.
    #[error("Host or endpoint does not contain a valid region: `{0}`")]
    InvalidHostFqdn(String),
}

impl Region {
    /// Returns every supported region across all providers.
    pub fn all() -> Vec<Self> {
        AwsRegion::all().into_iter().map(Self::Aws).collect()
    }

    /// Creates a new region from an identifier.
    /// Region identifiers are in the format `<region>.<provider>`.
    ///
    /// For example, `us-west-2.aws` is a valid region identifier.
    ///
    /// # Errors
    ///
    /// Returns [`RegionError::InvalidRegion`] when the identifier does not end in a
    /// known provider suffix, or when the part before the suffix is not a supported
    /// region. A repeated suffix such as `us-west-2.aws.aws` is rejected.
    ///
    /// # Example
    ///
    /// ```
    /// use cts_common::{AwsRegion, Region};
    /// let region = Region::new("us-west-2.aws").unwrap();
    /// assert_eq!(region, Region::Aws(AwsRegion::UsWest2));
    /// assert!(Region::new("us-west-2.aws.aws").is_err());
    /// ```
    ///
    pub fn new(identifier: &str) -> Result<Self, RegionError> {
        // `strip_suffix` removes exactly one `.aws`. `trim_end_matches` would strip the
        // suffix repeatedly and silently accept malformed input like `us-west-2.aws.aws`.
        match identifier.strip_suffix(".aws") {
            Some(region) => Self::aws(region),
            None => Err(RegionError::InvalidRegion(format!(
                "Missing or unknown provider (e.g. '.aws' suffix on '{identifier}')"
            ))),
        }
    }

    /// Creates a new AWS region from an identifier.
    /// Note that this is not the complete list of AWS regions, only the ones that are currently supported by CipherStash.
    ///
    /// # Errors
    ///
    /// Returns [`RegionError::InvalidRegion`] carrying `identifier` when it is not a
    /// supported AWS region.
    ///
    /// # Example
    ///
    /// ```
    /// use cts_common::{Region, AwsRegion};
    /// let region = Region::aws("us-west-2").unwrap();
    /// assert_eq!(region, Region::Aws(AwsRegion::UsWest2));
    /// ```
    ///
    pub fn aws(identifier: &str) -> Result<Self, RegionError> {
        AwsRegion::try_from(identifier).map(Self::Aws)
    }

    /// Returns the full identifier of the region, including the provider suffix
    /// (e.g. `us-west-2.aws`).
    pub fn identifier(&self) -> String {
        match self {
            Region::Aws(region) => format!("{}.aws", region.identifier()),
        }
    }

    /// Returns the human-readable name of the region, as defined by the provider type.
    pub fn name(&self) -> &'static str {
        match self {
            Region::Aws(region) => region.name(),
        }
    }
}

#[cfg(feature = "test_utils")]
impl Dummy<Faker> for Region {
    /// Picks one of the supported AWS regions using the supplied random number generator.
    fn dummy_with_rng<R>(_: &Faker, rng: &mut R) -> Self
    where
        R: rand::Rng + ?Sized,
    {
        let pool = AwsRegion::all();
        let picked = pool[rng.gen_range(0..pool.len())];
        Self::Aws(picked)
    }
}

// The AWS regions that can be named by a `Region`. Each variant corresponds to one
// AWS region code; the code and display name for each variant are defined by
// `AwsRegion::identifier` and `AwsRegion::name`.
//
// NOTE(review): plain `//` comments are used here rather than `///` because this type
// derives `utoipa::ToSchema`, which presumably turns doc comments into schema
// descriptions - confirm before converting these to doc comments.
//
// NOTE(review): serde's `kebab-case` rule does not appear to insert a dash before
// digits, so the derived names would look like `us-west2` rather than the `us-west-2`
// returned by `identifier()`. `Region` bypasses the derived impls via the
// `*_with_suffix` helpers, but confirm whether direct (de)serialization of
// `AwsRegion` (or its derived schema) is relied on anywhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, utoipa::ToSchema)]
#[serde(rename_all = "kebab-case")]
pub enum AwsRegion {
    // ap-southeast-2
    ApSoutheast2,
    // ca-central-1
    CaCentral1,
    // eu-central-1
    EuCentral1,
    // eu-west-1
    EuWest1,
    // us-east-1
    UsEast1,
    // us-east-2
    UsEast2,
    // us-west-1
    UsWest1,
    // us-west-2
    UsWest2,
}

impl AwsRegion {
    pub const ALL: [Self; 8] = [
        Self::ApSoutheast2,
        Self::CaCentral1,
        Self::EuCentral1,
        Self::EuWest1,
        Self::UsEast1,
        Self::UsEast2,
        Self::UsWest1,
        Self::UsWest2,
    ];

    pub fn all() -> Vec<Self> {
        Self::ALL.to_vec()
    }

    pub fn identifier(&self) -> &'static str {
        match self {
            Self::ApSoutheast2 => "ap-southeast-2",
            Self::CaCentral1 => "ca-central-1",
            Self::EuCentral1 => "eu-central-1",
            Self::EuWest1 => "eu-west-1",
            Self::UsEast1 => "us-east-1",
            Self::UsEast2 => "us-east-2",
            Self::UsWest1 => "us-west-1",
            Self::UsWest2 => "us-west-2",
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            Self::ApSoutheast2 => "Asia Pacific (Sydney)",
            Self::CaCentral1 => "Canada (Central)",
            Self::EuCentral1 => "Europe (Frankfurt)",
            Self::EuWest1 => "Europe (Ireland)",
            Self::UsEast1 => "US East (N. Virginia)",
            Self::UsEast2 => "US East (Ohio)",
            Self::UsWest1 => "US West (N. California)",
            Self::UsWest2 => "US West (Oregon)",
        }
    }

    pub fn serialize_with_suffix<S>(region: &AwsRegion, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&format!("{}.aws", region.identifier()))
    }

    pub fn deserialize_with_suffix<'de, D>(deserializer: D) -> Result<AwsRegion, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let region = String::deserialize(deserializer)?;
        region
            .trim_end_matches(".aws")
            .try_into()
            .map_err(serde::de::Error::custom)
    }
}

/// Resolves a bare AWS region code (e.g. `us-west-2`) to its `AwsRegion`.
impl TryFrom<&str> for AwsRegion {
    type Error = RegionError;

    /// Returns the region whose identifier equals `value` exactly, or
    /// `RegionError::InvalidRegion` carrying `value` when none matches.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        for candidate in Self::ALL {
            if candidate.identifier() == value {
                return Ok(candidate);
            }
        }

        Err(RegionError::InvalidRegion(value.to_string()))
    }
}

/// Implement TryFrom for (&str, &str) to support the format "<region>.<provider>".
///
/// The first element of the tuple is the region and the second is the provider.
///
/// # Errors
///
/// Returns [`RegionError::InvalidRegion`] when the provider is not known, or when
/// the region is not supported by that provider.
///
/// # Example
///
/// ```
/// use cts_common::{AwsRegion, Region, RegionError};
/// use std::convert::TryFrom;
///
/// let region = Region::try_from(("us-west-2", "aws"));
/// assert_eq!(region, Ok(Region::Aws(AwsRegion::UsWest2)));
/// ```
///
impl TryFrom<(&str, &str)> for Region {
    type Error = RegionError;

    fn try_from(value: (&str, &str)) -> Result<Self, Self::Error> {
        let (region, provider) = value;

        match provider {
            "aws" => AwsRegion::try_from(region).map(Region::Aws),
            // The variant's own message already starts with "Invalid region: ", so the
            // payload must not repeat that prefix. It names the unrecognised provider
            // using the same wording as `Region::new`.
            _ => Err(RegionError::InvalidRegion(format!(
                "Missing or unknown provider (e.g. '.aws' suffix on '{region}.{provider}')"
            ))),
        }
    }
}

impl Display for AwsRegion {
    /// Writes the bare AWS region code (e.g. `us-west-2`), without a provider suffix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.identifier())
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Every supported `<region>.aws` identifier parses to its matching variant.
    #[test]
    fn test_region_new() {
        let cases = [
            ("us-west-1.aws", AwsRegion::UsWest1),
            ("us-west-2.aws", AwsRegion::UsWest2),
            ("us-east-1.aws", AwsRegion::UsEast1),
            ("us-east-2.aws", AwsRegion::UsEast2),
            ("eu-west-1.aws", AwsRegion::EuWest1),
            ("eu-central-1.aws", AwsRegion::EuCentral1),
            ("ap-southeast-2.aws", AwsRegion::ApSoutheast2),
            ("ca-central-1.aws", AwsRegion::CaCentral1),
        ];

        for (identifier, expected) in cases {
            assert_eq!(Region::new(identifier).unwrap(), Region::Aws(expected));
        }
    }

    // An identifier without a provider suffix is rejected.
    #[test]
    fn test_region_new_invalid() {
        assert!(Region::new("us-west-2").is_err());
    }

    // A bare AWS region code resolves through `Region::aws`.
    #[test]
    fn test_region_aws() {
        let parsed = Region::aws("us-west-2").unwrap();
        assert_eq!(parsed, Region::Aws(AwsRegion::UsWest2));
    }

    // An unsupported AWS region code is rejected.
    #[test]
    fn test_region_aws_invalid() {
        assert!(Region::aws("us-west-3").is_err());
    }

    // The full identifier includes the provider suffix.
    #[test]
    fn test_region_identifier() {
        let parsed = Region::aws("us-west-2").unwrap();
        assert_eq!(parsed.identifier(), "us-west-2.aws");
    }

    // `FromStr` accepts the same identifiers as `Region::new`.
    #[test]
    fn test_region_from_string() {
        let cases = [
            ("us-west-2.aws", AwsRegion::UsWest2),
            ("ap-southeast-2.aws", AwsRegion::ApSoutheast2),
        ];

        for (identifier, expected) in cases {
            let parsed = Region::from_str(identifier).unwrap();
            assert_eq!(parsed, Region::Aws(expected));
        }
    }

    // An unknown provider suffix yields the "missing or unknown provider" error.
    #[test]
    fn test_region_from_string_invalid_provider() {
        let expected = RegionError::InvalidRegion(
            "Missing or unknown provider (e.g. '.aws' suffix on 'us-west-2.gcp')".to_string(),
        );

        assert_eq!(Region::from_str("us-west-2.gcp"), Err(expected));
    }

    // An unknown region with a valid provider reports the bare region code.
    #[test]
    fn test_region_from_string_invalid_region() {
        let expected = RegionError::InvalidRegion("us-invalid-2".to_string());

        assert_eq!(Region::from_str("us-invalid-2.aws"), Err(expected));
    }

    mod aws {
        use super::*;

        // Each variant maps to its AWS region code.
        #[test]
        fn test_aws_region_identifier() {
            let cases = [
                (AwsRegion::UsWest1, "us-west-1"),
                (AwsRegion::UsWest2, "us-west-2"),
                (AwsRegion::UsEast1, "us-east-1"),
                (AwsRegion::UsEast2, "us-east-2"),
                (AwsRegion::EuWest1, "eu-west-1"),
                (AwsRegion::EuCentral1, "eu-central-1"),
                (AwsRegion::ApSoutheast2, "ap-southeast-2"),
                (AwsRegion::CaCentral1, "ca-central-1"),
            ];

            for (region, expected) in cases {
                assert_eq!(region.identifier(), expected);
            }
        }

        // `Display` on `Region` renders the code followed by the provider suffix.
        #[test]
        fn test_display() {
            let cases = [
                (AwsRegion::UsWest1, "us-west-1.aws"),
                (AwsRegion::UsWest2, "us-west-2.aws"),
                (AwsRegion::UsEast1, "us-east-1.aws"),
                (AwsRegion::UsEast2, "us-east-2.aws"),
                (AwsRegion::EuWest1, "eu-west-1.aws"),
                (AwsRegion::EuCentral1, "eu-central-1.aws"),
                (AwsRegion::ApSoutheast2, "ap-southeast-2.aws"),
                (AwsRegion::CaCentral1, "ca-central-1.aws"),
            ];

            for (region, expected) in cases {
                assert_eq!(Region::Aws(region).to_string(), expected);
            }
        }
    }
}