use azure_core::{
http::{headers::HeaderName, Context, Request},
tracing::{
AsAny, Attribute, AttributeValue, Span, SpanGuard, SpanKind, SpanStatus, Tracer,
TracerProvider,
},
Uuid,
};
use rand::{rng, RngExt};
use std::{
borrow::Cow,
collections::HashMap,
fmt::Debug,
sync::{Arc, Mutex},
};
use tracing::{trace, warn};
/// A [`TracerProvider`] that remembers every tracer it hands out, so tests can
/// inspect the spans recorded by instrumented client code.
#[derive(Debug, Default)]
pub struct MockTracingProvider {
    /// Every tracer returned by `get_tracer`, in creation order.
    tracers: Mutex<Vec<Arc<MockTracer>>>,
}

impl MockTracingProvider {
    /// Creates a provider that has not handed out any tracers yet.
    pub fn new() -> Self {
        Self::default()
    }
}

impl TracerProvider for MockTracingProvider {
    /// Creates a new [`MockTracer`] for the given crate, records it, and returns it.
    ///
    /// A fresh tracer is created on every call; repeated calls with the same
    /// arguments are not deduplicated.
    fn get_tracer(
        &self,
        azure_namespace: Option<&'static str>,
        crate_name: &'static str,
        crate_version: Option<&'static str>,
    ) -> Arc<dyn crate::tracing::Tracer> {
        let new_tracer = Arc::new(MockTracer {
            namespace: azure_namespace,
            package_name: crate_name,
            package_version: crate_version,
            spans: Mutex::new(Vec::new()),
        });
        self.tracers.lock().unwrap().push(Arc::clone(&new_tracer));
        new_tracer
    }
}
#[derive(Debug)]
pub struct MockTracer {
namespace: Option<&'static str>,
package_name: &'static str,
package_version: Option<&'static str>,
spans: Mutex<Vec<Arc<MockSpanInner>>>,
}
impl Tracer for MockTracer {
fn namespace(&self) -> Option<&'static str> {
self.namespace
}
fn start_span_with_parent(
&self,
name: Cow<'static, str>,
kind: SpanKind,
attributes: Vec<Attribute>,
parent: Arc<dyn crate::tracing::Span>,
) -> Arc<dyn crate::tracing::Span> {
let span = Arc::new(MockSpanInner::new(
name,
kind,
attributes.clone(),
Some(parent),
));
self.spans.lock().unwrap().push(span.clone());
Arc::new(MockSpan { inner: span })
}
fn start_span(
&self,
name: Cow<'static, str>,
kind: SpanKind,
attributes: Vec<Attribute>,
) -> Arc<dyn Span> {
let attributes = attributes
.into_iter()
.map(|attr| Attribute {
key: attr.key.clone(),
value: attr.value.clone(),
})
.collect();
let span = Arc::new(MockSpanInner::new(name, kind, attributes, None));
self.spans.lock().unwrap().push(span.clone());
Arc::new(MockSpan { inner: span })
}
}
/// Shared state of a mock span. The owning [`MockTracer`] keeps a reference to
/// it, so the recorded data can still be inspected after the handle given to
/// the instrumented code is dropped.
#[derive(Debug)]
struct MockSpanInner {
    /// Span name as passed to the tracer.
    pub name: Cow<'static, str>,
    /// Span kind as passed to the tracer.
    pub kind: SpanKind,
    /// `span_id()` of the parent span, when the span was started with a parent.
    pub parent: Option<[u8; 8]>,
    /// Randomly generated span id.
    pub id: [u8; 8],
    /// Initial attributes followed by every attribute set later, in order.
    pub attributes: Mutex<Vec<Attribute>>,
    /// Most recent status set on the span; starts as `SpanStatus::Unset`.
    pub state: Mutex<SpanStatus>,
    /// `true` until `end()` is called.
    pub is_open: Mutex<bool>,
}

