use commonware_utils::sync::Mutex;
use std::{
fmt,
ops::{Deref, DerefMut},
sync::Arc,
};
use thiserror::Error;
use tracing::{field, span, Event, Level, Subscriber};
use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
#[derive(Error, Debug, Clone, Eq, PartialEq)]
#[error("Trace assertion error: {0}")]
pub struct TraceAssertionError(String);
impl From<String> for TraceAssertionError {
fn from(value: String) -> Self {
Self(value)
}
}
/// Message text plus key/value fields captured from an event or a span.
///
/// For an event, `content` holds the `message` field. For a span entry of a
/// recorded event, `content` holds the span's name.
#[derive(Clone, Debug, Default)]
pub struct EventMetadata {
    /// The event's `message` field, or the span name for span entries.
    pub content: String,
    /// Every non-`message` field as a `(name, rendered value)` pair, in recording order.
    pub fields: Vec<(String, String)>,
}
impl EventMetadata {
    /// Succeeds if the captured content is exactly `content`.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] naming the expected and actual content otherwise.
    pub fn expect_content_exact(&self, content: &str) -> Result<(), TraceAssertionError> {
        if self.content == content {
            Ok(())
        } else {
            Err(format!("Expected content '{content}', found '{}'", self.content).into())
        }
    }

    /// Succeeds if the captured content contains `substring`.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] naming the substring and the actual content otherwise.
    pub fn expect_content_contains(&self, substring: &str) -> Result<(), TraceAssertionError> {
        if self.content.contains(substring) {
            Ok(())
        } else {
            Err(format!(
                "Expected content containing '{substring}', found '{}'",
                self.content
            )
            .into())
        }
    }

    /// Succeeds if exactly `n` (non-`message`) fields were captured.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] with the expected and actual counts otherwise.
    pub fn expect_field_count(&self, n: usize) -> Result<(), TraceAssertionError> {
        if self.fields.len() == n {
            Ok(())
        } else {
            Err(format!("Expected {n} fields, found {}", self.fields.len()).into())
        }
    }

    /// Applies `predicate` to the `(name, value)` field at position `index`.
    ///
    /// Fields are stored in recording order. The predicate runs at most once, so any
    /// `FnOnce` closure is accepted. It may, for example, move a captured value into the
    /// error it returns.
    ///
    /// # Errors
    ///
    /// Returns an error if there is no field at `index`. Otherwise returns the predicate's result.
    pub fn expect_field_at_index<F>(
        &self,
        index: usize,
        predicate: F,
    ) -> Result<(), TraceAssertionError>
    where
        F: FnOnce(&(String, String)) -> Result<(), TraceAssertionError>,
    {
        self.fields.get(index).map_or_else(
            || Err(format!("Missing field at index {index}").into()),
            predicate,
        )
    }

    /// Succeeds if some field is named `field_name` and has exactly the value `field_value`.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] if no field matches both name and value.
    pub fn expect_field_exact(
        &self,
        field_name: &str,
        field_value: &str,
    ) -> Result<(), TraceAssertionError> {
        let found = self
            .fields
            .iter()
            .any(|(name, value)| name == field_name && value == field_value);
        if found {
            Ok(())
        } else {
            Err(format!("Expected a field '{field_name}' with value '{field_value}'").into())
        }
    }

    /// Succeeds if some field is named `field_name` and its value contains `field_value`.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] if no field matches.
    pub fn expect_field_contains(
        &self,
        field_name: &str,
        field_value: &str,
    ) -> Result<(), TraceAssertionError> {
        let found = self
            .fields
            .iter()
            .any(|(name, value)| name == field_name && value.contains(field_value));
        if found {
            Ok(())
        } else {
            Err(format!("Expected a field '{field_name}' containing value '{field_value}'").into())
        }
    }
}
impl field::Visit for EventMetadata {
    /// Records a string field verbatim, without `Debug` quoting.
    ///
    /// The `message` field becomes the content. Every other field is appended to `fields`.
    fn record_str(&mut self, field: &field::Field, value: &str) {
        match field.name() {
            "message" => self.content = value.to_string(),
            name => self.fields.push((name.to_string(), value.to_string())),
        }
    }

