object_store/azure/
credential.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::client::UserDelegationKey;
19use crate::azure::STORE;
20use crate::client::builder::{add_query_pairs, HttpRequestBuilder};
21use crate::client::retry::RetryExt;
22use crate::client::token::{TemporaryToken, TokenCache};
23use crate::client::{CredentialProvider, HttpClient, HttpError, HttpRequest, TokenProvider};
24use crate::util::hmac_sha256;
25use crate::RetryConfig;
26use async_trait::async_trait;
27use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
28use base64::Engine;
29use chrono::{DateTime, SecondsFormat, Utc};
30use http::header::{
31    HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE,
32    CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
33    IF_UNMODIFIED_SINCE, RANGE,
34};
35use http::Method;
36use serde::Deserialize;
37use std::borrow::Cow;
38use std::collections::HashMap;
39use std::fmt::Debug;
40use std::ops::Deref;
41use std::process::Command;
42use std::str;
43use std::sync::Arc;
44use std::time::{Duration, Instant, SystemTime};
45use url::Url;
46
47static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
48static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
49pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
50pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
51pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
52static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
53static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
54static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
55static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
56static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
57pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
58const CONTENT_TYPE_JSON: &str = "application/json";
59const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
60const MSI_API_VERSION: &str = "2019-08-01";
61const TOKEN_MIN_TTL: u64 = 300;
62
63/// OIDC scope used when interacting with OAuth2 APIs
64///
65/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/scopes-oidc#the-default-scope>
66const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";
67
68/// Resource ID used when obtaining an access token from the metadata endpoint
69///
70/// <https://learn.microsoft.com/en-us/azure/storage/blobs/authorize-access-azure-active-directory#microsoft-authentication-library-msal>
71const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
72
73#[derive(Debug, thiserror::Error)]
74pub enum Error {
75    #[error("Error performing token request: {}", source)]
76    TokenRequest {
77        source: crate::client::retry::RetryError,
78    },
79
80    #[error("Error getting token response body: {}", source)]
81    TokenResponseBody { source: HttpError },
82
83    #[error("Error reading federated token file ")]
84    FederatedTokenFile,
85
86    #[error("Invalid Access Key: {}", source)]
87    InvalidAccessKey { source: base64::DecodeError },
88
89    #[error("'az account get-access-token' command failed: {message}")]
90    AzureCli { message: String },
91
92    #[error("Failed to parse azure cli response: {source}")]
93    AzureCliResponse { source: serde_json::Error },
94
95    #[error("Generating SAS keys with SAS tokens auth is not supported")]
96    SASforSASNotSupported,
97}
98
99pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
100
101impl From<Error> for crate::Error {
102    fn from(value: Error) -> Self {
103        Self::Generic {
104            store: STORE,
105            source: Box::new(value),
106        }
107    }
108}
109
110/// A shared Azure Storage Account Key
111#[derive(Debug, Clone, Eq, PartialEq)]
112pub struct AzureAccessKey(Vec<u8>);
113
114impl AzureAccessKey {
115    /// Create a new [`AzureAccessKey`], checking it for validity
116    pub fn try_new(key: &str) -> Result<Self> {
117        let key = BASE64_STANDARD
118            .decode(key)
119            .map_err(|source| Error::InvalidAccessKey { source })?;
120
121        Ok(Self(key))
122    }
123}
124
125/// An Azure storage credential
126#[derive(Debug, Eq, PartialEq)]
127pub enum AzureCredential {
128    /// A shared access key
129    ///
130    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
131    AccessKey(AzureAccessKey),
132    /// A shared access signature
133    ///
134    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature>
135    SASToken(Vec<(String, String)>),
136    /// An authorization token
137    ///
138    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-azure-active-directory>
139    BearerToken(String),
140}
141
142impl AzureCredential {
143    /// Determines if the credential requires the request be treated as sensitive
144    pub fn sensitive_request(&self) -> bool {
145        match self {
146            Self::AccessKey(_) => false,
147            Self::BearerToken(_) => false,
148            // SAS tokens are sent as query parameters in the url
149            Self::SASToken(_) => true,
150        }
151    }
152}
153
pub mod authority_hosts {
    //! A list of known Azure authority hosts

