use async_trait::async_trait;
use std::collections::HashMap;
use parking_lot::Mutex;
use super::span::{Span, SpanExporter, SpanStatus};
use crate::context::Context;
use crate::errors::ModuleError;
use crate::middleware::base::Middleware;
/// How [`TracingMiddleware`] decides which finished spans are exported.
///
/// Every variant carries an explicit `#[serde(rename)]`, which takes precedence over the
/// container-level `rename_all`; the serialized names are therefore `"full"`,
/// `"proportional"`, `"error_first"` and `"off"`.
///
/// A boolean stored under the context-data key `_apcore.mw.tracing.sampled` overrides the
/// strategy for a call (except that `ErrorFirst` still exports spans closed by `on_error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SamplingStrategy {
    /// Export every span.
    #[serde(rename = "full")]
    Always,
    /// Export a trace's spans with probability `sampling_rate`; the decision is drawn once
    /// per trace id and reused by that trace's spans.
    #[serde(rename = "proportional")]
    Probabilistic,
    /// Like `Probabilistic` for successful spans, but spans closed by `on_error` are always
    /// exported.
    #[serde(rename = "error_first")]
    ErrorFirst,
    /// Do not export spans.
    #[serde(rename = "off")]
    Never,
}
/// Mutable per-trace bookkeeping shared by all calls through one `TracingMiddleware`.
#[derive(Debug, Default)]
struct TraceState {
    /// Open spans per trace id; the innermost (most recently opened) span is last.
    spans: HashMap<String, Vec<Span>>,
    /// Cached sampling decision per trace id, dropped once that trace's span stack empties.
    sampling: HashMap<String, bool>,
}
/// Middleware that records an `apcore.module.execute` span for every module call and hands
/// finished spans to a [`SpanExporter`], subject to a [`SamplingStrategy`].
///
/// Nested calls that share a trace id are linked parent -> child through a per-trace span
/// stack.
#[derive(Debug)]
pub struct TracingMiddleware {
    /// Destination for finished spans.
    exporter: Box<dyn SpanExporter>,
    /// Strategy used when a trace has no cached decision and no context override.
    pub sampling_strategy: SamplingStrategy,
    /// Export probability used by `Probabilistic` and `ErrorFirst`. `with_sampling` clamps it
    /// into `[0.0, 1.0]`, but nothing enforces that range when the field is assigned directly.
    pub sampling_rate: f64,
    /// Per-trace span stacks and cached sampling decisions.
    state: Mutex<TraceState>,
}
impl TracingMiddleware {
    /// Creates a tracing middleware that exports every span
    /// ([`SamplingStrategy::Always`] with a rate of `1.0`).
    #[must_use]
    pub fn new(exporter: Box<dyn SpanExporter>) -> Self {
        Self::with_sampling(exporter, SamplingStrategy::Always, 1.0)
    }

    /// Creates a tracing middleware with an explicit sampling configuration.
    ///
    /// `rate` is clamped into `[0.0, 1.0]` and only influences
    /// [`SamplingStrategy::Probabilistic`] and [`SamplingStrategy::ErrorFirst`].
    /// NOTE: `f64::clamp` passes a NaN `rate` through unchanged; such a rate never samples.
    #[must_use]
    pub fn with_sampling(
        exporter: Box<dyn SpanExporter>,
        strategy: SamplingStrategy,
        rate: f64,
    ) -> Self {
        let sampling_rate = rate.clamp(0.0, 1.0);
        Self {
            exporter,
            sampling_strategy: strategy,
            sampling_rate,
            state: Mutex::new(TraceState::default()),
        }
    }

    /// Decides whether spans of `ctx`'s trace should be exported.
    ///
    /// Precedence:
    /// 1. a boolean stored under `_apcore.mw.tracing.sampled` in the context data;
    /// 2. a decision already cached for `ctx.trace_id`;
    /// 3. a fresh draw from the configured strategy, which is then cached for the trace.
    fn should_sample(&self, ctx: &Context<serde_json::Value>) -> bool {
        let forced = ctx
            .data
            .read()
            .get("_apcore.mw.tracing.sampled")
            .and_then(serde_json::Value::as_bool);
        if let Some(flag) = forced {
            return flag;
        }

        let mut state = self.state.lock();
        let cached = state.sampling.get(&ctx.trace_id).copied();
        cached.unwrap_or_else(|| {
            let decision = self.roll_sampling_decision();
            state.sampling.insert(ctx.trace_id.clone(), decision);
            decision
        })
    }

