// azure_identity 0.35.0
//
// Rust wrappers around Microsoft Azure REST APIs - Azure identity helper crate
// Documentation
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use crate::{authentication_error, get_authority_host, TokenCache};
use azure_core::credentials::TokenRequestOptions;
use azure_core::http::PipelineSendOptions;
use azure_core::Result;
use azure_core::{
    credentials::{AccessToken, Secret, TokenCredential},
    error::{ErrorKind, ResultExt},
    http::{
        headers::{self, content_type},
        ClientOptions, Method, Pipeline, Request, Url,
    },
    Error,
};
use std::{str, sync::Arc};
use url::form_urlencoded;

/// Configuration accepted by [`ClientSecretCredential::new`].
///
/// All fields have defaults, so `ClientSecretCredentialOptions::default()` (or passing
/// `None` to the constructor) is always a valid choice.
#[derive(Default, Debug)]
pub struct ClientSecretCredentialOptions {
    /// Settings for the HTTP pipeline the credential uses to request tokens,
    /// e.g. the transport and the cloud configuration.
    pub client_options: ClientOptions,
}

/// A [`TokenCredential`] that authenticates a service principal (application)
/// with its client ID and a client secret.
///
/// Acquired tokens are stored in a token cache owned by this credential instance.
#[derive(Debug)]
pub struct ClientSecretCredential {
    /// Cache consulted before a new token is requested.
    cache: TokenCache,
    /// Application (client) ID, sent as the `client_id` form field.
    client_id: String,
    /// Tenant-specific token endpoint, `{authority}/{tenant_id}/oauth2/v2.0/token`.
    endpoint: Url,
    /// HTTP pipeline used to POST token requests to `endpoint`.
    pipeline: Pipeline,
    /// Client secret, sent as the `client_secret` form field.
    secret: Secret,
}

impl ClientSecretCredential {
    /// Creates a `ClientSecretCredential` for a service principal.
    ///
    /// # Arguments
    /// - `tenant_id`: The tenant (directory) ID of the service principal.
    /// - `client_id`: The client (application) ID of the service principal.
    /// - `secret`: The client secret that was generated for the service principal.
    /// - `options`: Options for configuring the credential. `None` selects the default options.
    ///
    /// # Errors
    /// Returns an error when `tenant_id` fails validation, when `client_id` or `secret` is
    /// empty, when the authority host cannot be resolved from the configured cloud, or when
    /// the token endpoint URL cannot be built from `tenant_id`.
    pub fn new(
        tenant_id: &str,
        client_id: String,
        secret: Secret,
        options: Option<ClientSecretCredentialOptions>,
    ) -> Result<Arc<Self>> {
        // Reject obviously bad inputs before doing any URL or pipeline work.
        crate::validate_tenant_id(tenant_id)?;
        crate::validate_not_empty(&client_id, "no client ID specified")?;
        crate::validate_not_empty(secret.secret(), "no secret specified")?;

        let client_options = options.unwrap_or_default().client_options;
        let authority_host = get_authority_host(None, client_options.cloud.as_deref())?;

        // The leading '/' makes the tenant path replace any path on the authority host.
        let token_path = format!("/{tenant_id}/oauth2/v2.0/token");
        let endpoint = authority_host
            .join(&token_path)
            .with_context_fn(ErrorKind::DataConversion, || {
                format!("tenant_id '{tenant_id}' could not be URL encoded")
            })?;

        // No extra per-call or per-retry policies are added to the pipeline.
        let pipeline = Pipeline::new(
            option_env!("CARGO_PKG_NAME"),
            option_env!("CARGO_PKG_VERSION"),
            client_options,
            Vec::new(),
            Vec::new(),
            None,
        );

        let credential = Self {
            cache: TokenCache::new(),
            client_id,
            endpoint,
            pipeline,
            secret,
        };
        Ok(Arc::new(credential))
    }

