use std::borrow::Cow;
use std::collections::BTreeMap;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use hmac::Mac;
use http_cache_stream_reqwest::semantics;
use reqwest::Method;
use reqwest::Request;
use reqwest::header;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use secrecy::ExposeSecret;
use sha2::Sha256;
use url::form_urlencoded;
use crate::AzureAuthConfig;
/// Returns `true` when `name` is a Microsoft-specific header, i.e. one whose name begins
/// with `x-ms-`. These are the headers that get canonicalized into the Shared Key signature.
fn is_microsoft_header(name: &HeaderName) -> bool {
    const MICROSOFT_PREFIX: &str = "x-ms-";
    name.as_str().strip_prefix(MICROSOFT_PREFIX).is_some()
}
/// HMAC-SHA256, the MAC used to compute the `SharedKey` request signature.
type Hmac = hmac::Hmac<Sha256>;
/// Produces Azure Shared Key `Authorization` header values for outgoing requests.
///
/// Borrows an [`AzureAuthConfig`], from which it reads the account name and the
/// base64-encoded access key used as the HMAC key.
pub struct RequestSigner<'a>(&'a AzureAuthConfig);
impl<'a> RequestSigner<'a> {
    /// Creates a signer that uses the account name and access key from `config`.
    pub fn new(config: &'a AzureAuthConfig) -> Self {
        Self(config)
    }

    /// Computes the `Authorization` header value (`SharedKey <account>:<signature>`) for
    /// `request`.
    ///
    /// Returns `None` if the configured access key cannot be base64-decoded or cannot be
    /// used as an HMAC key.
    ///
    /// # Panics
    ///
    /// Panics if a signed header value is not representable as a string, or if an
    /// `x-ms-*` header appears more than once.
    pub fn sign(&self, request: &Request) -> Option<String> {
        let canonical_headers = self.canonical_headers(
            request
                .headers()
                .iter()
                .filter(|(k, _)| is_microsoft_header(k)),
        );
        let canonical_resource =
            self.canonical_resource(request.url().path(), request.url().query_pairs());
        let string_to_sign = self.string_to_sign(
            request.method(),
            &canonical_headers,
            &canonical_resource,
            |name| request.headers().get(name),
        );
        self.authorization_header(&string_to_sign)
    }

    /// Computes the `Authorization` header value for a cache revalidation of `request`.
    ///
    /// Standard headers (e.g. `If-None-Match`, `If-Modified-Since`) are looked up in
    /// `headers` first, falling back to the original request's headers.
    ///
    /// NOTE(review): only `request`'s own `x-ms-*` headers are canonicalized; any `x-ms-*`
    /// entries present in `headers` are not signed — confirm `headers` never carries them.
    ///
    /// Returns `None` under the same conditions as [`RequestSigner::sign`], and panics under
    /// the same conditions.
    pub fn sign_revalidation(
        &self,
        request: &dyn semantics::RequestLike,
        headers: &HeaderMap,
    ) -> Option<String> {
        let uri = request.uri();
        let canonical_headers = self.canonical_headers(
            request
                .headers()
                .iter()
                .filter(|(k, _)| is_microsoft_header(k)),
        );
        let canonical_resource = self.canonical_resource(
            uri.path(),
            form_urlencoded::parse(uri.query().unwrap_or("").as_bytes()),
        );
        let string_to_sign = self.string_to_sign(
            request.method(),
            &canonical_headers,
            &canonical_resource,
            |name| headers.get(&name).or_else(|| request.headers().get(name)),
        );
        self.authorization_header(&string_to_sign)
    }

