Skip to main content

aws_config/sso/
cache.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_runtime::fs_util::{home_dir, Os};
7use aws_smithy_json::deserialize::token::skip_value;
8use aws_smithy_json::deserialize::Token;
9use aws_smithy_json::deserialize::{json_token_iter, EscapeError};
10use aws_smithy_json::serialize::JsonObjectWriter;
11use aws_smithy_types::date_time::{DateTimeFormatError, Format};
12use aws_smithy_types::DateTime;
13use aws_types::os_shim_internal::{Env, Fs};
14use sha1::{Digest, Sha1};
15use std::borrow::Cow;
16use std::error::Error as StdError;
17use std::fmt;
18use std::path::PathBuf;
19use std::time::SystemTime;
20use zeroize::Zeroizing;
21
22#[cfg_attr(test, derive(Eq, PartialEq))]
23#[derive(Clone)]
24pub(super) struct CachedSsoToken {
25    pub(super) access_token: Zeroizing<String>,
26    pub(super) client_id: Option<String>,
27    pub(super) client_secret: Option<Zeroizing<String>>,
28    pub(super) expires_at: SystemTime,
29    pub(super) refresh_token: Option<Zeroizing<String>>,
30    pub(super) region: Option<String>,
31    pub(super) registration_expires_at: Option<SystemTime>,
32    pub(super) start_url: Option<String>,
33}
34
35impl CachedSsoToken {
36    /// True if the information required to refresh this token is present.
37    ///
38    /// The expiration times are not considered by this function.
39    pub(super) fn refreshable(&self) -> bool {
40        self.client_id.is_some()
41            && self.client_secret.is_some()
42            && self.refresh_token.is_some()
43            && self.registration_expires_at.is_some()
44    }
45}
46
47impl fmt::Debug for CachedSsoToken {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        f.debug_struct("CachedSsoToken")
50            .field("access_token", &"** redacted **")
51            .field("client_id", &self.client_id)
52            .field("client_secret", &"** redacted **")
53            .field("expires_at", &self.expires_at)
54            .field("refresh_token", &"** redacted **")
55            .field("region", &self.region)
56            .field("registration_expires_at", &self.registration_expires_at)
57            .field("start_url", &self.start_url)
58            .finish()
59    }
60}
61
62#[derive(Debug)]
63pub(super) enum CachedSsoTokenError {
64    FailedToFormatDateTime {
65        source: Box<dyn StdError + Send + Sync>,
66    },
67    InvalidField {
68        field: &'static str,
69        source: Box<dyn StdError + Send + Sync>,
70    },
71    IoError {
72        what: &'static str,
73        path: PathBuf,
74        source: std::io::Error,
75    },
76    JsonError(Box<dyn StdError + Send + Sync>),
77    MissingField(&'static str),
78    NoHomeDirectory,
79    Other(Cow<'static, str>),
80}
81
82impl fmt::Display for CachedSsoTokenError {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::FailedToFormatDateTime { .. } => write!(f, "failed to format date time"),
86            Self::InvalidField { field, .. } => write!(
87                f,
88                "invalid value for the `{field}` field in the cached SSO token file"
89            ),
90            Self::IoError { what, path, .. } => write!(f, "failed to {what} `{}`", path.display()),
91            Self::JsonError(_) => write!(f, "invalid JSON in cached SSO token file"),
92            Self::MissingField(field) => {
93                write!(f, "missing field `{field}` in cached SSO token file")
94            }
95            Self::NoHomeDirectory => write!(f, "couldn't resolve a home directory"),
96            Self::Other(message) => f.write_str(message),
97        }
98    }
99}
100
101impl StdError for CachedSsoTokenError {
102    fn source(&self) -> Option<&(dyn StdError + 'static)> {
103        match self {
104            Self::FailedToFormatDateTime { source } => Some(source.as_ref()),
105            Self::InvalidField { source, .. } => Some(source.as_ref()),
106            Self::IoError { source, .. } => Some(source),
107            Self::JsonError(source) => Some(source.as_ref()),
108            Self::MissingField(_) => None,
109            Self::NoHomeDirectory => None,
110            Self::Other(_) => None,
111        }
112    }
113}
114
115impl From<EscapeError> for CachedSsoTokenError {
116    fn from(err: EscapeError) -> Self {
117        Self::JsonError(err.into())
118    }
119}
120
121impl From<aws_smithy_json::deserialize::error::DeserializeError> for CachedSsoTokenError {
122    fn from(err: aws_smithy_json::deserialize::error::DeserializeError) -> Self {
123        Self::JsonError(err.into())
124    }
125}
126
127impl From<DateTimeFormatError> for CachedSsoTokenError {
128    fn from(value: DateTimeFormatError) -> Self {
129        Self::FailedToFormatDateTime {
130            source: value.into(),
131        }
132    }
133}
134
135/// Determine the SSO cached token path for a given identifier.
136///
137/// The `identifier` is the `sso_start_url` for credentials providers, and `sso_session_name` for token providers.
138fn cached_token_path(identifier: &str, home: &str) -> PathBuf {
139    // hex::encode returns a lowercase string
140    let mut out = PathBuf::with_capacity(home.len() + "/.aws/sso/cache".len() + ".json".len() + 40);
141    out.push(home);
142    out.push(".aws/sso/cache");
143    out.push(hex::encode(Sha1::digest(identifier.as_bytes())));
144    out.set_extension("json");
145    out
146}
147
148/// Load the token for `identifier` from `~/.aws/sso/cache/<hashofidentifier>.json`
149///
150/// The `identifier` is the `sso_start_url` for credentials providers, and `sso_session_name` for token providers.
151pub(super) async fn load_cached_token(
152    env: &Env,
153    fs: &Fs,
154    identifier: &str,
155) -> Result<CachedSsoToken, CachedSsoTokenError> {
156    let home = home_dir(env, Os::real()).ok_or(CachedSsoTokenError::NoHomeDirectory)?;
157    let path = cached_token_path(identifier, &home);
158    let data = Zeroizing::new(fs.read_to_end(&path).await.map_err(|source| {
159        CachedSsoTokenError::IoError {
160            what: "read",
161            path,
162            source,
163        }
164    })?);
165    parse_cached_token(&data)
166}
167
168/// Parse SSO token JSON from input
169fn parse_cached_token(
170    cached_token_file_contents: &[u8],
171) -> Result<CachedSsoToken, CachedSsoTokenError> {
172    use CachedSsoTokenError as Error;
173
174    let mut access_token = None;
175    let mut expires_at = None;
176    let mut client_id = None;
177    let mut client_secret = None;
178    let mut refresh_token = None;
179    let mut region = None;
180    let mut registration_expires_at = None;
181    let mut start_url = None;
182    json_parse_loop(cached_token_file_contents, |key, value| {
183        match (key, value) {
184            /*
185            // Required fields:
186            "accessToken": "string",
187            "expiresAt": "2019-11-14T04:05:45Z",
188
189            // Optional fields:
190            "refreshToken": "string",
191            "clientId": "ABCDEFG323242423121312312312312312",
192            "clientSecret": "ABCDE123",
193            "registrationExpiresAt": "2022-03-06T19:53:17Z",
194            "region": "us-west-2",
195            "startUrl": "https://d-abc123.awsapps.com/start"
196            */
197            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("accessToken") => {
198                access_token = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
199            }
200            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("expiresAt") => {
201                expires_at = Some(value.to_unescaped()?);
202            }
203            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("clientId") => {
204                client_id = Some(value.to_unescaped()?);
205            }
206            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("clientSecret") => {
207                client_secret = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
208            }
209            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("refreshToken") => {
210                refresh_token = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
211            }
212            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("region") => {
213                region = Some(value.to_unescaped()?.into_owned());
214            }
215            (key, Token::ValueString { value, .. })
216                if key.eq_ignore_ascii_case("registrationExpiresAt") =>
217            {
218                registration_expires_at = Some(value.to_unescaped()?);
219            }
220            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("startUrl") => {
221                start_url = Some(value.to_unescaped()?.into_owned());
222            }
223            _ => {}
224        };
225        Ok(())
226    })?;
227
228    Ok(CachedSsoToken {
229        access_token: access_token.ok_or(Error::MissingField("accessToken"))?,
230        expires_at: expires_at
231            .ok_or(Error::MissingField("expiresAt"))
232            .and_then(|expires_at| {
233                DateTime::from_str(expires_at.as_ref(), Format::DateTime)
234                    .map_err(|err| Error::InvalidField { field: "expiresAt", source: err.into() })
235                    .and_then(|date_time| {
236                        SystemTime::try_from(date_time).map_err(|_| {
237                            Error::Other(
238                                "SSO token expiration time cannot be represented by a SystemTime"
239                                    .into(),
240                            )
241                        })
242                    })
243            })?,
244        client_id: client_id.map(Cow::into_owned),
245        client_secret,
246        refresh_token,
247        region,
248        registration_expires_at: Ok(registration_expires_at).and_then(|maybe_expires_at| {
249            if let Some(expires_at) = maybe_expires_at {
250                Some(
251                    DateTime::from_str(expires_at.as_ref(), Format::DateTime)
252                        .map_err(|err| Error::InvalidField { field: "registrationExpiresAt", source: err.into()})
253                        .and_then(|date_time| {
254                            SystemTime::try_from(date_time).map_err(|_| {
255                                Error::Other(
256                                    "SSO registration expiration time cannot be represented by a SystemTime"
257                                        .into(),
258                                )
259                            })
260                        }),
261                )
262                .transpose()
263            } else {
264                Ok(None)
265            }
266        })?,
267        start_url,
268    })
269}
270
271fn json_parse_loop<'a>(
272    input: &'a [u8],
273    mut f: impl FnMut(Cow<'a, str>, &Token<'a>) -> Result<(), CachedSsoTokenError>,
274) -> Result<(), CachedSsoTokenError> {
275    use CachedSsoTokenError as Error;
276    let mut tokens = json_token_iter(input).peekable();
277    if !matches!(tokens.next().transpose()?, Some(Token::StartObject { .. })) {
278        return Err(Error::Other(
279            "expected a JSON document starting with `{`".into(),
280        ));
281    }
282    loop {
283        match tokens.next().transpose()? {
284            Some(Token::EndObject { .. }) => break,
285            Some(Token::ObjectKey { key, .. }) => {
286                if let Some(Ok(token)) = tokens.peek() {
287                    let key = key.to_unescaped()?;
288                    f(key, token)?
289                }
290                skip_value(&mut tokens)?;
291            }
292            other => {
293                return Err(Error::Other(
294                    format!("expected object key, found: {other:?}").into(),
295                ));
296            }
297        }
298    }
299    if tokens.next().is_some() {
300        return Err(Error::Other(
301            "found more JSON tokens after completing parsing".into(),
302        ));
303    }
304    Ok(())
305}
306
307pub(super) async fn save_cached_token(
308    env: &Env,
309    fs: &Fs,
310    identifier: &str,
311    token: &CachedSsoToken,
312) -> Result<(), CachedSsoTokenError> {
313    let expires_at = DateTime::from(token.expires_at).fmt(Format::DateTime)?;
314    let registration_expires_at = token
315        .registration_expires_at
316        .map(|time| DateTime::from(time).fmt(Format::DateTime))
317        .transpose()?;
318
319    let mut out = Zeroizing::new(String::new());
320    let mut writer = JsonObjectWriter::new(&mut out);
321    writer.key("accessToken").string(&token.access_token);
322    writer.key("expiresAt").string(&expires_at);
323    if let Some(refresh_token) = &token.refresh_token {
324        writer.key("refreshToken").string(refresh_token);
325    }
326    if let Some(client_id) = &token.client_id {
327        writer.key("clientId").string(client_id);
328    }
329    if let Some(client_secret) = &token.client_secret {
330        writer.key("clientSecret").string(client_secret);
331    }
332    if let Some(registration_expires_at) = registration_expires_at {
333        writer
334            .key("registrationExpiresAt")
335            .string(&registration_expires_at);
336    }
337    if let Some(region) = &token.region {
338        writer.key("region").string(region);
339    }
340    if let Some(start_url) = &token.start_url {
341        writer.key("startUrl").string(start_url);
342    }
343    writer.finish();
344
345    let home = home_dir(env, Os::real()).ok_or(CachedSsoTokenError::NoHomeDirectory)?;
346    let path = cached_token_path(identifier, &home);
347    fs.write(&path, out.as_bytes())
348        .await
349        .map_err(|err| CachedSsoTokenError::IoError {
350            what: "write",
351            path,
352            source: err,
353        })?;
354    Ok(())
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    #[test]
    fn redact_fields_in_token_debug() {
        let token = CachedSsoToken {
            access_token: Zeroizing::new("!!SENSITIVE!!".into()),
            client_id: Some("clientid".into()),
            client_secret: Some(Zeroizing::new("!!SENSITIVE!!".into())),
            expires_at: SystemTime::now(),
            refresh_token: Some(Zeroizing::new("!!SENSITIVE!!".into())),
            region: Some("region".into()),
            registration_expires_at: Some(SystemTime::now()),
            start_url: Some("starturl".into()),
        };
        let debug_str = format!("{:?}", token);
        assert!(!debug_str.contains("!!SENSITIVE!!"), "The `Debug` impl for `CachedSsoToken` isn't properly redacting sensitive fields: {debug_str}");
    }

    // `refreshable` needs every refresh-related field, but ignores expiration times.
    #[test]
    fn refreshable_requires_every_refresh_field() {
        let token = CachedSsoToken {
            access_token: Zeroizing::new("access-token".into()),
            client_id: Some("client-id".into()),
            client_secret: Some(Zeroizing::new("client-secret".into())),
            // Long expired: expiration is deliberately not considered.
            expires_at: SystemTime::UNIX_EPOCH,
            refresh_token: Some(Zeroizing::new("refresh-token".into())),
            region: None,
            registration_expires_at: Some(SystemTime::UNIX_EPOCH),
            start_url: None,
        };
        assert!(token.refreshable());

        let mut missing_client_id = token.clone();
        missing_client_id.client_id = None;
        assert!(!missing_client_id.refreshable());

        let mut missing_secret = token.clone();
        missing_secret.client_secret = None;
        assert!(!missing_secret.refreshable());

        let mut missing_refresh_token = token.clone();
        missing_refresh_token.refresh_token = None;
        assert!(!missing_refresh_token.refreshable());

        let mut missing_registration_expiry = token;
        missing_registration_expiry.registration_expires_at = None;
        assert!(!missing_registration_expiry.refreshable());
    }

    // Valid token with all fields
    #[test]
    fn parse_valid_token() {
        let file_contents = r#"
        {
            "startUrl": "https://d-123.awsapps.com/start",
            "region": "us-west-2",
            "accessToken": "cachedtoken",
            "expiresAt": "2021-12-25T21:30:00Z",
            "clientId": "clientid",
            "clientSecret": "YSBzZWNyZXQ=",
            "registrationExpiresAt": "2022-12-25T13:30:00Z",
            "refreshToken": "cachedrefreshtoken"
        }
        "#;
        let cached = parse_cached_token(file_contents.as_bytes()).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        assert_eq!("clientid", cached.client_id.expect("client id is present"));
        assert_eq!(
            "YSBzZWNyZXQ=",
            cached
                .client_secret
                .expect("client secret is present")
                .as_str()
        );
        assert_eq!(
            "cachedrefreshtoken",
            cached
                .refresh_token
                .expect("refresh token is present")
                .as_str()
        );
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1671975000),
            cached
                .registration_expires_at
                .expect("registration expiration is present")
        );
        assert_eq!("us-west-2", cached.region.expect("region is present"));
        assert_eq!(
            "https://d-123.awsapps.com/start",
            cached.start_url.expect("startUrl is present")
        );
    }

    // Minimal valid cached token
    #[test]
    fn parse_valid_token_with_optional_fields_absent() {
        let file_contents = r#"
        {
            "accessToken": "cachedtoken",
            "expiresAt": "2021-12-25T21:30:00Z"
        }
        "#;
        let cached = parse_cached_token(file_contents.as_bytes()).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        assert!(cached.client_id.is_none());
        assert!(cached.client_secret.is_none());
        assert!(cached.refresh_token.is_none());
        assert!(cached.registration_expires_at.is_none());
    }

    // Field names are matched ASCII case-insensitively.
    #[test]
    fn parse_field_names_case_insensitively() {
        let file_contents = br#"{"ACCESSTOKEN": "cachedtoken", "expiresat": "2021-12-25T21:30:00Z", "Region": "us-west-2"}"#;
        let cached = parse_cached_token(file_contents).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        assert_eq!(Some("us-west-2"), cached.region.as_deref());
    }

    // Unknown fields (including nested ones) are skipped, and non-string values are ignored.
    #[test]
    fn ignore_unknown_and_non_string_fields() {
        let file_contents = br#"
        {
            "unknown": {"nested": [1, 2.5, true, null, {"deeper": "value"}]},
            "accessToken": "cachedtoken",
            "region": 5,
            "clientId": null,
            "list": [],
            "expiresAt": "2021-12-25T21:30:00Z"
        }"#;
        let cached = parse_cached_token(file_contents).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        assert!(cached.region.is_none());
        assert!(cached.client_id.is_none());
    }

    #[test]
    fn parse_invalid_timestamp() {
        let token = br#"
        {
            "accessToken": "base64string",
            "expiresAt": "notatimestamp",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("invalid timestamp");
        let expected = "invalid value for the `expiresAt` field in the cached SSO token file";
        let actual = format!("{err}");
        assert!(
            actual.contains(expected),
            "expected error to contain `{expected}`, but was `{actual}`",
        );
    }

    #[test]
    fn parse_invalid_registration_timestamp() {
        let token = br#"{"accessToken": "a", "expiresAt": "2021-12-25T21:30:00Z", "registrationExpiresAt": "later"}"#;
        let err = parse_cached_token(token).expect_err("invalid registration timestamp");
        assert!(
            matches!(
                err,
                CachedSsoTokenError::InvalidField {
                    field: "registrationExpiresAt",
                    ..
                }
            ),
            "incorrect error: {:?}",
            err
        );
    }

    #[test]
    fn parse_missing_fields() {
        // Token missing accessToken field
        let token = br#"
        {
            "expiresAt": "notatimestamp",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("missing akid");
        assert!(
            matches!(err, CachedSsoTokenError::MissingField("accessToken")),
            "incorrect error: {:?}",
            err
        );

        // Token missing expiresAt field
        let token = br#"
        {
            "accessToken": "akid",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("missing expiry");
        assert!(
            matches!(err, CachedSsoTokenError::MissingField("expiresAt")),
            "incorrect error: {:?}",
            err
        );
    }

    // The document must be exactly one JSON object.
    #[test]
    fn reject_documents_that_are_not_a_single_object() {
        let inputs: [&[u8]; 3] = [b"[]", b"\"just a string\"", b"{} {}"];
        for input in inputs {
            let err = parse_cached_token(input).expect_err("not a single JSON object");
            assert!(
                matches!(err, CachedSsoTokenError::Other(_)),
                "incorrect error for {:?}: {:?}",
                String::from_utf8_lossy(input),
                err
            );
        }
    }

    #[test]
    fn malformed_json_is_a_json_error() {
        let token = br#"{"accessToken": "unterminated"#;
        let err = parse_cached_token(token).expect_err("malformed JSON");
        assert!(
            matches!(err, CachedSsoTokenError::JsonError(_)),
            "incorrect error: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn gracefully_handle_missing_files() {
        let err = load_cached_token(
            &Env::from_slice(&[("HOME", "/home")]),
            &Fs::from_slice(&[]),
            "asdf",
        )
        .await
        .expect_err("should fail, file is missing");
        assert!(
            matches!(err, CachedSsoTokenError::IoError { .. }),
            "should be io error, got {}",
            err
        );
    }

    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
    #[cfg_attr(windows, ignore)]
    #[test]
    fn determine_correct_cache_filenames() {
        assert_eq!(
            "/home/someuser/.aws/sso/cache/d033e22ae348aeb5660fc2140aec35850c4da997.json",
            cached_token_path("admin", "/home/someuser").as_os_str()
        );
        assert_eq!(
            "/home/someuser/.aws/sso/cache/75e4d41276d8bd17f85986fc6cccef29fd725ce3.json",
            cached_token_path("dev-scopes", "/home/someuser").as_os_str()
        );
        assert_eq!(
            "/home/me/.aws/sso/cache/13f9d35043871d073ab260e020f0ffde092cb14b.json",
            cached_token_path("https://d-92671207e4.awsapps.com/start", "/home/me").as_os_str(),
        );
        assert_eq!(
            "/home/me/.aws/sso/cache/13f9d35043871d073ab260e020f0ffde092cb14b.json",
            cached_token_path("https://d-92671207e4.awsapps.com/start", "/home/me/").as_os_str(),
        );
    }

    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
    #[cfg_attr(windows, ignore)]
    #[tokio::test]
    async fn save_cached_token() {
        let expires_at = SystemTime::UNIX_EPOCH + Duration::from_secs(50_000_000);
        let reg_expires_at = SystemTime::UNIX_EPOCH + Duration::from_secs(100_000_000);
        let token = CachedSsoToken {
            access_token: Zeroizing::new("access-token".into()),
            client_id: Some("client-id".into()),
            client_secret: Some(Zeroizing::new("client-secret".into())),
            expires_at,
            refresh_token: Some(Zeroizing::new("refresh-token".into())),
            region: Some("region".into()),
            registration_expires_at: Some(reg_expires_at),
            start_url: Some("start-url".into()),
        };

        let env = Env::from_slice(&[("HOME", "/home/user")]);
        let fs = Fs::from_map(HashMap::<_, Vec<u8>>::new());
        super::save_cached_token(&env, &fs, "test", &token)
            .await
            .expect("success");

        let contents = fs
            .read_to_end("/home/user/.aws/sso/cache/a94a8fe5ccb19ba61c4c0873d391e987982fbbd3.json")
            .await
            .expect("correct file written");
        let contents_str = String::from_utf8(contents).expect("valid utf8");
        assert_eq!(
            r#"{"accessToken":"access-token","expiresAt":"1971-08-02T16:53:20Z","refreshToken":"refresh-token","clientId":"client-id","clientSecret":"client-secret","registrationExpiresAt":"1973-03-03T09:46:40Z","region":"region","startUrl":"start-url"}"#,
            contents_str,
        );
    }

    #[tokio::test]
    async fn round_trip_token() {
        let expires_at = SystemTime::UNIX_EPOCH + Duration::from_secs(50_000_000);
        let reg_expires_at = SystemTime::UNIX_EPOCH + Duration::from_secs(100_000_000);
        let original = CachedSsoToken {
            access_token: Zeroizing::new("access-token".into()),
            client_id: Some("client-id".into()),
            client_secret: Some(Zeroizing::new("client-secret".into())),
            expires_at,
            refresh_token: Some(Zeroizing::new("refresh-token".into())),
            region: Some("region".into()),
            registration_expires_at: Some(reg_expires_at),
            start_url: Some("start-url".into()),
        };

        let env = Env::from_slice(&[("HOME", "/home/user")]);
        let fs = Fs::from_map(HashMap::<_, Vec<u8>>::new());

        super::save_cached_token(&env, &fs, "test", &original)
            .await
            .unwrap();

        let roundtripped = load_cached_token(&env, &fs, "test").await.unwrap();
        assert_eq!(original, roundtripped)
    }
}