impl MockSpanInner {
    /// Creates an open span with a random id and `Unset` status.
    fn new<C>(
        name: C,
        kind: SpanKind,
        attributes: Vec<Attribute>,
        parent: Option<Arc<dyn crate::tracing::Span>>,
    ) -> Self
    where
        C: Into<Cow<'static, str>> + Debug,
    {
        trace!("Creating MockSpan: {:?}", name);
        trace!("Attributes: {:?}", attributes);
        let span_id: [u8; 8] = rng().random();
        Self {
            id: span_id,
            parent: parent.map(|parent_span| parent_span.span_id()),
            name: name.into(),
            kind,
            attributes: Mutex::new(attributes),
            state: Mutex::new(SpanStatus::Unset),
            is_open: Mutex::new(true),
        }
    }

    /// Returns `true` if `end()` has not been called yet.
    fn is_open(&self) -> bool {
        *self.is_open.lock().unwrap()
    }
}

impl AsAny for MockSpanInner {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl Span for MockSpanInner {
    /// Appends an attribute; it does not replace earlier entries with the same key.
    fn set_attribute(&self, key: &'static str, value: AttributeValue) {
        trace!("{}: Setting attribute {}: {:?}", self.name, key, value);
        let attribute = Attribute {
            key: key.into(),
            value,
        };
        self.attributes.lock().unwrap().push(attribute);
    }

    /// Overwrites the recorded status.
    fn set_status(&self, status: crate::tracing::SpanStatus) {
        trace!("{}: Setting span status: {:?}", self.name, status);
        *self.state.lock().unwrap() = status;
    }

    /// Marks the span as closed.
    fn end(&self) {
        trace!("Ending span: {}", self.name);
        *self.is_open.lock().unwrap() = false;
    }

    /// The mock always records.
    fn is_recording(&self) -> bool {
        true
    }

    fn span_id(&self) -> [u8; 8] {
        self.id
    }

    /// Not supported by this mock; calling it panics.
    fn record_error(&self, _error: &dyn std::error::Error) {
        todo!()
    }

    /// Not supported by this mock; calling it panics.
    fn set_current(&self, _context: &Context) -> Box<dyn SpanGuard> {
        todo!()
    }

    /// Inserts fixed placeholder `traceparent` and `tracestate` headers into the request.
    fn propagate_headers(&self, request: &mut Request) {
        let placeholder_headers = [
            ("traceparent", "00-<trace_id>-<span_id>-01"),
            ("tracestate", "<key>=<value>"),
        ];
        for (header_name, header_value) in placeholder_headers {
            request.insert_header(HeaderName::from_static(header_name), header_value);
        }
    }
}
/// Span handle returned to instrumented code by [`MockTracer`].
///
/// Every operation is forwarded to the shared inner span state, which the
/// tracer also keeps.
#[derive(Debug)]
pub struct MockSpan {
    inner: Arc<MockSpanInner>,
}

impl Drop for MockSpan {
    /// A handle dropped while its span is still open is reported with a
    /// warning, and the span is then ended.
    fn drop(&mut self) {
        if self.inner.is_open() {
            warn!("Dropping open span: {}", self.inner.name);
            self.inner.end();
        }
    }
}

impl AsAny for MockSpan {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl Span for MockSpan {
    fn set_attribute(&self, key: &'static str, value: AttributeValue) {
        self.inner.set_attribute(key, value);
    }

    fn set_status(&self, status: crate::tracing::SpanStatus) {
        self.inner.set_status(status);
    }

    fn end(&self) {
        self.inner.end();
    }

    fn is_recording(&self) -> bool {
        self.inner.is_recording()
    }

    fn span_id(&self) -> [u8; 8] {
        self.inner.span_id()
    }

    /// Forwards to the inner span, which does not support this and panics.
    fn record_error(&self, error: &dyn std::error::Error) {
        self.inner.record_error(error);
    }

    /// Forwards to the inner span, which does not support this and panics.
    fn set_current(&self, context: &Context) -> Box<dyn SpanGuard> {
        self.inner.set_current(context)
    }

