azure_core 0.33.0

Rust wrappers around Microsoft Azure REST APIs - Core crate
Documentation
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

//! Azure authentication and authorization.

use crate::Bytes;
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt};
use typespec_client_core::{fmt::SafeDebug, http::ClientMethodOptions, time::OffsetDateTime};

/// A secret string value, such as an access token.
///
/// Formatting a `Secret` with [`Debug`] yields only the text `Secret`, so the
/// wrapped value does not end up in logs. Use [`Secret::secret`] to read it.
#[derive(Clone, Eq, Serialize, Deserialize)]
pub struct Secret(Cow<'static, str>);

impl Secret {
    /// Create a new `Secret`.
    pub fn new<T>(access_token: T) -> Self
    where
        T: Into<Cow<'static, str>>,
    {
        Self(access_token.into())
    }

    /// Get the secret value.
    pub fn secret(&self) -> &str {
        &self.0
    }
}

// Equality does not stop at the first differing byte: every byte pair is
// XOR-ed into an accumulator, so for equal-length inputs the loop does the same
// work wherever (or whether) they differ. This is best-effort only: the
// optimizer may still rewrite the loop, and inputs of different lengths are
// rejected immediately, so the length itself is not hidden.
impl PartialEq for Secret {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.secret().as_bytes();
        let rhs = other.secret().as_bytes();

        if lhs.len() != rhs.len() {
            return false;
        }

        let mut diff = 0u8;
        for (x, y) in lhs.iter().zip(rhs.iter()) {
            diff |= x ^ y;
        }
        diff == 0
    }
}

impl From<String> for Secret {
    fn from(access_token: String) -> Self {
        Self::new(access_token)
    }
}

impl From<&'static str> for Secret {
    fn from(access_token: &'static str) -> Self {
        Self::new(access_token)
    }
}

impl fmt::Debug for Secret {
    /// Writes the fixed text `Secret`; the wrapped value is never emitted.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "Secret";
        formatter.write_str(REDACTED)
    }
}

/// Secret binary data, for example the raw bytes of a certificate.
///
/// Both the [`Debug`](fmt::Debug) and [`Display`](fmt::Display) output are the
/// fixed text `SecretBytes`; the contents themselves are never printed.
#[derive(Eq, Clone)]
pub struct SecretBytes(Vec<u8>);

impl SecretBytes {
    /// Create a new `SecretBytes`.
    pub fn new(bytes: impl Into<Vec<u8>>) -> Self {
        Self(bytes.into())
    }

    /// Get the secret bytes.
    pub fn bytes(&self) -> &[u8] {
        &self.0
    }
}

// Equal-length inputs are compared without an early exit: the XOR of every byte
// pair is OR-ed together and only the final accumulator is inspected. This is
// best-effort; the optimizer may still transform the loop, and a length mismatch
// returns immediately, so the length is not hidden.
impl PartialEq for SecretBytes {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (self.bytes(), other.bytes());

        if mine.len() == theirs.len() {
            mine.iter()
                .zip(theirs)
                .map(|(x, y)| x ^ y)
                .fold(0u8, |acc, diff| acc | diff)
                == 0
        } else {
            false
        }
    }
}

impl fmt::Debug for SecretBytes {
    /// Writes the type name only, never the underlying bytes.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let placeholder = "SecretBytes";
        formatter.write_str(placeholder)
    }
}

impl fmt::Display for SecretBytes {
    /// Writes the same fixed placeholder as `Debug`, so user-facing output
    /// cannot leak the bytes either.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let placeholder = "SecretBytes";
        formatter.write_str(placeholder)
    }
}

impl From<Bytes> for SecretBytes {
    fn from(bytes: Bytes) -> Self {
        Self(bytes.to_vec())
    }
}

impl From<&[u8]> for SecretBytes {
    fn from(bytes: &[u8]) -> Self {
        Self(bytes.to_vec())
    }
}

impl From<Vec<u8>> for SecretBytes {
    fn from(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }
}

