1use super::client::UserDelegationKey;
19use crate::azure::STORE;
20use crate::client::builder::{add_query_pairs, HttpRequestBuilder};
21use crate::client::retry::RetryExt;
22use crate::client::token::{TemporaryToken, TokenCache};
23use crate::client::{CredentialProvider, HttpClient, HttpError, HttpRequest, TokenProvider};
24use crate::util::hmac_sha256;
25use crate::RetryConfig;
26use async_trait::async_trait;
27use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
28use base64::Engine;
29use chrono::{DateTime, SecondsFormat, Utc};
30use http::header::{
31 HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE,
32 CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
33 IF_UNMODIFIED_SINCE, RANGE,
34};
35use http::Method;
36use serde::Deserialize;
37use std::borrow::Cow;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::ops::Deref;
41use std::process::Command;
42use std::str;
43use std::sync::Arc;
44use std::time::{Duration, Instant, SystemTime};
45use url::Url;
46
/// Storage service version: sent as the `x-ms-version` header and as the `sv` SAS field.
static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
/// Name of the header that carries [`AZURE_VERSION`].
static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
/// Header name `x-ms-blob-type`.
pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
/// Header name `x-ms-delete-snapshots`.
pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
/// Header name `x-ms-copy-source`.
pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
/// Header name `content-md5`; its value is one of the fields of the shared-key string to sign.
static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
/// Fabric token request header that carries the Fabric session token.
static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
/// Fabric token request header that carries the cluster identifier.
static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
/// Fabric token request header; in this file it is also filled with the cluster identifier.
static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
/// Fabric token request header that carries the workload host.
static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
/// `chrono` format string used to render the `Date` request header.
pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
/// Value of the `Accept` header on OAuth token requests.
const CONTENT_TYPE_JSON: &str = "application/json";
/// Environment variable whose value, when set, is forwarded as `x-identity-header`
/// on managed-identity token requests.
const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
/// `api-version` query parameter of managed-identity token requests.
const MSI_API_VERSION: &str = "2019-08-01";
/// Minimum remaining lifetime, in seconds, for a pre-supplied Fabric storage token to be used.
const TOKEN_MIN_TTL: u64 = 300;

/// OAuth scope requested from the token endpoints and from the Azure CLI.
const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";

/// `resource` query parameter of managed-identity and Fabric token requests.
const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
72
/// Errors raised while obtaining Azure credentials or signing with them.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The HTTP request to a token endpoint failed.
    #[error("Error performing token request: {}", source)]
    TokenRequest {
        source: crate::client::retry::RetryError,
    },

    /// The body of a token response could not be read or decoded.
    #[error("Error getting token response body: {}", source)]
    TokenResponseBody { source: HttpError },

    /// The federated token file could not be read; the underlying I/O error is not retained.
    #[error("Error reading federated token file ")]
    FederatedTokenFile,

    /// The supplied access key is not valid base64.
    #[error("Invalid Access Key: {}", source)]
    InvalidAccessKey { source: base64::DecodeError },

    /// The `az` command could not be run, failed, or returned an unusable response.
    #[error("'az account get-access-token' command failed: {message}")]
    AzureCli { message: String },

    /// The JSON printed by the `az` command could not be parsed.
    #[error("Failed to parse azure cli response: {source}")]
    AzureCliResponse { source: serde_json::Error },

    /// SAS keys cannot be generated when the store itself authenticates with a SAS token.
    #[error("Generating SAS keys with SAS tokens auth is not supported")]
    SASforSASNotSupported,
}

