akv_cli/
credentials.rs

// Copyright 2024 Heath Stewart.
// Licensed under the MIT License. See LICENSE.txt in the project root for license information.

4use async_lock::RwLock;
5use azure_core::{
6    credentials::{AccessToken, TokenCredential, TokenRequestOptions},
7    error::{Error, ErrorKind},
8};
9use azure_identity::{
10    AzureCliCredential, AzureCliCredentialOptions, AzureDeveloperCliCredential,
11    AzureDeveloperCliCredentialOptions,
12};
13use std::sync::{Arc, LazyLock};
14use tracing::Instrument;
15
16#[derive(Debug)]
17pub struct DeveloperCredential {
18    options: Option<DeveloperCredentialOptions>,
19    credential: RwLock<Option<Arc<dyn TokenCredential>>>,
20}
21
22impl DeveloperCredential {
23    pub fn new(options: Option<DeveloperCredentialOptions>) -> Arc<Self> {
24        Arc::new(Self {
25            options,
26            credential: RwLock::new(None),
27        })
28    }
29}
30
31#[async_trait::async_trait]
32impl TokenCredential for DeveloperCredential {
33    async fn get_token(
34        &self,
35        scopes: &[&str],
36        options: Option<TokenRequestOptions<'_>>,
37    ) -> azure_core::Result<AccessToken> {
38        if let Some(credential) = self.credential.read().await.as_ref() {
39            return credential.get_token(scopes, options).await;
40        }
41
42        let mut lock = self.credential.write().await;
43        if let Some(credential) = lock.as_ref() {
44            return credential.get_token(scopes, options).await;
45        }
46
47        let mut errors = Vec::new();
48        for (name, f) in CREDENTIALS.iter() {
49            let options = options.clone();
50            match async {
51                match f(self.options.as_ref()) {
52                    Ok(c) => match c.get_token(scopes, options).await {
53                        Ok(token) => {
54                            tracing::debug!(target: "akv::credentials", "acquired token");
55                            *lock = Some(c);
56                            Ok(token)
57                        }
58                        Err(err) => {
59                            tracing::debug!(target: "akv::credentials", "failed acquiring token: {err}");
60                            Err(err)
61                        }
62                    },
63                    Err(err) => {
64                        tracing::debug!(target: "akv::credentials", "failed creating credential: {err}");
65                        Err(err)
66                    }
67                }
68            }
69            .instrument(tracing::debug_span!(target: "akv::credentials", "trying credential", name))
70            .await
71            {
72                Ok(token) => return Ok(token),
73                Err(err) => errors.push(err),
74            }
75        }
76
77        Err(Error::with_message_fn(ErrorKind::Credential, || {
78            format!(
79                "Multiple errors when attempting to authenticate:\n{}",
80                aggregate(&errors)
81            )
82        }))
83    }
84}
85
/// Options shared by all candidate credentials tried by `DeveloperCredential`.
#[derive(Debug, Default)]
pub struct DeveloperCredentialOptions {
    /// Subscription to pass to credentials that accept one (the Azure CLI credential).
    pub subscription: Option<String>,
    /// Tenant to authenticate against; forwarded to every candidate credential.
    pub tenant_id: Option<String>,
    /// Extra tenants to allow (forwarded to the Azure CLI credential).
    pub additionally_allowed_tenants: Vec<String>,
}
92
93impl From<&DeveloperCredentialOptions> for AzureCliCredentialOptions {
94    fn from(options: &DeveloperCredentialOptions) -> Self {
95        AzureCliCredentialOptions {
96            subscription: options.subscription.clone(),
97            tenant_id: options.tenant_id.clone(),
98            additionally_allowed_tenants: options.additionally_allowed_tenants.clone(),
99            ..Default::default()
100        }
101    }
102}
103
104impl From<&DeveloperCredentialOptions> for AzureDeveloperCliCredentialOptions {
105    fn from(options: &DeveloperCredentialOptions) -> Self {
106        AzureDeveloperCliCredentialOptions {
107            tenant_id: options.tenant_id.clone(),
108            ..Default::default()
109        }
110    }
111}
112
113type CredentialFn = (
114    &'static str,
115    Box<
116        dyn Fn(Option<&DeveloperCredentialOptions>) -> azure_core::Result<Arc<dyn TokenCredential>>
117            + Send
118            + Sync
119            + 'static,
120    >,
121);
122
123static CREDENTIALS: LazyLock<Vec<CredentialFn>> = LazyLock::new(|| {
124    // Though az is likely more common, try azd first because it fails faster if even in $PATH.
125    // This is reverse of DefaultAzureCredential because azd was added long after az and compat was a concern.
126    vec![
127        (
128            "AzureDeveloperCliCredential",
129            Box::new(|options| Ok(AzureDeveloperCliCredential::new(options.map(Into::into))?)),
130        ),
131        (
132            "AzureCliCredential",
133            Box::new(|options| Ok(AzureCliCredential::new(options.map(Into::into))?)),
134        ),
135    ]
136});
137
138fn aggregate(errors: &[Error]) -> String {
139    use std::error::Error;
140    errors
141        .iter()
142        .map(|err| {
143            let mut current: Option<&dyn Error> = Some(err);
144            let mut stack = vec![];
145            while let Some(err) = current.take() {
146                stack.push(err.to_string());
147                current = err.source();
148            }
149            stack.join(" - ")
150        })
151        .collect::<Vec<String>>()
152        .join("\n")
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Each error goes on its own line, and an error's source chain is joined with " - ".
    #[test]
    fn aggregate_multiple_errors() {
        let inner = Error::with_message(ErrorKind::Other, "first inner error");
        let first = Error::with_error(ErrorKind::Other, inner, "first outer error");
        let second = Error::with_message(ErrorKind::Other, "second error");

        let rendered = aggregate(&[first, second]);

        let expected = "first outer error - first inner error\nsecond error";
        assert_eq!(rendered, expected);
    }
}