    /// Draws a new sampling decision according to `self.sampling_strategy`.
    fn roll_sampling_decision(&self) -> bool {
        match self.sampling_strategy {
            SamplingStrategy::Always => true,
            SamplingStrategy::Never => false,
            SamplingStrategy::Probabilistic | SamplingStrategy::ErrorFirst => {
                // A random v4 UUID scaled by u128::MAX serves as the random draw in [0, 1].
                // The u128 -> f64 casts lose precision, which is irrelevant for sampling.
                #[allow(clippy::cast_precision_loss)]
                let draw = uuid::Uuid::new_v4().as_u128() as f64 / u128::MAX as f64;
                draw < self.sampling_rate
            }
        }
    }
}
#[async_trait]
impl Middleware for TracingMiddleware {
fn name(&self) -> &'static str {
"tracing"
}
async fn before(
&self,
module_id: &str,
_inputs: serde_json::Value,
ctx: &Context<serde_json::Value>,
) -> Result<Option<serde_json::Value>, ModuleError> {
self.should_sample(ctx);
let mut span = Span::new("apcore.module.execute", &ctx.trace_id);
span.set_attribute("module_id".to_string(), serde_json::json!(module_id));
span.set_attribute("method".to_string(), serde_json::json!("execute"));
if let Some(ref caller_id) = ctx.caller_id {
span.set_attribute("caller_id".to_string(), serde_json::json!(caller_id));
}
{
let mut state = self.state.lock();
let stack = state.spans.entry(ctx.trace_id.clone()).or_default();
if let Some(parent) = stack.last() {
span.parent_span_id = Some(parent.span_id.clone());
}
stack.push(span);
}
Ok(None)
}
async fn after(
&self,
_module_id: &str,
_inputs: serde_json::Value,
_output: serde_json::Value,
ctx: &Context<serde_json::Value>,
) -> Result<Option<serde_json::Value>, ModuleError> {
let (span, should_clean_sampling) = {
let mut state = self.state.lock();
let popped = state
.spans
.get_mut(&ctx.trace_id)
.and_then(std::vec::Vec::pop);
let should_clean = if let Some(stack) = state.spans.get(&ctx.trace_id) {
if stack.is_empty() {
state.spans.remove(&ctx.trace_id);
true
} else {
false
}
} else {
false
};
(popped, should_clean)
};
if let Some(mut span) = span {
span.status = SpanStatus::Ok;
span.end();
let duration_ms = span
.end_time
.map_or(0.0, |e| (e - span.start_time) * 1000.0);
span.set_attribute("duration_ms".to_string(), serde_json::json!(duration_ms));
span.set_attribute("success".to_string(), serde_json::json!(true));
if self.should_sample(ctx) {
let _ = self.exporter.export(&span).await;
}
}
if should_clean_sampling {
let mut state = self.state.lock();
state.sampling.remove(&ctx.trace_id);
}
Ok(None)
}
async fn on_error(
&self,
_module_id: &str,
_inputs: serde_json::Value,
error: &ModuleError,
ctx: &Context<serde_json::Value>,
) -> Result<Option<serde_json::Value>, ModuleError> {
let (span, should_clean_sampling) = {
let mut state = self.state.lock();
let popped = state
.spans
.get_mut(&ctx.trace_id)
.and_then(std::vec::Vec::pop);
let should_clean = if let Some(stack) = state.spans.get(&ctx.trace_id) {
if stack.is_empty() {
state.spans.remove(&ctx.trace_id);
true
} else {
false
}
} else {
false
};
(popped, should_clean)
};
if let Some(mut span) = span {
span.status = SpanStatus::Error;
span.end();
let duration_ms = span
.end_time
.map_or(0.0, |e| (e - span.start_time) * 1000.0);
span.set_attribute("duration_ms".to_string(), serde_json::json!(duration_ms));
span.set_attribute("success".to_string(), serde_json::json!(false));
span.set_attribute(
"error_code".to_string(),
serde_json::json!(format!("{:?}", error.code)),
);
span.set_attribute(
"error.message".to_string(),
serde_json::json!(error.message),
);
span.add_event("exception");
let should_export = match self.sampling_strategy {
SamplingStrategy::ErrorFirst => true,
_ => self.should_sample(ctx),
};
if should_export {
let _ = self.exporter.export(&span).await;
}
}
if should_clean_sampling {
let mut state = self.state.lock();
state.sampling.remove(&ctx.trace_id);
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Identity;
    use crate::observability::exporters::InMemoryExporter;

    /// Builds a context for a fixed test identity and pins its trace id to `trace_id`.
    fn make_ctx(trace_id: &str) -> Context<serde_json::Value> {
        let identity = Identity::new(
            "test-user".to_string(),
            "user".to_string(),
            vec![],
            HashMap::new(),
        );
        let mut ctx = Context::create(identity, serde_json::Value::Null, None, None);
        ctx.trace_id = trace_id.to_string();
        ctx
    }

    /// Runs `before` for `module_id` with empty inputs, panicking on error.
    async fn enter(mw: &TracingMiddleware, module_id: &str, ctx: &Context<serde_json::Value>) {
        mw.before(module_id, serde_json::json!({}), ctx)
            .await
            .unwrap();
    }

    /// Runs `after` for `module_id` with empty inputs and output, panicking on error.
    async fn leave(mw: &TracingMiddleware, module_id: &str, ctx: &Context<serde_json::Value>) {
        mw.after(module_id, serde_json::json!({}), serde_json::json!({}), ctx)
            .await
            .unwrap();
    }

    /// Reads the cached sampling decision for `trace_id`, holding the lock only briefly.
    fn cached_decision(mw: &TracingMiddleware, trace_id: &str) -> Option<bool> {
        mw.state.lock().sampling.get(trace_id).copied()
    }

    #[tokio::test]
    async fn test_nested_spans_parent_child_linking() {
        let exporter = InMemoryExporter::new();
        let mw = TracingMiddleware::new(Box::new(exporter.clone()));
        let ctx = make_ctx("trace-1");

        enter(&mw, "mod_a", &ctx).await;
        enter(&mw, "mod_b", &ctx).await;
        leave(&mw, "mod_b", &ctx).await;
        leave(&mw, "mod_a", &ctx).await;

        let spans = exporter.get_spans();
        assert_eq!(spans.len(), 2, "expected 2 exported spans");
        // The inner span closes first, so it is exported first.
        let (span_b, span_a) = (&spans[0], &spans[1]);
        assert_eq!(
            span_b.attributes.get("module_id").unwrap(),
            &serde_json::json!("mod_b")
        );
        assert_eq!(
            span_a.attributes.get("module_id").unwrap(),
            &serde_json::json!("mod_a")
        );
        assert_eq!(
            span_b.parent_span_id.as_ref().unwrap(),
            &span_a.span_id,
            "inner span should reference outer span as parent"
        );
        assert!(
            span_a.parent_span_id.is_none(),
            "root span should have no parent"
        );
    }

    #[tokio::test]
    async fn test_nested_spans_cleanup_after_all_pops() {
        let exporter = InMemoryExporter::new();
        let mw = TracingMiddleware::new(Box::new(exporter.clone()));
        let ctx = make_ctx("trace-cleanup");

        enter(&mw, "mod_a", &ctx).await;
        enter(&mw, "mod_b", &ctx).await;
        leave(&mw, "mod_b", &ctx).await;
        leave(&mw, "mod_a", &ctx).await;

        let state = mw.state.lock();
        assert!(
            !state.spans.contains_key("trace-cleanup"),
            "span stack should be removed after all spans are popped"
        );
        assert!(
            !state.sampling.contains_key("trace-cleanup"),
            "sampling decision should be removed after trace completes"
        );
    }

    #[tokio::test]
    async fn test_sampling_decision_inherited_from_parent() {
        let exporter = InMemoryExporter::new();
        let mw = TracingMiddleware::with_sampling(
            Box::new(exporter.clone()),
            SamplingStrategy::Always,
            1.0,
        );
        let ctx = make_ctx("trace-inherit");

        enter(&mw, "mod_a", &ctx).await;
        assert_eq!(
            cached_decision(&mw, "trace-inherit"),
            Some(true),
            "sampling decision should be cached after first before()"
        );

        enter(&mw, "mod_b", &ctx).await;
        assert_eq!(
            cached_decision(&mw, "trace-inherit"),
            Some(true),
            "sampling decision should remain the same for nested calls"
        );

        leave(&mw, "mod_b", &ctx).await;
        leave(&mw, "mod_a", &ctx).await;
    }

    #[tokio::test]
    async fn test_sampling_never_does_not_export() {
        let exporter = InMemoryExporter::new();
        let mw = TracingMiddleware::with_sampling(
            Box::new(exporter.clone()),
            SamplingStrategy::Never,
            0.0,
        );
        let ctx = make_ctx("trace-never");

        enter(&mw, "mod_a", &ctx).await;
        leave(&mw, "mod_a", &ctx).await;

        assert!(
            exporter.get_spans().is_empty(),
            "Never strategy should not export any spans"
        );
    }
}