#[cfg_attr(not(feature = "key_auth"), allow(unused_imports))]
use azure_core::{
credentials::{Secret, TokenCredential},
fmt::SafeDebug,
http::{
headers::{HeaderValue, AUTHORIZATION, MS_DATE, VERSION},
policies::{Policy, PolicyResult},
request::Request,
Context,
},
time::{self, OffsetDateTime},
};
use std::sync::Arc;
use tracing::{debug, trace};
use crate::{pipeline::signature_target::SignatureTarget, resource_context::ResourceLink};
use crate::utils::url_encode;
/// Value sent in the `VERSION` header of every request signed by `AuthorizationPolicy`.
const AZURE_VERSION: &str = "2020-07-15";
/// Scope requested from the `TokenCredential` when building an AAD (`type=aad`) authorization.
const COSMOS_AAD_SCOPE: &str = "https://cosmos.azure.com/.default";
/// The credential `AuthorizationPolicy` uses to produce the `Authorization` header.
///
/// Marked `#[safe(false)]` for `SafeDebug`. Presumably this keeps the credential out of debug
/// output; confirm against the `SafeDebug` documentation.
#[derive(SafeDebug, Clone)]
#[safe(false)]
enum Credential {
    /// A token credential; access tokens are requested for `COSMOS_AAD_SCOPE`.
    Token(Arc<dyn TokenCredential>),
    /// A shared (primary) key used to sign the request. Only exists with the `key_auth` feature.
    #[cfg(feature = "key_auth")]
    PrimaryKey(Secret),
}
/// Pipeline policy that stamps each request with date, version, and authorization headers
/// before handing it to the next policy.
///
/// The struct is `#[safe(true)]`. The credential it holds carries its own `#[safe(false)]`
/// marker, which decides how that field is rendered.
#[derive(SafeDebug, Clone)]
#[safe(true)]
pub(crate) struct AuthorizationPolicy {
    /// How the `Authorization` header value is produced.
    credential: Credential,
}
impl AuthorizationPolicy {
pub(crate) fn from_token_credential(token: Arc<dyn TokenCredential>) -> Self {
Self {
credential: Credential::Token(token),
}
}
#[cfg(feature = "key_auth")]
pub(crate) fn from_shared_key(key: Secret) -> Self {
Self {
credential: Credential::PrimaryKey(key),
}
}
}
#[async_trait::async_trait]
impl Policy for AuthorizationPolicy {
    /// Adds the `MS_DATE`, `VERSION` and `AUTHORIZATION` headers to `request`, then forwards it
    /// to the next policy in `next`.
    ///
    /// # Panics
    ///
    /// Panics if `next` is empty, or if `ctx` does not carry a `ResourceLink`.
    ///
    /// # Errors
    ///
    /// Returns any error from generating the authorization value, or whatever the next policy
    /// returns.
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        trace!("called AuthorizationPolicy::send. self == {:#?}", self);
        // This policy only decorates the request; a later policy has to transmit it.
        assert!(
            !next.is_empty(),
            "Authorization policies cannot be the last policy of a pipeline"
        );

        // One lowercase RFC 7231 timestamp feeds both the signature and the date header,
        // so the two always agree.
        let now = OffsetDateTime::now_utc();
        let request_date = time::to_rfc7231(&now).to_lowercase();

        let resource_link: &ResourceLink = ctx
            .value()
            .expect("ResourceContext should have been provided by CosmosPipeline");
        debug!(?resource_link, "generating authorization for resource");

        let target = SignatureTarget::new(request.method(), resource_link, &request_date);
        let authorization = generate_authorization(&self.credential, target).await?;

        // Inserted in the same order as always: date, version, authorization.
        let headers = [
            (MS_DATE, HeaderValue::from(request_date)),
            (VERSION, HeaderValue::from_static(AZURE_VERSION)),
            (AUTHORIZATION, HeaderValue::from(authorization)),
        ];
        for (name, value) in headers {
            request.insert_header(name, value);
        }