/// Represents an Azure service bearer access token with expiry information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// The access token value.
    ///
    /// Formatting an `AccessToken` with `Debug` does not reveal this value,
    /// because the `Debug` output of [`Secret`] omits it.
    pub token: Secret,
    /// The time at which the token expires.
    pub expires_on: OffsetDateTime,
}

impl AccessToken {
    /// Create a new `AccessToken` from a token value and its expiry time.
    ///
    /// `token` may be anything convertible into a [`Secret`], such as a
    /// `String` or a `&'static str`.
    pub fn new<T>(token: T, expires_on: OffsetDateTime) -> Self
    where
        T: Into<Secret>,
    {
        Self {
            token: token.into(),
            expires_on,
        }
    }
}

/// Options for getting a token from a [`TokenCredential`].
#[derive(Clone, Default, SafeDebug)]
pub struct TokenRequestOptions<'a> {
    /// Method options to be used when requesting a token.
    pub method_options: ClientMethodOptions<'a>,
}

/// Represents a credential capable of providing an OAuth token.
#[async_trait::async_trait]
pub trait TokenCredential: Send + Sync + fmt::Debug {
    /// Gets an [`AccessToken`] for the specified scopes.
    ///
    /// # Errors
    ///
    /// Returns an error when the credential fails to acquire a token.
    async fn get_token(
        &self,
        scopes: &[&str],
        options: Option<TokenRequestOptions<'_>>,
    ) -> crate::Result<AccessToken>;
}

// The test module is kept as the last item in the file; placing items after it
// triggers clippy's `items_after_test_module` lint.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn secret_debug_does_not_print_value() {
        let secret = Secret::new("super-secret");
        assert_eq!("Secret", format!("{secret:?}"));
    }

    #[test]
    fn secret_eq_same_value() {
        assert_eq!(Secret::new("hello"), Secret::from(String::from("hello")));
    }

    #[test]
    fn secret_ne_different_value() {
        assert_ne!(Secret::new("hello"), Secret::new("world"));
    }

    #[test]
    fn secret_ne_different_lengths() {
        assert_ne!(Secret::new("hello"), Secret::new("hello!"));
    }

    #[test]
    fn secret_from_str_and_string() {
        assert_eq!("data", Secret::from("data").secret());
        assert_eq!("data", Secret::from(String::from("data")).secret());
    }

    #[test]
    fn debug_does_not_print_bytes() {
        let secret = SecretBytes::new(b"super-secret".to_vec());
        assert_eq!("SecretBytes", format!("{secret:?}"));
    }

    #[test]
    fn display_does_not_print_bytes() {
        let secret = SecretBytes::new(b"super-secret".to_vec());
        assert_eq!("SecretBytes", format!("{secret}"));
    }

    #[test]
    fn eq_same_bytes() {
        let a = SecretBytes::new(b"hello".to_vec());
        let b = SecretBytes::new(b"hello".to_vec());
        assert_eq!(a, b);
    }

    #[test]
    fn ne_different_bytes() {
        let a = SecretBytes::new(b"hello".to_vec());
        let b = SecretBytes::new(b"world".to_vec());
        assert_ne!(a, b);
    }

    #[test]
    fn ne_different_lengths() {
        let a = SecretBytes::new(b"hello".to_vec());
        let b = SecretBytes::new(b"hello!".to_vec());
        assert_ne!(a, b);
    }

    #[test]
    fn from_bytes_type() {
        let bytes = Bytes::from_static(b"data");
        let secret = SecretBytes::from(bytes);
        assert_eq!(b"data", secret.bytes());
    }

    #[test]
    fn from_slice() {
        let data: &[u8] = b"data";
        let secret = SecretBytes::from(data);
        assert_eq!(b"data", secret.bytes());
    }

    #[test]
    fn from_vec() {
        let secret = SecretBytes::from(b"data".to_vec());
        assert_eq!(b"data", secret.bytes());
    }
}