    /// Records a field using its `Debug` rendering.
    ///
    /// `Visit`'s provided methods for other value types forward here.
    fn record_debug(&mut self, field: &field::Field, value: &dyn fmt::Debug) {
        let text = format!("{value:?}");
        match field.name() {
            "message" => self.content = text,
            name => self.fields.push((name.to_string(), text)),
        }
    }
}
/// A single event captured by [`CollectingLayer`].
#[derive(Clone, Debug)]
pub struct RecordedEvent {
    /// Level the event was emitted at.
    pub level: Level,
    /// Target taken from the event's metadata.
    pub target: String,
    /// Spans enclosing the event, innermost first. Each entry's `content` is the span name.
    pub spans: Vec<EventMetadata>,
    /// The event's own message and fields.
    pub metadata: EventMetadata,
}
impl RecordedEvent {
    /// Succeeds if exactly `n` spans enclosed the event when it was recorded.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] with the expected and actual counts otherwise.
    pub fn expect_span_count(&self, n: usize) -> Result<(), TraceAssertionError> {
        if self.spans.len() == n {
            Ok(())
        } else {
            Err(format!("Expected {n} active spans, found {}", self.spans.len()).into())
        }
    }

    /// Applies `predicate` to the span at position `index`, where index 0 is the innermost span.
    ///
    /// The predicate runs at most once, so any `FnOnce` closure is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error if there is no span at `index`. Otherwise returns the predicate's result.
    pub fn expect_span_at_index<F>(
        &self,
        index: usize,
        predicate: F,
    ) -> Result<(), TraceAssertionError>
    where
        F: FnOnce(&EventMetadata) -> Result<(), TraceAssertionError>,
    {
        self.spans.get(index).map_or_else(
            || Err(format!("Missing span at index {index}").into()),
            predicate,
        )
    }

    /// Succeeds if any enclosing span satisfies `predicate`.
    ///
    /// The predicate is called sequentially and evaluation stops at the first match.
    /// `FnMut` is therefore sufficient, so stateful closures (e.g. counters) are accepted.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] if no span matches.
    pub fn expect_span<F>(&self, predicate: F) -> Result<(), TraceAssertionError>
    where
        F: FnMut(&EventMetadata) -> bool,
    {
        if self.spans.iter().any(predicate) {
            Ok(())
        } else {
            Err("Missing span matching predicate".to_string().into())
        }
    }
}
/// Ordered collection of captured events.
///
/// Dereferences to `Vec<RecordedEvent>`, so slice and `Vec` methods are available directly.
#[derive(Clone, Debug, Default)]
pub struct RecordedEvents(Vec<RecordedEvent>);
impl RecordedEvents {
    /// Applies `predicate` to the event at position `index`, in recording order.
    ///
    /// The predicate runs at most once, so any `FnOnce` closure is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error if there is no event at `index`. Otherwise returns the predicate's result.
    pub fn expect_event_at_index<F>(
        &self,
        index: usize,
        predicate: F,
    ) -> Result<(), TraceAssertionError>
    where
        F: FnOnce(&RecordedEvent) -> Result<(), TraceAssertionError>,
    {
        self.get(index).map_or_else(
            || Err(format!("Missing event at index {index}").into()),
            predicate,
        )
    }

    /// Succeeds if any recorded event satisfies `predicate`.
    ///
    /// The predicate is called sequentially and evaluation stops at the first match.
    /// `FnMut` is therefore sufficient, so stateful closures are accepted.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] if no event matches.
    pub fn expect_event<F>(&self, predicate: F) -> Result<(), TraceAssertionError>
    where
        F: FnMut(&RecordedEvent) -> bool,
    {
        if self.iter().any(predicate) {
            Ok(())
        } else {
            Err("Missing event matching predicate".to_string().into())
        }
    }

    /// Succeeds if some event's message equals `message` exactly.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] naming the missing message otherwise.
    pub fn expect_message_exact(&self, message: &str) -> Result<(), TraceAssertionError> {
        let found = self.iter().any(|event| event.metadata.content == message);
        if found {
            Ok(())
        } else {
            Err(format!("Missing message: '{message}'").into())
        }
    }

