use super::policies::ClientRequestIdPolicy;
use crate::{
error::CheckSuccessOptions,
http::{
check_success,
headers::{RETRY_AFTER_MS, X_MS_RETRY_AFTER_MS},
policies::{
Policy, PublicApiInstrumentationPolicy, RequestInstrumentationPolicy, UserAgentPolicy,
},
ClientOptions,
},
};
use std::{
any::{Any, TypeId},
sync::Arc,
};
use typespec_client_core::http::{
self, headers::RETRY_AFTER, policies::RetryHeaders, PipelineOptions,
};
/// Azure HTTP pipeline: a thin wrapper around the core
/// [`typespec_client_core::http::Pipeline`] that appends Azure default policies when
/// constructed and, unless asked not to, runs `check_success` on every response.
#[derive(Clone, Debug)]
pub struct Pipeline(http::Pipeline);
/// Per-call options for [`Pipeline::send`].
#[derive(Default, Debug)]
pub struct PipelineSendOptions {
    /// When `true`, the response is returned as-is without being passed to `check_success`.
    pub skip_checks: bool,

    /// Options handed to `check_success` when `skip_checks` is `false`.
    pub check_success: CheckSuccessOptions,
}

/// The portion of [`PipelineSendOptions`] that [`Pipeline::send`] consumes itself
/// instead of forwarding to the core pipeline.
#[derive(Default, Debug)]
struct CorePipelineSendOptions {
    check_success: CheckSuccessOptions,
    skip_checks: bool,
}

impl PipelineSendOptions {
    /// Splits these options into the part handled locally by [`Pipeline::send`] and the
    /// options forwarded to the core pipeline (currently always `None`).
    fn deconstruct(self) -> (CorePipelineSendOptions, Option<http::PipelineSendOptions>) {
        let Self {
            skip_checks,
            check_success,
        } = self;
        let core = CorePipelineSendOptions {
            check_success,
            skip_checks,
        };
        (core, None)
    }
}
/// Per-call options for [`Pipeline::stream`].
#[derive(Default, Debug)]
pub struct PipelineStreamOptions {
    /// When `true`, the response is returned as-is without being passed to `check_success`.
    pub skip_checks: bool,

    /// Options handed to `check_success` when `skip_checks` is `false`.
    pub check_success: CheckSuccessOptions,
}

/// The portion of [`PipelineStreamOptions`] that [`Pipeline::stream`] consumes itself
/// instead of forwarding to the core pipeline.
#[derive(Default, Debug)]
struct CorePipelineStreamOptions {
    check_success: CheckSuccessOptions,
    skip_checks: bool,
}

impl PipelineStreamOptions {
    /// Splits these options into the part handled locally by [`Pipeline::stream`] and the
    /// options forwarded to the core pipeline (currently always `None`).
    fn deconstruct(
        self,
    ) -> (
        CorePipelineStreamOptions,
        Option<http::PipelineStreamOptions>,
    ) {
        let Self {
            skip_checks,
            check_success,
        } = self;
        let core = CorePipelineStreamOptions {
            check_success,
            skip_checks,
        };
        (core, None)
    }
}
impl Pipeline {
    /// Creates a new pipeline, appending Azure default policies to the caller's policies.
    ///
    /// Per-call policies end up, in order, as: the caller's `per_call_policies`, then a
    /// default [`ClientRequestIdPolicy`], a [`PublicApiInstrumentationPolicy`] (only when
    /// `options` configures a tracer provider), and a [`UserAgentPolicy`]. Per-try policies
    /// are the caller's `per_try_policies` followed by a [`RequestInstrumentationPolicy`]
    /// when a tracer is configured. Each default is added through `push_unique`, which is
    /// meant to skip it when a policy of the same type was already supplied.
    ///
    /// # Arguments
    ///
    /// * `crate_name` - Name of the calling crate, passed to the [`UserAgentPolicy`] and used
    ///   as the tracer name (`"Unknown"` for the tracer when `None`).
    /// * `crate_version` - Version of the calling crate, passed alongside `crate_name`.
    /// * `options` - Client options. The Azure-specific part configures instrumentation and
    ///   the User-Agent, and the remainder is forwarded to the core pipeline.
    /// * `per_call_policies` - Caller policies placed before the per-call defaults.
    /// * `per_try_policies` - Caller policies placed before the per-try defaults.
    /// * `pipeline_options` - Core pipeline options. When `None`, the defaults are used,
    ///   except that retries honor the `X_MS_RETRY_AFTER_MS`, `RETRY_AFTER_MS`, and
    ///   `RETRY_AFTER` headers in that order.
    pub fn new(
        crate_name: Option<&'static str>,
        crate_version: Option<&'static str>,
        options: ClientOptions,
        per_call_policies: Vec<Arc<dyn Policy>>,
        per_try_policies: Vec<Arc<dyn Policy>>,
        pipeline_options: Option<PipelineOptions>,
    ) -> Self {
        let (core_client_options, options) = options.deconstruct();

        let tracer = core_client_options
            .instrumentation
            .tracer_provider
            .map(|provider| {
                provider.get_tracer(None, crate_name.unwrap_or("Unknown"), crate_version)
            });

        // The policy vectors are already owned, so extend them in place instead of cloning.
        let mut per_call_policies = per_call_policies;
        push_unique(&mut per_call_policies, ClientRequestIdPolicy::default());
        if let Some(tracer) = &tracer {
            let public_api_policy = PublicApiInstrumentationPolicy::new(Some(tracer.clone()));
            push_unique(&mut per_call_policies, public_api_policy);
        }
        let user_agent_policy =
            UserAgentPolicy::new(crate_name, crate_version, &core_client_options.user_agent);
        push_unique(&mut per_call_policies, user_agent_policy);

        let mut per_try_policies = per_try_policies;
        if let Some(tracer) = &tracer {
            let request_instrumentation_policy =
                RequestInstrumentationPolicy::new(Some(tracer.clone()), &options.logging);
            push_unique(&mut per_try_policies, request_instrumentation_policy);
        }

        // Without explicit options, honor Azure's millisecond retry-after headers before the
        // standard `Retry-After` header.
        let pipeline_options = pipeline_options.unwrap_or_else(|| PipelineOptions {
            retry_headers: RetryHeaders {
                retry_headers: vec![X_MS_RETRY_AFTER_MS, RETRY_AFTER_MS, RETRY_AFTER],
            },
            ..PipelineOptions::default()
        });

        Self(http::Pipeline::new(
            options,
            per_call_policies,
            per_try_policies,
            Some(pipeline_options),
        ))
    }

