use azure_core::http::{
headers::{HeaderValue, USER_AGENT},
policies::{Policy, PolicyResult},
Context, Request,
};
use std::sync::Arc;
/// Pipeline policy that stamps the Cosmos SDK `User-Agent` header onto each
/// request before handing it to the next policy.
#[derive(Clone, Debug)]
pub(crate) struct CosmosHeadersPolicy {
/// Header value built once in `new` and cloned into each request on `send`.
user_agent: HeaderValue,
}
impl CosmosHeadersPolicy {
    /// Creates a policy whose `User-Agent` value is
    /// `azsdk-rust-cosmos/{crate_version}`, optionally followed by a space and
    /// `suffix`.
    ///
    /// A `suffix` of `None` or `Some("")` yields the bare product token with
    /// no trailing space.
    pub(crate) fn new(crate_version: &str, suffix: Option<&str>) -> Self {
        // Treat an empty suffix exactly like an absent one.
        let rendered = suffix.filter(|s| !s.is_empty()).map_or_else(
            || format!("azsdk-rust-cosmos/{crate_version}"),
            |s| format!("azsdk-rust-cosmos/{crate_version} {s}"),
        );

        let user_agent = HeaderValue::from(rendered);
        Self { user_agent }
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl Policy for CosmosHeadersPolicy {
    /// Sets the `User-Agent` header on `request` to this policy's stored
    /// value, then forwards the request to the first policy in `next`,
    /// passing along the rest of the slice.
    ///
    /// # Panics
    ///
    /// Panics if `next` is empty, because the first element is indexed
    /// directly.
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        let header_value = self.user_agent.clone();
        request.insert_header(USER_AGENT, header_value);

        // Hand off to the next policy with the remainder of the chain.
        let downstream = &next[0];
        let remaining = &next[1..];
        downstream.send(ctx, request, remaining).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::http::{Method, Url};

    /// Stand-in for the final pipeline stage: it never performs I/O and
    /// always returns an error, so tests only inspect the mutated request.
    #[derive(Debug)]
    struct MockTransport;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl Policy for MockTransport {
        async fn send(
            &self,
            _ctx: &Context,
            _request: &mut Request,
            _next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            Err(azure_core::Error::with_message(
                azure_core::error::ErrorKind::Other,
                "mock transport",
            ))
        }
    }

    /// Builds a GET request aimed at a placeholder endpoint.
    fn new_request() -> Request {
        let endpoint = Url::parse("https://test.documents.azure.com/").unwrap();
        Request::new(endpoint, Method::Get)
    }

    /// Sends `request` through `policy` with [`MockTransport`] as the only
    /// downstream stage and returns the `User-Agent` header found on the
    /// request afterwards.
    async fn user_agent_after_send(
        policy: &CosmosHeadersPolicy,
        request: &mut Request,
    ) -> Option<String> {
        let transport: Arc<dyn Policy> = Arc::new(MockTransport);
        let downstream: Vec<Arc<dyn Policy>> = vec![transport];
        let ctx = Context::default();
        // The mock always fails; the result is irrelevant to these tests.
        let _ = policy.send(&ctx, request, &downstream).await;
        request
            .headers()
            .get_optional_str(&USER_AGENT)
            .map(String::from)
    }

    #[tokio::test]
    async fn sets_user_agent_without_suffix() {
        let policy = CosmosHeadersPolicy::new("0.31.0", None);
        let mut request = new_request();

        let observed = user_agent_after_send(&policy, &mut request).await;

        assert_eq!(observed.as_deref(), Some("azsdk-rust-cosmos/0.31.0"));
    }

    #[tokio::test]
    async fn sets_user_agent_with_suffix() {
        let policy = CosmosHeadersPolicy::new("0.31.0", Some("my-app"));
        let mut request = new_request();

        let observed = user_agent_after_send(&policy, &mut request).await;

        assert_eq!(
            observed.as_deref(),
            Some("azsdk-rust-cosmos/0.31.0 my-app")
        );
    }

    #[tokio::test]
    async fn overrides_existing_user_agent() {
        let policy = CosmosHeadersPolicy::new("0.31.0", None);
        let mut request = new_request();
        // Pre-populate the header so the test proves it gets replaced.
        request
            .headers_mut()
            .insert(USER_AGENT, HeaderValue::from_static("azure-core-default"));

        let observed = user_agent_after_send(&policy, &mut request).await;

        assert_eq!(observed.as_deref(), Some("azsdk-rust-cosmos/0.31.0"));
    }
}