1use async_lock::RwLock;
5use azure_core::{
6 credentials::{AccessToken, TokenCredential, TokenRequestOptions},
7 error::{Error, ErrorKind},
8};
9use azure_identity::{
10 AzureCliCredential, AzureCliCredentialOptions, AzureDeveloperCliCredential,
11 AzureDeveloperCliCredentialOptions,
12};
13use std::sync::{Arc, LazyLock};
14use tracing::Instrument;
15
16#[derive(Debug)]
17pub struct DeveloperCredential {
18 options: Option<DeveloperCredentialOptions>,
19 credential: RwLock<Option<Arc<dyn TokenCredential>>>,
20}
21
22impl DeveloperCredential {
23 pub fn new(options: Option<DeveloperCredentialOptions>) -> Arc<Self> {
24 Arc::new(Self {
25 options,
26 credential: RwLock::new(None),
27 })
28 }
29}
30
31#[async_trait::async_trait]
32impl TokenCredential for DeveloperCredential {
33 async fn get_token(
34 &self,
35 scopes: &[&str],
36 options: Option<TokenRequestOptions<'_>>,
37 ) -> azure_core::Result<AccessToken> {
38 if let Some(credential) = self.credential.read().await.as_ref() {
39 return credential.get_token(scopes, options).await;
40 }
41
42 let mut lock = self.credential.write().await;
43 if let Some(credential) = lock.as_ref() {
44 return credential.get_token(scopes, options).await;
45 }
46
47 let mut errors = Vec::new();
48 for (name, f) in CREDENTIALS.iter() {
49 let options = options.clone();
50 match async {
51 match f(self.options.as_ref()) {
52 Ok(c) => match c.get_token(scopes, options).await {
53 Ok(token) => {
54 tracing::debug!(target: "akv::credentials", "acquired token");
55 *lock = Some(c);
56 Ok(token)
57 }
58 Err(err) => {
59 tracing::debug!(target: "akv::credentials", "failed acquiring token: {err}");
60 Err(err)
61 }
62 },
63 Err(err) => {
64 tracing::debug!(target: "akv::credentials", "failed creating credential: {err}");
65 Err(err)
66 }
67 }
68 }
69 .instrument(tracing::debug_span!(target: "akv::credentials", "trying credential", name))
70 .await
71 {
72 Ok(token) => return Ok(token),
73 Err(err) => errors.push(err),
74 }
75 }
76
77 Err(Error::with_message_fn(ErrorKind::Credential, || {
78 format!(
79 "Multiple errors when attempting to authenticate:\n{}",
80 aggregate(&errors)
81 )
82 }))
83 }
84}
85
/// Options shared by every developer credential that `DeveloperCredential` may try.
///
/// Not every candidate uses every field: the Azure CLI credential receives all three, while
/// the Azure Developer CLI credential receives only `tenant_id`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DeveloperCredentialOptions {
    /// Subscription passed through to the Azure CLI credential.
    pub subscription: Option<String>,
    /// Tenant to authenticate in; passed to both the Azure CLI and Azure Developer CLI
    /// credentials.
    pub tenant_id: Option<String>,
    /// Additional tenants passed through to the Azure CLI credential.
    pub additionally_allowed_tenants: Vec<String>,
}
92
93impl From<&DeveloperCredentialOptions> for AzureCliCredentialOptions {
94 fn from(options: &DeveloperCredentialOptions) -> Self {
95 AzureCliCredentialOptions {
96 subscription: options.subscription.clone(),
97 tenant_id: options.tenant_id.clone(),
98 additionally_allowed_tenants: options.additionally_allowed_tenants.clone(),
99 ..Default::default()
100 }
101 }
102}
103
104impl From<&DeveloperCredentialOptions> for AzureDeveloperCliCredentialOptions {
105 fn from(options: &DeveloperCredentialOptions) -> Self {
106 AzureDeveloperCliCredentialOptions {
107 tenant_id: options.tenant_id.clone(),
108 ..Default::default()
109 }
110 }
111}
112
113type CredentialFn = (
114 &'static str,
115 Box<
116 dyn Fn(Option<&DeveloperCredentialOptions>) -> azure_core::Result<Arc<dyn TokenCredential>>
117 + Send
118 + Sync
119 + 'static,
120 >,
121);
122
123static CREDENTIALS: LazyLock<Vec<CredentialFn>> = LazyLock::new(|| {
124 vec![
127 (
128 "AzureDeveloperCliCredential",
129 Box::new(|options| Ok(AzureDeveloperCliCredential::new(options.map(Into::into))?)),
130 ),
131 (
132 "AzureCliCredential",
133 Box::new(|options| Ok(AzureCliCredential::new(options.map(Into::into))?)),
134 ),
135 ]
136});
137
/// Renders each error on its own line, with the error's `source()` chain joined by `" - "`.
///
/// For example, an error whose source is another error renders as `"outer - inner"`. An empty
/// slice produces an empty string.
fn aggregate(errors: &[Error]) -> String {
    use std::error::Error as StdError;

    let lines: Vec<String> = errors
        .iter()
        .map(|top| {
            // Walk from the error itself down through each `source()` until the chain ends.
            let chain: Vec<String> =
                std::iter::successors(Some(top as &dyn StdError), |&err| err.source())
                    .map(|err| err.to_string())
                    .collect();
            chain.join(" - ")
        })
        .collect();

    lines.join("\n")
}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Each error is rendered on its own line, with its source chain joined by " - ".
    #[test]
    fn aggregate_multiple_errors() {
        let errors = vec![
            Error::with_error(
                ErrorKind::Other,
                Error::with_message(ErrorKind::Other, "first inner error"),
                "first outer error",
            ),
            Error::with_message(ErrorKind::Other, "second error"),
        ];
        assert_eq!(
            aggregate(&errors),
            "first outer error - first inner error\nsecond error"
        );
    }

    /// No errors renders as an empty string rather than a stray separator.
    #[test]
    fn aggregate_no_errors() {
        assert_eq!(aggregate(&[]), "");
    }

    /// Source chains deeper than one level are followed all the way down.
    #[test]
    fn aggregate_nested_error_chain() {
        let errors = vec![Error::with_error(
            ErrorKind::Other,
            Error::with_error(
                ErrorKind::Other,
                Error::with_message(ErrorKind::Other, "innermost error"),
                "middle error",
            ),
            "outer error",
        )];
        assert_eq!(
            aggregate(&errors),
            "outer error - middle error - innermost error"
        );
    }
}