    /// Requests a fresh token from the token endpoint with the `client_credentials` grant.
    ///
    /// This does not consult the cache; the `TokenCredential::get_token` implementation
    /// calls it through the cache.
    async fn get_token_impl(
        &self,
        scopes: &[&str],
        options: Option<TokenRequestOptions<'_>>,
    ) -> Result<AccessToken> {
        // Multiple scopes are sent space-delimited in a single `scope` form field.
        let scope = scopes.join(" ");
        let form = form_urlencoded::Serializer::new(String::new())
            .append_pair("client_id", &self.client_id)
            .append_pair("client_secret", self.secret.secret())
            .append_pair("grant_type", "client_credentials")
            .append_pair("scope", &scope)
            .finish();

        let mut request = Request::new(self.endpoint.clone(), Method::Post);
        request.insert_header(
            headers::CONTENT_TYPE,
            content_type::APPLICATION_X_WWW_FORM_URLENCODED,
        );
        request.set_body(form);

        let request_options = options.unwrap_or_default();
        let context = request_options.method_options.context.to_borrowed();
        // `skip_checks` presumably leaves status-code interpretation to
        // `handle_entra_response` below, so error bodies are not lost -- confirm.
        let send_options = PipelineSendOptions {
            skip_checks: true,
            ..Default::default()
        };

        let response = self
            .pipeline
            .send(&context, &mut request, Some(send_options))
            .await?;
        crate::handle_entra_response(response)
    }
}