    /// Sends `request` through the pipeline and returns the resulting [`http::RawResponse`].
    ///
    /// Unless `skip_checks` is set in `options`, the response is passed to [`check_success`]
    /// together with the supplied [`CheckSuccessOptions`].
    ///
    /// # Errors
    ///
    /// Returns an error if the core pipeline fails, or if [`check_success`] rejects the
    /// response.
    pub async fn send(
        &self,
        ctx: &http::Context<'_>,
        request: &mut http::Request,
        options: Option<PipelineSendOptions>,
    ) -> crate::Result<http::RawResponse> {
        let (core_send_options, send_options) = options.unwrap_or_default().deconstruct();
        let response = self.0.send(ctx, request, send_options).await?;
        if core_send_options.skip_checks {
            return Ok(response);
        }
        check_success(response, Some(core_send_options.check_success)).await
    }

    /// Sends `request` through the pipeline and returns the resulting
    /// [`http::AsyncRawResponse`].
    ///
    /// Unless `skip_checks` is set in `options`, the response is passed to [`check_success`]
    /// together with the supplied [`CheckSuccessOptions`].
    ///
    /// # Errors
    ///
    /// Returns an error if the core pipeline fails, or if [`check_success`] rejects the
    /// response.
    pub async fn stream(
        &self,
        ctx: &http::Context<'_>,
        request: &mut http::Request,
        options: Option<PipelineStreamOptions>,
    ) -> crate::Result<http::AsyncRawResponse> {
        let (core_stream_options, stream_options) = options.unwrap_or_default().deconstruct();
        let response = self.0.stream(ctx, request, stream_options).await?;
        if core_stream_options.skip_checks {
            return Ok(response);
        }
        check_success(response, Some(core_stream_options.check_success)).await
    }
}
/// Intended to append `policy` to `policies` only when no policy of the same concrete type
/// `T` is already present, so a caller-supplied policy can stand in for a default one.
///
/// NOTE(review): `p` below is a `&Arc<dyn Policy>`, and `Any` is in scope, so method
/// resolution picks `<Arc<dyn Policy> as Any>::type_id` on the first probe step. That
/// yields `TypeId::of::<Arc<dyn Policy>>()`, not the concrete policy's `TypeId`, so the
/// comparison never matches and `policy` appears to always be pushed. Comparing concrete
/// types needs `Any` reachable through the `dyn Policy` vtable (e.g. an `Any` supertrait
/// or an `as_any` method on `Policy`). Confirm against the `Policy` definition.
#[inline]
fn push_unique<T: Policy + 'static>(policies: &mut Vec<Arc<dyn Policy>>, policy: T) {
    if policies.iter().all(|p| TypeId::of::<T>() != p.type_id()) {
        policies.push(Arc::new(policy));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{
            headers::{self, HeaderName, Headers},
            policies::Policy,
            request::options::ClientRequestId,
            AsyncRawResponse, ClientOptions, Context, Method, Request, StatusCode, Transport,
            UserAgentOptions,
        },
        Bytes,
    };
    use azure_core_test::http::MockHttpClient;
    use futures::FutureExt as _;
    use std::sync::Arc;

    /// Crate name every test pipeline is created with.
    const CRATE_NAME: &str = "test-crate";
    /// Crate version every test pipeline is created with.
    const CRATE_VERSION: &str = "1.0.0";

    /// Returns the empty `200 OK` response the mock transports reply with.
    fn ok_response() -> AsyncRawResponse {
        AsyncRawResponse::from_bytes(StatusCode::Ok, Headers::new(), Bytes::new())
    }

    /// Builds a [`Pipeline`] for `CRATE_NAME`/`CRATE_VERSION` with the given client options
    /// and per-call policies, no per-try policies, and default pipeline options.
    fn test_pipeline(options: ClientOptions, per_call_policies: Vec<Arc<dyn Policy>>) -> Pipeline {
        Pipeline::new(
            Some(CRATE_NAME),
            Some(CRATE_VERSION),
            options,
            per_call_policies,
            Vec::new(),
            None,
        )
    }

    /// Sends a `GET https://example.com` request through `pipeline`, panicking on failure.
    async fn send_get(pipeline: &Pipeline, ctx: &Context<'_>) {
        let mut request = Request::new("https://example.com".parse().unwrap(), Method::Get);
        pipeline
            .send(ctx, &mut request, None)
            .await
            .expect("Pipeline execution failed");
    }

    #[tokio::test]
    async fn pipeline_with_custom_client_request_id_policy() {
        const CUSTOM_HEADER_NAME: &str = "x-custom-request-id";
        const CUSTOM_HEADER: HeaderName = HeaderName::from_static(CUSTOM_HEADER_NAME);
        const CLIENT_REQUEST_ID: &str = "custom-request-id";

        let mut ctx = Context::new();
        ctx.insert(ClientRequestId::new(CLIENT_REQUEST_ID.to_string()));

        let transport = Transport::new(Arc::new(MockHttpClient::new(|req| {
            async {
                let header_value = req
                    .headers()
                    .get_optional_str(&CUSTOM_HEADER)
                    .expect("Custom header should be present");
                assert_eq!(
                    header_value, CLIENT_REQUEST_ID,
                    "Custom header value should match the client request ID"
                );
                Ok(ok_response())
            }
            .boxed()
        })));
        let options = ClientOptions {
            transport: Some(transport),
            ..Default::default()
        };

        // A caller-supplied ClientRequestIdPolicy should write the ID under its own header.
        let custom_policy: Arc<dyn Policy> =
            Arc::new(ClientRequestIdPolicy::with_header_name(CUSTOM_HEADER_NAME));
        let pipeline = test_pipeline(options, vec![custom_policy]);
        send_get(&pipeline, &ctx).await;
    }

    #[tokio::test]
    async fn pipeline_without_client_request_id_policy() {
        const CLIENT_REQUEST_ID: &str = "default-request-id";

        let mut ctx = Context::new();
        ctx.insert(ClientRequestId::new(CLIENT_REQUEST_ID.to_string()));

        let transport = Transport::new(Arc::new(MockHttpClient::new(|req| {
            async {
                let header_value = req
                    .headers()
                    .get_optional_str(&headers::CLIENT_REQUEST_ID)
                    .expect("Default header should be present");
                assert_eq!(
                    header_value, CLIENT_REQUEST_ID,
                    "Default header value should match the client request ID"
                );
                Ok(ok_response())
            }
            .boxed()
        })));
        let options = ClientOptions {
            transport: Some(transport),
            ..Default::default()
        };

        let pipeline = test_pipeline(options, Vec::new());
        send_get(&pipeline, &ctx).await;
    }

    #[tokio::test]
    async fn pipeline_with_user_agent_enabled_default() {
        let ctx = Context::new();

        let transport = Transport::new(Arc::new(MockHttpClient::new(|req| {
            async {
                let user_agent = req
                    .headers()
                    .get_optional_str(&headers::USER_AGENT)
                    .expect("User-Agent header should be present by default");
                assert!(
                    user_agent.starts_with("azsdk-rust-test-crate/1.0.0 "),
                    "User-Agent header should start with expected prefix, got: {}",
                    user_agent
                );
                Ok(ok_response())
            }
            .boxed()
        })));
        let options = ClientOptions {
            transport: Some(transport),
            ..Default::default()
        };

        let pipeline = test_pipeline(options, Vec::new());
        send_get(&pipeline, &ctx).await;
    }

    #[tokio::test]
    async fn pipeline_with_custom_application_id() {
        const CUSTOM_APPLICATION_ID: &str = "my-custom-app/2.1.0";

        let ctx = Context::new();

        let transport = Transport::new(Arc::new(MockHttpClient::new(|req| {
            async {
                let user_agent = req
                    .headers()
                    .get_optional_str(&headers::USER_AGENT)
                    .expect("User-Agent header should be present");
                assert!(
                    user_agent.starts_with("my-custom-app/2.1.0 azsdk-rust-test-crate/1.0.0 "),
                    "User-Agent header should start with custom application_id and expected prefix, got: {}",
                    user_agent
                );
                Ok(ok_response())
            }
            .boxed()
        })));
        let user_agent_options = UserAgentOptions {
            application_id: Some(CUSTOM_APPLICATION_ID.to_string()),
        };
        let options = ClientOptions {
            transport: Some(transport),
            user_agent: user_agent_options,
            ..Default::default()
        };

        let pipeline = test_pipeline(options, Vec::new());
        send_get(&pipeline, &ctx).await;
    }
}