/// Result alias whose error type defaults to this module's [`Error`].
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
100
101impl From<Error> for crate::Error {
102 fn from(value: Error) -> Self {
103 Self::Generic {
104 store: STORE,
105 source: Box::new(value),
106 }
107 }
108}
109
110#[derive(Debug, Clone, Eq, PartialEq)]
112pub struct AzureAccessKey(Vec<u8>);
113
114impl AzureAccessKey {
115 pub fn try_new(key: &str) -> Result<Self> {
117 let key = BASE64_STANDARD
118 .decode(key)
119 .map_err(|source| Error::InvalidAccessKey { source })?;
120
121 Ok(Self(key))
122 }
123}
124
125#[derive(Debug, Eq, PartialEq)]
127pub enum AzureCredential {
128 AccessKey(AzureAccessKey),
132 SASToken(Vec<(String, String)>),
136 BearerToken(String),
140}
141
142impl AzureCredential {
143 pub fn sensitive_request(&self) -> bool {
145 match self {
146 Self::AccessKey(_) => false,
147 Self::BearerToken(_) => false,
148 Self::SASToken(_) => true,
150 }
151 }
152}
153
/// Base URLs of the authority hosts that OAuth token URLs are built from.
pub mod authority_hosts {
    /// Authority host named for Azure China.
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Authority host named for Azure Germany.
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
    /// Authority host named for Azure Government.
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
    /// Authority host of the public cloud; the providers in this file default to it.
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
}
165
166pub(crate) struct AzureSigner {
167 signing_key: AzureAccessKey,
168 start: DateTime<Utc>,
169 end: DateTime<Utc>,
170 account: String,
171 delegation_key: Option<UserDelegationKey>,
172}
173
174impl AzureSigner {
175 pub(crate) fn new(
176 signing_key: AzureAccessKey,
177 account: String,
178 start: DateTime<Utc>,
179 end: DateTime<Utc>,
180 delegation_key: Option<UserDelegationKey>,
181 ) -> Self {
182 Self {
183 signing_key,
184 account,
185 start,
186 end,
187 delegation_key,
188 }
189 }
190
191 pub(crate) fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
192 let (str_to_sign, query_pairs) = match &self.delegation_key {
193 Some(delegation_key) => string_to_sign_user_delegation_sas(
194 url,
195 method,
196 &self.account,
197 &self.start,
198 &self.end,
199 delegation_key,
200 ),
201 None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
202 };
203 let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
204 url.query_pairs_mut().extend_pairs(query_pairs);
205 url.query_pairs_mut()
206 .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
207 Ok(())
208 }
209}
210
211fn add_date_and_version_headers(request: &mut HttpRequest) {
212 let date = Utc::now();
214 let date_str = date.format(RFC1123_FMT).to_string();
215 let date_val = HeaderValue::from_str(&date_str).unwrap();
217 request.headers_mut().insert(DATE, date_val);
218 request
219 .headers_mut()
220 .insert(&VERSION, AZURE_VERSION.clone());
221}
222
223#[derive(Debug)]
225pub struct AzureAuthorizer<'a> {
226 credential: &'a AzureCredential,
227 account: &'a str,
228}
229
230impl<'a> AzureAuthorizer<'a> {
231 pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
233 AzureAuthorizer {
234 credential,
235 account,
236 }
237 }
238
239 pub fn authorize(&self, request: &mut HttpRequest) {
241 add_date_and_version_headers(request);
242
243 match self.credential {
244 AzureCredential::AccessKey(key) => {
245 let url = Url::parse(&request.uri().to_string()).unwrap();
246 let signature = generate_authorization(
247 request.headers(),
248 &url,
249 request.method(),
250 self.account,
251 key,
252 );
253
254 request.headers_mut().append(
257 AUTHORIZATION,
258 HeaderValue::from_str(signature.as_str()).unwrap(),
259 );
260 }
261 AzureCredential::BearerToken(token) => {
262 request.headers_mut().append(
263 AUTHORIZATION,
264 HeaderValue::from_str(format!("Bearer {token}").as_str()).unwrap(),
265 );
266 }
267 AzureCredential::SASToken(query_pairs) => {
268 add_query_pairs(request.uri_mut(), query_pairs);
269 }
270 }
271 }
272}
273
/// Extension trait for applying Azure authorization to a request builder.
pub(crate) trait CredentialExt {
    /// Applies `credential` for `account` to the request and returns the builder.
    ///
    /// `credential` may be `None`; what happens then is up to the implementation.
    fn with_azure_authorization(
        self,
        credential: &Option<impl Deref<Target = AzureCredential>>,
        account: &str,
    ) -> Self;
}
283
284impl CredentialExt for HttpRequestBuilder {
285 fn with_azure_authorization(
286 self,
287 credential: &Option<impl Deref<Target = AzureCredential>>,
288 account: &str,
289 ) -> Self {
290 let (client, request) = self.into_parts();
291 let mut request = request.expect("request valid");
292
293 match credential.as_deref() {
294 Some(credential) => {
295 AzureAuthorizer::new(credential, account).authorize(&mut request);
296 }
297 None => {
298 add_date_and_version_headers(&mut request);
299 }
300 }
301
302 Self::from_parts(client, request)
303 }
304}
305
306fn generate_authorization(
309 h: &HeaderMap,
310 u: &Url,
311 method: &Method,
312 account: &str,
313 key: &AzureAccessKey,
314) -> String {
315 let str_to_sign = string_to_sign(h, u, method, account);
316 let auth = hmac_sha256(&key.0, str_to_sign);
317 format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
318}
319
320fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
321 h.get(key)
322 .map(|s| s.to_str())
323 .transpose()
324 .ok()
325 .flatten()
326 .unwrap_or_default()
327}
328
329fn string_to_sign_sas(
330 u: &Url,
331 method: &Method,
332 account: &str,
333 start: &DateTime<Utc>,
334 end: &DateTime<Utc>,
335) -> (String, String, String, String, String) {
336 let signed_resource = "b".to_string();
338
339 let signed_permissions = match *method {
341 Method::GET => match signed_resource.as_str() {
343 "c" => "rl",
344 "b" => "r",
345 _ => unreachable!(),
346 },
347 Method::PUT => "w",
349 Method::DELETE => "d",
351 _ => "",
353 }
354 .to_string();
355 let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
356 let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
357 let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
358 format!("/blob/{}{}", account, u.path())
359 } else {
360 format!("/blob{}", u.path())
363 };
364
365 (
366 signed_resource,
367 signed_permissions,
368 signed_start,
369 signed_expiry,
370 canonicalized_resource,
371 )
372}
373
374fn string_to_sign_service_sas(
378 u: &Url,
379 method: &Method,
380 account: &str,
381 start: &DateTime<Utc>,
382 end: &DateTime<Utc>,
383) -> (String, HashMap<&'static str, String>) {
384 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
385 string_to_sign_sas(u, method, account, start, end);
386
387 let string_to_sign = format!(
388 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
389 signed_permissions,
390 signed_start,
391 signed_expiry,
392 canonicalized_resource,
393 "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
406
407 let mut pairs = HashMap::new();
408 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
409 pairs.insert("sp", signed_permissions);
410 pairs.insert("st", signed_start);
411 pairs.insert("se", signed_expiry);
412 pairs.insert("sr", signed_resource);
413
414 (string_to_sign, pairs)
415}
416
417fn string_to_sign_user_delegation_sas(
421 u: &Url,
422 method: &Method,
423 account: &str,
424 start: &DateTime<Utc>,
425 end: &DateTime<Utc>,
426 delegation_key: &UserDelegationKey,
427) -> (String, HashMap<&'static str, String>) {
428 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
429 string_to_sign_sas(u, method, account, start, end);
430
431 let string_to_sign = format!(
432 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
433 signed_permissions,
434 signed_start,
435 signed_expiry,
436 canonicalized_resource,
437 delegation_key.signed_oid, delegation_key.signed_tid, delegation_key.signed_start, delegation_key.signed_expiry, delegation_key.signed_service, delegation_key.signed_version, "", "", "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
458
459 let mut pairs = HashMap::new();
460 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
461 pairs.insert("sp", signed_permissions);
462 pairs.insert("st", signed_start);
463 pairs.insert("se", signed_expiry);
464 pairs.insert("sr", signed_resource);
465 pairs.insert("skoid", delegation_key.signed_oid.clone());
466 pairs.insert("sktid", delegation_key.signed_tid.clone());
467 pairs.insert("skt", delegation_key.signed_start.clone());
468 pairs.insert("ske", delegation_key.signed_expiry.clone());
469 pairs.insert("sks", delegation_key.signed_service.clone());
470 pairs.insert("skv", delegation_key.signed_version.clone());
471
472 (string_to_sign, pairs)
473}
474
475fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
477 let content_length = h
480 .get(&CONTENT_LENGTH)
481 .map(|s| s.to_str())
482 .transpose()
483 .ok()
484 .flatten()
485 .filter(|&v| v != "0")
486 .unwrap_or_default();
487 format!(
488 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
489 method.as_ref(),
490 add_if_exists(h, &CONTENT_ENCODING),
491 add_if_exists(h, &CONTENT_LANGUAGE),
492 content_length,
493 add_if_exists(h, &CONTENT_MD5),
494 add_if_exists(h, &CONTENT_TYPE),
495 add_if_exists(h, &DATE),
496 add_if_exists(h, &IF_MODIFIED_SINCE),
497 add_if_exists(h, &IF_MATCH),
498 add_if_exists(h, &IF_NONE_MATCH),
499 add_if_exists(h, &IF_UNMODIFIED_SINCE),
500 add_if_exists(h, &RANGE),
501 canonicalize_header(h),
502 canonicalize_resource(account, u)
503 )
504}
505
506fn canonicalize_header(headers: &HeaderMap) -> String {
508 let mut names = headers
509 .iter()
510 .filter(|&(k, _)| (k.as_str().starts_with("x-ms")))
511 .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
513 .collect::<Vec<_>>();
514 names.sort_unstable();
515
516 let mut result = String::new();
517 for (name, value) in names {
518 result.push_str(name);
519 result.push(':');
520 result.push_str(value);
521 result.push('\n');
522 }
523 result
524}
525
526fn canonicalize_resource(account: &str, uri: &Url) -> String {
528 let mut can_res: String = String::new();
529 can_res.push('/');
530 can_res.push_str(account);
531 can_res.push_str(uri.path().to_string().as_str());
532 can_res.push('\n');
533
534 let query_pairs = uri.query_pairs();
536 {
537 let mut qps: Vec<String> = Vec::new();
538 for (q, _) in query_pairs {
539 if !(qps.iter().any(|x| x == &*q)) {
540 qps.push(q.into_owned());
541 }
542 }
543
544 qps.sort();
545
546 for qparam in qps {
547 let ret = lexy_sort(query_pairs, &qparam);
549
550 can_res = can_res + &qparam.to_lowercase() + ":";
551
552 for (i, item) in ret.iter().enumerate() {
553 if i > 0 {
554 can_res.push(',');
555 }
556 can_res.push_str(item);
557 }
558
559 can_res.push('\n');
560 }
561 };
562
563 can_res[0..can_res.len() - 1].to_owned()
564}
565
/// Collects the values of every pair in `vec` whose key equals `query_param`
/// (case-sensitive) and returns them sorted.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = vec
        .filter_map(|(name, value)| (name == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
577
/// The fields read from an OAuth token endpoint response.
#[derive(Deserialize, Debug)]
struct OAuthTokenResponse {
    /// The token, used as a bearer token.
    access_token: String,
    /// Lifetime of the token; callers add it, as seconds, to the time of receipt.
    expires_in: u64,
}
584
585#[derive(Debug)]
589pub(crate) struct ClientSecretOAuthProvider {
590 token_url: String,
591 client_id: String,
592 client_secret: String,
593}
594
595impl ClientSecretOAuthProvider {
596 pub(crate) fn new(
598 client_id: String,
599 client_secret: String,
600 tenant_id: impl AsRef<str>,
601 authority_host: Option<String>,
602 ) -> Self {
603 let authority_host =
604 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
605
606 Self {
607 token_url: format!(
608 "{}/{}/oauth2/v2.0/token",
609 authority_host,
610 tenant_id.as_ref()
611 ),
612 client_id,
613 client_secret,
614 }
615 }
616}
617
618#[async_trait::async_trait]
619impl TokenProvider for ClientSecretOAuthProvider {
620 type Credential = AzureCredential;
621
622 async fn fetch_token(
624 &self,
625 client: &HttpClient,
626 retry: &RetryConfig,
627 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
628 let response: OAuthTokenResponse = client
629 .request(Method::POST, &self.token_url)
630 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
631 .form([
632 ("client_id", self.client_id.as_str()),
633 ("client_secret", self.client_secret.as_str()),
634 ("scope", AZURE_STORAGE_SCOPE),
635 ("grant_type", "client_credentials"),
636 ])
637 .retryable(retry)
638 .idempotent(true)
639 .send()
640 .await
641 .map_err(|source| Error::TokenRequest { source })?
642 .into_body()
643 .json()
644 .await
645 .map_err(|source| Error::TokenResponseBody { source })?;
646
647 Ok(TemporaryToken {
648 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
649 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
650 })
651 }
652}
653
654fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
655where
656 D: serde::de::Deserializer<'de>,
657{
658 let v = String::deserialize(deserializer)?;
659 let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
660 let now = SystemTime::now()
661 .duration_since(SystemTime::UNIX_EPOCH)
662 .map_err(serde::de::Error::custom)?;
663
664 Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
665}
666
/// The fields read from a managed-identity token response.
#[derive(Debug, Clone, Deserialize)]
struct ImdsTokenResponse {
    /// The token, used as a bearer token.
    pub access_token: String,
    /// Expiry, parsed from a string of seconds since the Unix epoch.
    #[serde(deserialize_with = "expires_on_string")]
    pub expires_on: Instant,
}
676
/// Obtains bearer tokens from a managed-identity metadata endpoint.
#[derive(Debug)]
pub(crate) struct ImdsManagedIdentityProvider {
    /// URL the token is requested from.
    msi_endpoint: String,
    /// Optional identity selectors; at most one is sent with the request.
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Creates a provider.
    ///
    /// `msi_endpoint` defaults to
    /// `http://169.254.169.254/metadata/identity/oauth2/token` when `None`.
    pub(crate) fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        let msi_endpoint = match msi_endpoint {
            Some(endpoint) => endpoint,
            None => "http://169.254.169.254/metadata/identity/oauth2/token".to_owned(),
        };

        Self {
            msi_endpoint,
            client_id,
            object_id,
            msi_res_id,
        }
    }
}
708
709#[async_trait::async_trait]
710impl TokenProvider for ImdsManagedIdentityProvider {
711 type Credential = AzureCredential;
712
713 async fn fetch_token(
715 &self,
716 client: &HttpClient,
717 retry: &RetryConfig,
718 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
719 let mut query_items = vec![
720 ("api-version", MSI_API_VERSION),
721 ("resource", AZURE_STORAGE_RESOURCE),
722 ];
723
724 let mut identity = None;
725 if let Some(client_id) = &self.client_id {
726 identity = Some(("client_id", client_id));
727 }
728 if let Some(object_id) = &self.object_id {
729 identity = Some(("object_id", object_id));
730 }
731 if let Some(msi_res_id) = &self.msi_res_id {
732 identity = Some(("msi_res_id", msi_res_id));
733 }
734 if let Some((key, value)) = identity {
735 query_items.push((key, value));
736 }
737
738 let mut builder = client
739 .request(Method::GET, &self.msi_endpoint)
740 .header("metadata", "true")
741 .query(&query_items);
742
743 if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
744 builder = builder.header("x-identity-header", val);
745 };
746
747 let response: ImdsTokenResponse = builder
748 .send_retry(retry)
749 .await
750 .map_err(|source| Error::TokenRequest { source })?
751 .into_body()
752 .json()
753 .await
754 .map_err(|source| Error::TokenResponseBody { source })?;
755
756 Ok(TemporaryToken {
757 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
758 expiry: Some(response.expires_on),
759 })
760 }
761}
762
763#[derive(Debug)]
767pub(crate) struct WorkloadIdentityOAuthProvider {
768 token_url: String,
769 client_id: String,
770 federated_token_file: String,
771}
772
773impl WorkloadIdentityOAuthProvider {
774 pub(crate) fn new(
776 client_id: impl Into<String>,
777 federated_token_file: impl Into<String>,
778 tenant_id: impl AsRef<str>,
779 authority_host: Option<String>,
780 ) -> Self {
781 let authority_host =
782 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
783
784 Self {
785 token_url: format!(
786 "{}/{}/oauth2/v2.0/token",
787 authority_host,
788 tenant_id.as_ref()
789 ),
790 client_id: client_id.into(),
791 federated_token_file: federated_token_file.into(),
792 }
793 }
794}
795
796#[async_trait::async_trait]
797impl TokenProvider for WorkloadIdentityOAuthProvider {
798 type Credential = AzureCredential;
799
800 async fn fetch_token(
802 &self,
803 client: &HttpClient,
804 retry: &RetryConfig,
805 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
806 let token_str = std::fs::read_to_string(&self.federated_token_file)
807 .map_err(|_| Error::FederatedTokenFile)?;
808
809 let response: OAuthTokenResponse = client
811 .request(Method::POST, &self.token_url)
812 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
813 .form([
814 ("client_id", self.client_id.as_str()),
815 (
816 "client_assertion_type",
817 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
818 ),
819 ("client_assertion", token_str.as_str()),
820 ("scope", AZURE_STORAGE_SCOPE),
821 ("grant_type", "client_credentials"),
822 ])
823 .retryable(retry)
824 .idempotent(true)
825 .send()
826 .await
827 .map_err(|source| Error::TokenRequest { source })?
828 .into_body()
829 .json()
830 .await
831 .map_err(|source| Error::TokenResponseBody { source })?;
832
833 Ok(TemporaryToken {
834 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
835 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
836 })
837 }
838}
839
840mod az_cli_date_format {
841 use chrono::{DateTime, TimeZone};
842 use serde::{self, Deserialize, Deserializer};
843
844 pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
845 where
846 D: Deserializer<'de>,
847 {
848 let s = String::deserialize(deserializer)?;
849 let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
851 .map_err(serde::de::Error::custom)?;
852 chrono::Local
853 .from_local_datetime(&date)
854 .single()
855 .ok_or(serde::de::Error::custom(
856 "azure cli returned ambiguous expiry date",
857 ))
858 }
859}
860
/// The fields read from the JSON printed by `az account get-access-token`
/// (keys are camelCase in the JSON).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AzureCliTokenResponse {
    /// The token, used as a bearer token.
    pub access_token: String,
    /// Expiry, parsed as a local date-time by `az_cli_date_format`.
    #[serde(with = "az_cli_date_format")]
    pub expires_on: DateTime<chrono::Local>,
    /// Token type; anything other than "bearer" (case-insensitive) is rejected by the caller.
    pub token_type: String,
}
869
870#[derive(Default, Debug)]
871pub(crate) struct AzureCliCredential {
872 cache: TokenCache<Arc<AzureCredential>>,
873}
874
875impl AzureCliCredential {
876 pub(crate) fn new() -> Self {
877 Self::default()
878 }
879
880 async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
882 let program = if cfg!(target_os = "windows") {
885 "cmd"
886 } else {
887 "az"
888 };
889 let mut args = Vec::new();
890 if cfg!(target_os = "windows") {
891 args.push("/C");
892 args.push("az");
893 }
894 args.push("account");
895 args.push("get-access-token");
896 args.push("--output");
897 args.push("json");
898 args.push("--scope");
899 args.push(AZURE_STORAGE_SCOPE);
900
901 match Command::new(program).args(args).output() {
902 Ok(az_output) if az_output.status.success() => {
903 let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
904 message: "az response is not a valid utf-8 string".to_string(),
905 })?;
906
907 let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
908 .map_err(|source| Error::AzureCliResponse { source })?;
909
910 if !token_response.token_type.eq_ignore_ascii_case("bearer") {
911 return Err(Error::AzureCli {
912 message: format!(
913 "got unexpected token type from azure cli: {0}",
914 token_response.token_type
915 ),
916 });
917 }
918 let duration =
919 token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
920 Ok(TemporaryToken {
921 token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
922 expiry: Some(
923 Instant::now()
924 + duration.to_std().map_err(|_| Error::AzureCli {
925 message: "az returned invalid lifetime".to_string(),
926 })?,
927 ),
928 })
929 }
930 Ok(az_output) => {
931 let message = String::from_utf8_lossy(&az_output.stderr);
932 Err(Error::AzureCli {
933 message: message.into(),
934 })
935 }
936 Err(e) => match e.kind() {
937 std::io::ErrorKind::NotFound => Err(Error::AzureCli {
938 message: "Azure Cli not installed".into(),
939 }),
940 error_kind => Err(Error::AzureCli {
941 message: format!("io error: {error_kind:?}"),
942 }),
943 },
944 }
945 }
946}
947
/// Obtains bearer tokens from a Fabric token service, optionally starting from a
/// pre-supplied storage access token.
#[derive(Debug)]
pub(crate) struct FabricTokenOAuthProvider {
    /// URL the token is requested from.
    fabric_token_service_url: String,
    /// Sent as the `x-ms-proxy-host` header.
    fabric_workload_host: String,
    /// Sent as the `x-ms-partner-token` header.
    fabric_session_token: String,
    /// Sent as both the `x-ms-cluster-identifier` and `x-ms-workload-resource-moniker` headers.
    fabric_cluster_identifier: String,
    /// Pre-supplied token; only kept if it had more than `TOKEN_MIN_TTL` seconds left
    /// when the provider was constructed.
    storage_access_token: Option<String>,
    /// `exp` claim of `storage_access_token`, compared against seconds since the Unix epoch.
    token_expiry: Option<u64>,
}
958
/// The only claim read from a token payload.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Expiry of the token; compared against seconds since the Unix epoch.
    exp: u64,
}
963
964impl FabricTokenOAuthProvider {
965 pub(crate) fn new(
967 fabric_token_service_url: impl Into<String>,
968 fabric_workload_host: impl Into<String>,
969 fabric_session_token: impl Into<String>,
970 fabric_cluster_identifier: impl Into<String>,
971 storage_access_token: Option<String>,
972 ) -> Self {
973 let (storage_access_token, token_expiry) = match storage_access_token {
974 Some(token) => match Self::validate_and_get_expiry(&token) {
975 Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
976 (Some(token), Some(expiry))
977 }
978 _ => (None, None),
979 },
980 None => (None, None),
981 };
982
983 Self {
984 fabric_token_service_url: fabric_token_service_url.into(),
985 fabric_workload_host: fabric_workload_host.into(),
986 fabric_session_token: fabric_session_token.into(),
987 fabric_cluster_identifier: fabric_cluster_identifier.into(),
988 storage_access_token,
989 token_expiry,
990 }
991 }
992
993 fn validate_and_get_expiry(token: &str) -> Option<u64> {
994 let payload = token.split('.').nth(1)?;
995 let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
996 let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
997 let claims: Claims = serde_json::from_str(decoded_str).ok()?;
998 Some(claims.exp)
999 }
1000
1001 fn get_current_timestamp() -> u64 {
1002 SystemTime::now()
1003 .duration_since(SystemTime::UNIX_EPOCH)
1004 .map_or(0, |d| d.as_secs())
1005 }
1006}
1007
1008#[async_trait::async_trait]
1009impl TokenProvider for FabricTokenOAuthProvider {
1010 type Credential = AzureCredential;
1011
1012 async fn fetch_token(
1014 &self,
1015 client: &HttpClient,
1016 retry: &RetryConfig,
1017 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
1018 if let Some(storage_access_token) = &self.storage_access_token {
1019 if let Some(expiry) = self.token_expiry {
1020 let exp_in = expiry - Self::get_current_timestamp();
1021 if exp_in > TOKEN_MIN_TTL {
1022 return Ok(TemporaryToken {
1023 token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
1024 expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1025 });
1026 }
1027 }
1028 }
1029
1030 let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
1031 let access_token: String = client
1032 .request(Method::GET, &self.fabric_token_service_url)
1033 .header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
1034 .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
1035 .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
1036 .header(&PROXY_HOST, self.fabric_workload_host.as_str())
1037 .query(&query_items)
1038 .retryable(retry)
1039 .idempotent(true)
1040 .send()
1041 .await
1042 .map_err(|source| Error::TokenRequest { source })?
1043 .into_body()
1044 .text()
1045 .await
1046 .map_err(|source| Error::TokenResponseBody { source })?;
1047 let exp_in = Self::validate_and_get_expiry(&access_token)
1048 .map_or(3600, |expiry| expiry - Self::get_current_timestamp());
1049 Ok(TemporaryToken {
1050 token: Arc::new(AzureCredential::BearerToken(access_token)),
1051 expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1052 })
1053 }
1054}
1055
1056#[async_trait]
1057impl CredentialProvider for AzureCliCredential {
1058 type Credential = AzureCredential;
1059
1060 async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
1061 Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
1062 }
1063}
1064
1065#[cfg(test)]
1066mod tests {
1067 use futures::executor::block_on;
1068 use http::{Response, StatusCode};
1069 use http_body_util::BodyExt;
1070 use reqwest::{Client, Method};
1071 use tempfile::NamedTempFile;
1072
1073 use super::*;
1074 use crate::azure::MicrosoftAzureBuilder;
1075 use crate::client::mock_server::MockServer;
1076 use crate::{ObjectStore, Path};
1077
    /// The managed-identity provider sends the expected request (path, `client_id`
    /// query parameter, `metadata` and `x-identity-header` headers) and returns the
    /// token from the response.
    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        // The provider forwards this variable as the `x-identity-header` header.
        std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret");

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Check the incoming request, then answer with a token response.
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }
    /// The workload-identity provider posts the content of the federated token file to
    /// the tenant's token endpoint and returns the token from the response.
    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Check the incoming request, then answer with a token response.
        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            // The form body must carry the content of the token file.
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        // The mock server stands in for the authority host.
        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }
    /// With `with_skip_signature(true)`, requests carry no `Authorization` header even
    /// though a bearer token was configured.
    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        // Assert the request is unauthenticated, then answer 404.
        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        // The 404 from the mock server must surface as `NotFound`.
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }
1220}