use crate::http::{
headers::{HeaderValue, USER_AGENT},
options::UserAgentOptions,
};
use std::{
env::consts::{ARCH, OS},
sync::Arc,
};
use typespec_client_core::http::{
policies::{Policy, PolicyResult},
Context, Request,
};
#[derive(Clone, Debug)]
pub struct UserAgentPolicy {
header: HeaderValue,
}
impl<'a> UserAgentPolicy {
pub fn new(
crate_name: Option<&'a str>,
crate_version: Option<&'a str>,
options: &UserAgentOptions,
) -> Self {
Self::new_with_rustc_version(
crate_name,
crate_version,
option_env!("AZSDK_RUSTC_VERSION"),
options,
)
}
fn new_with_rustc_version(
crate_name: Option<&'a str>,
crate_version: Option<&'a str>,
rustc_version: Option<&'a str>,
options: &UserAgentOptions,
) -> Self {
const UNKNOWN: &str = "unknown";
let mut crate_name = crate_name.unwrap_or(UNKNOWN);
let crate_version = crate_version.unwrap_or(UNKNOWN);
let rustc_version = rustc_version.unwrap_or(UNKNOWN);
let platform_info = format!("({rustc_version}; {OS}; {ARCH})",);
if let Some(name) = crate_name.strip_prefix("azure_") {
crate_name = name;
}
const MAX_APPLICATION_ID_LEN: usize = 24;
let header_str = match &options.application_id {
Some(application_id) => {
if application_id.len() > MAX_APPLICATION_ID_LEN {
panic!(
"application_id must be shorter than {} characters",
MAX_APPLICATION_ID_LEN + 1
);
}
if !application_id.chars().all(|c| {
c.is_ascii_alphanumeric()
|| matches!(
c,
'!' | '#'
| '$'
| '%'
| '&'
| '\''
| '*'
| '+'
| '-'
| '.'
| '^'
| '_'
| '`'
| '|'
| '~'
)
}) {
panic!("application_id contains invalid characters. Only RFC 9110 tokens are allowed.");
}
format!("{application_id} azsdk-rust-{crate_name}/{crate_version} {platform_info}")
}
None => format!("azsdk-rust-{crate_name}/{crate_version} {platform_info}"),
};
UserAgentPolicy {
header: HeaderValue::from(header_str),
}
}
}
#[async_trait::async_trait]
impl Policy for UserAgentPolicy {
async fn send(
&self,
ctx: &Context,
request: &mut Request,
next: &[Arc<dyn Policy>],
) -> PolicyResult {
request.insert_header(USER_AGENT, self.header.clone());
next[0].send(ctx, request, &next[1..]).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn without_application_id() {
let policy = UserAgentPolicy::new_with_rustc_version(
Some("azure_test"), Some("1.2.3"),
Some("4.5.6"),
&UserAgentOptions::default(),
);
assert_eq!(
policy.header.as_str(),
format!("azsdk-rust-test/1.2.3 (4.5.6; {OS}; {ARCH})")
);
}
#[test]
fn with_application_id() {
let options = UserAgentOptions {
application_id: Some("my_app".to_string()),
};
let policy = UserAgentPolicy::new_with_rustc_version(
Some("test"),
Some("1.2.3"),
Some("4.5.6"),
&options,
);
assert_eq!(
policy.header.as_str(),
format!("my_app azsdk-rust-test/1.2.3 (4.5.6; {OS}; {ARCH})")
);
}
#[test]
fn missing_env() {
let policy =
UserAgentPolicy::new_with_rustc_version(None, None, None, &UserAgentOptions::default());
assert_eq!(
policy.header.as_str(),
format!("azsdk-rust-unknown/unknown (unknown; {OS}; {ARCH})")
);
}
#[test]
#[should_panic(expected = "application_id must be shorter than 25 characters")]
fn panics_when_application_id_too_long() {
let options = UserAgentOptions {
application_id: Some(
"this_application_id_is_way_too_long_and_exceeds_limit".to_string(),
), };
let _policy = UserAgentPolicy::new_with_rustc_version(
Some("test"),
Some("1.2.3"),
Some("4.5.6"),
&options,
);
}
#[test]
fn works_with_application_id_at_limit() {
let options = UserAgentOptions {
application_id: Some("exactly_24_characters!!!".to_string()), };
let policy = UserAgentPolicy::new_with_rustc_version(
Some("test"),
Some("1.2.3"),
Some("4.5.6"),
&options,
);
assert_eq!(
policy.header.as_str(),
format!("exactly_24_characters!!! azsdk-rust-test/1.2.3 (4.5.6; {OS}; {ARCH})")
);
}
#[test]
#[should_panic(
expected = "application_id contains invalid characters. Only RFC 9110 tokens are allowed."
)]
fn test_user_agent_invalid_chars() {
let options = UserAgentOptions {
application_id: Some("invalid application id".to_string()),
};
let _policy = UserAgentPolicy::new_with_rustc_version(
Some("test"),
Some("1.2.3"),
Some("4.5.6"),
&options,
);
}
}