Skip to main content

object_store/azure/
credential.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::client::UserDelegationKey;
19use crate::RetryConfig;
20use crate::azure::STORE;
21use crate::client::builder::{HttpRequestBuilder, add_query_pairs};
22use crate::client::retry::RetryExt;
23use crate::client::token::{TemporaryToken, TokenCache};
24use crate::client::{CredentialProvider, HttpClient, HttpError, HttpRequest, TokenProvider};
25use crate::util::hmac_sha256;
26use async_trait::async_trait;
27use base64::Engine;
28use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
29use chrono::{DateTime, SecondsFormat, Utc};
30use http::Method;
31use http::header::{
32    ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE,
33    HeaderMap, HeaderName, HeaderValue, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
34    IF_UNMODIFIED_SINCE, RANGE,
35};
36use serde::Deserialize;
37use std::borrow::Cow;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::ops::Deref;
41use std::process::Command;
42use std::str;
43use std::sync::Arc;
44use std::time::{Duration, Instant, SystemTime};
45use url::Url;
46
/// Storage service REST API version: sent as the `x-ms-version` header on every request and
/// used as the signed version (`sv`) of generated SAS tokens.
static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
/// Name of the header that carries `AZURE_VERSION`.
static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
/// `x-ms-blob-type` request header, shared with the rest of the crate.
pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
/// `x-ms-copy-source` request header, shared with the rest of the crate.
pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
/// `Content-MD5`, one of the standard headers included in the shared key string-to-sign.
static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
// NOTE(review): the next four header names are not used in this part of the file; presumably
// they belong to a token provider defined further down — confirm before changing them.
static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
/// `chrono` format string for the `Date` request header, e.g. `Mon, 01 Jan 2024 00:00:00 GMT`.
pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
/// `Accept` header value sent with OAuth token requests.
const CONTENT_TYPE_JSON: &str = "application/json";
/// Environment variable whose value, when set, is forwarded as the `x-identity-header` of
/// managed identity token requests.
const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
/// `api-version` query parameter sent to the managed identity endpoint.
const MSI_API_VERSION: &str = "2019-08-01";
// NOTE(review): not used in this part of the file; the unit (presumably seconds) should be
// confirmed at its point of use.
const TOKEN_MIN_TTL: u64 = 300;

/// OIDC scope used when interacting with OAuth2 APIs
///
/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/scopes-oidc#the-default-scope>
const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";