        let (following, remaining) = (&next[0], &next[1..]);
        following.send(ctx, request, remaining).await
    }
}
/// Builds the URL-encoded value for the `Authorization` header.
///
/// - `Credential::Token`: requests an access token for `COSMOS_AAD_SCOPE` and formats it as
///   `type=aad&ver=1.0&sig=<token>`.
/// - `Credential::PrimaryKey` (only with the `key_auth` feature): delegates to
///   `SignatureTarget::into_authorization` with the key.
///
/// Either result is passed through `url_encode` before being returned.
///
/// # Errors
///
/// Propagates errors from `TokenCredential::get_token` or from
/// `SignatureTarget::into_authorization`.
async fn generate_authorization(
    auth_token: &Credential,
    // Only the `key_auth` branch reads this. Suppress the lint only when that branch is
    // compiled out, so a genuinely unused parameter is still reported with the feature on.
    #[cfg_attr(not(feature = "key_auth"), allow(unused_variables))]
    signature_target: SignatureTarget<'_>,
) -> azure_core::Result<String> {
    let unencoded = match auth_token {
        Credential::Token(token_credential) => {
            let access_token = token_credential
                .get_token(&[COSMOS_AAD_SCOPE], None)
                .await?;
            // Format straight from the borrowed secret; no intermediate String copy.
            format!("type=aad&ver=1.0&sig={}", access_token.token.secret())
        }
        #[cfg(feature = "key_auth")]
        Credential::PrimaryKey(key) => signature_target.into_authorization(key)?,
    };
    Ok(url_encode(unencoded))
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use azure_core::{
        credentials::{AccessToken, TokenCredential, TokenRequestOptions},
        http::Method,
        time::{Duration, OffsetDateTime},
    };

    use crate::{
        pipeline::{
            authorization_policy::{generate_authorization, Credential, COSMOS_AAD_SCOPE},
            signature_target::SignatureTarget,
        },
        resource_context::{ResourceLink, ResourceType},
        utils::url_encode,
    };

    /// Parses a fixed RFC 3339 timestamp and renders it as the lowercase RFC 7231 string used
    /// for signing, so every test signs with a deterministic date.
    fn signing_date(rfc3339: &str) -> String {
        let timestamp = azure_core::time::parse_rfc3339(rfc3339).unwrap();
        azure_core::time::to_rfc7231(&timestamp).to_lowercase()
    }

    /// Credential whose token is `"<prefix>+<comma-joined scopes>"`, so assertions can see
    /// which scopes were requested.
    #[derive(Debug)]
    struct TestTokenCredential(String);

    #[async_trait::async_trait]
    impl TokenCredential for TestTokenCredential {
        async fn get_token(
            &self,
            scopes: &[&str],
            _: Option<TokenRequestOptions<'_>>,
        ) -> azure_core::Result<AccessToken> {
            let token = format!("{}+{}", self.0, scopes.join(","));
            Ok(AccessToken::new(
                token,
                OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)),
            ))
        }
    }

    #[tokio::test]
    async fn generate_authorization_for_token_credential() {
        let date_string = signing_date("1900-01-01T01:00:00.000000000+00:00");
        let cred = Arc::new(TestTokenCredential("test_token".to_string()));
        let auth_token = Credential::Token(cred);

        let ret = generate_authorization(
            &auth_token,
            SignatureTarget::new(
                Method::Get,
                &ResourceLink::root(ResourceType::Databases).item("ToDoList"),
                &date_string,
            ),
        )
        .await
        .unwrap();

        let expected: String =
            url_encode(format!("type=aad&ver=1.0&sig=test_token+{}", COSMOS_AAD_SCOPE).as_bytes());
        assert_eq!(ret, expected);
    }

    #[tokio::test]
    #[cfg(feature = "key_auth")]
    async fn generate_authorization_for_primary_key_0() {
        let date_string = signing_date("1900-01-01T01:00:00.000000000+00:00");
        let auth_token = Credential::PrimaryKey(
            "8F8xXXOptJxkblM1DBXW7a6NMI5oE8NnwPGYBmwxLCKfejOK7B7yhcCHMGvN3PBrlMLIOeol1Hv9RCdzAZR5sg==".into(),
        );

        let ret = generate_authorization(
            &auth_token,
            SignatureTarget::new(
                Method::Get,
                &ResourceLink::root(ResourceType::Databases)
                    .item("MyDatabase")
                    .feed(ResourceType::Containers)
                    .item("MyCollection"),
                &date_string,
            ),
        )
        .await
        .unwrap();

        let expected: String =
            url_encode(b"type=master&ver=1.0&sig=vrHmd02almbIg1e4htVWH+Eg/OhEHip3VTwFivZLH0A=");
        assert_eq!(ret, expected);
    }

    #[tokio::test]
    #[cfg(feature = "key_auth")]
    async fn generate_authorization_for_primary_key_1() {
        let date_string = signing_date("2017-04-27T00:51:12.000000000+00:00");
        let auth_token = Credential::PrimaryKey(
            "dsZQi3KtZmCv1ljt3VNWNm7sQUF1y5rJfC6kv5JiwvW0EndXdDku/dkKBp8/ufDToSxL".into(),
        );

        let ret = generate_authorization(
            &auth_token,
            SignatureTarget::new(
                Method::Get,
                &ResourceLink::root(ResourceType::Databases).item("ToDoList"),
                &date_string,
            ),
        )
        .await
        .unwrap();

        let expected: String =
            url_encode(b"type=master&ver=1.0&sig=KvBM8vONofkv3yKm/8zD9MEGlbu6jjHDJBp4E9c2ZZI=");
        assert_eq!(ret, expected);
    }

    /// Verifies that the AAD path requests exactly one token, for exactly `COSMOS_AAD_SCOPE`.
    #[tokio::test]
    async fn aad_token_uses_constant_scope() {
        use std::sync::Mutex;

        /// Records every scope list passed to `get_token`.
        #[derive(Debug)]
        struct ScopeCapturingCredential {
            captured_scopes: Arc<Mutex<Vec<Vec<String>>>>,
        }

        #[async_trait::async_trait]
        impl TokenCredential for ScopeCapturingCredential {
            async fn get_token(
                &self,
                scopes: &[&str],
                _: Option<TokenRequestOptions<'_>>,
            ) -> azure_core::Result<AccessToken> {
                self.captured_scopes
                    .lock()
                    .unwrap()
                    .push(scopes.iter().map(|s| s.to_string()).collect());
                Ok(AccessToken::new(
                    "mock_token".to_string(),
                    OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)),
                ))
            }
        }

        let captured_scopes = Arc::new(Mutex::new(Vec::new()));
        let cred = Arc::new(ScopeCapturingCredential {
            captured_scopes: captured_scopes.clone(),
        });
        let auth_token = Credential::Token(cred);
        let date_string = signing_date("1900-01-01T01:00:00.000000000+00:00");

        // Only the side effect on `captured_scopes` matters here; the header value is ignored.
        generate_authorization(
            &auth_token,
            SignatureTarget::new(
                Method::Get,
                &ResourceLink::root(ResourceType::Databases).item("TestDB"),
                &date_string,
            ),
        )
        .await
        .unwrap();

        let scopes = captured_scopes.lock().unwrap();
        assert_eq!(scopes.len(), 1, "get_token should be called exactly once");
        assert_eq!(
            scopes[0].len(),
            1,
            "get_token should be called with exactly one scope"
        );
        assert_eq!(
            scopes[0][0], COSMOS_AAD_SCOPE,
            "get_token should be called with COSMOS_AAD_SCOPE constant"
        );
    }
}