    fn propagate_headers(&self, request: &mut Request) {
        self.inner.propagate_headers(request);
    }
}
/// Expected properties of one tracer created through `MockTracingProvider`.
#[derive(Debug)]
pub struct ExpectedTracerInformation<'a> {
/// Expected crate name passed to `get_tracer`.
pub name: &'a str,
/// Expected crate version passed to `get_tracer`.
pub version: Option<&'a str>,
/// Expected Azure namespace passed to `get_tracer`.
pub namespace: Option<&'a str>,
/// Expected spans, in the order the tracer started them.
pub spans: Vec<ExpectedSpanInformation<'a>>,
}
/// Asserts that the tracers recorded by `mock_tracer` match `expected_tracers`.
///
/// Tracers are matched to expectations by creation order. Within a tracer,
/// actual spans are matched in start order. An expected span marked
/// `is_wildcard` may match one or more consecutive actual spans: it keeps
/// matching as long as the next actual span also satisfies it.
///
/// # Panics
///
/// Panics (failing the test) on any mismatch, including:
/// - a wrong number of tracers;
/// - a wrong number of spans when no wildcard is present;
/// - an actual span left over after every expectation was consumed;
/// - expectations that were never reached.
pub fn check_instrumentation_result(
    mock_tracer: Arc<MockTracingProvider>,
    expected_tracers: Vec<ExpectedTracerInformation<'_>>,
) {
    let tracers = mock_tracer.tracers.lock().unwrap();
    if tracers.len() != expected_tracers.len() {
        trace!("Expected tracers: {:?}", expected_tracers);
        trace!("Found tracers: {:?}", tracers);
    }
    assert_eq!(
        tracers.len(),
        expected_tracers.len(),
        "Unexpected number of tracers, expected: {}, found: {}",
        expected_tracers.len(),
        tracers.len()
    );
    // Lengths are equal (asserted above), so zip pairs every tracer.
    for (index, (tracer, expected)) in tracers.iter().zip(&expected_tracers).enumerate() {
        trace!("Checking tracer {}: {}", index, expected.name);
        // Maps test-local expected span ids to the actual ids they matched,
        // so that children can verify their parent link.
        let mut parent_span_map = HashMap::new();
        assert_eq!(tracer.package_name, expected.name);
        assert_eq!(tracer.package_version, expected.version);
        assert_eq!(tracer.namespace, expected.namespace);
        let spans = tracer.spans.lock().unwrap();
        if !expected.spans.iter().any(|s| s.is_wildcard) {
            assert_eq!(
                spans.len(),
                expected.spans.len(),
                "Unexpected number of spans for tracer {}",
                expected.name
            );
        }
        let mut expected_index = 0;
        for (span_index, span_actual) in spans.iter().enumerate() {
            trace!(
                "Checking span {} of tracer {}: {}",
                span_index,
                expected.name,
                span_actual.name
            );
            // With wildcards the length check above is skipped, so extra actual
            // spans must be reported here instead of panicking on an index.
            let Some(expected_span) = expected.spans.get(expected_index) else {
                panic!(
                    "Unexpected span {} ({}) for tracer {}: all {} expected spans were already matched",
                    span_index,
                    span_actual.name,
                    expected.name,
                    expected.spans.len()
                );
            };
            check_span_information(span_actual, expected_span, &parent_span_map);
            parent_span_map.insert(expected_span.span_id, span_actual.id);

            // A wildcard stays current while the next actual span also matches it.
            let stay_on_wildcard = if expected_span.is_wildcard {
                trace!(
                    "Span {} is a wildcard, not incrementing expected index",
                    span_actual.name
                );
                match spans.get(span_index + 1) {
                    Some(next_span) => {
                        let next_matches =
                            compare_span_information(next_span, expected_span, &parent_span_map);
                        if !next_matches {
                            trace!(
                                "Next actual span does not match expected span: {}",
                                expected_span.span_name
                            );
                        }
                        next_matches
                    }
                    None => false,
                }
            } else {
                false
            };
            if !stay_on_wildcard {
                expected_index += 1;
            }
        }
        assert_eq!(
            expected_index,
            expected.spans.len(),
            "Not all expected spans were found for tracer {}",
            expected.name
        );
    }
}
/// Expected properties of a single recorded span.
#[derive(Debug)]
pub struct ExpectedSpanInformation<'a> {
    /// Expected span name.
    pub span_name: &'a str,
    /// Expected final span status.
    pub status: SpanStatus,
    /// Test-local identifier of this expectation. Other expectations reference
    /// it through `parent_id`; it is never compared with the real span id.
    pub span_id: Uuid,
    /// `span_id` of the expected parent, or `None` for a span without a parent.
    pub parent_id: Option<Uuid>,
    /// Expected span kind.
    pub kind: SpanKind,
    /// Exact set of expected attributes. A value of
    /// `AttributeValue::String("<ANY>")` matches any actual value.
    pub attributes: Vec<(&'a str, AttributeValue)>,
    /// When `true`, this expectation may match one or more consecutive actual spans.
    pub is_wildcard: bool,
}

impl Default for ExpectedSpanInformation<'_> {
    /// A non-wildcard, parentless `Client` span named `"get"` with `Unset`
    /// status, no attributes, and a fresh random `span_id`.
    fn default() -> Self {
        Self {
            span_id: Uuid::new_v4(),
            span_name: "get",
            kind: SpanKind::Client,
            status: SpanStatus::Unset,
            parent_id: None,
            attributes: Vec::new(),
            is_wildcard: false,
        }
    }
}
/// Asserts that the recorded `span` matches `expected`.
///
/// `parent_span_map` maps the `span_id`s of already-matched expectations to the
/// ids of the actual spans they matched, and is used to verify the parent link.
///
/// # Panics
///
/// Panics (failing the test) when any of the following do not match: name,
/// kind, status, parent link, or the exact attribute set (keys must match
/// exactly; values must match unless the expected value is `"<ANY>"`). Also
/// panics when the span was never ended.
fn check_span_information(
    span: &Arc<MockSpanInner>,
    expected: &ExpectedSpanInformation<'_>,
    parent_span_map: &HashMap<Uuid, [u8; 8]>,
) {
    assert_eq!(span.name, expected.span_name, "Unexpected span name");
    assert_eq!(
        span.kind, expected.kind,
        "Unexpected kind for span {}",
        span.name
    );
    assert_eq!(
        *span.state.lock().unwrap(),
        expected.status,
        "Unexpected status for span {}",
        span.name
    );
    match (span.parent, expected.parent_id.as_ref()) {
        (None, None) => {}
        (None, Some(_)) => panic!("Span {} has no parent, but one was expected", span.name),
        (Some(_), None) => panic!("Span {} has a parent, but none was expected", span.name),
        (Some(parent), Some(expected_parent)) => {
            let parent_id = parent_span_map.get(expected_parent).unwrap_or_else(|| {
                panic!(
                    "Expected parent of span {} has not been matched to any actual span",
                    span.name
                )
            });
            assert_eq!(parent, *parent_id, "Unexpected parent for span {}", span.name);
        }
    }
    let attributes = span.attributes.lock().unwrap();
    trace!("Expected attributes: {:?}", expected.attributes);
    trace!("Found attributes: {:?}", attributes);
    for (index, attr) in attributes.iter().enumerate() {
        trace!("Attribute {}: {} = {:?}", index, attr.key, attr.value);
        let Some((key, value)) = expected.attributes.iter().find(|(key, _)| attr.key == *key)
        else {
            panic!("Unexpected attribute: {} = {:?}", attr.key, attr.value);
        };
        if *value != AttributeValue::String("<ANY>".into()) {
            assert_eq!(attr.value, *value, "Attribute mismatch for key: {}", *key);
        }
    }
    for (key, value) in &expected.attributes {
        assert!(
            attributes.iter().any(|attr| attr.key == *key),
            "Expected attribute not found: {} = {:?}",
            key,
            value
        );
    }
    assert!(
        !*span.is_open.lock().unwrap(),
        "Span {} was not ended",
        span.name
    );
}
/// Returns whether `actual` matches `expected` on name, kind, status, parent
/// link and attributes. Used for wildcard look-ahead.
///
/// Unlike `check_span_information`, this never panics on a mismatch, and it
/// does not check whether the span was ended. A parent mismatch, including an
/// expected parent that has not been matched yet, yields `false`.
fn compare_span_information(
    actual: &Arc<MockSpanInner>,
    expected: &ExpectedSpanInformation<'_>,
    parent_span_map: &HashMap<Uuid, [u8; 8]>,
) -> bool {
    if actual.name != expected.span_name {
        return false;
    }
    if actual.kind != expected.kind {
        return false;
    }
    if *actual.state.lock().unwrap() != expected.status {
        return false;
    }
    // Previously this unwrapped `expected.parent_id` and the map lookup, so a
    // predicate meant to answer "does it match?" could abort the test instead.
    let parent_matches = match (actual.parent, expected.parent_id.as_ref()) {
        (None, None) => true,
        (Some(parent), Some(expected_parent)) => {
            parent_span_map.get(expected_parent) == Some(&parent)
        }
        _ => false,
    };
    if !parent_matches {
        return false;
    }
    let attributes = actual.attributes.lock().unwrap();
    trace!("Expected attributes: {:?}", expected.attributes);
    trace!("Found attributes: {:?}", attributes);
    for (index, attr) in attributes.iter().enumerate() {
        trace!("Attribute {}: {} = {:?}", index, attr.key, attr.value);
        let Some((_, value)) = expected.attributes.iter().find(|(key, _)| attr.key == *key)
        else {
            return false;
        };
        if *value != AttributeValue::String("<ANY>".into()) && attr.value != *value {
            return false;
        }
    }
    expected
        .attributes
        .iter()
        .all(|(key, _)| attributes.iter().any(|attr| attr.key == *key))
}
#[derive(Debug, Clone)]
pub struct ExpectedApiInformation {
pub api_name: Option<&'static str>,
pub api_children: Vec<ExpectedRestApiSpan>,
pub additional_api_attributes: Vec<(&'static str, AttributeValue)>,
}
impl Default for ExpectedApiInformation {
fn default() -> Self {
Self {
api_name: None,
additional_api_attributes: Vec::new(),
api_children: vec![ExpectedRestApiSpan::default()],
}
}
}
/// Expected properties of one HTTP request span.
#[derive(Debug, Clone)]
pub struct ExpectedRestApiSpan {
    /// HTTP method; also the expected span name and `http.request.method` value.
    pub api_verb: azure_core::http::Method,
    /// Expected response status code; non-success codes also imply an error status.
    pub expected_status_code: azure_core::http::StatusCode,
    /// When `true`, this request may appear one or more times in a row (e.g. retries).
    pub is_wildcard: bool,
}

impl Default for ExpectedRestApiSpan {
    /// A single, non-wildcard `GET` request expecting `200 OK`.
    fn default() -> Self {
        use azure_core::http::{Method, StatusCode};

        Self {
            is_wildcard: false,
            api_verb: Method::Get,
            expected_status_code: StatusCode::Ok,
        }
    }
}
/// Full instrumentation expectation passed to `assert_instrumentation_information`.
#[derive(Debug, Default, Clone)]
pub struct ExpectedInstrumentation {
/// Crate name expected on both the public-API and request-activity tracers.
pub package_name: String,
/// Crate version expected on both tracers.
pub package_version: String,
/// Namespace expected on the public-API tracer. When set, it is also expected
/// as the `az.namespace` attribute on API spans and on HTTP spans that have an API parent.
pub package_namespace: Option<&'static str>,
/// Expected API calls, in the order they are made.
pub api_calls: Vec<ExpectedApiInformation>,
}
/// Runs a client API against a [`MockTracingProvider`] and asserts that the
/// recorded spans match `api_information`.
///
/// `create_client` receives the mock provider and builds the client under
/// test; `test_api` then exercises that client. Two tracers are expected, in
/// this order:
/// - a public-API tracer carrying `package_namespace`, which holds each API
///   span and its HTTP child spans;
/// - a request-activity tracer with no namespace, which holds the HTTP spans
///   of calls that have no `api_name`.
///
/// # Errors
///
/// Returns any error produced by `create_client`.
///
/// # Panics
///
/// Panics if the recorded spans do not match the expectations.
pub async fn assert_instrumentation_information<C, FnInit, FnTest, T>(
    create_client: FnInit,
    test_api: FnTest,
    api_information: ExpectedInstrumentation,
) -> azure_core::Result<()>
where
    FnInit: FnOnce(Arc<dyn TracerProvider>) -> azure_core::Result<C>,
    FnTest: AsyncFnOnce(C) -> azure_core::Result<T>,
{
    let provider = Arc::new(MockTracingProvider::new());
    let client = create_client(provider.clone())?;
    // The API may legitimately fail (e.g. when a non-success status code is
    // expected), so its result is ignored; only the recorded spans are checked.
    let _ = test_api(client).await;

    let package_name = api_information.package_name.as_str();
    let package_version = Some(api_information.package_version.as_str());
    let mut public_api_tracer = ExpectedTracerInformation {
        name: package_name,
        version: package_version,
        namespace: api_information.package_namespace,
        spans: Vec::new(),
    };
    let mut request_activity_tracer = ExpectedTracerInformation {
        name: package_name,
        version: package_version,
        namespace: None,
        spans: Vec::new(),
    };

    for api_call in &api_information.api_calls {
        let has_api_span = api_call.api_name.is_some();

        // Attributes and status of the public API span, derived from its children.
        let mut public_api_attributes = api_call.additional_api_attributes.clone();
        if let Some(namespace) = api_information.package_namespace {
            public_api_attributes.push(("az.namespace", namespace.into()));
        }
        let mut span_status = SpanStatus::Unset;
        for child in &api_call.api_children {
            if !child.expected_status_code.is_success() {
                public_api_attributes.push((
                    "error.type",
                    child.expected_status_code.to_string().into(),
                ));
            }
            if child.expected_status_code.is_server_error() {
                span_status = SpanStatus::Error {
                    description: "".into(),
                };
                break;
            }
        }

        let api_id = Uuid::new_v4();
        let target_spans = if has_api_span {
            &mut public_api_tracer.spans
        } else {
            &mut request_activity_tracer.spans
        };

        if let Some(api_name) = api_call.api_name {
            target_spans.push(ExpectedSpanInformation {
                span_name: api_name,
                status: span_status,
                span_id: api_id,
                parent_id: None,
                kind: SpanKind::Internal,
                attributes: public_api_attributes,
                is_wildcard: false,
            });
        }

        for child in &api_call.api_children {
            let succeeded = child.expected_status_code.is_success();
            let mut http_request_attributes = vec![
                ("http.request.method", child.api_verb.as_str().into()),
                ("url.full", "<ANY>".into()),
                ("server.address", "<ANY>".into()),
                ("server.port", "<ANY>".into()),
                ("az.client_request_id", "<ANY>".into()),
                (
                    "http.response.status_code",
                    (*child.expected_status_code).into(),
                ),
            ];
            if !succeeded {
                http_request_attributes.push((
                    "error.type",
                    child.expected_status_code.to_string().into(),
                ));
            }
            // Only HTTP spans that belong to a public API span carry the namespace.
            if let Some(namespace) = api_information.package_namespace.filter(|_| has_api_span) {
                http_request_attributes.push(("az.namespace", namespace.into()));
            }
            target_spans.push(ExpectedSpanInformation {
                span_name: child.api_verb.as_str(),
                status: if succeeded {
                    SpanStatus::Unset
                } else {
                    SpanStatus::Error {
                        description: "".into(),
                    }
                },
                span_id: Uuid::new_v4(),
                parent_id: has_api_span.then_some(api_id),
                kind: SpanKind::Client,
                attributes: http_request_attributes,
                is_wildcard: child.is_wildcard,
            });
        }
    }

    check_instrumentation_result(provider, vec![public_api_tracer, request_activity_tracer]);
    Ok(())
}