    /// Succeeds if some event's message contains `substring`.
    ///
    /// # Errors
    ///
    /// Returns a [`TraceAssertionError`] naming the missing substring otherwise.
    pub fn expect_message_contains(&self, substring: &str) -> Result<(), TraceAssertionError> {
        let found = self
            .iter()
            .any(|event| event.metadata.content.contains(substring));
        if found {
            Ok(())
        } else {
            Err(format!("Missing message containing: '{substring}'").into())
        }
    }
}
impl Deref for RecordedEvents {
type Target = Vec<RecordedEvent>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for RecordedEvents {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl From<Vec<RecordedEvent>> for RecordedEvents {
fn from(events: Vec<RecordedEvent>) -> Self {
Self(events)
}
}
/// Shared buffer of events captured by [`CollectingLayer`].
///
/// The events sit behind an `Arc`, so cloning yields another handle to the same buffer.
#[derive(Clone, Debug, Default)]
pub struct TraceStorage(Arc<Mutex<RecordedEvents>>);
impl TraceStorage {
    /// Returns a snapshot of the recorded events emitted at exactly `level`, in recording order.
    pub fn get_by_level(&self, level: Level) -> RecordedEvents {
        self.0
            .lock()
            .iter()
            // Filter before cloning so only matching events are copied. `bool::then_some`
            // evaluates its argument eagerly and would clone every event, including rejected ones.
            .filter(|event| event.level == level)
            .cloned()
            .collect::<Vec<_>>()
            .into()
    }

    /// Returns a snapshot of every recorded event, in recording order.
    pub fn get_all(&self) -> RecordedEvents {
        self.0.lock().clone()
    }

    /// Returns `true` if no events have been recorded.
    pub fn is_empty(&self) -> bool {
        self.0.lock().is_empty()
    }
}
/// A `tracing_subscriber` layer that appends every event it observes to a [`TraceStorage`].
#[derive(Default, Debug)]
pub struct CollectingLayer(TraceStorage);
impl CollectingLayer {
pub const fn new(storage: TraceStorage) -> Self {
Self(storage)
}
}
impl<S> Layer<S> for CollectingLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    /// Captures the fields a span is created with.
    ///
    /// The fields are stored in the span's extensions so that events inside the span can
    /// report them later.
    fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
        if let Some(span) = ctx.span(id) {
            let mut span_fields = EventMetadata::default();
            attrs.record(&mut span_fields);
            span.extensions_mut().insert(span_fields);
        }
    }

    /// Captures fields recorded on a span after creation, e.g. ones declared
    /// `field::Empty` and later filled in with `Span::record`.
    ///
    /// A field that already has a value is overwritten rather than duplicated.
    fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
        if let Some(span) = ctx.span(id) {
            let mut update = EventMetadata::default();
            values.record(&mut update);
            let mut extensions = span.extensions_mut();
            match extensions.get_mut::<EventMetadata>() {
                Some(existing) => {
                    for (name, value) in update.fields {
                        match existing.fields.iter_mut().find(|(known, _)| *known == name) {
                            Some(slot) => slot.1 = value,
                            None => existing.fields.push((name, value)),
                        }
                    }
                }
                None => extensions.insert(update),
            }
        }
    }

    /// Records the event together with the chain of spans it belongs to.
    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        let metadata = event.metadata();
        let mut event_meta = EventMetadata::default();
        event.record(&mut event_meta);

        // Resolve the event's parent span:
        // - an explicit `parent: span` wins over the current span;
        // - `parent: None` marks a root event with no enclosing spans;
        // - otherwise the event is contextual and uses the current span.
        let leaf = if event.is_root() {
            None
        } else if let Some(parent) = event.parent() {
            ctx.span(parent)
        } else {
            ctx.lookup_current()
        };

        // Walk from the innermost span outwards, so index 0 is the closest enclosing span.
        let mut spans = Vec::new();
        let mut current = leaf;
        while let Some(span) = current {
            let fields = span
                .extensions()
                .get::<EventMetadata>()
                .map(|recorded| recorded.fields.clone())
                .unwrap_or_default();
            spans.push(EventMetadata {
                content: span.metadata().name().to_string(),
                fields,
            });
            current = span.parent();
        }

        let recorded = RecordedEvent {
            level: *metadata.level(),
            target: metadata.target().to_string(),
            spans,
            metadata: event_meta,
        };
        // Build the record first so the shared buffer is locked only for the push.
        self.0 .0.lock().push(recorded);
    }
}