/// Resource ID used when obtaining an access token from the metadata endpoint
///
/// <https://learn.microsoft.com/en-us/azure/storage/blobs/authorize-access-azure-active-directory#microsoft-authentication-library-msal>
const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
71
72#[derive(Debug, thiserror::Error)]
73pub enum Error {
74    #[error("Error performing token request: {}", source)]
75    TokenRequest {
76        source: crate::client::retry::RetryError,
77    },
78
79    #[error("Error getting token response body: {}", source)]
80    TokenResponseBody { source: HttpError },
81
82    #[error("Error reading federated token file ")]
83    FederatedTokenFile,
84
85    #[error("Invalid Access Key: {}", source)]
86    InvalidAccessKey { source: base64::DecodeError },
87
88    #[error("'az account get-access-token' command failed: {message}")]
89    AzureCli { message: String },
90
91    #[error("Failed to parse azure cli response: {source}")]
92    AzureCliResponse { source: serde_json::Error },
93
94    #[error("Generating SAS keys with SAS tokens auth is not supported")]
95    SASforSASNotSupported,
96}
97
/// Result type whose error defaults to this module's `Error`.
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
99
100impl From<Error> for crate::Error {
101    fn from(value: Error) -> Self {
102        Self::Generic {
103            store: STORE,
104            source: Box::new(value),
105        }
106    }
107}
108
109/// A shared Azure Storage Account Key
110#[derive(Debug, Clone, Eq, PartialEq)]
111pub struct AzureAccessKey(Vec<u8>);
112
113impl AzureAccessKey {
114    /// Create a new [`AzureAccessKey`], checking it for validity
115    pub fn try_new(key: &str) -> Result<Self> {
116        let key = BASE64_STANDARD
117            .decode(key)
118            .map_err(|source| Error::InvalidAccessKey { source })?;
119
120        Ok(Self(key))
121    }
122}
123
124/// An Azure storage credential
125#[derive(Debug, Eq, PartialEq)]
126pub enum AzureCredential {
127    /// A shared access key
128    ///
129    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
130    AccessKey(AzureAccessKey),
131    /// A shared access signature
132    ///
133    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature>
134    SASToken(Vec<(String, String)>),
135    /// An authorization token
136    ///
137    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-azure-active-directory>
138    BearerToken(String),
139}
140
141impl AzureCredential {
142    /// Determines if the credential requires the request be treated as sensitive
143    pub fn sensitive_request(&self) -> bool {
144        match self {
145            Self::AccessKey(_) => false,
146            Self::BearerToken(_) => false,
147            // SAS tokens are sent as query parameters in the url
148            Self::SASToken(_) => true,
149        }
150    }
151}
152
/// A list of known Azure authority hosts
///
/// The OAuth token providers in this module fall back to `AZURE_PUBLIC_CLOUD` when no authority
/// host is configured.
pub mod authority_hosts {
    /// China-based Azure Authority Host
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Germany-based Azure Authority Host
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
    /// US Government Azure Authority Host
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
    /// Public Cloud Azure Authority Host
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
}
164
165pub(crate) struct AzureSigner {
166    signing_key: AzureAccessKey,
167    start: DateTime<Utc>,
168    end: DateTime<Utc>,
169    account: String,
170    delegation_key: Option<UserDelegationKey>,
171}
172
173impl AzureSigner {
174    pub(crate) fn new(
175        signing_key: AzureAccessKey,
176        account: String,
177        start: DateTime<Utc>,
178        end: DateTime<Utc>,
179        delegation_key: Option<UserDelegationKey>,
180    ) -> Self {
181        Self {
182            signing_key,
183            account,
184            start,
185            end,
186            delegation_key,
187        }
188    }
189
190    pub(crate) fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
191        let (str_to_sign, query_pairs) = match &self.delegation_key {
192            Some(delegation_key) => string_to_sign_user_delegation_sas(
193                url,
194                method,
195                &self.account,
196                &self.start,
197                &self.end,
198                delegation_key,
199            ),
200            None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
201        };
202        let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
203        url.query_pairs_mut().extend_pairs(query_pairs);
204        url.query_pairs_mut()
205            .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
206        Ok(())
207    }
208}
209
210fn add_date_and_version_headers(request: &mut HttpRequest) {
211    // rfc2822 string should never contain illegal characters
212    let date = Utc::now();
213    let date_str = date.format(RFC1123_FMT).to_string();
214    // we formatted the data string ourselves, so unwrapping should be fine
215    let date_val = HeaderValue::from_str(&date_str).unwrap();
216    request.headers_mut().insert(DATE, date_val);
217    request
218        .headers_mut()
219        .insert(&VERSION, AZURE_VERSION.clone());
220}
221
222/// Authorize a [`HttpRequest`] with an [`AzureAuthorizer`]
223#[derive(Debug)]
224pub struct AzureAuthorizer<'a> {
225    credential: &'a AzureCredential,
226    account: &'a str,
227}
228
229impl<'a> AzureAuthorizer<'a> {
230    /// Create a new [`AzureAuthorizer`]
231    pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
232        AzureAuthorizer {
233            credential,
234            account,
235        }
236    }
237
238    /// Authorize `request`
239    pub fn authorize(&self, request: &mut HttpRequest) {
240        add_date_and_version_headers(request);
241
242        match self.credential {
243            AzureCredential::AccessKey(key) => {
244                let url = Url::parse(&request.uri().to_string()).unwrap();
245                let signature = generate_authorization(
246                    request.headers(),
247                    &url,
248                    request.method(),
249                    self.account,
250                    key,
251                );
252
253                // "signature" is a base 64 encoded string so it should never
254                // contain illegal characters
255                request.headers_mut().append(
256                    AUTHORIZATION,
257                    HeaderValue::from_str(signature.as_str()).unwrap(),
258                );
259            }
260            AzureCredential::BearerToken(token) => {
261                request.headers_mut().append(
262                    AUTHORIZATION,
263                    HeaderValue::from_str(format!("Bearer {token}").as_str()).unwrap(),
264                );
265            }
266            AzureCredential::SASToken(query_pairs) => {
267                add_query_pairs(request.uri_mut(), query_pairs);
268            }
269        }
270    }
271}
272
/// Extension trait for attaching Azure authorization to an outgoing request.
pub(crate) trait CredentialExt {
    /// Apply authorization to requests against azure storage accounts
    /// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-requests-to-azure-storage>
    ///
    /// `credential` is the optional credential to authorize with and `account` is the storage
    /// account name. Implementations consume `self` and return the updated request.
    fn with_azure_authorization(
        self,
        credential: &Option<impl Deref<Target = AzureCredential>>,
        account: &str,
    ) -> Self;
}
282
283impl CredentialExt for HttpRequestBuilder {
284    fn with_azure_authorization(
285        self,
286        credential: &Option<impl Deref<Target = AzureCredential>>,
287        account: &str,
288    ) -> Self {
289        let (client, request) = self.into_parts();
290        let mut request = request.expect("request valid");
291
292        match credential.as_deref() {
293            Some(credential) => {
294                AzureAuthorizer::new(credential, account).authorize(&mut request);
295            }
296            None => {
297                add_date_and_version_headers(&mut request);
298            }
299        }
300
301        Self::from_parts(client, request)
302    }
303}
304
305/// Generate signed key for authorization via access keys
306/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
307fn generate_authorization(
308    h: &HeaderMap,
309    u: &Url,
310    method: &Method,
311    account: &str,
312    key: &AzureAccessKey,
313) -> String {
314    let str_to_sign = string_to_sign(h, u, method, account);
315    let auth = hmac_sha256(&key.0, str_to_sign);
316    format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
317}
318
319fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
320    h.get(key)
321        .map(|s| s.to_str())
322        .transpose()
323        .ok()
324        .flatten()
325        .unwrap_or_default()
326}
327
328fn string_to_sign_sas(
329    u: &Url,
330    method: &Method,
331    account: &str,
332    start: &DateTime<Utc>,
333    end: &DateTime<Utc>,
334) -> (String, String, String, String, String) {
335    // NOTE: for now only blob signing is supported.
336    let signed_resource = "b".to_string();
337
338    // https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob
339    let signed_permissions = match *method {
340        // read and list permissions
341        Method::GET => match signed_resource.as_str() {
342            "c" => "rl",
343            "b" => "r",
344            _ => unreachable!(),
345        },
346        // write permissions (also allows crating a new blob in a sub-key)
347        Method::PUT => "w",
348        // delete permissions
349        Method::DELETE => "d",
350        // other methods are not used in any of the current operations
351        _ => "",
352    }
353    .to_string();
354    let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
355    let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
356    let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
357        format!("/blob/{}{}", account, u.path())
358    } else {
359        // NOTE: in case of the emulator, the account name is not part of the host
360        //      but the path starts with the account name
361        format!("/blob{}", u.path())
362    };
363
364    (
365        signed_resource,
366        signed_permissions,
367        signed_start,
368        signed_expiry,
369        canonicalized_resource,
370    )
371}
372
373/// Create a string to be signed for authorization via [service sas].
374///
375/// [service sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#version-2020-12-06-and-later
376fn string_to_sign_service_sas(
377    u: &Url,
378    method: &Method,
379    account: &str,
380    start: &DateTime<Utc>,
381    end: &DateTime<Utc>,
382) -> (String, HashMap<&'static str, String>) {
383    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
384        string_to_sign_sas(u, method, account, start, end);
385
386    let string_to_sign = format!(
387        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
388        signed_permissions,
389        signed_start,
390        signed_expiry,
391        canonicalized_resource,
392        "",                               // signed identifier
393        "",                               // signed ip
394        "",                               // signed protocol
395        &AZURE_VERSION.to_str().unwrap(), // signed version
396        signed_resource,                  // signed resource
397        "",                               // signed snapshot time
398        "",                               // signed encryption scope
399        "",                               // rscc - response header: Cache-Control
400        "",                               // rscd - response header: Content-Disposition
401        "",                               // rsce - response header: Content-Encoding
402        "",                               // rscl - response header: Content-Language
403        "",                               // rsct - response header: Content-Type
404    );
405
406    let mut pairs = HashMap::new();
407    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
408    pairs.insert("sp", signed_permissions);
409    pairs.insert("st", signed_start);
410    pairs.insert("se", signed_expiry);
411    pairs.insert("sr", signed_resource);
412
413    (string_to_sign, pairs)
414}
415
416/// Create a string to be signed for authorization via [user delegation sas].
417///
418/// [user delegation sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-user-delegation-sas#version-2020-12-06-and-later
419fn string_to_sign_user_delegation_sas(
420    u: &Url,
421    method: &Method,
422    account: &str,
423    start: &DateTime<Utc>,
424    end: &DateTime<Utc>,
425    delegation_key: &UserDelegationKey,
426) -> (String, HashMap<&'static str, String>) {
427    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
428        string_to_sign_sas(u, method, account, start, end);
429
430    let string_to_sign = format!(
431        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
432        signed_permissions,
433        signed_start,
434        signed_expiry,
435        canonicalized_resource,
436        delegation_key.signed_oid,        // signed key object id
437        delegation_key.signed_tid,        // signed key tenant id
438        delegation_key.signed_start,      // signed key start
439        delegation_key.signed_expiry,     // signed key expiry
440        delegation_key.signed_service,    // signed key service
441        delegation_key.signed_version,    // signed key version
442        "",                               // signed authorized user object id
443        "",                               // signed unauthorized user object id
444        "",                               // signed correlation id
445        "",                               // signed ip
446        "",                               // signed protocol
447        &AZURE_VERSION.to_str().unwrap(), // signed version
448        signed_resource,                  // signed resource
449        "",                               // signed snapshot time
450        "",                               // signed encryption scope
451        "",                               // rscc - response header: Cache-Control
452        "",                               // rscd - response header: Content-Disposition
453        "",                               // rsce - response header: Content-Encoding
454        "",                               // rscl - response header: Content-Language
455        "",                               // rsct - response header: Content-Type
456    );
457
458    let mut pairs = HashMap::new();
459    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
460    pairs.insert("sp", signed_permissions);
461    pairs.insert("st", signed_start);
462    pairs.insert("se", signed_expiry);
463    pairs.insert("sr", signed_resource);
464    pairs.insert("skoid", delegation_key.signed_oid.clone());
465    pairs.insert("sktid", delegation_key.signed_tid.clone());
466    pairs.insert("skt", delegation_key.signed_start.clone());
467    pairs.insert("ske", delegation_key.signed_expiry.clone());
468    pairs.insert("sks", delegation_key.signed_service.clone());
469    pairs.insert("skv", delegation_key.signed_version.clone());
470
471    (string_to_sign, pairs)
472}
473
474/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-signature-string>
475fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
476    // content length must only be specified if != 0
477    // this is valid from 2015-02-21
478    let content_length = h
479        .get(&CONTENT_LENGTH)
480        .map(|s| s.to_str())
481        .transpose()
482        .ok()
483        .flatten()
484        .filter(|&v| v != "0")
485        .unwrap_or_default();
486    format!(
487        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
488        method.as_ref(),
489        add_if_exists(h, &CONTENT_ENCODING),
490        add_if_exists(h, &CONTENT_LANGUAGE),
491        content_length,
492        add_if_exists(h, &CONTENT_MD5),
493        add_if_exists(h, &CONTENT_TYPE),
494        add_if_exists(h, &DATE),
495        add_if_exists(h, &IF_MODIFIED_SINCE),
496        add_if_exists(h, &IF_MATCH),
497        add_if_exists(h, &IF_NONE_MATCH),
498        add_if_exists(h, &IF_UNMODIFIED_SINCE),
499        add_if_exists(h, &RANGE),
500        canonicalize_header(h),
501        canonicalize_resource(account, u)
502    )
503}
504
505/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-headers-string>
506fn canonicalize_header(headers: &HeaderMap) -> String {
507    let mut names = headers
508        .iter()
509        .filter(|&(k, _)| k.as_str().starts_with("x-ms"))
510        // TODO remove unwraps
511        .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
512        .collect::<Vec<_>>();
513    names.sort_unstable();
514
515    let mut result = String::new();
516    for (name, value) in names {
517        result.push_str(name);
518        result.push(':');
519        result.push_str(value);
520        result.push('\n');
521    }
522    result
523}
524
525/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-resource-string>
526fn canonicalize_resource(account: &str, uri: &Url) -> String {
527    let mut can_res: String = String::new();
528    can_res.push('/');
529    can_res.push_str(account);
530    can_res.push_str(uri.path().to_string().as_str());
531    can_res.push('\n');
532
533    // query parameters
534    let query_pairs = uri.query_pairs();
535    {
536        let mut qps: Vec<String> = Vec::new();
537        for (q, _) in query_pairs {
538            if !(qps.iter().any(|x| x == &*q)) {
539                qps.push(q.into_owned());
540            }
541        }
542
543        qps.sort();
544
545        for qparam in qps {
546            // find correct parameter
547            let ret = lexy_sort(query_pairs, &qparam);
548
549            can_res = can_res + &qparam.to_lowercase() + ":";
550
551            for (i, item) in ret.iter().enumerate() {
552                if i > 0 {
553                    can_res.push(',');
554                }
555                can_res.push_str(item);
556            }
557
558            can_res.push('\n');
559        }
560    };
561
562    can_res[0..can_res.len() - 1].to_owned()
563}
564
/// Collects the values of every pair whose key equals `query_param` exactly (case-sensitive),
/// sorted lexicographically.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = Vec::new();
    for (key, value) in vec {
        if key == query_param {
            matching.push(value);
        }
    }
    matching.sort_unstable();
    matching
}
576
/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#successful-response-1>
///
/// Successful OAuth2 token endpoint response. Only the two fields read by the token providers
/// in this module are declared.
#[derive(Deserialize, Debug)]
struct OAuthTokenResponse {
    /// Access token, wrapped as an `AzureCredential::BearerToken` by the callers.
    access_token: String,
    /// Lifetime of `access_token` in seconds, counted from when the response is received.
    expires_in: u64,
}
583
584/// Encapsulates the logic to perform an OAuth token challenge
585///
586/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#first-case-access-token-request-with-a-shared-secret>
587#[derive(Debug)]
588pub(crate) struct ClientSecretOAuthProvider {
589    token_url: String,
590    client_id: String,
591    client_secret: String,
592}
593
594impl ClientSecretOAuthProvider {
595    /// Create a new [`ClientSecretOAuthProvider`] for an azure backed store
596    pub(crate) fn new(
597        client_id: String,
598        client_secret: String,
599        tenant_id: impl AsRef<str>,
600        authority_host: Option<String>,
601    ) -> Self {
602        let authority_host =
603            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
604
605        Self {
606            token_url: format!(
607                "{}/{}/oauth2/v2.0/token",
608                authority_host,
609                tenant_id.as_ref()
610            ),
611            client_id,
612            client_secret,
613        }
614    }
615}
616
617#[async_trait::async_trait]
618impl TokenProvider for ClientSecretOAuthProvider {
619    type Credential = AzureCredential;
620
621    /// Fetch a token
622    async fn fetch_token(
623        &self,
624        client: &HttpClient,
625        retry: &RetryConfig,
626    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
627        let response: OAuthTokenResponse = client
628            .request(Method::POST, &self.token_url)
629            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
630            .form([
631                ("client_id", self.client_id.as_str()),
632                ("client_secret", self.client_secret.as_str()),
633                ("scope", AZURE_STORAGE_SCOPE),
634                ("grant_type", "client_credentials"),
635            ])
636            .retryable(retry)
637            .idempotent(true)
638            .send()
639            .await
640            .map_err(|source| Error::TokenRequest { source })?
641            .into_body()
642            .json()
643            .await
644            .map_err(|source| Error::TokenResponseBody { source })?;
645
646        Ok(TemporaryToken {
647            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
648            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
649        })
650    }
651}
652
653fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
654where
655    D: serde::de::Deserializer<'de>,
656{
657    let v = String::deserialize(deserializer)?;
658    let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
659    let now = SystemTime::now()
660        .duration_since(SystemTime::UNIX_EPOCH)
661        .map_err(serde::de::Error::custom)?;
662
663    Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
664}
665
/// NOTE: expires_on is a String version of unix epoch time, not an integer.
/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
/// <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#connect-to-azure-services-in-app-code>
#[derive(Debug, Clone, Deserialize)]
struct ImdsTokenResponse {
    /// Access token, wrapped as an `AzureCredential::BearerToken` by the caller.
    pub access_token: String,
    /// Expiry converted to a monotonic `Instant` by `expires_on_string`.
    #[serde(deserialize_with = "expires_on_string")]
    pub expires_on: Instant,
}
675
/// Attempts authentication using a managed identity that has been assigned to the deployment environment.
///
/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell
/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
#[derive(Debug)]
pub(crate) struct ImdsManagedIdentityProvider {
    /// Token endpoint that is queried.
    msi_endpoint: String,
    /// Optional identity selectors; at most one is sent with a request.
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Create a new [`ImdsManagedIdentityProvider`] for an azure backed store
    ///
    /// When `msi_endpoint` is `None`, the instance metadata endpoint
    /// `http://169.254.169.254/metadata/identity/oauth2/token` is used.
    pub(crate) fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        let msi_endpoint = match msi_endpoint {
            Some(endpoint) => endpoint,
            None => "http://169.254.169.254/metadata/identity/oauth2/token".to_owned(),
        };