#[async_trait::async_trait]
impl TokenCredential for ClientSecretCredential {
    async fn get_token(
        &self,
        scopes: &[&str],
        options: Option<TokenRequestOptions<'_>>,
    ) -> Result<AccessToken> {
        if scopes.is_empty() {
            return Err(Error::with_message(
                ErrorKind::Credential,
                "no scopes specified",
            ));
        }
        self.cache
            .get_token(scopes, options, |s, o| self.get_token_impl(s, o))
            .await
            .map_err(|err| authentication_error(stringify!(ClientSecretCredential), err))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;
    use azure_core::{
        http::{headers::Headers, AsyncRawResponse, RawResponse, StatusCode, Transport},
        Bytes, Result,
    };
    use std::vec;
    use time::OffsetDateTime;

    /// Placeholder client secret; `is_valid_request` does not inspect the request body.
    const FAKE_SECRET: &str = "fake secret";

    /// Builds a request validator for `MockSts` that asserts the credential POSTs a
    /// form-urlencoded request to `{expected_authority}/oauth2/v2.0/token`.
    fn is_valid_request(expected_authority: String) -> impl Fn(&Request) -> Result<()> {
        let expected_url = format!("{}/oauth2/v2.0/token", expected_authority);
        move |req: &Request| {
            assert_eq!(Method::Post, req.method());
            assert_eq!(expected_url, req.url().to_string());
            assert_eq!(
                req.headers().get_str(&headers::CONTENT_TYPE).unwrap(),
                content_type::APPLICATION_X_WWW_FORM_URLENCODED.as_str()
            );
            Ok(())
        }
    }

    /// The token endpoint must be derived from the configured cloud for every case
    /// returned by `cloud_configuration_cases`.
    #[tokio::test]
    async fn cloud_configuration() {
        for (cloud, expected_authority) in cloud_configuration_cases() {
            let sts = MockSts::new(
                vec![token_response()],
                Some(Arc::new(is_valid_request(expected_authority))),
            );
            let credential = ClientSecretCredential::new(
                FAKE_TENANT_ID,
                FAKE_CLIENT_ID.to_string(),
                FAKE_SECRET.into(),
                Some(ClientSecretCredentialOptions {
                    client_options: ClientOptions {
                        transport: Some(Transport::new(Arc::new(sts))),
                        cloud: Some(Arc::new(cloud)),
                        ..Default::default()
                    },
                }),
            )
            .expect("valid credential");

            credential
                .get_token(LIVE_TEST_SCOPES, None)
                .await
                .expect("token");
        }
    }

    /// An Entra error response must surface as a `Credential` error whose message includes
    /// the service's description and a troubleshooting link, and which wraps the original
    /// `HttpResponse` error (status, error code and raw response).
    #[tokio::test]
    async fn get_token_error() {
        let body = Bytes::from(
            r#"{"error":"invalid_client","error_description":"AADSTS7000215: Invalid client secret.","error_codes":[7000215],"timestamp":"2025-04-04 21:10:04Z","trace_id":"...","correlation_id":"...","error_uri":"https://login.microsoftonline.com/error?code=7000215"}"#,
        );
        let expected_status = StatusCode::BadRequest;
        let mut headers = Headers::default();
        headers.insert("key", "value");
        let expected_response =
            RawResponse::from_bytes(expected_status, headers.clone(), body.clone());
        let sts = MockSts::new(
            vec![AsyncRawResponse::from_bytes(expected_status, headers, body)],
            Some(Arc::new(is_valid_request(
                FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
            ))),
        );
        let cred = ClientSecretCredential::new(
            FAKE_TENANT_ID,
            FAKE_CLIENT_ID.to_string(),
            FAKE_SECRET.into(),
            Some(ClientSecretCredentialOptions {
                client_options: ClientOptions {
                    transport: Some(Transport::new(Arc::new(sts))),
                    ..Default::default()
                },
            }),
        )
        .expect("valid credential");

        let err = cred
            .get_token(LIVE_TEST_SCOPES, None)
            .await
            .expect_err("expected error");
        assert!(matches!(err.kind(), ErrorKind::Credential));
        assert_eq!(
            "ClientSecretCredential authentication failed. AADSTS7000215: Invalid client secret.\nTo troubleshoot, visit https://aka.ms/azsdk/rust/identity/troubleshoot#client-secret",
            err.to_string(),
        );
        match err
            .downcast_ref::<azure_core::Error>()
            .expect("returned error should wrap an azure_core::Error")
            .kind()
        {
            ErrorKind::HttpResponse {
                error_code: Some(error_code),
                raw_response: Some(response),
                status,
            } => {
                assert_eq!("7000215", error_code);
                assert_eq!(&expected_response, response.as_ref());
                assert_eq!(expected_status, *status);
            }
            err => panic!("unexpected {:?}", err),
        };
    }

    /// A successful response yields the expected token and expiry, and a second request
    /// for the same scopes is served from the cache.
    #[tokio::test]
    async fn get_token_success() {
        // Presumably matches the `expires_in` value in `token_response()` -- keep in sync.
        let expires_in = 3600;
        let sts = MockSts::new(
            vec![token_response()],
            Some(Arc::new(is_valid_request(
                FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
            ))),
        );
        let cred = ClientSecretCredential::new(
            FAKE_TENANT_ID,
            FAKE_CLIENT_ID.to_string(),
            FAKE_SECRET.into(),
            Some(ClientSecretCredentialOptions {
                client_options: ClientOptions {
                    transport: Some(Transport::new(Arc::new(sts))),
                    ..Default::default()
                },
            }),
        )
        .expect("valid credential");
        let token = cred.get_token(LIVE_TEST_SCOPES, None).await.expect("token");

        assert_eq!(FAKE_TOKEN, token.token.secret());

        // Allow a small margin when validating the expiration time. expires_on is computed as
        // the current time plus expires_in seconds. The system clock may have ticked into the
        // next second before we sample it again below. In that case the observed lifetime is
        // one second *shorter* than expires_in, so the accepted range is
        // [expires_in - 1, expires_in]. The previous range `expires_in..expires_in + 1`
        // accepted only expires_in and made this test flaky.
        let lifetime =
            token.expires_on.unix_timestamp() - OffsetDateTime::now_utc().unix_timestamp();
        assert!(
            (expires_in - 1..=expires_in).contains(&lifetime),
            "token should expire in ~{} seconds but actually expires in {} seconds",
            expires_in,
            lifetime
        );

        // sts will return an error if the credential sends another request
        let cached_token = cred
            .get_token(LIVE_TEST_SCOPES, None)
            .await
            .expect("cached token");
        assert_eq!(token.token.secret(), cached_token.token.secret());
        assert_eq!(token.expires_on, cached_token.expires_on);
    }

    /// Construction must fail when the tenant ID is not valid.
    #[test]
    fn invalid_tenant_id() {
        ClientSecretCredential::new(
            "not a valid tenant",
            FAKE_CLIENT_ID.to_string(),
            FAKE_SECRET.into(),
            None,
        )
        .expect_err("invalid tenant ID");
    }

    /// Requesting a token with an empty scope list must fail.
    #[tokio::test]
    async fn no_scopes() {
        ClientSecretCredential::new(
            FAKE_TENANT_ID,
            FAKE_CLIENT_ID.to_string(),
            FAKE_SECRET.into(),
            None,
        )
        .expect("valid credential")
        .get_token(&[], None)
        .await
        .expect_err("no scopes specified");
    }
}