1use super::client::UserDelegationKey;
19use crate::RetryConfig;
20use crate::azure::STORE;
21use crate::client::builder::{HttpRequestBuilder, add_query_pairs};
22use crate::client::retry::RetryExt;
23use crate::client::token::{TemporaryToken, TokenCache};
24use crate::client::{CredentialProvider, HttpClient, HttpError, HttpRequest, TokenProvider};
25use crate::util::hmac_sha256;
26use async_trait::async_trait;
27use base64::Engine;
28use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
29use chrono::{DateTime, SecondsFormat, Utc};
30use http::Method;
31use http::header::{
32 ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE,
33 HeaderMap, HeaderName, HeaderValue, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
34 IF_UNMODIFIED_SINCE, RANGE,
35};
36use serde::Deserialize;
37use std::borrow::Cow;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::ops::Deref;
41use std::process::Command;
42use std::str;
43use std::sync::Arc;
44use std::time::{Duration, Instant, SystemTime};
45use url::Url;
46
47static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
48static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
49pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
50pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
51static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
52static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
53static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
54static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
55static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
56pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
57const CONTENT_TYPE_JSON: &str = "application/json";
58const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
59const MSI_API_VERSION: &str = "2019-08-01";
60const TOKEN_MIN_TTL: u64 = 300;
61
62const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";
66
67const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
71
72#[derive(Debug, thiserror::Error)]
73pub enum Error {
74 #[error("Error performing token request: {}", source)]
75 TokenRequest {
76 source: crate::client::retry::RetryError,
77 },
78
79 #[error("Error getting token response body: {}", source)]
80 TokenResponseBody { source: HttpError },
81
82 #[error("Error reading federated token file ")]
83 FederatedTokenFile,
84
85 #[error("Invalid Access Key: {}", source)]
86 InvalidAccessKey { source: base64::DecodeError },
87
88 #[error("'az account get-access-token' command failed: {message}")]
89 AzureCli { message: String },
90
91 #[error("Failed to parse azure cli response: {source}")]
92 AzureCliResponse { source: serde_json::Error },
93
94 #[error("Generating SAS keys with SAS tokens auth is not supported")]
95 SASforSASNotSupported,
96}
97
98pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
99
100impl From<Error> for crate::Error {
101 fn from(value: Error) -> Self {
102 Self::Generic {
103 store: STORE,
104 source: Box::new(value),
105 }
106 }
107}
108
109#[derive(Debug, Clone, Eq, PartialEq)]
111pub struct AzureAccessKey(Vec<u8>);
112
113impl AzureAccessKey {
114 pub fn try_new(key: &str) -> Result<Self> {
116 let key = BASE64_STANDARD
117 .decode(key)
118 .map_err(|source| Error::InvalidAccessKey { source })?;
119
120 Ok(Self(key))
121 }
122}
123
124#[derive(Debug, Eq, PartialEq)]
126pub enum AzureCredential {
127 AccessKey(AzureAccessKey),
131 SASToken(Vec<(String, String)>),
135 BearerToken(String),
139}
140
141impl AzureCredential {
142 pub fn sensitive_request(&self) -> bool {
144 match self {
145 Self::AccessKey(_) => false,
146 Self::BearerToken(_) => false,
147 Self::SASToken(_) => true,
149 }
150 }
151}
152
/// Well-known Microsoft Entra ID (Azure AD) authority hosts.
pub mod authority_hosts {
    /// Azure public cloud; the default when no authority host is configured.
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
    /// Azure China cloud.
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Azure Germany cloud.
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
    /// Azure US Government cloud.
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
}
164
165pub(crate) struct AzureSigner {
166 signing_key: AzureAccessKey,
167 start: DateTime<Utc>,
168 end: DateTime<Utc>,
169 account: String,
170 delegation_key: Option<UserDelegationKey>,
171}
172
173impl AzureSigner {
174 pub(crate) fn new(
175 signing_key: AzureAccessKey,
176 account: String,
177 start: DateTime<Utc>,
178 end: DateTime<Utc>,
179 delegation_key: Option<UserDelegationKey>,
180 ) -> Self {
181 Self {
182 signing_key,
183 account,
184 start,
185 end,
186 delegation_key,
187 }
188 }
189
190 pub(crate) fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
191 let (str_to_sign, query_pairs) = match &self.delegation_key {
192 Some(delegation_key) => string_to_sign_user_delegation_sas(
193 url,
194 method,
195 &self.account,
196 &self.start,
197 &self.end,
198 delegation_key,
199 ),
200 None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
201 };
202 let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
203 url.query_pairs_mut().extend_pairs(query_pairs);
204 url.query_pairs_mut()
205 .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
206 Ok(())
207 }
208}
209
210fn add_date_and_version_headers(request: &mut HttpRequest) {
211 let date = Utc::now();
213 let date_str = date.format(RFC1123_FMT).to_string();
214 let date_val = HeaderValue::from_str(&date_str).unwrap();
216 request.headers_mut().insert(DATE, date_val);
217 request
218 .headers_mut()
219 .insert(&VERSION, AZURE_VERSION.clone());
220}
221
222#[derive(Debug)]
224pub struct AzureAuthorizer<'a> {
225 credential: &'a AzureCredential,
226 account: &'a str,
227}
228
229impl<'a> AzureAuthorizer<'a> {
230 pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
232 AzureAuthorizer {
233 credential,
234 account,
235 }
236 }
237
238 pub fn authorize(&self, request: &mut HttpRequest) {
240 add_date_and_version_headers(request);
241
242 match self.credential {
243 AzureCredential::AccessKey(key) => {
244 let url = Url::parse(&request.uri().to_string()).unwrap();
245 let signature = generate_authorization(
246 request.headers(),
247 &url,
248 request.method(),
249 self.account,
250 key,
251 );
252
253 request.headers_mut().append(
256 AUTHORIZATION,
257 HeaderValue::from_str(signature.as_str()).unwrap(),
258 );
259 }
260 AzureCredential::BearerToken(token) => {
261 request.headers_mut().append(
262 AUTHORIZATION,
263 HeaderValue::from_str(format!("Bearer {token}").as_str()).unwrap(),
264 );
265 }
266 AzureCredential::SASToken(query_pairs) => {
267 add_query_pairs(request.uri_mut(), query_pairs);
268 }
269 }
270 }
271}
272
273pub(crate) trait CredentialExt {
274 fn with_azure_authorization(
277 self,
278 credential: &Option<impl Deref<Target = AzureCredential>>,
279 account: &str,
280 ) -> Self;
281}
282
283impl CredentialExt for HttpRequestBuilder {
284 fn with_azure_authorization(
285 self,
286 credential: &Option<impl Deref<Target = AzureCredential>>,
287 account: &str,
288 ) -> Self {
289 let (client, request) = self.into_parts();
290 let mut request = request.expect("request valid");
291
292 match credential.as_deref() {
293 Some(credential) => {
294 AzureAuthorizer::new(credential, account).authorize(&mut request);
295 }
296 None => {
297 add_date_and_version_headers(&mut request);
298 }
299 }
300
301 Self::from_parts(client, request)
302 }
303}
304
305fn generate_authorization(
308 h: &HeaderMap,
309 u: &Url,
310 method: &Method,
311 account: &str,
312 key: &AzureAccessKey,
313) -> String {
314 let str_to_sign = string_to_sign(h, u, method, account);
315 let auth = hmac_sha256(&key.0, str_to_sign);
316 format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
317}
318
319fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
320 h.get(key)
321 .map(|s| s.to_str())
322 .transpose()
323 .ok()
324 .flatten()
325 .unwrap_or_default()
326}
327
328fn string_to_sign_sas(
329 u: &Url,
330 method: &Method,
331 account: &str,
332 start: &DateTime<Utc>,
333 end: &DateTime<Utc>,
334) -> (String, String, String, String, String) {
335 let signed_resource = "b".to_string();
337
338 let signed_permissions = match *method {
340 Method::GET => match signed_resource.as_str() {
342 "c" => "rl",
343 "b" => "r",
344 _ => unreachable!(),
345 },
346 Method::PUT => "w",
348 Method::DELETE => "d",
350 _ => "",
352 }
353 .to_string();
354 let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
355 let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
356 let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
357 format!("/blob/{}{}", account, u.path())
358 } else {
359 format!("/blob{}", u.path())
362 };
363
364 (
365 signed_resource,
366 signed_permissions,
367 signed_start,
368 signed_expiry,
369 canonicalized_resource,
370 )
371}
372
373fn string_to_sign_service_sas(
377 u: &Url,
378 method: &Method,
379 account: &str,
380 start: &DateTime<Utc>,
381 end: &DateTime<Utc>,
382) -> (String, HashMap<&'static str, String>) {
383 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
384 string_to_sign_sas(u, method, account, start, end);
385
386 let string_to_sign = format!(
387 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
388 signed_permissions,
389 signed_start,
390 signed_expiry,
391 canonicalized_resource,
392 "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
405
406 let mut pairs = HashMap::new();
407 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
408 pairs.insert("sp", signed_permissions);
409 pairs.insert("st", signed_start);
410 pairs.insert("se", signed_expiry);
411 pairs.insert("sr", signed_resource);
412
413 (string_to_sign, pairs)
414}
415
416fn string_to_sign_user_delegation_sas(
420 u: &Url,
421 method: &Method,
422 account: &str,
423 start: &DateTime<Utc>,
424 end: &DateTime<Utc>,
425 delegation_key: &UserDelegationKey,
426) -> (String, HashMap<&'static str, String>) {
427 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
428 string_to_sign_sas(u, method, account, start, end);
429
430 let string_to_sign = format!(
431 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
432 signed_permissions,
433 signed_start,
434 signed_expiry,
435 canonicalized_resource,
436 delegation_key.signed_oid, delegation_key.signed_tid, delegation_key.signed_start, delegation_key.signed_expiry, delegation_key.signed_service, delegation_key.signed_version, "", "", "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
457
458 let mut pairs = HashMap::new();
459 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
460 pairs.insert("sp", signed_permissions);
461 pairs.insert("st", signed_start);
462 pairs.insert("se", signed_expiry);
463 pairs.insert("sr", signed_resource);
464 pairs.insert("skoid", delegation_key.signed_oid.clone());
465 pairs.insert("sktid", delegation_key.signed_tid.clone());
466 pairs.insert("skt", delegation_key.signed_start.clone());
467 pairs.insert("ske", delegation_key.signed_expiry.clone());
468 pairs.insert("sks", delegation_key.signed_service.clone());
469 pairs.insert("skv", delegation_key.signed_version.clone());
470
471 (string_to_sign, pairs)
472}
473
474fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
476 let content_length = h
479 .get(&CONTENT_LENGTH)
480 .map(|s| s.to_str())
481 .transpose()
482 .ok()
483 .flatten()
484 .filter(|&v| v != "0")
485 .unwrap_or_default();
486 format!(
487 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
488 method.as_ref(),
489 add_if_exists(h, &CONTENT_ENCODING),
490 add_if_exists(h, &CONTENT_LANGUAGE),
491 content_length,
492 add_if_exists(h, &CONTENT_MD5),
493 add_if_exists(h, &CONTENT_TYPE),
494 add_if_exists(h, &DATE),
495 add_if_exists(h, &IF_MODIFIED_SINCE),
496 add_if_exists(h, &IF_MATCH),
497 add_if_exists(h, &IF_NONE_MATCH),
498 add_if_exists(h, &IF_UNMODIFIED_SINCE),
499 add_if_exists(h, &RANGE),
500 canonicalize_header(h),
501 canonicalize_resource(account, u)
502 )
503}
504
505fn canonicalize_header(headers: &HeaderMap) -> String {
507 let mut names = headers
508 .iter()
509 .filter(|&(k, _)| k.as_str().starts_with("x-ms"))
510 .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
512 .collect::<Vec<_>>();
513 names.sort_unstable();
514
515 let mut result = String::new();
516 for (name, value) in names {
517 result.push_str(name);
518 result.push(':');
519 result.push_str(value);
520 result.push('\n');
521 }
522 result
523}
524
525fn canonicalize_resource(account: &str, uri: &Url) -> String {
527 let mut can_res: String = String::new();
528 can_res.push('/');
529 can_res.push_str(account);
530 can_res.push_str(uri.path().to_string().as_str());
531 can_res.push('\n');
532
533 let query_pairs = uri.query_pairs();
535 {
536 let mut qps: Vec<String> = Vec::new();
537 for (q, _) in query_pairs {
538 if !(qps.iter().any(|x| x == &*q)) {
539 qps.push(q.into_owned());
540 }
541 }
542
543 qps.sort();
544
545 for qparam in qps {
546 let ret = lexy_sort(query_pairs, &qparam);
548
549 can_res = can_res + &qparam.to_lowercase() + ":";
550
551 for (i, item) in ret.iter().enumerate() {
552 if i > 0 {
553 can_res.push(',');
554 }
555 can_res.push_str(item);
556 }
557
558 can_res.push('\n');
559 }
560 };
561
562 can_res[0..can_res.len() - 1].to_owned()
563}
564
/// Returns the values of every `(key, value)` pair whose key equals `query_param`,
/// sorted lexicographically.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = Vec::new();
    for (key, value) in vec {
        if key == query_param {
            matching.push(value);
        }
    }
    matching.sort_unstable();
    matching
}
576
577#[derive(Deserialize, Debug)]
579struct OAuthTokenResponse {
580 access_token: String,
581 expires_in: u64,
582}
583
584#[derive(Debug)]
588pub(crate) struct ClientSecretOAuthProvider {
589 token_url: String,
590 client_id: String,
591 client_secret: String,
592}
593
594impl ClientSecretOAuthProvider {
595 pub(crate) fn new(
597 client_id: String,
598 client_secret: String,
599 tenant_id: impl AsRef<str>,
600 authority_host: Option<String>,
601 ) -> Self {
602 let authority_host =
603 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
604
605 Self {
606 token_url: format!(
607 "{}/{}/oauth2/v2.0/token",
608 authority_host,
609 tenant_id.as_ref()
610 ),
611 client_id,
612 client_secret,
613 }
614 }
615}
616
617#[async_trait::async_trait]
618impl TokenProvider for ClientSecretOAuthProvider {
619 type Credential = AzureCredential;
620
621 async fn fetch_token(
623 &self,
624 client: &HttpClient,
625 retry: &RetryConfig,
626 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
627 let response: OAuthTokenResponse = client
628 .request(Method::POST, &self.token_url)
629 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
630 .form([
631 ("client_id", self.client_id.as_str()),
632 ("client_secret", self.client_secret.as_str()),
633 ("scope", AZURE_STORAGE_SCOPE),
634 ("grant_type", "client_credentials"),
635 ])
636 .retryable(retry)
637 .idempotent(true)
638 .send()
639 .await
640 .map_err(|source| Error::TokenRequest { source })?
641 .into_body()
642 .json()
643 .await
644 .map_err(|source| Error::TokenResponseBody { source })?;
645
646 Ok(TemporaryToken {
647 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
648 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
649 })
650 }
651}
652
653fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
654where
655 D: serde::de::Deserializer<'de>,
656{
657 let v = String::deserialize(deserializer)?;
658 let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
659 let now = SystemTime::now()
660 .duration_since(SystemTime::UNIX_EPOCH)
661 .map_err(serde::de::Error::custom)?;
662
663 Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
664}
665
666#[derive(Debug, Clone, Deserialize)]
670struct ImdsTokenResponse {
671 pub access_token: String,
672 #[serde(deserialize_with = "expires_on_string")]
673 pub expires_on: Instant,
674}
675
/// Obtains bearer tokens from an Azure managed identity endpoint.
#[derive(Debug)]
pub(crate) struct ImdsManagedIdentityProvider {
    msi_endpoint: String,
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Creates a provider.
    ///
    /// `msi_endpoint` defaults to the IMDS token endpoint
    /// `http://169.254.169.254/metadata/identity/oauth2/token`. The optional identifiers
    /// select a specific identity.
    pub(crate) fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        Self {
            msi_endpoint: msi_endpoint.unwrap_or_else(|| {
                "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
            }),
            client_id,
            object_id,
            msi_res_id,
        }
    }
}
707
708#[async_trait::async_trait]
709impl TokenProvider for ImdsManagedIdentityProvider {
710 type Credential = AzureCredential;
711
712 async fn fetch_token(
714 &self,
715 client: &HttpClient,
716 retry: &RetryConfig,
717 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
718 let mut query_items = vec![
719 ("api-version", MSI_API_VERSION),
720 ("resource", AZURE_STORAGE_RESOURCE),
721 ];
722
723 let mut identity = None;
724 if let Some(client_id) = &self.client_id {
725 identity = Some(("client_id", client_id));
726 }
727 if let Some(object_id) = &self.object_id {
728 identity = Some(("object_id", object_id));
729 }
730 if let Some(msi_res_id) = &self.msi_res_id {
731 identity = Some(("msi_res_id", msi_res_id));
732 }
733 if let Some((key, value)) = identity {
734 query_items.push((key, value));
735 }
736
737 let mut builder = client
738 .request(Method::GET, &self.msi_endpoint)
739 .header("metadata", "true")
740 .query(&query_items);
741
742 if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
743 builder = builder.header("x-identity-header", val);
744 };
745
746 let response: ImdsTokenResponse = builder
747 .send_retry(retry)
748 .await
749 .map_err(|source| Error::TokenRequest { source })?
750 .into_body()
751 .json()
752 .await
753 .map_err(|source| Error::TokenResponseBody { source })?;
754
755 Ok(TemporaryToken {
756 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
757 expiry: Some(response.expires_on),
758 })
759 }
760}
761
762#[derive(Debug)]
766pub(crate) struct WorkloadIdentityOAuthProvider {
767 token_url: String,
768 client_id: String,
769 federated_token_file: String,
770}
771
772impl WorkloadIdentityOAuthProvider {
773 pub(crate) fn new(
775 client_id: impl Into<String>,
776 federated_token_file: impl Into<String>,
777 tenant_id: impl AsRef<str>,
778 authority_host: Option<String>,
779 ) -> Self {
780 let authority_host =
781 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
782
783 Self {
784 token_url: format!(
785 "{}/{}/oauth2/v2.0/token",
786 authority_host,
787 tenant_id.as_ref()
788 ),
789 client_id: client_id.into(),
790 federated_token_file: federated_token_file.into(),
791 }
792 }
793}
794
795#[async_trait::async_trait]
796impl TokenProvider for WorkloadIdentityOAuthProvider {
797 type Credential = AzureCredential;
798
799 async fn fetch_token(
801 &self,
802 client: &HttpClient,
803 retry: &RetryConfig,
804 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
805 let token_str = std::fs::read_to_string(&self.federated_token_file)
806 .map_err(|_| Error::FederatedTokenFile)?;
807
808 let response: OAuthTokenResponse = client
810 .request(Method::POST, &self.token_url)
811 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
812 .form([
813 ("client_id", self.client_id.as_str()),
814 (
815 "client_assertion_type",
816 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
817 ),
818 ("client_assertion", token_str.as_str()),
819 ("scope", AZURE_STORAGE_SCOPE),
820 ("grant_type", "client_credentials"),
821 ])
822 .retryable(retry)
823 .idempotent(true)
824 .send()
825 .await
826 .map_err(|source| Error::TokenRequest { source })?
827 .into_body()
828 .json()
829 .await
830 .map_err(|source| Error::TokenResponseBody { source })?;
831
832 Ok(TemporaryToken {
833 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
834 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
835 })
836 }
837}
838
839mod az_cli_date_format {
840 use chrono::{DateTime, TimeZone};
841 use serde::{self, Deserialize, Deserializer};
842
843 pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
844 where
845 D: Deserializer<'de>,
846 {
847 let s = String::deserialize(deserializer)?;
848 let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
850 .map_err(serde::de::Error::custom)?;
851 chrono::Local
852 .from_local_datetime(&date)
853 .single()
854 .ok_or(serde::de::Error::custom(
855 "azure cli returned ambiguous expiry date",
856 ))
857 }
858}
859
860#[derive(Debug, Clone, Deserialize)]
861#[serde(rename_all = "camelCase")]
862struct AzureCliTokenResponse {
863 pub access_token: String,
864 #[serde(with = "az_cli_date_format")]
865 pub expires_on: DateTime<chrono::Local>,
866 pub token_type: String,
867}
868
869#[derive(Default, Debug)]
870pub(crate) struct AzureCliCredential {
871 cache: TokenCache<Arc<AzureCredential>>,
872}
873
874impl AzureCliCredential {
875 pub(crate) fn new() -> Self {
876 Self::default()
877 }
878
879 async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
881 let program = if cfg!(target_os = "windows") {
884 "cmd"
885 } else {
886 "az"
887 };
888 let mut args = Vec::new();
889 if cfg!(target_os = "windows") {
890 args.push("/C");
891 args.push("az");
892 }
893 args.push("account");
894 args.push("get-access-token");
895 args.push("--output");
896 args.push("json");
897 args.push("--scope");
898 args.push(AZURE_STORAGE_SCOPE);
899
900 match Command::new(program).args(args).output() {
901 Ok(az_output) if az_output.status.success() => {
902 let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
903 message: "az response is not a valid utf-8 string".to_string(),
904 })?;
905
906 let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
907 .map_err(|source| Error::AzureCliResponse { source })?;
908
909 if !token_response.token_type.eq_ignore_ascii_case("bearer") {
910 return Err(Error::AzureCli {
911 message: format!(
912 "got unexpected token type from azure cli: {0}",
913 token_response.token_type
914 ),
915 });
916 }
917 let duration =
918 token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
919 Ok(TemporaryToken {
920 token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
921 expiry: Some(
922 Instant::now()
923 + duration.to_std().map_err(|_| Error::AzureCli {
924 message: "az returned invalid lifetime".to_string(),
925 })?,
926 ),
927 })
928 }
929 Ok(az_output) => {
930 let message = String::from_utf8_lossy(&az_output.stderr);
931 Err(Error::AzureCli {
932 message: message.into(),
933 })
934 }
935 Err(e) => match e.kind() {
936 std::io::ErrorKind::NotFound => Err(Error::AzureCli {
937 message: "Azure Cli not installed".into(),
938 }),
939 error_kind => Err(Error::AzureCli {
940 message: format!("io error: {error_kind:?}"),
941 }),
942 },
943 }
944 }
945}
946
/// Obtains storage bearer tokens for Microsoft Fabric workloads.
///
/// An optional pre-supplied token is reused until close to expiry. After that, tokens are
/// fetched from the Fabric token service.
#[derive(Debug)]
pub(crate) struct FabricTokenOAuthProvider {
    /// URL of the Fabric token service.
    fabric_token_service_url: String,
    /// Sent as `x-ms-proxy-host`.
    fabric_workload_host: String,
    /// Sent as `x-ms-partner-token`.
    fabric_session_token: String,
    /// Sent as `x-ms-cluster-identifier` (and `x-ms-workload-resource-moniker`).
    fabric_cluster_identifier: String,
    /// Pre-supplied token, if one was given and was still valid at construction.
    storage_access_token: Option<String>,
    /// JWT `exp` (Unix seconds) of `storage_access_token`.
    token_expiry: Option<u64>,
}
957
958#[derive(Debug, Deserialize)]
959struct Claims {
960 exp: u64,
961}
962
963impl FabricTokenOAuthProvider {
964 pub(crate) fn new(
966 fabric_token_service_url: impl Into<String>,
967 fabric_workload_host: impl Into<String>,
968 fabric_session_token: impl Into<String>,
969 fabric_cluster_identifier: impl Into<String>,
970 storage_access_token: Option<String>,
971 ) -> Self {
972 let (storage_access_token, token_expiry) = match storage_access_token {
973 Some(token) => match Self::validate_and_get_expiry(&token) {
974 Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
975 (Some(token), Some(expiry))
976 }
977 _ => (None, None),
978 },
979 None => (None, None),
980 };
981
982 Self {
983 fabric_token_service_url: fabric_token_service_url.into(),
984 fabric_workload_host: fabric_workload_host.into(),
985 fabric_session_token: fabric_session_token.into(),
986 fabric_cluster_identifier: fabric_cluster_identifier.into(),
987 storage_access_token,
988 token_expiry,
989 }
990 }
991
992 fn validate_and_get_expiry(token: &str) -> Option<u64> {
993 let payload = token.split('.').nth(1)?;
994 let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
995 let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
996 let claims: Claims = serde_json::from_str(decoded_str).ok()?;
997 Some(claims.exp)
998 }
999
1000 fn get_current_timestamp() -> u64 {
1001 SystemTime::now()
1002 .duration_since(SystemTime::UNIX_EPOCH)
1003 .map_or(0, |d| d.as_secs())
1004 }
1005}
1006
1007#[async_trait::async_trait]
1008impl TokenProvider for FabricTokenOAuthProvider {
1009 type Credential = AzureCredential;
1010
1011 async fn fetch_token(
1013 &self,
1014 client: &HttpClient,
1015 retry: &RetryConfig,
1016 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
1017 if let Some(storage_access_token) = &self.storage_access_token {
1018 if let Some(expiry) = self.token_expiry {
1019 let exp_in = expiry.saturating_sub(Self::get_current_timestamp());
1020 if exp_in > TOKEN_MIN_TTL {
1021 return Ok(TemporaryToken {
1022 token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
1023 expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1024 });
1025 }
1026 }
1027 }
1028
1029 let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
1030 let access_token: String = client
1031 .request(Method::GET, &self.fabric_token_service_url)
1032 .header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
1033 .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
1034 .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
1035 .header(&PROXY_HOST, self.fabric_workload_host.as_str())
1036 .query(&query_items)
1037 .retryable(retry)
1038 .idempotent(true)
1039 .send()
1040 .await
1041 .map_err(|source| Error::TokenRequest { source })?
1042 .into_body()
1043 .text()
1044 .await
1045 .map_err(|source| Error::TokenResponseBody { source })?;
1046 let exp_in = Self::validate_and_get_expiry(&access_token).map_or(3600, |expiry| {
1047 expiry.saturating_sub(Self::get_current_timestamp())
1048 });
1049 Ok(TemporaryToken {
1050 token: Arc::new(AzureCredential::BearerToken(access_token)),
1051 expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1052 })
1053 }
1054}
1055
1056#[async_trait]
1057impl CredentialProvider for AzureCliCredential {
1058 type Credential = AzureCredential;
1059
1060 async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
1061 Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
1062 }
1063}
1064
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use http::{Response, StatusCode};
    use http_body_util::BodyExt;
    use reqwest::{Client, Method};
    use tempfile::NamedTempFile;

    use super::*;
    use crate::azure::MicrosoftAzureBuilder;
    use crate::client::mock_server::MockServer;
    use crate::{ObjectStoreExt, Path};

    /// The IMDS provider sends the identity selector, `metadata: true` and the
    /// `x-identity-header` taken from the environment, then parses the token.
    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        // SAFETY: NOTE(review): `set_var` is only sound if no other thread reads or writes
        // the environment concurrently. The test harness runs tests on parallel threads, so
        // this relies on no concurrently running test touching the environment — TODO
        // confirm, or serialize environment-mutating tests.
        unsafe { std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret") };

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// The workload identity provider posts the federated token read from file to the
    /// tenant's token endpoint.
    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// With `skip_signature`, no Authorization header is sent even when a bearer token is
    /// configured.
    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }

    /// An expired pre-supplied Fabric token is not reused. A fresh token is fetched from the
    /// token service instead.
    #[tokio::test]
    async fn test_fabric_refresh_expired_token() {
        let server = MockServer::new().await;

        let expired_timestamp = FabricTokenOAuthProvider::get_current_timestamp() - 3600;
        let claims = format!(r#"{{"exp":{expired_timestamp}}}"#);
        let encoded_claims = BASE64_URL_SAFE_NO_PAD.encode(claims.as_bytes());
        let expired_token = format!("header.{encoded_claims}.signature");

        let fresh_timestamp = FabricTokenOAuthProvider::get_current_timestamp() + 3600;
        let fresh_claims = format!(r#"{{"exp":{fresh_timestamp}}}"#);
        let fresh_encoded = BASE64_URL_SAFE_NO_PAD.encode(fresh_claims.as_bytes());
        let fresh_token = format!("header.{fresh_encoded}.signature");
        let expected_token = fresh_token.clone();

        server.push_fn(move |req| {
            assert_eq!(req.headers().get(&PARTNER_TOKEN).unwrap(), "session-token");
            assert_eq!(
                req.headers().get(&CLUSTER_IDENTIFIER).unwrap(),
                "cluster-id"
            );
            assert_eq!(req.headers().get(&PROXY_HOST).unwrap(), "fake");

            Response::new(fresh_token)
        });

        let provider = FabricTokenOAuthProvider {
            fabric_token_service_url: server.url().to_string(),
            fabric_workload_host: "fake".to_string(),
            fabric_session_token: "session-token".to_string(),
            fabric_cluster_identifier: "cluster-id".to_string(),
            storage_access_token: Some(expired_token),
            token_expiry: Some(expired_timestamp),
        };

        let client = HttpClient::new(Client::new());
        let retry = RetryConfig::default();

        let result = provider.fetch_token(&client, &retry).await;

        assert!(
            result.is_ok(),
            "fetch_token should handle expired cached token gracefully"
        );

        let temp_token = result.unwrap();

        // Fail loudly rather than silently skipping the assertion for a non-bearer credential.
        let AzureCredential::BearerToken(token) = temp_token.token.as_ref() else {
            panic!("expected a bearer token credential");
        };
        assert_eq!(
            token, &expected_token,
            "Should have fetched fresh token from API, not returned expired cached token"
        );
    }
}