Skip to main content

cts_common/
crn.rs

1use crate::{AwsRegion, Region, WorkspaceId};
2use miette::Diagnostic;
3use nom::{
4    bytes::complete::{tag, take_while1, take_while_m_n},
5    combinator::{all_consuming, opt},
6    error::ErrorKind,
7    sequence::{preceded, separated_pair},
8    IResult, Parser,
9};
10use serde::{Deserialize, Serialize};
11use std::{fmt::Display, str::FromStr};
12use thiserror::Error;
13
14#[derive(Error, Debug, Diagnostic)]
15pub enum InvalidCrn {
16    #[error("Invalid CRN: {0}")]
17    #[diagnostic(help = "CRN format: `crn:<region>:<workspace_id>[:<service_name>]`")]
18    InvalidFormat(String),
19
20    #[error(transparent)]
21    #[diagnostic(transparent)]
22    InvalidRegion(#[from] crate::region::RegionError),
23
24    #[error(transparent)]
25    #[diagnostic(transparent)]
26    InvalidWorkspaceId(#[from] crate::workspace::InvalidWorkspaceId),
27}
28
29impl InvalidCrn {
30    pub fn invalid_format(input: &str) -> Self {
31        Self::InvalidFormat(input.to_string())
32    }
33}
34
35pub trait AsCrn {
36    /// Converts the implementing type to a CRN
37    fn as_crn(&self) -> Crn;
38}
39
40// TODO: Make some inner type variants of this to handle when a service name is present or not (and for other extensions)
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43/// Represents CRNs (CipherStash Resource Names)
44pub struct Crn {
45    /// The workspace ID
46    pub workspace_id: WorkspaceId,
47
48    /// The region
49    pub region: Region,
50
51    /// An optional service name
52    pub service_name: Option<String>,
53}
54
55impl Crn {
56    /// Creates a new CRN
57    pub fn new(region: Region, workspace_id: WorkspaceId) -> Self {
58        Self {
59            workspace_id,
60            region,
61            service_name: None,
62        }
63    }
64
65    pub fn with_service_name(mut self, service_name: &str) -> Self {
66        self.service_name = Some(service_name.into());
67        self
68    }
69}
70
71impl Serialize for Crn {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        let s = self.to_string();
77        serializer.serialize_str(&s)
78    }
79}
80
81impl<'de> Deserialize<'de> for Crn {
82    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
83    where
84        D: serde::Deserializer<'de>,
85    {
86        let s = String::deserialize(deserializer)?;
87        Crn::try_from(s.as_str()).map_err(serde::de::Error::custom)
88    }
89}
90
91impl Display for Crn {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "crn:{}:{}", self.region, self.workspace_id)?;
94        if let Some(service_name) = &self.service_name {
95            write!(f, ":{service_name}")?;
96        }
97        Ok(())
98    }
99}
100
101impl TryFrom<&str> for Crn {
102    type Error = InvalidCrn;
103
104    fn try_from(value: &str) -> Result<Self, Self::Error> {
105        parse_crn(value)
106    }
107}
108
109impl FromStr for Crn {
110    type Err = InvalidCrn;
111
112    fn from_str(value: &str) -> Result<Self, Self::Err> {
113        Self::try_from(value)
114    }
115}
116
117// TODO: Move all of this into a submodule
118
119/// Parse the "geo" part of the region (e.g. "us-east-1")
120/// Uses `AwsRegion::ALL` as the single source of truth for valid regions.
121///
122/// # Invariant
123/// This loop uses prefix matching (`nom::tag`). No region identifier may be a
124/// prefix of another — if one were, the shorter match would win and leave the
125/// remainder of the longer identifier unparsed, breaking that region's CRN parsing.
126/// The `no_region_identifier_is_prefix_of_another` test in this module enforces this.
127fn region_geo(input: &str) -> IResult<&str, AwsRegion> {
128    for region in AwsRegion::ALL.iter() {
129        if let Ok((rest, _)) =
130            tag::<&str, &str, nom::error::Error<&str>>(region.identifier())(input)
131        {
132            return Ok((rest, *region));
133        }
134    }
135    // Use Alt error kind to match the semantics of nom's alt() combinator,
136    // as originally implemented.
137    Err(nom::Err::Error(nom::error::Error::new(
138        input,
139        ErrorKind::Alt,
140    )))
141}
142
143/// Parse the "vendor" part of the region (e.g. "aws")
144/// Only AWS is supported for now.
145#[inline]
146fn region_vendor(input: &str) -> IResult<&str, &str> {
147    tag("aws")(input)
148}
149
150/// Parse the region (e.g. "us-east-1.aws")
151#[inline]
152fn region(input: &str) -> IResult<&str, Region, nom::error::Error<&str>> {
153    separated_pair(region_geo, tag("."), region_vendor)
154        .parse(input)
155        .map(|(rest, (aws_region, _))| (rest, Region::Aws(aws_region)))
156}
157
158/// Parse the workspace ID (e.g. "ZVATKW3VHMFG27DY")
159/// The workspace ID must be 20 alphanumeric characters
160#[inline]
161fn workspace_id(input: &str) -> IResult<&str, WorkspaceId, nom::error::Error<&str>> {
162    // parse the workspace ID
163    take_while_m_n(16, 16, |c: char| c.is_alphanumeric())(input).map(|(rest, id)| {
164        // Convert the ID to a WorkspaceId
165        // SAFETY: The ID is already validated to be 16 alphanumeric characters
166        // TODO: use the parse method on the inner ArrayString
167        let id = WorkspaceId::try_from(id).expect("Invalid workspace ID");
168        (rest, id)
169    })
170}
171
172fn service_name_chars(input: &str) -> IResult<&str, &str> {
173    // parse the service name
174    let (rest, service_name) =
175        take_while1(|c: char| c.is_alphanumeric() || c == '-' || c == '_').parse(input)?;
176    Ok((rest, service_name))
177}
178
179fn parse_crn(input: &str) -> Result<Crn, InvalidCrn> {
180    let (_, (region, workspace_id, service_name)) = all_consuming((
181        preceded(tag("crn:"), region),
182        preceded(tag(":"), workspace_id),
183        opt(preceded(tag(":"), service_name_chars)),
184    ))
185    .parse(input)
186    .map_err(|_| InvalidCrn::invalid_format(input))?;
187
188    Ok(Crn {
189        region,
190        workspace_id,
191        service_name: service_name.map(String::from),
192    })
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AwsRegion;

    /// Parsing CRNs from strings via `TryFrom<&str>`.
    mod try_from_str {
        use super::*;

        #[test]
        fn success_valid_with_service() {
            let region = Region::new("us-east-1.aws").unwrap();
            let workspace_id = WorkspaceId::try_from("ZVATKW3VHMFG27DY").unwrap();

            // Both `_` and `-` are accepted inside a service name.
            for service in ["service_name", "service-name"] {
                let input = format!("crn:us-east-1.aws:ZVATKW3VHMFG27DY:{service}");
                assert_eq!(
                    Crn::try_from(input.as_str()).unwrap(),
                    Crn::new(region, workspace_id).with_service_name(service)
                );
            }
        }

        #[test]
        fn success_valid_without_service() {
            let parsed = Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY").unwrap();

            assert_eq!(parsed.region, Region::Aws(AwsRegion::UsEast1));
            assert_eq!(parsed.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert!(parsed.service_name.is_none());
        }

        #[test]
        fn success_ca_central_1() {
            let parsed = Crn::try_from("crn:ca-central-1.aws:ZVATKW3VHMFG27DY").unwrap();

            assert_eq!(parsed.region, Region::Aws(AwsRegion::CaCentral1));
            assert_eq!(parsed.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert!(parsed.service_name.is_none());
        }

        #[test]
        fn all_regions_roundtrip_in_crn() {
            let workspace_id = "ZVATKW3VHMFG27DY";
            for region in AwsRegion::all() {
                let crn_str = format!("crn:{}.aws:{}", region.identifier(), workspace_id);
                let crn = match Crn::try_from(crn_str.as_str()) {
                    Ok(crn) => crn,
                    Err(err) => panic!(
                        "Failed to parse CRN for region {}: {}",
                        region.identifier(),
                        err
                    ),
                };
                assert_eq!(crn.region, Region::Aws(region));
                // Also verify round-trip through Display
                assert_eq!(crn.to_string(), crn_str);
            }
        }

        #[test]
        fn test_invalid_crn() {
            let rejected = [
                "invalid_crn",
                "crn:invalid_crn",
                // Trailing colon
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:",
                // Extra parts
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra",
                // Extra extra parts
                "crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra:extra",
                // Invalid workspace ID
                "crn:us-east-1.aws:ZVATKW3VH",
                // Invalid region
                "crn:us-east-1:ZVATKW3VHMFG27DY",
                // Missing CRN prefix
                "us-east-1.aws:ZVATKW3VHMFG27DY:service_name",
            ];

            for input in rejected {
                assert!(Crn::try_from(input).is_err());
            }
        }
    }

    /// Rendering CRNs through `Display`.
    mod display {
        use super::*;

        #[test]
        fn test_with_workspace_id() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("us-east-1.aws").unwrap();

            let rendered = Crn::new(region, workspace_id).to_string();

            assert_eq!(rendered, format!("crn:us-east-1.aws:{workspace_id}"));
        }

        #[test]
        fn test_ca_central_1_round_trip() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("ca-central-1.aws").unwrap();

            let rendered = Crn::new(region, workspace_id).to_string();

            assert_eq!(rendered, format!("crn:ca-central-1.aws:{workspace_id}"));
        }

        #[test]
        fn test_with_workspace_id_and_service() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let region = Region::new("us-east-1.aws").unwrap();

            let rendered = Crn::new(region, workspace_id)
                .with_service_name("zerokms")
                .to_string();

            assert_eq!(
                rendered,
                format!("crn:us-east-1.aws:{workspace_id}:zerokms")
            );
        }
    }

    /// Guards the prefix-matching invariant that `region_geo` relies on.
    #[test]
    fn no_region_identifier_is_prefix_of_another() {
        let identifiers: Vec<&str> = AwsRegion::ALL.iter().map(|r| r.identifier()).collect();
        for (i, a) in identifiers.iter().enumerate() {
            for (j, b) in identifiers.iter().enumerate() {
                // An identifier trivially starts with itself; only compare distinct entries.
                if i == j {
                    continue;
                }
                assert!(
                    !b.starts_with(a),
                    "region identifier {:?} is a prefix of {:?} — \
                     region_geo() would match {:?} first, making {:?} unparseable",
                    a,
                    b,
                    a,
                    b
                );
            }
        }
    }
}