Skip to main content

cts_common/
crn.rs

1use crate::{AwsRegion, Region, WorkspaceId};
2use miette::Diagnostic;
3use nom::{
4    bytes::complete::{tag, take_while1, take_while_m_n},
5    combinator::{all_consuming, opt},
6    error::ErrorKind,
7    sequence::{preceded, separated_pair},
8    IResult, Parser,
9};
10use serde::{Deserialize, Serialize};
11use std::{fmt::Display, str::FromStr};
12use thiserror::Error;
13
14#[derive(Error, Debug, Diagnostic)]
15pub enum InvalidCrn {
16    #[error("Invalid CRN: {0}")]
17    #[diagnostic(help = "CRN format: `crn:<region>:<workspace_id>[:<service_name>]`")]
18    InvalidFormat(String),
19
20    #[error(transparent)]
21    #[diagnostic(transparent)]
22    InvalidRegion(#[from] crate::region::RegionError),
23
24    #[error(transparent)]
25    #[diagnostic(transparent)]
26    InvalidWorkspaceId(#[from] crate::workspace::InvalidWorkspaceId),
27}
28
29impl InvalidCrn {
30    pub fn invalid_format(input: &str) -> Self {
31        Self::InvalidFormat(input.to_string())
32    }
33}
34
35pub trait AsCrn {
36    /// Converts the implementing type to a CRN
37    fn as_crn(&self) -> Crn;
38}
39
40// TODO: Make some inner type variants of this to handle when a service name is present or not (and for other extensions)
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43/// Represents CRNs (CipherStash Resource Names)
44pub struct Crn {
45    /// The workspace ID
46    pub workspace_id: WorkspaceId,
47
48    /// The region
49    pub region: Region,
50
51    /// An optional service name
52    pub service_name: Option<String>,
53}
54
55impl Crn {
56    /// Creates a new CRN
57    pub fn new(region: Region, workspace_id: WorkspaceId) -> Self {
58        Self {
59            workspace_id,
60            region,
61            service_name: None,
62        }
63    }
64
65    pub fn with_service_name(mut self, service_name: &str) -> Self {
66        self.service_name = Some(service_name.into());
67        self
68    }
69}
70
71impl Serialize for Crn {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        let s = self.to_string();
77        serializer.serialize_str(&s)
78    }
79}
80
81impl<'de> Deserialize<'de> for Crn {
82    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
83    where
84        D: serde::Deserializer<'de>,
85    {
86        let s = String::deserialize(deserializer)?;
87        Crn::try_from(s.as_str()).map_err(serde::de::Error::custom)
88    }
89}
90
91impl Display for Crn {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "crn:{}:{}", self.region, self.workspace_id)?;
94        if let Some(service_name) = &self.service_name {
95            write!(f, ":{service_name}")?;
96        }
97        Ok(())
98    }
99}
100
101impl TryFrom<&str> for Crn {
102    type Error = InvalidCrn;
103
104    fn try_from(value: &str) -> Result<Self, Self::Error> {
105        parse_crn(value)
106    }
107}
108
109impl FromStr for Crn {
110    type Err = InvalidCrn;
111
112    fn from_str(value: &str) -> Result<Self, Self::Err> {
113        Self::try_from(value)
114    }
115}
116
117// TODO: Move all of this into a submodule
118
119/// Parse the "geo" part of the region (e.g. "us-east-1")
120/// Uses `AwsRegion::ALL` as the single source of truth for valid regions.
121///
122/// # Invariant
123/// This loop uses prefix matching (`nom::tag`). No region identifier may be a
124/// prefix of another — if one were, the shorter match would win and leave the
125/// remainder of the longer identifier unparsed, breaking that region's CRN parsing.
126/// The `no_region_identifier_is_prefix_of_another` test in this module enforces this.
127fn region_geo(input: &str) -> IResult<&str, AwsRegion> {
128    for region in AwsRegion::ALL.iter() {
129        if let Ok((rest, _)) =
130            tag::<&str, &str, nom::error::Error<&str>>(region.identifier())(input)
131        {
132            return Ok((rest, *region));
133        }
134    }
135    // Use Alt error kind to match the semantics of nom's alt() combinator,
136    // as originally implemented.
137    Err(nom::Err::Error(nom::error::Error::new(
138        input,
139        ErrorKind::Alt,
140    )))
141}
142
143/// Parse the "vendor" part of the region (e.g. "aws")
144/// Only AWS is supported for now.
145#[inline]
146fn region_vendor(input: &str) -> IResult<&str, &str> {
147    tag("aws")(input)
148}
149
150/// Parse the region (e.g. "us-east-1.aws")
151#[inline]
152fn region(input: &str) -> IResult<&str, Region, nom::error::Error<&str>> {
153    separated_pair(region_geo, tag("."), region_vendor)
154        .parse(input)
155        .map(|(rest, (aws_region, _))| (rest, Region::Aws(aws_region)))
156}
157
158/// Parse the workspace ID (e.g. "ZVATKW3VHMFG27DY").
159///
160/// Takes 16 alphanumeric characters as a coarse delimiter, then fully validates
161/// them via `WorkspaceId::try_from`. `is_alphanumeric` is only a delimiter — it
162/// is Unicode-broad (accepts lowercase, `0`/`1`/`8`/`9` and non-ASCII) and so
163/// wider than the base32 alphabet `WorkspaceId` actually requires; `try_from`
164/// is the real validation. A failure becomes a recoverable nom error (which
165/// `parse_crn` maps to `InvalidCrn::InvalidFormat`) rather than a panic. See
166/// CIP-3239.
167#[inline]
168fn workspace_id(input: &str) -> IResult<&str, WorkspaceId, nom::error::Error<&str>> {
169    let (rest, id) = take_while_m_n(16, 16, |c: char| c.is_alphanumeric()).parse(input)?;
170    match WorkspaceId::try_from(id) {
171        Ok(workspace_id) => Ok((rest, workspace_id)),
172        Err(_) => Err(nom::Err::Error(nom::error::Error::new(
173            input,
174            ErrorKind::Verify,
175        ))),
176    }
177}
178
179fn service_name_chars(input: &str) -> IResult<&str, &str> {
180    // parse the service name
181    let (rest, service_name) =
182        take_while1(|c: char| c.is_alphanumeric() || c == '-' || c == '_').parse(input)?;
183    Ok((rest, service_name))
184}
185
186fn parse_crn(input: &str) -> Result<Crn, InvalidCrn> {
187    let (_, (region, workspace_id, service_name)) = all_consuming((
188        preceded(tag("crn:"), region),
189        preceded(tag(":"), workspace_id),
190        opt(preceded(tag(":"), service_name_chars)),
191    ))
192    .parse(input)
193    .map_err(|_| InvalidCrn::invalid_format(input))?;
194
195    Ok(Crn {
196        region,
197        workspace_id,
198        service_name: service_name.map(String::from),
199    })
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AwsRegion;

    mod try_from_str {
        use super::*;

        #[test]
        fn success_valid_with_service() {
            let region = Region::new("us-east-1.aws").unwrap();
            let workspace_id = WorkspaceId::try_from("ZVATKW3VHMFG27DY").unwrap();

            assert_eq!(
                Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name").unwrap(),
                Crn::new(region, workspace_id).with_service_name("service_name")
            );

            assert_eq!(
                Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service-name").unwrap(),
                Crn::new(region, workspace_id).with_service_name("service-name")
            );
        }

        #[test]
        fn success_valid_without_service() {
            let crn_str = "crn:us-east-1.aws:ZVATKW3VHMFG27DY";
            let crn = Crn::try_from(crn_str).unwrap();
            assert_eq!(crn.region, Region::Aws(AwsRegion::UsEast1));
            assert_eq!(crn.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert!(crn.service_name.is_none());
        }

        #[test]
        fn success_ca_central_1() {
            let crn_str = "crn:ca-central-1.aws:ZVATKW3VHMFG27DY";
            let crn = Crn::try_from(crn_str).unwrap();
            assert_eq!(crn.region, Region::Aws(AwsRegion::CaCentral1));
            assert_eq!(crn.workspace_id.to_string(), "ZVATKW3VHMFG27DY");
            assert!(crn.service_name.is_none());
        }

        #[test]
        fn all_regions_roundtrip_in_crn() {
            let workspace_id = "ZVATKW3VHMFG27DY";
            for region in AwsRegion::all() {
                let crn_str = format!("crn:{}.aws:{}", region.identifier(), workspace_id);
                let crn = Crn::try_from(crn_str.as_str()).unwrap_or_else(|err| {
                    panic!(
                        "Failed to parse CRN for region {}: {}",
                        region.identifier(),
                        err
                    )
                });
                assert_eq!(crn.region, Region::Aws(region));
                // Also verify round-trip through Display
                assert_eq!(crn.to_string(), crn_str);
            }
        }

        #[test]
        fn test_invalid_crn() {
            assert!(Crn::try_from("invalid_crn").is_err());
            assert!(Crn::try_from("crn:invalid_crn").is_err());
            // Trailing colon
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:").is_err());
            // Extra parts
            assert!(
                Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra").is_err()
            );
            // Extra extra parts
            assert!(
                Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service_name:extra:extra")
                    .is_err()
            );
            // Invalid workspace ID
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VH").is_err());
            // Invalid region
            assert!(Crn::try_from("crn:us-east-1:ZVATKW3VHMFG27DY").is_err());
            // Missing CRN prefix
            assert!(Crn::try_from("us-east-1.aws:ZVATKW3VHMFG27DY:service_name").is_err());
        }

        /// Regression for CIP-3239: the `workspace_id` sub-parser used
        /// `.expect()` on `WorkspaceId::try_from`. The `is_alphanumeric`
        /// delimiter is broader than the base32 alphabet, so a 16-char-but-
        /// invalid workspace segment panicked instead of returning `Err`.
        #[test]
        fn invalid_workspace_id_segment_is_err_not_panic() {
            // 16 alphanumeric chars, but a lowercase `s` is not valid base32:
            assert!(Crn::try_from("crn:ca-central-1.aws:ZVATKWsVHMFG27DY").is_err());
            assert!("crn:ca-central-1.aws:ZVATKWsVHMFG27DY"
                .parse::<Crn>()
                .is_err());
            // ... and `0` is outside the RFC4648 base32 alphabet:
            assert!(Crn::try_from("crn:ca-central-1.aws:0VATKW3VHMFG27DY").is_err());
        }

        /// Same CIP-3239 path, but with a multi-byte alphanumeric (`é`) that passes the
        /// Unicode-broad delimiter: must be rejected, never panic.
        #[test]
        fn non_ascii_workspace_id_segment_is_err_not_panic() {
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKWéVHMFG27DY").is_err());
        }

        /// The workspace parser takes exactly 16 characters; a 17th alphanumeric is not a
        /// separator, so the input must be rejected rather than silently truncated.
        #[test]
        fn workspace_id_longer_than_16_chars_is_err() {
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DYA").is_err());
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DYA:service").is_err());
        }

        /// Service names only accept alphanumerics, `-` and `_`.
        #[test]
        fn service_name_with_disallowed_chars_is_err() {
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service.name").is_err());
            assert!(Crn::try_from("crn:us-east-1.aws:ZVATKW3VHMFG27DY:service name").is_err());
        }

        /// Only the `aws` vendor is supported.
        #[test]
        fn unsupported_vendor_is_err() {
            assert!(Crn::try_from("crn:us-east-1.gcp:ZVATKW3VHMFG27DY").is_err());
        }

        /// Every parse failure is reported as `InvalidFormat` carrying the original input.
        #[test]
        fn parse_errors_are_invalid_format_with_input() {
            let input = "crn:us-east-1.aws:ZVATKW3VH";
            assert!(matches!(
                Crn::try_from(input),
                Err(InvalidCrn::InvalidFormat(rejected)) if rejected == input
            ));
        }
    }

    mod display {
        use super::*;

        #[test]
        fn test_with_workspace_id() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let crn = Crn::new(Region::new("us-east-1.aws").unwrap(), workspace_id);
            assert_eq!(crn.to_string(), format!("crn:us-east-1.aws:{workspace_id}"));
        }

        #[test]
        fn test_ca_central_1_round_trip() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let crn = Crn::new(Region::new("ca-central-1.aws").unwrap(), workspace_id);
            assert_eq!(
                crn.to_string(),
                format!("crn:ca-central-1.aws:{workspace_id}")
            );
        }

        #[test]
        fn test_with_workspace_id_and_service() {
            let workspace_id = WorkspaceId::generate().unwrap();
            let crn = Crn::new(Region::new("us-east-1.aws").unwrap(), workspace_id)
                .with_service_name("zerokms");
            assert_eq!(
                crn.to_string(),
                format!("crn:us-east-1.aws:{workspace_id}:zerokms")
            );
        }

        /// Display output must parse back to an equal CRN (service name included).
        #[test]
        fn test_display_parse_round_trip_with_service() {
            let workspace_id = WorkspaceId::try_from("ZVATKW3VHMFG27DY").unwrap();
            let crn = Crn::new(Region::new("ca-central-1.aws").unwrap(), workspace_id)
                .with_service_name("zerokms");
            let reparsed: Crn = crn.to_string().parse().unwrap();
            assert_eq!(reparsed, crn);
        }
    }

    #[test]
    fn no_region_identifier_is_prefix_of_another() {
        let identifiers: Vec<&str> = AwsRegion::ALL.iter().map(|r| r.identifier()).collect();
        for (i, a) in identifiers.iter().enumerate() {
            for (j, b) in identifiers.iter().enumerate() {
                if i != j {
                    assert!(
                        !b.starts_with(a),
                        "region identifier {:?} is a prefix of {:?} — \
                         region_geo() would match {:?} first, making {:?} unparseable",
                        a,
                        b,
                        a,
                        b
                    );
                }
            }
        }
    }
}