        Self {
            client_id,
            object_id,
            msi_res_id,
            msi_endpoint,
        }
    }
}
707
708#[async_trait::async_trait]
709impl TokenProvider for ImdsManagedIdentityProvider {
710    type Credential = AzureCredential;
711
712    /// Fetch a token
713    async fn fetch_token(
714        &self,
715        client: &HttpClient,
716        retry: &RetryConfig,
717    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
718        let mut query_items = vec![
719            ("api-version", MSI_API_VERSION),
720            ("resource", AZURE_STORAGE_RESOURCE),
721        ];
722
723        let mut identity = None;
724        if let Some(client_id) = &self.client_id {
725            identity = Some(("client_id", client_id));
726        }
727        if let Some(object_id) = &self.object_id {
728            identity = Some(("object_id", object_id));
729        }
730        if let Some(msi_res_id) = &self.msi_res_id {
731            identity = Some(("msi_res_id", msi_res_id));
732        }
733        if let Some((key, value)) = identity {
734            query_items.push((key, value));
735        }
736
737        let mut builder = client
738            .request(Method::GET, &self.msi_endpoint)
739            .header("metadata", "true")
740            .query(&query_items);
741
742        if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
743            builder = builder.header("x-identity-header", val);
744        };
745
746        let response: ImdsTokenResponse = builder
747            .send_retry(retry)
748            .await
749            .map_err(|source| Error::TokenRequest { source })?
750            .into_body()
751            .json()
752            .await
753            .map_err(|source| Error::TokenResponseBody { source })?;
754
755        Ok(TemporaryToken {
756            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
757            expiry: Some(response.expires_on),
758        })
759    }
760}
761
762/// Credential for using workload identity federation
763///
764/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/workload-identity-federation>
765#[derive(Debug)]
766pub(crate) struct WorkloadIdentityOAuthProvider {
767    token_url: String,
768    client_id: String,
769    federated_token_file: String,
770}
771
772impl WorkloadIdentityOAuthProvider {
773    /// Create a new [`WorkloadIdentityOAuthProvider`] for an azure backed store
774    pub(crate) fn new(
775        client_id: impl Into<String>,
776        federated_token_file: impl Into<String>,
777        tenant_id: impl AsRef<str>,
778        authority_host: Option<String>,
779    ) -> Self {
780        let authority_host =
781            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
782
783        Self {
784            token_url: format!(
785                "{}/{}/oauth2/v2.0/token",
786                authority_host,
787                tenant_id.as_ref()
788            ),
789            client_id: client_id.into(),
790            federated_token_file: federated_token_file.into(),
791        }
792    }
793}
794
795#[async_trait::async_trait]
796impl TokenProvider for WorkloadIdentityOAuthProvider {
797    type Credential = AzureCredential;
798
799    /// Fetch a token
800    async fn fetch_token(
801        &self,
802        client: &HttpClient,
803        retry: &RetryConfig,
804    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
805        let token_str = std::fs::read_to_string(&self.federated_token_file)
806            .map_err(|_| Error::FederatedTokenFile)?;
807
808        // https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential
809        let response: OAuthTokenResponse = client
810            .request(Method::POST, &self.token_url)
811            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
812            .form([
813                ("client_id", self.client_id.as_str()),
814                (
815                    "client_assertion_type",
816                    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
817                ),
818                ("client_assertion", token_str.as_str()),
819                ("scope", AZURE_STORAGE_SCOPE),
820                ("grant_type", "client_credentials"),
821            ])
822            .retryable(retry)
823            .idempotent(true)
824            .send()
825            .await
826            .map_err(|source| Error::TokenRequest { source })?
827            .into_body()
828            .json()
829            .await
830            .map_err(|source| Error::TokenResponseBody { source })?;
831
832        Ok(TemporaryToken {
833            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
834            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
835        })
836    }
837}
838
839mod az_cli_date_format {
840    use chrono::{DateTime, TimeZone};
841    use serde::{self, Deserialize, Deserializer};
842
843    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
844    where
845        D: Deserializer<'de>,
846    {
847        let s = String::deserialize(deserializer)?;
848        // expiresOn from azure cli uses the local timezone
849        let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
850            .map_err(serde::de::Error::custom)?;
851        chrono::Local
852            .from_local_datetime(&date)
853            .single()
854            .ok_or(serde::de::Error::custom(
855                "azure cli returned ambiguous expiry date",
856            ))
857    }
858}
859
/// Subset of the JSON printed by `az account get-access-token --output json`.
///
/// Fields map to the CLI's camelCase keys (`accessToken`, `expiresOn`, `tokenType`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AzureCliTokenResponse {
    /// The access token returned by the CLI.
    pub access_token: String,
    /// Token expiry; the CLI reports it in the local timezone.
    #[serde(with = "az_cli_date_format")]
    pub expires_on: DateTime<chrono::Local>,
    /// Token type reported by the CLI; only `bearer` (case-insensitive) is accepted.
    pub token_type: String,
}
868
/// Credential provider that obtains bearer tokens by invoking the Azure CLI
/// (`az account get-access-token`) and caches the result.
#[derive(Default, Debug)]
pub(crate) struct AzureCliCredential {
    /// Cache of the most recently fetched token, populated via `fetch_token`.
    cache: TokenCache<Arc<AzureCredential>>,
}
873
874impl AzureCliCredential {
875    pub(crate) fn new() -> Self {
876        Self::default()
877    }
878
879    /// Fetch a token
880    async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
881        // on window az is a cmd and it should be called like this
882        // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html
883        let program = if cfg!(target_os = "windows") {
884            "cmd"
885        } else {
886            "az"
887        };
888        let mut args = Vec::new();
889        if cfg!(target_os = "windows") {
890            args.push("/C");
891            args.push("az");
892        }
893        args.push("account");
894        args.push("get-access-token");
895        args.push("--output");
896        args.push("json");
897        args.push("--scope");
898        args.push(AZURE_STORAGE_SCOPE);
899
900        match Command::new(program).args(args).output() {
901            Ok(az_output) if az_output.status.success() => {
902                let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
903                    message: "az response is not a valid utf-8 string".to_string(),
904                })?;
905
906                let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
907                    .map_err(|source| Error::AzureCliResponse { source })?;
908
909                if !token_response.token_type.eq_ignore_ascii_case("bearer") {
910                    return Err(Error::AzureCli {
911                        message: format!(
912                            "got unexpected token type from azure cli: {0}",
913                            token_response.token_type
914                        ),
915                    });
916                }
917                let duration =
918                    token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
919                Ok(TemporaryToken {
920                    token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
921                    expiry: Some(
922                        Instant::now()
923                            + duration.to_std().map_err(|_| Error::AzureCli {
924                                message: "az returned invalid lifetime".to_string(),
925                            })?,
926                    ),
927                })
928            }
929            Ok(az_output) => {
930                let message = String::from_utf8_lossy(&az_output.stderr);
931                Err(Error::AzureCli {
932                    message: message.into(),
933                })
934            }
935            Err(e) => match e.kind() {
936                std::io::ErrorKind::NotFound => Err(Error::AzureCli {
937                    message: "Azure Cli not installed".into(),
938                }),
939                error_kind => Err(Error::AzureCli {
940                    message: format!("io error: {error_kind:?}"),
941                }),
942            },
943        }
944    }
945}
946
/// Encapsulates the logic to perform an OAuth token challenge for Fabric
///
/// Tokens are requested from a Fabric token service using the session token and
/// cluster identifier. An optional pre-supplied storage token is reused while it
/// remains valid.
#[derive(Debug)]
pub(crate) struct FabricTokenOAuthProvider {
    /// URL of the Fabric token service queried for storage access tokens.
    fabric_token_service_url: String,
    /// Value sent in the proxy-host header of token requests.
    fabric_workload_host: String,
    /// Value sent in the partner-token header of token requests.
    fabric_session_token: String,
    /// Value sent in the cluster-identifier and workload-resource headers.
    fabric_cluster_identifier: String,
    /// Optional pre-supplied storage token. It is kept at construction only if its
    /// JWT `exp` claim is more than `TOKEN_MIN_TTL` seconds in the future.
    storage_access_token: Option<String>,
    /// `exp` claim (seconds since the Unix epoch) of `storage_access_token`.
    token_expiry: Option<u64>,
}
957
/// Minimal view of a JWT payload: only the `exp` (expiry) claim is read.
#[derive(Debug, Deserialize)]
struct Claims {
    /// Expiry time in whole seconds since the Unix epoch.
    exp: u64,
}
962
963impl FabricTokenOAuthProvider {
964    /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store
965    pub(crate) fn new(
966        fabric_token_service_url: impl Into<String>,
967        fabric_workload_host: impl Into<String>,
968        fabric_session_token: impl Into<String>,
969        fabric_cluster_identifier: impl Into<String>,
970        storage_access_token: Option<String>,
971    ) -> Self {
972        let (storage_access_token, token_expiry) = match storage_access_token {
973            Some(token) => match Self::validate_and_get_expiry(&token) {
974                Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
975                    (Some(token), Some(expiry))
976                }
977                _ => (None, None),
978            },
979            None => (None, None),
980        };
981
982        Self {
983            fabric_token_service_url: fabric_token_service_url.into(),
984            fabric_workload_host: fabric_workload_host.into(),
985            fabric_session_token: fabric_session_token.into(),
986            fabric_cluster_identifier: fabric_cluster_identifier.into(),
987            storage_access_token,
988            token_expiry,
989        }
990    }
991
992    fn validate_and_get_expiry(token: &str) -> Option<u64> {
993        let payload = token.split('.').nth(1)?;
994        let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
995        let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
996        let claims: Claims = serde_json::from_str(decoded_str).ok()?;
997        Some(claims.exp)
998    }
999
1000    fn get_current_timestamp() -> u64 {
1001        SystemTime::now()
1002            .duration_since(SystemTime::UNIX_EPOCH)
1003            .map_or(0, |d| d.as_secs())
1004    }
1005}
1006
1007#[async_trait::async_trait]
1008impl TokenProvider for FabricTokenOAuthProvider {
1009    type Credential = AzureCredential;
1010
1011    /// Fetch a token
1012    async fn fetch_token(
1013        &self,
1014        client: &HttpClient,
1015        retry: &RetryConfig,
1016    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
1017        if let Some(storage_access_token) = &self.storage_access_token {
1018            if let Some(expiry) = self.token_expiry {
1019                let exp_in = expiry.saturating_sub(Self::get_current_timestamp());
1020                if exp_in > TOKEN_MIN_TTL {
1021                    return Ok(TemporaryToken {
1022                        token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
1023                        expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1024                    });
1025                }
1026            }
1027        }
1028
1029        let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
1030        let access_token: String = client
1031            .request(Method::GET, &self.fabric_token_service_url)
1032            .header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
1033            .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
1034            .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
1035            .header(&PROXY_HOST, self.fabric_workload_host.as_str())
1036            .query(&query_items)
1037            .retryable(retry)
1038            .idempotent(true)
1039            .send()
1040            .await
1041            .map_err(|source| Error::TokenRequest { source })?
1042            .into_body()
1043            .text()
1044            .await
1045            .map_err(|source| Error::TokenResponseBody { source })?;
1046        let exp_in = Self::validate_and_get_expiry(&access_token).map_or(3600, |expiry| {
1047            expiry.saturating_sub(Self::get_current_timestamp())
1048        });
1049        Ok(TemporaryToken {
1050            token: Arc::new(AzureCredential::BearerToken(access_token)),
1051            expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1052        })
1053    }
1054}
1055
1056#[async_trait]
1057impl CredentialProvider for AzureCliCredential {
1058    type Credential = AzureCredential;
1059
1060    async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
1061        Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
1062    }
1063}
1064
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use http::{Response, StatusCode};
    use http_body_util::BodyExt;
    use reqwest::{Client, Method};
    use tempfile::NamedTempFile;

    use super::*;
    use crate::azure::MicrosoftAzureBuilder;
    use crate::client::mock_server::MockServer;
    use crate::{ObjectStoreExt, Path};

    /// Build an unsigned, JWT-shaped string whose payload carries only an `exp` claim.
    fn fake_jwt_with_exp(exp: u64) -> String {
        let claims = format!(r#"{{"exp":{exp}}}"#);
        let encoded_claims = BASE64_URL_SAFE_NO_PAD.encode(claims.as_bytes());
        format!("header.{encoded_claims}.signature")
    }

    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        // SAFETY: mutating the process environment is unsound if another thread reads
        // or writes it concurrently. No other test in this module touches this
        // variable, but tests run on multiple threads. This therefore relies on nothing
        // else accessing the environment at the same moment.
        unsafe { std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret") };

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Test IMDS
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Mock the tenant token endpoint used for the federated-credential exchange
        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }

    #[tokio::test]
    async fn test_fabric_refresh_expired_token() {
        let server = MockServer::new().await;

        // Create an expired initial token (1 hour in the past)
        let expired_timestamp = FabricTokenOAuthProvider::get_current_timestamp() - 3600;
        let expired_token = fake_jwt_with_exp(expired_timestamp);

        // Create a fresh token that the mock API will return (1 hour in the future)
        let fresh_timestamp = FabricTokenOAuthProvider::get_current_timestamp() + 3600;
        let fresh_token = fake_jwt_with_exp(fresh_timestamp);
        let expected_token = fresh_token.clone();

        // Mock the Fabric token service to return a fresh token
        server.push_fn(move |req| {
            assert_eq!(req.headers().get(&PARTNER_TOKEN).unwrap(), "session-token");
            assert_eq!(
                req.headers().get(&CLUSTER_IDENTIFIER).unwrap(),
                "cluster-id"
            );
            assert_eq!(req.headers().get(&PROXY_HOST).unwrap(), "fake");

            Response::new(fresh_token)
        });

        let provider = FabricTokenOAuthProvider {
            fabric_token_service_url: server.url().to_string(),
            fabric_workload_host: "fake".to_string(),
            fabric_session_token: "session-token".to_string(),
            fabric_cluster_identifier: "cluster-id".to_string(),
            storage_access_token: Some(expired_token),
            token_expiry: Some(expired_timestamp),
        };

        let client = HttpClient::new(Client::new());
        let retry = RetryConfig::default();

        let temp_token = provider
            .fetch_token(&client, &retry)
            .await
            .expect("fetch_token should handle expired cached token gracefully");

        // Verify we got a fresh token from the API (not the expired cached one).
        // A non-bearer credential is a failure, not something to skip silently.
        let AzureCredential::BearerToken(token) = temp_token.token.as_ref() else {
            panic!("expected a bearer token from the Fabric token service");
        };
        assert_eq!(
            token, &expected_token,
            "Should have fetched fresh token from API, not returned expired cached token"
        );
    }
}