    /// Builds the Shared Key string-to-sign.
    ///
    /// The result is the HTTP verb followed by the eleven standard headers (Content-Encoding,
    /// Content-Language, Content-Length, Content-MD5, Content-Type, Date, If-Modified-Since,
    /// If-Match, If-None-Match, If-Unmodified-Since, Range). Each item is on its own line and
    /// is empty when the header is absent. The canonicalized headers and canonicalized
    /// resource are then appended verbatim.
    ///
    /// # Panics
    ///
    /// Panics if one of the standard header values is not representable as a string.
    fn string_to_sign<'b>(
        &self,
        method: &Method,
        canonical_headers: &str,
        canonical_resource: &str,
        headers: impl Fn(HeaderName) -> Option<&'b HeaderValue>,
    ) -> String {
        // Renders one standard header for signing; an absent header signs as "".
        let value = |name: HeaderName| -> &'b str {
            match headers(name.clone()) {
                Some(v) => v.to_str().unwrap_or_else(|e| {
                    panic!("{name} should be a string: {e:?}", name = name.as_str())
                }),
                None => "",
            }
        };

        // Since x-ms-version 2015-02-21, a zero Content-Length must appear in the
        // string-to-sign as the empty string rather than "0"; otherwise the service
        // computes a different signature and rejects the request.
        let content_length = match value(header::CONTENT_LENGTH) {
            "0" => "",
            length => length,
        };

        let fields = [
            method.as_str(),
            value(header::CONTENT_ENCODING),
            value(header::CONTENT_LANGUAGE),
            content_length,
            value(HeaderName::from_static("content-md5")),
            value(header::CONTENT_TYPE),
            value(header::DATE),
            value(header::IF_MODIFIED_SINCE),
            value(header::IF_MATCH),
            value(header::IF_NONE_MATCH),
            value(header::IF_UNMODIFIED_SINCE),
            value(header::RANGE),
        ];

        // Every field (including Range) is newline-terminated. The canonicalized headers
        // carry their own trailing newlines and are immediately followed by the resource.
        let mut string_to_sign = fields.join("\n");
        string_to_sign.push('\n');
        string_to_sign.push_str(canonical_headers);
        string_to_sign.push_str(canonical_resource);
        string_to_sign
    }

    /// Builds the canonicalized-headers string: one `name:value\n` line per header, sorted
    /// by header name.
    ///
    /// # Panics
    ///
    /// Panics if a value is not representable as a string or if the same header name
    /// appears twice. Debug builds also assert that no value contains whitespace.
    fn canonical_headers<'b>(
        &self,
        microsoft_headers: impl Iterator<Item = (&'b HeaderName, &'b HeaderValue)>,
    ) -> String {
        // A BTreeMap keeps the names in ascending order, as the signature requires.
        let mut headers = BTreeMap::new();
        for (k, v) in microsoft_headers {
            let value = v.to_str().expect("expected a string value");
            debug_assert!(
                !value.chars().any(|c| c.is_whitespace()),
                "canonical Azure header contains whitespace"
            );
            if headers.insert(k.as_str(), value).is_some() {
                panic!("duplicate header `{k}`", k = k.as_str());
            }
        }

        let mut canonical_headers = String::new();
        for (k, v) in headers {
            canonical_headers.push_str(k);
            canonical_headers.push(':');
            canonical_headers.push_str(v);
            canonical_headers.push('\n');
        }
        canonical_headers
    }

    /// Builds the canonicalized-resource string.
    ///
    /// The string starts with `/<account><path>`. Then, for each distinct query parameter
    /// name (lowercased, in ascending order), it appends `\nname:v1,v2,…` with that
    /// parameter's values sorted.
    ///
    /// `path` is appended as given. Both call sites pass pairs produced by form-urlencoded
    /// parsing, which are presumably already percent-decoded.
    fn canonical_resource<'b>(
        &self,
        path: &str,
        query_pairs: impl Iterator<Item = (Cow<'b, str>, Cow<'b, str>)>,
    ) -> String {
        let mut canonical_resource = String::new();
        canonical_resource.push('/');
        canonical_resource.push_str(self.0.account_name());
        canonical_resource.push_str(path);

        // Group values by lowercased name. Repeated parameters become one comma-joined entry.
        let mut parameters: BTreeMap<_, Vec<_>> = BTreeMap::new();
        for (key, value) in query_pairs {
            parameters
                .entry(key.to_lowercase())
                .or_default()
                .push(value);
        }

        for (key, mut values) in parameters {
            values.sort();
            canonical_resource.push('\n');
            canonical_resource.push_str(&key);
            canonical_resource.push(':');
            canonical_resource.push_str(&values.join(","));
        }
        canonical_resource
    }

    /// Signs `string_to_sign` with HMAC-SHA256 keyed by the base64-decoded access key.
    ///
    /// The result is formatted as `SharedKey <account>:<base64 signature>`. Returns `None`
    /// if the access key is not valid base64 or is rejected as an HMAC key.
    fn authorization_header(&self, string_to_sign: &str) -> Option<String> {
        let mut hmac = Hmac::new_from_slice(
            &BASE64_STANDARD
                .decode(self.0.access_key().expose_secret())
                .ok()?,
        )
        .ok()?;
        hmac.update(string_to_sign.as_bytes());
        let signature = BASE64_STANDARD.encode(hmac.finalize().into_bytes());
        Some(format!(
            "SharedKey {account_name}:{signature}",
            account_name = self.0.account_name()
        ))
    }
}