use super::{
AZ_CLIENT_REQUEST_ID_ATTRIBUTE, AZ_NAMESPACE_ATTRIBUTE, AZ_SERVICE_REQUEST_ID_ATTRIBUTE,
ERROR_TYPE_ATTRIBUTE, HTTP_REQUEST_METHOD_ATTRIBUTE, HTTP_REQUEST_RESEND_COUNT_ATTRIBUTE,
HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, SERVER_ADDRESS_ATTRIBUTE, SERVER_PORT_ATTRIBUTE,
URL_FULL_ATTRIBUTE,
};
use crate::{
http::{headers, Context, Request},
tracing::{Span, SpanKind},
};
use std::sync::Arc;
use typespec_client_core::{
http::policies::{Policy, PolicyResult, RetryPolicyCount},
tracing::Attribute,
};
/// Pipeline policy that records each outgoing HTTP request as a client span.
///
/// The tracer supplied at construction is only a fallback: `send` prefers an
/// `Arc<dyn Tracer>` stored in the request `Context`, and forwards the request
/// without creating a span when neither is available.
#[derive(Clone, Debug)]
pub(crate) struct RequestInstrumentationPolicy {
    /// Fallback tracer, used when the request context does not carry one.
    tracer: Option<Arc<dyn crate::tracing::Tracer>>,
}

