azure_identity/
developer_tools_credential.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4use crate::{
5    AzureCliCredential, AzureCliCredentialOptions, AzureDeveloperCliCredential,
6    AzureDeveloperCliCredentialOptions, Executor,
7};
8use azure_core::{
9    credentials::{AccessToken, TokenCredential, TokenRequestOptions},
10    error::{Error, ErrorKind},
11};
12use std::{
13    fmt,
14    sync::{
15        atomic::{AtomicUsize, Ordering},
16        Arc,
17    },
18};
19
20/// Options for constructing a new [`DeveloperToolsCredential`]
21#[derive(Clone, Debug, Default)]
22pub struct DeveloperToolsCredentialOptions {
23    /// An implementation of [`Executor`] to run commands asynchronously.
24    pub executor: Option<Arc<dyn Executor>>,
25}
26
27/// Authenticates through developer tools such as the Azure CLI.
28///
29/// It tries the following credential types, in this order, stopping when one provides a token:
30///
31/// * [`AzureCliCredential`]
32/// * [`AzureDeveloperCliCredential`]
33///
34/// `DeveloperToolsCredential` uses the first credential that provides a token for all subsequent token requests. It never tries the others again.
35pub struct DeveloperToolsCredential {
36    sources: Vec<Arc<dyn TokenCredential>>,
37    // index of the source that first provided a token. usize::MAX indicates no source has provided a token.
38    cached_source_index: AtomicUsize,
39}
40
41impl DeveloperToolsCredential {
42    /// Creates a new instance of `DeveloperToolsCredential`.
43    ///
44    /// # Arguments
45    /// * `options`: Options for configuring the credential. If `None` is provided, default options will be used.
46    pub fn new(
47        options: Option<DeveloperToolsCredentialOptions>,
48    ) -> azure_core::Result<Arc<DeveloperToolsCredential>> {
49        let options = options.unwrap_or_default();
50        let sources: Vec<Arc<dyn TokenCredential>> = vec![
51            AzureCliCredential::new(Some(AzureCliCredentialOptions {
52                executor: options.executor.clone(),
53                ..Default::default()
54            }))?,
55            AzureDeveloperCliCredential::new(Some(AzureDeveloperCliCredentialOptions {
56                executor: options.executor,
57                ..Default::default()
58            }))?,
59        ];
60        Ok(Arc::new(Self {
61            sources,
62            cached_source_index: AtomicUsize::new(usize::MAX),
63        }))
64    }
65
66    #[cfg(test)]
67    pub(crate) fn new_with_sources(
68        sources: Vec<Arc<dyn TokenCredential>>,
69    ) -> azure_core::Result<Arc<DeveloperToolsCredential>> {
70        Ok(Arc::new(Self {
71            sources,
72            cached_source_index: AtomicUsize::new(usize::MAX),
73        }))
74    }
75}
76
77impl fmt::Debug for DeveloperToolsCredential {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.write_str("DeveloperToolsCredential")
80    }
81}
82
83#[async_trait::async_trait]
84impl TokenCredential for DeveloperToolsCredential {
85    async fn get_token(
86        &self,
87        scopes: &[&str],
88        options: Option<TokenRequestOptions<'_>>,
89    ) -> azure_core::Result<AccessToken> {
90        let cached_index = self.cached_source_index.load(Ordering::Relaxed);
91        if cached_index != usize::MAX {
92            if let Some(source) = self.sources.get(cached_index) {
93                return source.get_token(scopes, options).await;
94            }
95            // impossible because the vector's size never changes
96            panic!("DeveloperToolsCredential source index {cached_index} is out of bounds")
97        }
98
99        let mut errors = Vec::new();
100        for (index, source) in self.sources.iter().enumerate() {
101            match source.get_token(scopes, options.clone()).await {
102                Ok(token) => {
103                    self.cached_source_index.store(index, Ordering::Relaxed);
104                    return Ok(token);
105                }
106                Err(error) => errors.push(error),
107            }
108        }
109        Err(Error::with_message_fn(ErrorKind::Credential, || {
110            format!(
111                "Multiple errors were encountered while attempting to authenticate:\n{}",
112                format_aggregate_error(&errors)
113            )
114        }))
115    }
116}
117
/// Renders each error on its own line as its full causal chain, with the error's own message
/// first and each `source()` after it, separated by `" - "`. An empty slice yields `""`.
fn format_aggregate_error(errors: &[Error]) -> String {
    let mut lines = Vec::with_capacity(errors.len());
    for error in errors {
        let chain: Vec<String> =
            std::iter::successors(Some(error as &dyn std::error::Error), |cause| cause.source())
                .map(|cause| cause.to_string())
                .collect();
        lines.push(chain.join(" - "));
    }
    lines.join("\n")
}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::MockExecutor;
    use azure_core::credentials::AccessToken;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::{Duration, SystemTime};

    /// Test credential that counts its calls. Depending on a switchable flag, it either returns
    /// a token whose secret is its `id` or fails with the message "`{id}` failed".
    #[derive(Debug)]
    struct MockCredential {
        call_count: AtomicUsize,
        id: String,
        succeed: AtomicBool,
    }

    impl MockCredential {
        fn new(id: &str, succeed: bool) -> Arc<Self> {
            Arc::new(Self {
                call_count: AtomicUsize::new(0),
                id: id.to_string(),
                succeed: AtomicBool::new(succeed),
            })
        }

        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Changes whether subsequent `get_token` calls succeed.
        fn set_succeed(&self, succeed: bool) {
            self.succeed.store(succeed, Ordering::SeqCst);
        }
    }

    #[async_trait::async_trait]
    impl TokenCredential for MockCredential {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<TokenRequestOptions<'_>>,
        ) -> azure_core::Result<AccessToken> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            if self.succeed.load(Ordering::SeqCst) {
                Ok(AccessToken {
                    token: self.id.clone().into(),
                    expires_on: (SystemTime::now() + Duration::from_secs(3600)).into(),
                })
            } else {
                Err(Error::with_message_fn(ErrorKind::Credential, || {
                    format!("{} failed", self.id)
                }))
            }
        }
    }

    #[tokio::test]
    async fn cached_credential() {
        let mock1 = MockCredential::new("mock1", false);
        let mock2 = MockCredential::new("mock2", false);
        let mock3 = MockCredential::new("mock3", true);
        let mock4 = MockCredential::new("mock4", true);
        let sources: Vec<Arc<dyn TokenCredential>> =
            vec![mock1.clone(), mock2.clone(), mock3.clone(), mock4.clone()];

        let credential = DeveloperToolsCredential::new_with_sources(sources).unwrap();

        for i in 1..=5 {
            let token = credential
                .get_token(&["scope"], None)
                .await
                .expect("authentication success");
            assert_eq!(token.token.secret(), "mock3");
            assert_eq!(mock1.call_count(), 1);
            assert_eq!(mock2.call_count(), 1);
            assert_eq!(mock3.call_count(), i);
            assert_eq!(mock4.call_count(), 0);
        }
    }

    #[tokio::test]
    async fn cached_source_error_is_not_retried() {
        let mock1 = MockCredential::new("mock1", true);
        let mock2 = MockCredential::new("mock2", true);
        let sources: Vec<Arc<dyn TokenCredential>> = vec![mock1.clone(), mock2.clone()];

        let credential = DeveloperToolsCredential::new_with_sources(sources).unwrap();

        let token = credential
            .get_token(&["scope"], None)
            .await
            .expect("authentication success");
        assert_eq!(token.token.secret(), "mock1");

        // Once mock1 is cached, its failures are returned as-is and mock2 is never consulted.
        mock1.set_succeed(false);
        let error_msg = credential
            .get_token(&["scope"], None)
            .await
            .expect_err("cached source error")
            .to_string();
        assert!(error_msg.contains("mock1 failed"));
        assert!(!error_msg.contains("Multiple errors"));
        assert_eq!(mock1.call_count(), 2);
        assert_eq!(mock2.call_count(), 0);
    }

    #[tokio::test]
    async fn error_message() {
        let mock1 = MockCredential::new("mock1", false);
        let mock2 = MockCredential::new("mock2", false);
        let mock3 = MockCredential::new("mock3", false);
        let sources: Vec<Arc<dyn TokenCredential>> =
            vec![mock1.clone(), mock2.clone(), mock3.clone()];

        let credential = DeveloperToolsCredential::new_with_sources(sources).unwrap();

        let error_msg = credential
            .get_token(&["scope"], None)
            .await
            .expect_err("authentication error")
            .to_string();

        assert_eq!(mock1.call_count(), 1);
        assert_eq!(mock2.call_count(), 1);
        assert_eq!(mock3.call_count(), 1);
        assert!(error_msg.contains("mock1 failed"));
        assert!(error_msg.contains("mock2 failed"));
        assert!(error_msg.contains("mock3 failed"));
    }

    #[tokio::test]
    async fn executor() {
        let err = std::io::Error::other("something went wrong");
        let executor = MockExecutor::with_error(err);
        let options = DeveloperToolsCredentialOptions {
            executor: Some(executor.clone()),
        };
        let err = DeveloperToolsCredential::new(Some(options))
            .expect("valid credential")
            .get_token(&["scope"], None)
            .await
            .expect_err("expected error");
        assert!(err.to_string().contains("something went wrong"));
        assert_eq!(
            2,
            executor.call_count(),
            "Executor should have been called once for each inner credential"
        );
    }
}