    /// Public Cloud Azure Authority Host
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
    /// US Government Azure Authority Host
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
    /// China-based Azure Authority Host
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Germany-based Azure Authority Host
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
}
165
166pub(crate) struct AzureSigner {
167    signing_key: AzureAccessKey,
168    start: DateTime<Utc>,
169    end: DateTime<Utc>,
170    account: String,
171    delegation_key: Option<UserDelegationKey>,
172}
173
174impl AzureSigner {
175    pub(crate) fn new(
176        signing_key: AzureAccessKey,
177        account: String,
178        start: DateTime<Utc>,
179        end: DateTime<Utc>,
180        delegation_key: Option<UserDelegationKey>,
181    ) -> Self {
182        Self {
183            signing_key,
184            account,
185            start,
186            end,
187            delegation_key,
188        }
189    }
190
191    pub(crate) fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
192        let (str_to_sign, query_pairs) = match &self.delegation_key {
193            Some(delegation_key) => string_to_sign_user_delegation_sas(
194                url,
195                method,
196                &self.account,
197                &self.start,
198                &self.end,
199                delegation_key,
200            ),
201            None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
202        };
203        let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
204        url.query_pairs_mut().extend_pairs(query_pairs);
205        url.query_pairs_mut()
206            .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
207        Ok(())
208    }
209}
210
211fn add_date_and_version_headers(request: &mut HttpRequest) {
212    // rfc2822 string should never contain illegal characters
213    let date = Utc::now();
214    let date_str = date.format(RFC1123_FMT).to_string();
215    // we formatted the data string ourselves, so unwrapping should be fine
216    let date_val = HeaderValue::from_str(&date_str).unwrap();
217    request.headers_mut().insert(DATE, date_val);
218    request
219        .headers_mut()
220        .insert(&VERSION, AZURE_VERSION.clone());
221}
222
223/// Authorize a [`HttpRequest`] with an [`AzureAuthorizer`]
224#[derive(Debug)]
225pub struct AzureAuthorizer<'a> {
226    credential: &'a AzureCredential,
227    account: &'a str,
228}
229
230impl<'a> AzureAuthorizer<'a> {
231    /// Create a new [`AzureAuthorizer`]
232    pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
233        AzureAuthorizer {
234            credential,
235            account,
236        }
237    }
238
239    /// Authorize `request`
240    pub fn authorize(&self, request: &mut HttpRequest) {
241        add_date_and_version_headers(request);
242
243        match self.credential {
244            AzureCredential::AccessKey(key) => {
245                let url = Url::parse(&request.uri().to_string()).unwrap();
246                let signature = generate_authorization(
247                    request.headers(),
248                    &url,
249                    request.method(),
250                    self.account,
251                    key,
252                );
253
254                // "signature" is a base 64 encoded string so it should never
255                // contain illegal characters
256                request.headers_mut().append(
257                    AUTHORIZATION,
258                    HeaderValue::from_str(signature.as_str()).unwrap(),
259                );
260            }
261            AzureCredential::BearerToken(token) => {
262                request.headers_mut().append(
263                    AUTHORIZATION,
264                    HeaderValue::from_str(format!("Bearer {token}").as_str()).unwrap(),
265                );
266            }
267            AzureCredential::SASToken(query_pairs) => {
268                add_query_pairs(request.uri_mut(), query_pairs);
269            }
270        }
271    }
272}
273
274pub(crate) trait CredentialExt {
275    /// Apply authorization to requests against azure storage accounts
276    /// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-requests-to-azure-storage>
277    fn with_azure_authorization(
278        self,
279        credential: &Option<impl Deref<Target = AzureCredential>>,
280        account: &str,
281    ) -> Self;
282}
283
284impl CredentialExt for HttpRequestBuilder {
285    fn with_azure_authorization(
286        self,
287        credential: &Option<impl Deref<Target = AzureCredential>>,
288        account: &str,
289    ) -> Self {
290        let (client, request) = self.into_parts();
291        let mut request = request.expect("request valid");
292
293        match credential.as_deref() {
294            Some(credential) => {
295                AzureAuthorizer::new(credential, account).authorize(&mut request);
296            }
297            None => {
298                add_date_and_version_headers(&mut request);
299            }
300        }
301
302        Self::from_parts(client, request)
303    }
304}
305
306/// Generate signed key for authorization via access keys
307/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
308fn generate_authorization(
309    h: &HeaderMap,
310    u: &Url,
311    method: &Method,
312    account: &str,
313    key: &AzureAccessKey,
314) -> String {
315    let str_to_sign = string_to_sign(h, u, method, account);
316    let auth = hmac_sha256(&key.0, str_to_sign);
317    format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
318}
319
320fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
321    h.get(key)
322        .map(|s| s.to_str())
323        .transpose()
324        .ok()
325        .flatten()
326        .unwrap_or_default()
327}
328
329fn string_to_sign_sas(
330    u: &Url,
331    method: &Method,
332    account: &str,
333    start: &DateTime<Utc>,
334    end: &DateTime<Utc>,
335) -> (String, String, String, String, String) {
336    // NOTE: for now only blob signing is supported.
337    let signed_resource = "b".to_string();
338
339    // https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob
340    let signed_permissions = match *method {
341        // read and list permissions
342        Method::GET => match signed_resource.as_str() {
343            "c" => "rl",
344            "b" => "r",
345            _ => unreachable!(),
346        },
347        // write permissions (also allows crating a new blob in a sub-key)
348        Method::PUT => "w",
349        // delete permissions
350        Method::DELETE => "d",
351        // other methods are not used in any of the current operations
352        _ => "",
353    }
354    .to_string();
355    let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
356    let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
357    let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
358        format!("/blob/{}{}", account, u.path())
359    } else {
360        // NOTE: in case of the emulator, the account name is not part of the host
361        //      but the path starts with the account name
362        format!("/blob{}", u.path())
363    };
364
365    (
366        signed_resource,
367        signed_permissions,
368        signed_start,
369        signed_expiry,
370        canonicalized_resource,
371    )
372}
373
374/// Create a string to be signed for authorization via [service sas].
375///
376/// [service sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#version-2020-12-06-and-later
377fn string_to_sign_service_sas(
378    u: &Url,
379    method: &Method,
380    account: &str,
381    start: &DateTime<Utc>,
382    end: &DateTime<Utc>,
383) -> (String, HashMap<&'static str, String>) {
384    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
385        string_to_sign_sas(u, method, account, start, end);
386
387    let string_to_sign = format!(
388        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
389        signed_permissions,
390        signed_start,
391        signed_expiry,
392        canonicalized_resource,
393        "",                               // signed identifier
394        "",                               // signed ip
395        "",                               // signed protocol
396        &AZURE_VERSION.to_str().unwrap(), // signed version
397        signed_resource,                  // signed resource
398        "",                               // signed snapshot time
399        "",                               // signed encryption scope
400        "",                               // rscc - response header: Cache-Control
401        "",                               // rscd - response header: Content-Disposition
402        "",                               // rsce - response header: Content-Encoding
403        "",                               // rscl - response header: Content-Language
404        "",                               // rsct - response header: Content-Type
405    );
406
407    let mut pairs = HashMap::new();
408    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
409    pairs.insert("sp", signed_permissions);
410    pairs.insert("st", signed_start);
411    pairs.insert("se", signed_expiry);
412    pairs.insert("sr", signed_resource);
413
414    (string_to_sign, pairs)
415}
416
417/// Create a string to be signed for authorization via [user delegation sas].
418///
419/// [user delegation sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-user-delegation-sas#version-2020-12-06-and-later
420fn string_to_sign_user_delegation_sas(
421    u: &Url,
422    method: &Method,
423    account: &str,
424    start: &DateTime<Utc>,
425    end: &DateTime<Utc>,
426    delegation_key: &UserDelegationKey,
427) -> (String, HashMap<&'static str, String>) {
428    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
429        string_to_sign_sas(u, method, account, start, end);
430
431    let string_to_sign = format!(
432        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
433        signed_permissions,
434        signed_start,
435        signed_expiry,
436        canonicalized_resource,
437        delegation_key.signed_oid,        // signed key object id
438        delegation_key.signed_tid,        // signed key tenant id
439        delegation_key.signed_start,      // signed key start
440        delegation_key.signed_expiry,     // signed key expiry
441        delegation_key.signed_service,    // signed key service
442        delegation_key.signed_version,    // signed key version
443        "",                               // signed authorized user object id
444        "",                               // signed unauthorized user object id
445        "",                               // signed correlation id
446        "",                               // signed ip
447        "",                               // signed protocol
448        &AZURE_VERSION.to_str().unwrap(), // signed version
449        signed_resource,                  // signed resource
450        "",                               // signed snapshot time
451        "",                               // signed encryption scope
452        "",                               // rscc - response header: Cache-Control
453        "",                               // rscd - response header: Content-Disposition
454        "",                               // rsce - response header: Content-Encoding
455        "",                               // rscl - response header: Content-Language
456        "",                               // rsct - response header: Content-Type
457    );
458
459    let mut pairs = HashMap::new();
460    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
461    pairs.insert("sp", signed_permissions);
462    pairs.insert("st", signed_start);
463    pairs.insert("se", signed_expiry);
464    pairs.insert("sr", signed_resource);
465    pairs.insert("skoid", delegation_key.signed_oid.clone());
466    pairs.insert("sktid", delegation_key.signed_tid.clone());
467    pairs.insert("skt", delegation_key.signed_start.clone());
468    pairs.insert("ske", delegation_key.signed_expiry.clone());
469    pairs.insert("sks", delegation_key.signed_service.clone());
470    pairs.insert("skv", delegation_key.signed_version.clone());
471
472    (string_to_sign, pairs)
473}
474
475/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-signature-string>
476fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
477    // content length must only be specified if != 0
478    // this is valid from 2015-02-21
479    let content_length = h
480        .get(&CONTENT_LENGTH)
481        .map(|s| s.to_str())
482        .transpose()
483        .ok()
484        .flatten()
485        .filter(|&v| v != "0")
486        .unwrap_or_default();
487    format!(
488        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
489        method.as_ref(),
490        add_if_exists(h, &CONTENT_ENCODING),
491        add_if_exists(h, &CONTENT_LANGUAGE),
492        content_length,
493        add_if_exists(h, &CONTENT_MD5),
494        add_if_exists(h, &CONTENT_TYPE),
495        add_if_exists(h, &DATE),
496        add_if_exists(h, &IF_MODIFIED_SINCE),
497        add_if_exists(h, &IF_MATCH),
498        add_if_exists(h, &IF_NONE_MATCH),
499        add_if_exists(h, &IF_UNMODIFIED_SINCE),
500        add_if_exists(h, &RANGE),
501        canonicalize_header(h),
502        canonicalize_resource(account, u)
503    )
504}
505
506/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-headers-string>
507fn canonicalize_header(headers: &HeaderMap) -> String {
508    let mut names = headers
509        .iter()
510        .filter(|&(k, _)| (k.as_str().starts_with("x-ms")))
511        // TODO remove unwraps
512        .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
513        .collect::<Vec<_>>();
514    names.sort_unstable();
515
516    let mut result = String::new();
517    for (name, value) in names {
518        result.push_str(name);
519        result.push(':');
520        result.push_str(value);
521        result.push('\n');
522    }
523    result
524}
525
526/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-resource-string>
527fn canonicalize_resource(account: &str, uri: &Url) -> String {
528    let mut can_res: String = String::new();
529    can_res.push('/');
530    can_res.push_str(account);
531    can_res.push_str(uri.path().to_string().as_str());
532    can_res.push('\n');
533
534    // query parameters
535    let query_pairs = uri.query_pairs();
536    {
537        let mut qps: Vec<String> = Vec::new();
538        for (q, _) in query_pairs {
539            if !(qps.iter().any(|x| x == &*q)) {
540                qps.push(q.into_owned());
541            }
542        }
543
544        qps.sort();
545
546        for qparam in qps {
547            // find correct parameter
548            let ret = lexy_sort(query_pairs, &qparam);
549
550            can_res = can_res + &qparam.to_lowercase() + ":";
551
552            for (i, item) in ret.iter().enumerate() {
553                if i > 0 {
554                    can_res.push(',');
555                }
556                can_res.push_str(item);
557            }
558
559            can_res.push('\n');
560        }
561    };
562
563    can_res[0..can_res.len() - 1].to_owned()
564}
565
/// Collect the value of every pair whose key equals `query_param`, sorted ascending.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = vec
        .filter_map(|(key, value)| (key == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
577
578/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#successful-response-1>
579#[derive(Deserialize, Debug)]
580struct OAuthTokenResponse {
581    access_token: String,
582    expires_in: u64,
583}
584
585/// Encapsulates the logic to perform an OAuth token challenge
586///
587/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#first-case-access-token-request-with-a-shared-secret>
588#[derive(Debug)]
589pub(crate) struct ClientSecretOAuthProvider {
590    token_url: String,
591    client_id: String,
592    client_secret: String,
593}
594
595impl ClientSecretOAuthProvider {
596    /// Create a new [`ClientSecretOAuthProvider`] for an azure backed store
597    pub(crate) fn new(
598        client_id: String,
599        client_secret: String,
600        tenant_id: impl AsRef<str>,
601        authority_host: Option<String>,
602    ) -> Self {
603        let authority_host =
604            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
605
606        Self {
607            token_url: format!(
608                "{}/{}/oauth2/v2.0/token",
609                authority_host,
610                tenant_id.as_ref()
611            ),
612            client_id,
613            client_secret,
614        }
615    }
616}
617
618#[async_trait::async_trait]
619impl TokenProvider for ClientSecretOAuthProvider {
620    type Credential = AzureCredential;
621
622    /// Fetch a token
623    async fn fetch_token(
624        &self,
625        client: &HttpClient,
626        retry: &RetryConfig,
627    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
628        let response: OAuthTokenResponse = client
629            .request(Method::POST, &self.token_url)
630            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
631            .form([
632                ("client_id", self.client_id.as_str()),
633                ("client_secret", self.client_secret.as_str()),
634                ("scope", AZURE_STORAGE_SCOPE),
635                ("grant_type", "client_credentials"),
636            ])
637            .retryable(retry)
638            .idempotent(true)
639            .send()
640            .await
641            .map_err(|source| Error::TokenRequest { source })?
642            .into_body()
643            .json()
644            .await
645            .map_err(|source| Error::TokenResponseBody { source })?;
646
647        Ok(TemporaryToken {
648            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
649            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
650        })
651    }
652}
653
654fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
655where
656    D: serde::de::Deserializer<'de>,
657{
658    let v = String::deserialize(deserializer)?;
659    let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
660    let now = SystemTime::now()
661        .duration_since(SystemTime::UNIX_EPOCH)
662        .map_err(serde::de::Error::custom)?;
663
664    Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
665}
666
667/// NOTE: expires_on is a String version of unix epoch time, not an integer.
668/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
669/// <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#connect-to-azure-services-in-app-code>
670#[derive(Debug, Clone, Deserialize)]
671struct ImdsTokenResponse {
672    pub access_token: String,
673    #[serde(deserialize_with = "expires_on_string")]
674    pub expires_on: Instant,
675}
676
/// Attempts authentication using a managed identity that has been assigned to the deployment environment.
///
/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell
/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
#[derive(Debug)]
pub(crate) struct ImdsManagedIdentityProvider {
    msi_endpoint: String,
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Create a new [`ImdsManagedIdentityProvider`] for an azure backed store
    ///
    /// `msi_endpoint` defaults to the instance metadata service token URL
    /// `http://169.254.169.254/metadata/identity/oauth2/token`.
    pub(crate) fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        Self {
            msi_endpoint: msi_endpoint.unwrap_or_else(|| {
                "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
            }),
            client_id,
            object_id,
            msi_res_id,
        }
    }
}
708
709#[async_trait::async_trait]
710impl TokenProvider for ImdsManagedIdentityProvider {
711    type Credential = AzureCredential;
712
713    /// Fetch a token
714    async fn fetch_token(
715        &self,
716        client: &HttpClient,
717        retry: &RetryConfig,
718    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
719        let mut query_items = vec![
720            ("api-version", MSI_API_VERSION),
721            ("resource", AZURE_STORAGE_RESOURCE),
722        ];
723
724        let mut identity = None;
725        if let Some(client_id) = &self.client_id {
726            identity = Some(("client_id", client_id));
727        }
728        if let Some(object_id) = &self.object_id {
729            identity = Some(("object_id", object_id));
730        }
731        if let Some(msi_res_id) = &self.msi_res_id {
732            identity = Some(("msi_res_id", msi_res_id));
733        }
734        if let Some((key, value)) = identity {
735            query_items.push((key, value));
736        }
737
738        let mut builder = client
739            .request(Method::GET, &self.msi_endpoint)
740            .header("metadata", "true")
741            .query(&query_items);
742
743        if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
744            builder = builder.header("x-identity-header", val);
745        };
746
747        let response: ImdsTokenResponse = builder
748            .send_retry(retry)
749            .await
750            .map_err(|source| Error::TokenRequest { source })?
751            .into_body()
752            .json()
753            .await
754            .map_err(|source| Error::TokenResponseBody { source })?;
755
756        Ok(TemporaryToken {
757            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
758            expiry: Some(response.expires_on),
759        })
760    }
761}
762
763/// Credential for using workload identity federation
764///
765/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/workload-identity-federation>
766#[derive(Debug)]
767pub(crate) struct WorkloadIdentityOAuthProvider {
768    token_url: String,
769    client_id: String,
770    federated_token_file: String,
771}
772
773impl WorkloadIdentityOAuthProvider {
774    /// Create a new [`WorkloadIdentityOAuthProvider`] for an azure backed store
775    pub(crate) fn new(
776        client_id: impl Into<String>,
777        federated_token_file: impl Into<String>,
778        tenant_id: impl AsRef<str>,
779        authority_host: Option<String>,
780    ) -> Self {
781        let authority_host =
782            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
783
784        Self {
785            token_url: format!(
786                "{}/{}/oauth2/v2.0/token",
787                authority_host,
788                tenant_id.as_ref()
789            ),
790            client_id: client_id.into(),
791            federated_token_file: federated_token_file.into(),
792        }
793    }
794}
795
796#[async_trait::async_trait]
797impl TokenProvider for WorkloadIdentityOAuthProvider {
798    type Credential = AzureCredential;
799
800    /// Fetch a token
801    async fn fetch_token(
802        &self,
803        client: &HttpClient,
804        retry: &RetryConfig,
805    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
806        let token_str = std::fs::read_to_string(&self.federated_token_file)
807            .map_err(|_| Error::FederatedTokenFile)?;
808
809        // https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential
810        let response: OAuthTokenResponse = client
811            .request(Method::POST, &self.token_url)
812            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
813            .form([
814                ("client_id", self.client_id.as_str()),
815                (
816                    "client_assertion_type",
817                    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
818                ),
819                ("client_assertion", token_str.as_str()),
820                ("scope", AZURE_STORAGE_SCOPE),
821                ("grant_type", "client_credentials"),
822            ])
823            .retryable(retry)
824            .idempotent(true)
825            .send()
826            .await
827            .map_err(|source| Error::TokenRequest { source })?
828            .into_body()
829            .json()
830            .await
831            .map_err(|source| Error::TokenResponseBody { source })?;
832
833        Ok(TemporaryToken {
834            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
835            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
836        })
837    }
838}
839
840mod az_cli_date_format {
841    use chrono::{DateTime, TimeZone};
842    use serde::{self, Deserialize, Deserializer};
843
844    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
845    where
846        D: Deserializer<'de>,
847    {
848        let s = String::deserialize(deserializer)?;
849        // expiresOn from azure cli uses the local timezone
850        let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
851            .map_err(serde::de::Error::custom)?;
852        chrono::Local
853            .from_local_datetime(&date)
854            .single()
855            .ok_or(serde::de::Error::custom(
856                "azure cli returned ambiguous expiry date",
857            ))
858    }
859}
860
/// Fields read from the JSON printed by `az account get-access-token --output json`.
/// The CLI uses camelCase keys; keys not listed here are ignored by serde.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AzureCliTokenResponse {
    /// The access token itself
    pub access_token: String,
    /// Expiry time, which the CLI reports in the local timezone
    #[serde(with = "az_cli_date_format")]
    pub expires_on: DateTime<chrono::Local>,
    /// Token type; `AzureCliCredential::fetch_token` only accepts `bearer` (ASCII case-insensitive)
    pub token_type: String,
}
869
/// Credential that obtains bearer tokens by invoking the Azure CLI (`az`)
#[derive(Default, Debug)]
pub(crate) struct AzureCliCredential {
    /// Token cache consulted by `get_credential`, populated via `fetch_token`
    cache: TokenCache<Arc<AzureCredential>>,
}
874
875impl AzureCliCredential {
876    pub(crate) fn new() -> Self {
877        Self::default()
878    }
879
880    /// Fetch a token
881    async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
882        // on window az is a cmd and it should be called like this
883        // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html
884        let program = if cfg!(target_os = "windows") {
885            "cmd"
886        } else {
887            "az"
888        };
889        let mut args = Vec::new();
890        if cfg!(target_os = "windows") {
891            args.push("/C");
892            args.push("az");
893        }
894        args.push("account");
895        args.push("get-access-token");
896        args.push("--output");
897        args.push("json");
898        args.push("--scope");
899        args.push(AZURE_STORAGE_SCOPE);
900
901        match Command::new(program).args(args).output() {
902            Ok(az_output) if az_output.status.success() => {
903                let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
904                    message: "az response is not a valid utf-8 string".to_string(),
905                })?;
906
907                let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
908                    .map_err(|source| Error::AzureCliResponse { source })?;
909
910                if !token_response.token_type.eq_ignore_ascii_case("bearer") {
911                    return Err(Error::AzureCli {
912                        message: format!(
913                            "got unexpected token type from azure cli: {0}",
914                            token_response.token_type
915                        ),
916                    });
917                }
918                let duration =
919                    token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
920                Ok(TemporaryToken {
921                    token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
922                    expiry: Some(
923                        Instant::now()
924                            + duration.to_std().map_err(|_| Error::AzureCli {
925                                message: "az returned invalid lifetime".to_string(),
926                            })?,
927                    ),
928                })
929            }
930            Ok(az_output) => {
931                let message = String::from_utf8_lossy(&az_output.stderr);
932                Err(Error::AzureCli {
933                    message: message.into(),
934                })
935            }
936            Err(e) => match e.kind() {
937                std::io::ErrorKind::NotFound => Err(Error::AzureCli {
938                    message: "Azure Cli not installed".into(),
939                }),
940                error_kind => Err(Error::AzureCli {
941                    message: format!("io error: {error_kind:?}"),
942                }),
943            },
944        }
945    }
946}
947
/// Encapsulates the logic to perform an OAuth token challenge for Fabric
// NOTE(review): the derived `Debug` prints `fabric_session_token` and
// `storage_access_token` verbatim; consider redacting them before logging.
#[derive(Debug)]
pub(crate) struct FabricTokenOAuthProvider {
    /// URL of the Fabric token service that issues storage tokens
    fabric_token_service_url: String,
    /// Sent as the `PROXY_HOST` header on token requests
    fabric_workload_host: String,
    /// Sent as the `PARTNER_TOKEN` header on token requests
    fabric_session_token: String,
    /// Sent as both the `CLUSTER_IDENTIFIER` and `WORKLOAD_RESOURCE` headers
    fabric_cluster_identifier: String,
    /// Optional caller-supplied storage token, reused while enough lifetime remains
    storage_access_token: Option<String>,
    /// JWT `exp` claim (seconds since the Unix epoch) of `storage_access_token`
    token_expiry: Option<u64>,
}
958
/// JWT payload claims read by `FabricTokenOAuthProvider`; only the expiry is needed
#[derive(Debug, Deserialize)]
struct Claims {
    /// Expiration time, compared against seconds since the Unix epoch
    exp: u64,
}
963
964impl FabricTokenOAuthProvider {
965    /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store
966    pub(crate) fn new(
967        fabric_token_service_url: impl Into<String>,
968        fabric_workload_host: impl Into<String>,
969        fabric_session_token: impl Into<String>,
970        fabric_cluster_identifier: impl Into<String>,
971        storage_access_token: Option<String>,
972    ) -> Self {
973        let (storage_access_token, token_expiry) = match storage_access_token {
974            Some(token) => match Self::validate_and_get_expiry(&token) {
975                Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
976                    (Some(token), Some(expiry))
977                }
978                _ => (None, None),
979            },
980            None => (None, None),
981        };
982
983        Self {
984            fabric_token_service_url: fabric_token_service_url.into(),
985            fabric_workload_host: fabric_workload_host.into(),
986            fabric_session_token: fabric_session_token.into(),
987            fabric_cluster_identifier: fabric_cluster_identifier.into(),
988            storage_access_token,
989            token_expiry,
990        }
991    }
992
993    fn validate_and_get_expiry(token: &str) -> Option<u64> {
994        let payload = token.split('.').nth(1)?;
995        let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
996        let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
997        let claims: Claims = serde_json::from_str(decoded_str).ok()?;
998        Some(claims.exp)
999    }
1000
1001    fn get_current_timestamp() -> u64 {
1002        SystemTime::now()
1003            .duration_since(SystemTime::UNIX_EPOCH)
1004            .map_or(0, |d| d.as_secs())
1005    }
1006}
1007
1008#[async_trait::async_trait]
1009impl TokenProvider for FabricTokenOAuthProvider {
1010    type Credential = AzureCredential;
1011
1012    /// Fetch a token
1013    async fn fetch_token(
1014        &self,
1015        client: &HttpClient,
1016        retry: &RetryConfig,
1017    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
1018        if let Some(storage_access_token) = &self.storage_access_token {
1019            if let Some(expiry) = self.token_expiry {
1020                let exp_in = expiry - Self::get_current_timestamp();
1021                if exp_in > TOKEN_MIN_TTL {
1022                    return Ok(TemporaryToken {
1023                        token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
1024                        expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1025                    });
1026                }
1027            }
1028        }
1029
1030        let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
1031        let access_token: String = client
1032            .request(Method::GET, &self.fabric_token_service_url)
1033            .header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
1034            .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
1035            .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
1036            .header(&PROXY_HOST, self.fabric_workload_host.as_str())
1037            .query(&query_items)
1038            .retryable(retry)
1039            .idempotent(true)
1040            .send()
1041            .await
1042            .map_err(|source| Error::TokenRequest { source })?
1043            .into_body()
1044            .text()
1045            .await
1046            .map_err(|source| Error::TokenResponseBody { source })?;
1047        let exp_in = Self::validate_and_get_expiry(&access_token)
1048            .map_or(3600, |expiry| expiry - Self::get_current_timestamp());
1049        Ok(TemporaryToken {
1050            token: Arc::new(AzureCredential::BearerToken(access_token)),
1051            expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
1052        })
1053    }
1054}
1055
1056#[async_trait]
1057impl CredentialProvider for AzureCliCredential {
1058    type Credential = AzureCredential;
1059
1060    async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
1061        Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
1062    }
1063}
1064
// Tests for the Azure credential providers, run against a local `MockServer`
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use http::{Response, StatusCode};
    use http_body_util::BodyExt;
    use reqwest::{Client, Method};
    use tempfile::NamedTempFile;

    use super::*;
    use crate::azure::MicrosoftAzureBuilder;
    use crate::client::mock_server::MockServer;
    use crate::{ObjectStore, Path};

    /// The managed identity provider sends a GET to the configured endpoint with the
    /// client id, the `metadata: true` header and the secret from `MSI_SECRET_ENV_KEY`.
    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        // NOTE(review): this sets a process-wide env var and never restores it, which
        // may leak into other tests running in the same process.
        std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret");

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Mock the IMDS endpoint: validate the request, then reply with a token whose
        // `expires_in` is encoded as a JSON string
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// The workload identity provider POSTs the contents of the federated token file
    /// to `{authority_host}/{tenant}/oauth2/v2.0/token`.
    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = HttpClient::new(Client::new());
        let retry_config = RetryConfig::default();

        // Mock the tenant token endpoint: the form body must contain the federated
        // token read from the file
        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// With `with_skip_signature(true)`, requests carry no `Authorization` header even
    /// though a bearer token was configured.
    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        // Reply 404 so the `get` below surfaces as `Error::NotFound`
        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }
}