impl RequestInstrumentationPolicy {
    /// Creates the policy with an optional fallback `tracer`.
    ///
    /// Passing `None` is valid: spans are then produced only for requests whose
    /// `Context` supplies a tracer.
    pub fn new(tracer: Option<Arc<dyn crate::tracing::Tracer>>) -> Self {
        RequestInstrumentationPolicy { tracer }
    }
}
#[async_trait::async_trait]
impl Policy for RequestInstrumentationPolicy {
    /// Sends `request` through the rest of the pipeline inside a `SpanKind::Client` span.
    ///
    /// The tracer comes from `ctx` when one is stored there, otherwise from the tracer this
    /// policy was constructed with. With neither, the request is forwarded to `next[0]`
    /// untouched and no span is created.
    ///
    /// The span is named after the HTTP method and is started as a child of any
    /// `Arc<dyn Span>` stored in `ctx`. Attributes are recorded in three phases:
    /// - at span start: the request method, the tracer namespace (if any), `url.full` (only
    ///   when the URL carries no username or password), the server address, and the server
    ///   port (explicit, or the scheme's known default);
    /// - before sending, while the span is recording: the client and service request-id
    ///   headers, and the resend count when the context's `RetryPolicyCount` is non-zero;
    /// - after the call, while the span is recording: the response status code, plus
    ///   `error.type` for 4xx/5xx responses, or the error kind when the pipeline failed
    ///   without a response.
    ///
    /// The span status is set to `Error` for 4xx/5xx responses and for failures that produced
    /// no response. Trace-context headers are written onto `request` before it is sent. The
    /// downstream result is returned unchanged.
    async fn send(
        &self,
        ctx: &Context,
        request: &mut Request,
        next: &[Arc<dyn Policy>],
    ) -> PolicyResult {
        // A tracer stored in the context takes precedence over the policy's fallback tracer.
        // `Option::as_ref` is free, so the eager `or` is appropriate here.
        let tracer = ctx
            .value::<Arc<dyn crate::tracing::Tracer>>()
            .or(self.tracer.as_ref());
        let Some(tracer) = tracer else {
            // No tracer configured anywhere: pass the request through uninstrumented.
            return next[0].send(ctx, request, &next[1..]).await;
        };

        let mut span_attributes = vec![Attribute {
            key: HTTP_REQUEST_METHOD_ATTRIBUTE.into(),
            value: request.method().to_string().into(),
        }];

        if let Some(namespace) = tracer.namespace() {
            span_attributes.push(Attribute {
                key: AZ_NAMESPACE_ATTRIBUTE.into(),
                value: namespace.into(),
            });
        }

        // Only record the full URL when it does not embed credentials.
        if request.url().username().is_empty() && request.url().password().is_none() {
            span_attributes.push(Attribute {
                key: URL_FULL_ATTRIBUTE.into(),
                value: request.url().into(),
            });
        }

        if let Some(host) = request.url().host() {
            span_attributes.push(Attribute {
                key: SERVER_ADDRESS_ATTRIBUTE.into(),
                value: host.to_string().into(),
            });
        }

        // Falls back to the scheme's well-known port (e.g. 443 for https) when none is explicit.
        if let Some(port) = request.url().port_or_known_default() {
            span_attributes.push(Attribute {
                key: SERVER_PORT_ATTRIBUTE.into(),
                value: port.into(),
            });
        }

        let method_str = request.method().as_str();
        let span = if let Some(parent_span) = ctx.value::<Arc<dyn Span>>() {
            tracer.start_span_with_parent(
                method_str.into(),
                SpanKind::Client,
                span_attributes,
                parent_span.clone(),
            )
        } else {
            tracer.start_span(method_str.into(), SpanKind::Client, span_attributes)
        };

        if span.is_recording() {
            if let Some(client_request_id) = request
                .headers()
                .get_optional_str(&headers::CLIENT_REQUEST_ID)
            {
                span.set_attribute(AZ_CLIENT_REQUEST_ID_ATTRIBUTE, client_request_id.into());
            }

            if let Some(service_request_id) =
                request.headers().get_optional_str(&headers::REQUEST_ID)
            {
                span.set_attribute(AZ_SERVICE_REQUEST_ID_ATTRIBUTE, service_request_id.into());
            }

            // Only resends are worth recording; the first attempt has a count of zero.
            if let Some(retry_count) = ctx.value::<RetryPolicyCount>() {
                if **retry_count > 0 {
                    span.set_attribute(HTTP_REQUEST_RESEND_COUNT_ATTRIBUTE, (**retry_count).into());
                }
            }
        }

        // Write trace-context headers (e.g. `traceparent`) onto the outgoing request.
        span.propagate_headers(request);

        let result = next[0].send(ctx, request, &next[1..]).await;

        if span.is_recording() {
            match &result {
                Ok(response) => {
                    span.set_attribute(
                        HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE,
                        u16::from(response.status()).into(),
                    );

                    if response.status().is_server_error() || response.status().is_client_error()
                    {
                        // The reason phrase is unreliable, so the description is left empty.
                        span.set_status(crate::tracing::SpanStatus::Error {
                            description: "".to_string(),
                        });
                        span.set_attribute(
                            ERROR_TYPE_ATTRIBUTE,
                            response.status().to_string().into(),
                        );
                    }
                }
                Err(err) => {
                    // No response was received, so the span is marked failed as well as
                    // tagged with the error kind. Only the kind is used as the description:
                    // the full error message could include request details (such as the URL)
                    // that this policy otherwise avoids recording when they hold credentials.
                    let kind = err.kind().to_string();
                    span.set_status(crate::tracing::SpanStatus::Error {
                        description: kind.clone(),
                    });
                    span.set_attribute(ERROR_TYPE_ATTRIBUTE, kind.into());
                }
            }
        }

        span.end();
        result
    }
}
// Unit tests for `RequestInstrumentationPolicy`, driven through the mock tracing
// provider and mock HTTP client from `azure_core_test`.
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::{
http::{
headers::{HeaderName, Headers},
policies::TransportPolicy,
AsyncRawResponse, Method, StatusCode, Transport,
},
tracing::{AttributeValue, SpanStatus, TracerProvider},
Result, Uuid,
};
use azure_core_test::{
http::MockHttpClient,
tracing::{
check_instrumentation_result, ExpectedSpanInformation, ExpectedTracerInformation,
MockTracingProvider,
},
};
use futures::future::BoxFuture;
use std::sync::Arc;
// Sends `request` through a two-stage pipeline: the instrumentation policy, then a
// `TransportPolicy` whose mock HTTP client answers with `callback`.
//
// The tracer is obtained from a fresh `MockTracingProvider` using `test_namespace`,
// `crate_name` (defaulting to "unknown" when `None`) and `version`, and is passed to the
// policy as its fallback tracer. The request is sent with a default `Context` and the
// pipeline result is discarded. The provider is returned so the caller can inspect the
// recorded spans.
async fn run_instrumentation_test<C>(
test_namespace: Option<&'static str>,
crate_name: Option<&'static str>,
version: Option<&'static str>,
request: &mut Request,
callback: C,
) -> Arc<MockTracingProvider>
where
C: FnMut(&Request) -> BoxFuture<'_, Result<AsyncRawResponse>> + Send + Sync + 'static,
{
let mock_tracer_provider = Arc::new(MockTracingProvider::new());
let tracer = mock_tracer_provider.get_tracer(
test_namespace,
crate_name.unwrap_or("unknown"),
version,
);
let policy = Arc::new(RequestInstrumentationPolicy::new(Some(tracer.clone())));
let transport =
TransportPolicy::new(Transport::new(Arc::new(MockHttpClient::new(callback))));
let ctx = Context::default();
let next: Vec<Arc<dyn Policy>> = vec![Arc::new(transport)];
let _result = policy.send(&ctx, request, &next).await;
mock_tracer_provider
}
// A plain `http` GET answered with 200 yields one client span named "GET" with `Unset`
// status. The namespace attribute is present because the tracer has a namespace, and the
// port is 80 because the URL has no explicit port.
#[tokio::test]
async fn simple_instrumentation_policy() {
let url = "http://example.com/path?query=value&api-version=2024-01-01";
let mut request = Request::new(url.parse().unwrap(), Method::Get);
let mock_tracer = run_instrumentation_test(
Some("test namespace"),
Some("test_crate"),
Some("1.0.0"),
&mut request,
|req| {
Box::pin(async move {
assert_eq!(req.url().host_str(), Some("example.com"));
assert_eq!(req.method(), Method::Get);
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
vec![],
))
})
},
)
.await;
// NOTE(review): `span_id` is a fresh random UUID here; presumably
// `check_instrumentation_result` does not compare it — confirm in azure_core_test.
check_instrumentation_result(
mock_tracer,
vec![ExpectedTracerInformation {
namespace: Some("test namespace"),
name: "test_crate",
version: Some("1.0.0"),
spans: vec![ExpectedSpanInformation {
span_name: "GET",
status: SpanStatus::Unset,
kind: SpanKind::Client,
span_id: Uuid::new_v4(),
parent_id: None,
attributes: vec![
(
AZ_NAMESPACE_ATTRIBUTE,
AttributeValue::from("test namespace"),
),
(
HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE,
AttributeValue::from(200),
),
(HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")),
(
SERVER_ADDRESS_ATTRIBUTE,
AttributeValue::from("example.com"),
),
(SERVER_PORT_ATTRIBUTE, AttributeValue::from(80)),
(
URL_FULL_ATTRIBUTE,
AttributeValue::from(
"http://example.com/path?query=value&api-version=2024-01-01",
),
),
],
..Default::default()
}],
}],
);
}
// The constructor stores whatever fallback tracer it is given, including `None`.
#[test]
fn test_request_instrumentation_policy_creation() {
let policy = RequestInstrumentationPolicy::new(None);
assert!(policy.tracer.is_none());
let mock_tracer_provider = Arc::new(MockTracingProvider::new());
let tracer =
mock_tracer_provider.get_tracer(Some("test namespace"), "test_crate", Some("1.0.0"));
let policy_with_tracer = RequestInstrumentationPolicy::new(Some(tracer));
assert!(policy_with_tracer.tracer.is_some());
}
// Constructing without a tracer leaves the fallback empty. This overlaps with the first
// half of `test_request_instrumentation_policy_creation`.
#[test]
fn test_request_instrumentation_policy_without_tracer() {
let policy = RequestInstrumentationPolicy::new(None);
assert!(policy.tracer.is_none());
}
// A client-request-id header on the request is copied into a span attribute, and the
// transport sees a propagated `traceparent` header. The expected value is the mock
// tracer's placeholder format, not a real W3C trace context. There is no namespace
// attribute because the tracer is created without a namespace.
#[tokio::test]
async fn client_request_id() {
let url = "https://example.com/client_request_id";
let mut request = Request::new(url.parse().unwrap(), Method::Get);
request.insert_header(headers::CLIENT_REQUEST_ID, "test-client-request-id");
let mock_tracer = run_instrumentation_test(
None,
Some("test_crate"),
Some("1.0.0"),
&mut request,
|req| {
Box::pin(async move {
assert_eq!(req.url().host_str(), Some("example.com"));
assert_eq!(req.method(), Method::Get);
assert_eq!(
req.headers()
.get_optional_str(&HeaderName::from_static("traceparent")),
Some("00-<trace_id>-<span_id>-01")
);
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
vec![],
))
})
},
)
.await;
check_instrumentation_result(
mock_tracer,
vec![ExpectedTracerInformation {
namespace: None,
name: "test_crate",
version: Some("1.0.0"),
spans: vec![ExpectedSpanInformation {
span_name: "GET",
status: SpanStatus::Unset,
kind: SpanKind::Client,
span_id: Uuid::new_v4(),
parent_id: None,
attributes: vec![
(
AZ_CLIENT_REQUEST_ID_ATTRIBUTE,
AttributeValue::from("test-client-request-id"),
),
(
HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE,
AttributeValue::from(200),
),
(HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")),
(
SERVER_ADDRESS_ATTRIBUTE,
AttributeValue::from("example.com"),
),
(SERVER_PORT_ATTRIBUTE, AttributeValue::from(443)),
(
URL_FULL_ATTRIBUTE,
AttributeValue::from("https://example.com/client_request_id"),
),
],
..Default::default()
}],
}],
);
}
// A URL containing a username/password must not be recorded as `url.full`. The expected
// attributes therefore omit it, while the explicit port 8080 is still reported. With no
// crate name or version, the tracer is registered as "unknown" with no version.
#[tokio::test]
async fn test_url_with_password() {
let url = "https://user:password@host:8080/path?query=value#fragment";
let mut request = Request::new(url.parse().unwrap(), Method::Get);
let mock_tracer_provider =
run_instrumentation_test(None, None, None, &mut request, |req| {
Box::pin(async move {
assert_eq!(req.url().host_str(), Some("host"));
assert_eq!(req.method(), Method::Get);
Ok(AsyncRawResponse::from_bytes(
StatusCode::Ok,
Headers::new(),
vec![],
))
})
})
.await;
check_instrumentation_result(
mock_tracer_provider,
vec![ExpectedTracerInformation {
namespace: None,
name: "unknown",
version: None,
spans: vec![ExpectedSpanInformation {
span_name: "GET",
status: SpanStatus::Unset,
kind: SpanKind::Client,
span_id: Uuid::new_v4(),
parent_id: None,
attributes: vec![
(
HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE,
AttributeValue::from(200),
),
(HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")),
(SERVER_ADDRESS_ATTRIBUTE, AttributeValue::from("host")),
(SERVER_PORT_ATTRIBUTE, AttributeValue::from(8080)),
],
..Default::default()
}],
}],
);
}
// A 404 response marks the span status as `Error` with an empty description and sets
// `error.type` to the status code text "404". The request's service request id header
// is also recorded as an attribute.
#[tokio::test]
async fn request_failed() {
let url = "https://microsoft.com/request_failed.htm";
let mut request = Request::new(url.parse().unwrap(), Method::Put);
request.insert_header(headers::REQUEST_ID, "test-service-request-id");
let mock_tracer = run_instrumentation_test(
Some("test namespace"),
Some("test_crate"),
Some("1.0.0"),
&mut request,
|req| {
Box::pin(async move {
assert_eq!(req.url().host_str(), Some("microsoft.com"));
assert_eq!(req.method(), Method::Put);
Ok(AsyncRawResponse::from_bytes(
StatusCode::NotFound,
Headers::new(),
vec![],
))
})
},
)
.await;
check_instrumentation_result(
mock_tracer,
vec![ExpectedTracerInformation {
namespace: Some("test namespace"),
name: "test_crate",
version: Some("1.0.0"),
spans: vec![ExpectedSpanInformation {
span_name: "PUT",
status: SpanStatus::Error {
description: "".to_string(),
},
kind: SpanKind::Client,
span_id: Uuid::new_v4(),
parent_id: None,
attributes: vec![
(ERROR_TYPE_ATTRIBUTE, AttributeValue::from("404")),
(
AZ_SERVICE_REQUEST_ID_ATTRIBUTE,
AttributeValue::from("test-service-request-id"),
),
(
AZ_NAMESPACE_ATTRIBUTE,
AttributeValue::from("test namespace"),
),
(
HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE,
AttributeValue::from(404),
),
(HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("PUT")),
(
SERVER_ADDRESS_ATTRIBUTE,
AttributeValue::from("microsoft.com"),
),
(SERVER_PORT_ATTRIBUTE, AttributeValue::from(443)),
(
URL_FULL_ATTRIBUTE,
AttributeValue::from("https://microsoft.com/request_failed.htm"),
),
],
..Default::default()
}],
}],
);
}
}