use std::{
env::VarError,
error::Error,
fmt,
io::{self, IsTerminal},
str::FromStr,
sync::atomic::{AtomicU64, Ordering},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{metrics::SdkMeterProvider, trace::SdkTracerProvider};
use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
use tracing::{Metadata, Subscriber, info, level_filters::LevelFilter, span, subscriber::Interest, warn};
use tracing_subscriber::{
EnvFilter, Registry,
filter::Filtered,
fmt::{
FmtContext, FormatEvent, FormatFields, FormattedFields, Layer,
format::{FmtSpan, Format, Json, JsonFields, Writer},
},
layer::{Context, Filter, Layered, SubscriberExt},
prelude::*,
registry::LookupSpan,
util::SubscriberInitExt,
};
/// Environment variable holding the filter directives for the human-readable log layer.
const AMARU_LOG_VAR: &str = "AMARU_LOG";
/// Directives used for the log layer when `AMARU_LOG` is unset or cannot be parsed.
const DEFAULT_AMARU_LOG_FILTER: &str = "info,amaru::consensus=debug,amaru::ledger=debug,pure_stage=warn";
/// Environment variable holding the filter directives for the JSON and OpenTelemetry trace layers.
const AMARU_TRACE_VAR: &str = "AMARU_TRACE";
/// Directives used for the trace layers when `AMARU_TRACE` is unset or cannot be parsed.
const DEFAULT_AMARU_TRACE_FILTER: &str = "amaru=trace,pure_stage=trace,amaru_protocols=warn,amaru_consensus=info";
/// Minimum delay, in milliseconds, between two admitted events whose target starts with
/// `opentelemetry`; handed to every `ThrottledEnvFilter` built by `new_default_filter`.
const OTEL_ERROR_THROTTLE_MS: u64 = 5_000;
/// A subscriber `S` with the filtered OpenTelemetry layer stacked on top of it.
type OpenTelemetryLayer<S> = Layered<OpenTelemetryFilter<S>, S>;
/// The `tracing_opentelemetry` layer (backed by the SDK tracer) gated by a `ThrottledEnvFilter`.
type OpenTelemetryFilter<S> =
Filtered<tracing_opentelemetry::OpenTelemetryLayer<S, opentelemetry_sdk::trace::Tracer>, ThrottledEnvFilter, S>;
/// A subscriber `S` with the filtered JSON formatting layer stacked on top of it.
type JsonLayer<S> = Layered<JsonFilter<S>, S>;
/// The JSON `fmt` layer (formatted by `SpanJsonFormat`) gated by a `ThrottledEnvFilter`.
type JsonFilter<S> = Filtered<Layer<S, JsonFields, SpanJsonFormat>, ThrottledEnvFilter, S>;
/// A deferred log statement: produced while filters are being built and invoked only after the
/// subscriber has been initialised. `None` means there is nothing to report.
type DelayedWarning = Option<Box<dyn FnOnce()>>;
/// JSON event formatter that wraps the stock JSON format and enriches each record with
/// information about the current span.
pub struct SpanJsonFormat(Format<Json>);

impl<S, N> FormatEvent<S, N> for SpanJsonFormat
where
    S: tracing::Subscriber + for<'a> LookupSpan<'a>,
    N: for<'a> FormatFields<'a> + 'static,
{
    /// Renders `event` with the wrapped formatter, then splices extra members in front of the
    /// last `}` of the rendered text:
    ///
    /// * the JSON fields already formatted for the current span (outer braces stripped),
    /// * `"id"` of the current span, only when the event's metadata is a span's metadata,
    /// * `"parent_id"` when the current span has a parent.
    ///
    /// When there is no current span, or the rendered text holds no `}`, the output of the
    /// wrapped formatter is written unchanged.
    fn format_event(
        &self,
        ctx: &FmtContext<'_, S, N>,
        mut writer: Writer<'_>,
        event: &tracing::Event<'_>,
    ) -> fmt::Result {
        // Render into a scratch buffer first so that the result can be patched before emission.
        let mut rendered = String::new();
        self.0.format_event(ctx, Writer::new(&mut rendered), event)?;

        let Some(span) = ctx.lookup_current() else {
            return writer.write_str(&rendered);
        };
        // Insertion point: the last closing brace, assumed to terminate the JSON object.
        let Some(closing_brace) = rendered.rfind('}') else {
            return writer.write_str(&rendered);
        };

        let mut additions = String::new();
        {
            // Keep the extensions borrow confined to the part that actually reads the fields.
            let extensions = span.extensions();
            if let Some(fields) = extensions.get::<FormattedFields<JsonFields>>() {
                let raw = fields.as_str().trim();
                let members = match raw.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
                    Some(unwrapped) => unwrapped,
                    None => raw,
                };
                if !members.is_empty() {
                    additions.push(',');
                    additions.push_str(members);
                }
            }
        }

        if event.metadata().is_span() {
            let id = span.id().into_u64();
            additions.push_str(&format!(",\"id\":{id}"));
        }

        if let Some(parent) = span.parent() {
            let parent_id = parent.id().into_u64();
            additions.push_str(&format!(",\"parent_id\":{parent_id}"));
        }

        if !additions.is_empty() {
            rendered.insert_str(closing_brace, &additions);
        }

        writer.write_str(&rendered)
    }
}
/// The subscriber stack under construction, one variant per combination of optional layers.
#[expect(clippy::large_enum_variant)]
#[derive(Default)]
pub enum TracingSubscriber<S> {
/// Placeholder left behind by `std::mem::take` while the state is being replaced.
#[default]
Empty,
/// A bare registry with no optional layer attached yet.
Registry(Registry),
/// The OpenTelemetry layer stacked on `S`.
WithOpenTelemetry(OpenTelemetryLayer<S>),
/// The JSON layer stacked on `S`.
WithJson(JsonLayer<S>),
/// The JSON layer stacked on top of the OpenTelemetry layer, itself stacked on `S`.
WithBoth(JsonLayer<OpenTelemetryLayer<S>>),
}
impl TracingSubscriber<Registry> {
pub fn new() -> Self {
Self::Registry(tracing_subscriber::registry())
}
#[expect(clippy::panic)]
#[expect(clippy::wildcard_enum_match_arm)]
pub fn with_open_telemetry(&mut self, layer: OpenTelemetryFilter<Registry>) {
match std::mem::take(self) {
Self::Registry(registry) => {
*self = TracingSubscriber::WithOpenTelemetry(registry.with(layer));
}
_ => panic!("'with_open_telemetry' called after 'with_json'"),
}
}
#[expect(clippy::panic)]
#[expect(clippy::wildcard_enum_match_arm)]
pub fn with_json<F, G>(&mut self, layer_json: F, layer_both: G) -> DelayedWarning
where
F: FnOnce() -> (JsonFilter<Registry>, DelayedWarning),
G: FnOnce() -> (JsonFilter<OpenTelemetryLayer<Registry>>, DelayedWarning),
{
match std::mem::take(self) {
Self::Registry(registry) => {
let (layer, warning) = layer_json();
*self = TracingSubscriber::WithJson(registry.with(layer));
warning
}
Self::WithOpenTelemetry(layered) => {
let (layer, warning) = layer_both();
*self = TracingSubscriber::WithBoth(layered.with(layer));
warning
}
_ => panic!("'with_open_telemetry' called after 'with_json'"),
}
}
pub fn init(self, color: bool) {
let (default_filter, warning) = new_default_filter(AMARU_LOG_VAR, DEFAULT_AMARU_LOG_FILTER);
let log_format = || tracing_subscriber::fmt::format().with_ansi(color).compact();
let log_writer = || io::stderr as fn() -> io::Stderr;
let log_events = || FmtSpan::CLOSE;
let log_filter = || default_filter;
match self {
TracingSubscriber::Empty => unreachable!(),
TracingSubscriber::Registry(registry) => registry
.with(
tracing_subscriber::fmt::layer()
.with_writer(log_writer())
.event_format(log_format())
.with_span_events(log_events())
.with_filter(log_filter()),
)
.init(),
TracingSubscriber::WithOpenTelemetry(layered) => layered
.with(
tracing_subscriber::fmt::layer()
.with_writer(log_writer())
.event_format(log_format())
.with_span_events(log_events())
.with_filter(log_filter()),
)
.init(),
TracingSubscriber::WithJson(layered) => layered.init(),
TracingSubscriber::WithBoth(layered) => layered.init(),
};
if let Some(notify) = warning {
notify();
}
}
}
/// Attaches the JSON trace layer to `subscriber`.
///
/// The layer reports span enter/exit, formats records with [`SpanJsonFormat`] (span list
/// disabled) and is filtered by `AMARU_TRACE` (or its default). The returned delayed warning
/// describes how that filter was obtained.
pub fn setup_json_traces(subscriber: &mut TracingSubscriber<Registry>) -> DelayedWarning {
    let span_format = || SpanJsonFormat(tracing_subscriber::fmt::format().json().with_span_list(false));
    let trace_filter = || new_default_filter(AMARU_TRACE_VAR, DEFAULT_AMARU_TRACE_FILTER);

    // Both factories are textually identical; they differ only in the subscriber type the
    // layer is inferred for, and `with_json` runs exactly one of them.
    subscriber.with_json(
        || {
            let (filter, notice) = trace_filter();
            let layer = tracing_subscriber::fmt::layer()
                .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT)
                .event_format(span_format())
                .fmt_fields(JsonFields::new())
                .with_filter(filter);
            (layer, notice)
        },
        || {
            let (filter, notice) = trace_filter();
            let layer = tracing_subscriber::fmt::layer()
                .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT)
                .event_format(span_format())
                .fmt_fields(JsonFields::new())
                .with_filter(filter);
            (layer, notice)
        },
    )
}
pub struct OpenTelemetryHandle {
pub metrics: Option<SdkMeterProvider>,
pub teardown: Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>,
}
impl Default for OpenTelemetryHandle {
fn default() -> Self {
OpenTelemetryHandle {
metrics: None::<SdkMeterProvider>,
teardown: Box::new(|| Ok(())) as Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>,
}
}
}
/// Service name reported when `OTEL_SERVICE_NAME` is not set; also used as the tracer name.
pub const DEFAULT_OTLP_SERVICE_NAME: &str = "amaru";
/// Endpoint the OTLP metric exporter sends to over HTTP.
pub const DEFAULT_OTLP_METRIC_URL: &str = "http://localhost:4318/v1/metrics";
#[expect(clippy::panic)]
pub fn setup_open_telemetry(subscriber: &mut TracingSubscriber<Registry>) -> (OpenTelemetryHandle, DelayedWarning) {
use opentelemetry::KeyValue;
use opentelemetry_sdk::{Resource, metrics::Temporality};
let service_name = std::env::var("OTEL_SERVICE_NAME").unwrap_or_else(|_| DEFAULT_OTLP_SERVICE_NAME.to_string());
let resource = Resource::builder().with_attribute(KeyValue::new(SERVICE_NAME, service_name)).build();
let opentelemetry_provider = SdkTracerProvider::builder()
.with_resource(resource.clone())
.with_batch_exporter(
opentelemetry_otlp::SpanExporter::builder()
.with_tonic()
.build()
.unwrap_or_else(|e| panic!("failed to setup opentelemetry span exporter: {e}")),
)
.build();
let metric_exporter = opentelemetry_otlp::MetricExporter::builder()
.with_http()
.with_endpoint(DEFAULT_OTLP_METRIC_URL)
.with_temporality(Temporality::default())
.build()
.unwrap_or_else(|e| panic!("unable to create metric exporter: {e:?}"));
let metric_reader = opentelemetry_sdk::metrics::PeriodicReader::builder(metric_exporter)
.with_interval(Duration::from_secs(10))
.build();
let metrics_provider = opentelemetry_sdk::metrics::SdkMeterProvider::builder()
.with_reader(metric_reader)
.with_resource(resource)
.build();
opentelemetry::global::set_meter_provider(metrics_provider.clone());
let opentelemetry_tracer = opentelemetry_provider.tracer(DEFAULT_OTLP_SERVICE_NAME);
let (default_filter, warning) = new_default_filter(AMARU_TRACE_VAR, DEFAULT_AMARU_TRACE_FILTER);
let opentelemetry_layer =
tracing_opentelemetry::layer().with_tracer(opentelemetry_tracer).with_level(true).with_filter(default_filter);
subscriber.with_open_telemetry(opentelemetry_layer);
(
OpenTelemetryHandle {
metrics: Some(metrics_provider.clone()),
teardown: Box::new(|| teardown_open_telemetry(opentelemetry_provider, metrics_provider)),
},
warning,
)
}
/// Shuts down both OpenTelemetry providers.
///
/// Both shutdowns are always attempted: a failure of the tracer provider no longer prevents
/// the meter provider from being shut down.
///
/// # Errors
///
/// Returns the tracer provider's error if it failed, otherwise the meter provider's error.
/// When both fail, only the tracer provider's error is reported.
fn teardown_open_telemetry(
    tracing: SdkTracerProvider,
    metrics: SdkMeterProvider,
) -> Result<(), Box<dyn std::error::Error>> {
    // Run both before looking at either outcome, so that one failure cannot skip the other.
    let tracing_outcome = tracing.shutdown();
    let metrics_outcome = metrics.shutdown();
    tracing_outcome?;
    metrics_outcome?;
    Ok(())
}
/// An [`EnvFilter`] that additionally rate-limits events whose target starts with
/// `opentelemetry`.
pub struct ThrottledEnvFilter {
    /// The directive-based filter every decision is first delegated to.
    inner: EnvFilter,
    /// Milliseconds since the Unix epoch at which the last `opentelemetry` event was admitted
    /// (0 until the first one).
    last_otel_event: AtomicU64,
    /// Minimum number of milliseconds between two admitted `opentelemetry` events.
    throttle_ms: u64,
}

impl ThrottledEnvFilter {
    /// Wraps `inner`, with no `opentelemetry` event admitted so far.
    fn new(inner: EnvFilter, throttle_ms: u64) -> Self {
        let last_otel_event = AtomicU64::default();
        ThrottledEnvFilter { inner, last_otel_event, throttle_ms }
    }

    /// Whether `meta` belongs to a callsite whose target starts with `opentelemetry`.
    fn is_otel_internal(meta: &Metadata<'_>) -> bool {
        const OTEL_TARGET_PREFIX: &str = "opentelemetry";
        meta.target().starts_with(OTEL_TARGET_PREFIX)
    }
}
impl<S: Subscriber> Filter<S> for ThrottledEnvFilter {
fn enabled(&self, meta: &Metadata<'_>, cx: &Context<'_, S>) -> bool {
if !<EnvFilter as Filter<S>>::enabled(&self.inner, meta, cx) {
return false;
}
if Self::is_otel_internal(meta) {
let Some(now) = SystemTime::now().duration_since(UNIX_EPOCH).ok().map(|d| d.as_millis() as u64) else {
return true;
};
return self
.last_otel_event
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |last| {
(now.saturating_sub(last) >= self.throttle_ms).then_some(now)
})
.is_ok();
}
true
}
fn callsite_enabled(&self, meta: &'static Metadata<'static>) -> Interest {
if Self::is_otel_internal(meta) {
return Interest::sometimes();
}
<EnvFilter as Filter<S>>::callsite_enabled(&self.inner, meta)
}
fn max_level_hint(&self) -> Option<LevelFilter> {
<EnvFilter as Filter<S>>::max_level_hint(&self.inner)
}
fn on_new_span(&self, attrs: &span::Attributes<'_>, id: &span::Id, ctx: Context<'_, S>) {
<EnvFilter as Filter<S>>::on_new_span(&self.inner, attrs, id, ctx);
}
fn on_record(&self, id: &span::Id, values: &span::Record<'_>, ctx: Context<'_, S>) {
<EnvFilter as Filter<S>>::on_record(&self.inner, id, values, ctx);
}
fn on_enter(&self, id: &span::Id, ctx: Context<'_, S>) {
<EnvFilter as Filter<S>>::on_enter(&self.inner, id, ctx);
}
fn on_exit(&self, id: &span::Id, ctx: Context<'_, S>) {
<EnvFilter as Filter<S>>::on_exit(&self.inner, id, ctx);
}
fn on_close(&self, id: span::Id, ctx: Context<'_, S>) {
<EnvFilter as Filter<S>>::on_close(&self.inner, id, ctx);
}
}
/// Builds the throttled filter for one layer from the environment variable `var`, falling back
/// to the `default` directives.
///
/// The second element is a deferred log statement describing the outcome, to be invoked once a
/// subscriber is installed:
///
/// * `var` set and valid: info, with the variable's value;
/// * `var` not present: info, with the fallback directives;
/// * anything else: warning, with the fallback directives and the reason.
///
/// # Panics
///
/// Panics when `default` itself is not a valid filter.
fn new_default_filter(var: &str, default: &str) -> (ThrottledEnvFilter, DelayedWarning) {
    // NOTE: the locals `var`, `value` and `fallback` double as field names in the log
    // statements below, so their names are part of the emitted output.
    let e = match EnvFilter::try_from_env(var) {
        Ok(filter) => {
            let var = var.to_string();
            let value = std::env::var(&var).unwrap_or_default();
            let notice: Box<dyn FnOnce()> = Box::new(move || info!(var, value, "using ENV variable"));
            return (ThrottledEnvFilter::new(filter, OTEL_ERROR_THROTTLE_MS), Some(notice));
        }
        Err(e) => e,
    };

    let fallback = default.to_string();
    let var = var.to_string();

    // An absent variable is the normal case and only worth an informational note.
    let var_is_unset =
        matches!(e.source().and_then(|cause| cause.downcast_ref::<VarError>()), Some(VarError::NotPresent));
    let warning: Box<dyn FnOnce()> = if var_is_unset {
        Box::new(move || info!(var, fallback, "unspecified ENV variable"))
    } else {
        Box::new(move || warn!(var, fallback, reason = %e, "invalid ENV variable"))
    };

    #[expect(clippy::expect_used)]
    let filter = EnvFilter::try_new(default).expect("invalid default filter");
    (ThrottledEnvFilter::new(filter, OTEL_ERROR_THROTTLE_MS), Some(warning))
}
/// Assembles and installs the global tracing subscriber.
///
/// * `with_open_telemetry` — attach the OpenTelemetry trace layer and set up metrics export.
/// * `with_json_traces` — attach the JSON trace layer.
/// * `color` — whether the human-readable log layer uses ANSI colors.
///
/// Returns the meter provider (if OpenTelemetry is enabled) and the teardown to run before
/// exiting. Both trace layers read the same environment variable, so at most one of their
/// delayed notices is emitted, the OpenTelemetry one taking precedence.
pub fn setup_observability(
    with_open_telemetry: bool,
    with_json_traces: bool,
    color: bool,
) -> (Option<SdkMeterProvider>, Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>) {
    let mut subscriber = TracingSubscriber::new();

    // OpenTelemetry has to go first: the JSON layer can stack on top of it, not the reverse.
    let (handle, warning_otlp) = match with_open_telemetry {
        true => setup_open_telemetry(&mut subscriber),
        false => (OpenTelemetryHandle::default(), None),
    };

    let warning_json = with_json_traces.then(|| setup_json_traces(&mut subscriber)).flatten();

    subscriber.init(color);

    // Deferred until here so that the notice is actually recorded by the new subscriber.
    if let Some(notify) = warning_otlp.or(warning_json) {
        notify();
    }

    (handle.metrics, handle.teardown)
}
/// User preference for colored output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Color {
    /// Never use colors.
    Never,
    /// Always use colors.
    Always,
    /// Use colors when stderr is a terminal.
    Auto,
}

impl FromStr for Color {
    type Err = &'static str;

    /// Parses exactly `never`, `always` or `auto` (case-sensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const CHOICES: [(&str, Color); 3] =
            [("never", Color::Never), ("always", Color::Always), ("auto", Color::Auto)];

        CHOICES
            .iter()
            .find(|(keyword, _)| *keyword == s)
            .map(|&(_, color)| color)
            .ok_or("valid color settings are 'never', 'always' or 'auto'")
    }
}

impl Color {
    /// Resolves an optional preference into a yes/no decision.
    ///
    /// An explicit preference wins. Without one, colors are disabled when `NO_COLOR` is set to
    /// a non-empty value, and otherwise follow whether stderr is a terminal.
    pub fn is_enabled(this: Option<Self>) -> bool {
        let stderr_is_tty = || std::io::stderr().is_terminal();

        match this {
            Some(Color::Always) => true,
            Some(Color::Never) => false,
            Some(Color::Auto) => stderr_is_tty(),
            None => {
                let opted_out = matches!(std::env::var("NO_COLOR"), Ok(value) if !value.is_empty());
                !opted_out && stderr_is_tty()
            }
        }
    }
}
// Unit tests for `ThrottledEnvFilter`: target detection and throttling of `opentelemetry` events.
#[cfg(test)]
mod tests {
use std::sync::{
Arc, Mutex,
atomic::{AtomicUsize, Ordering as AtomicOrdering},
};
use super::*;
// Layer that does nothing but count the events it receives.
struct CountingLayer {
count: Arc<AtomicUsize>,
}
impl<S: tracing::Subscriber> tracing_subscriber::Layer<S> for CountingLayer {
fn on_event(&self, _event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
self.count.fetch_add(1, AtomicOrdering::Relaxed);
}
}
// Runs `f` with a registry whose only layer is a `CountingLayer` gated by `filter`, and
// returns how many events made it through the filter.
fn count_events<F: FnOnce()>(filter: ThrottledEnvFilter, f: F) -> usize {
let count = Arc::new(AtomicUsize::new(0));
let subscriber =
tracing_subscriber::registry().with(CountingLayer { count: Arc::clone(&count) }.with_filter(filter));
tracing::subscriber::with_default(subscriber, f);
count.load(AtomicOrdering::Relaxed)
}
// `is_otel_internal` needs a `Metadata`, which is obtained here by intercepting a real event
// in a minimal subscriber and recording the verdict in `CHECK`.
#[test]
fn otel_target_is_recognised_as_internal() {
static CHECK: Mutex<Option<bool>> = Mutex::new(None);
struct CaptureMeta;
impl tracing::Subscriber for CaptureMeta {
fn enabled(&self, meta: &tracing::Metadata<'_>) -> bool {
if meta.target().starts_with("opentelemetry") {
*CHECK.lock().unwrap() = Some(ThrottledEnvFilter::is_otel_internal(meta));
}
true
}
fn new_span(&self, _: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _: &tracing::span::Id, _: &tracing::span::Record<'_>) {}
fn record_follows_from(&self, _: &tracing::span::Id, _: &tracing::span::Id) {}
fn event(&self, _: &tracing::Event<'_>) {}
fn enter(&self, _: &tracing::span::Id) {}
fn exit(&self, _: &tracing::span::Id) {}
}
tracing::subscriber::with_default(CaptureMeta, || {
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "test");
});
assert_eq!(*CHECK.lock().unwrap(), Some(true));
}
// Same capture technique as above, for a target that must not be treated as internal.
#[test]
fn non_otel_target_is_not_recognised_as_internal() {
static CHECK: Mutex<Option<bool>> = Mutex::new(None);
struct CaptureMeta;
impl tracing::Subscriber for CaptureMeta {
fn enabled(&self, meta: &tracing::Metadata<'_>) -> bool {
if meta.target() == "amaru::stages" {
*CHECK.lock().unwrap() = Some(ThrottledEnvFilter::is_otel_internal(meta));
}
true
}
fn new_span(&self, _: &tracing::span::Attributes<'_>) -> tracing::span::Id {
tracing::span::Id::from_u64(1)
}
fn record(&self, _: &tracing::span::Id, _: &tracing::span::Record<'_>) {}
fn record_follows_from(&self, _: &tracing::span::Id, _: &tracing::span::Id) {}
fn event(&self, _: &tracing::Event<'_>) {}
fn enter(&self, _: &tracing::span::Id) {}
fn exit(&self, _: &tracing::span::Id) {}
}
tracing::subscriber::with_default(CaptureMeta, || {
tracing::event!(target: "amaru::stages", tracing::Level::DEBUG, "test");
});
assert_eq!(*CHECK.lock().unwrap(), Some(false));
}
// A fresh filter has no previous admission recorded, so the first event passes.
#[test]
fn first_otel_event_is_allowed() {
let filter = ThrottledEnvFilter::new(EnvFilter::try_new("error").unwrap(), 1_000);
let seen = count_events(filter, || {
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "test");
});
assert_eq!(seen, 1);
}
// The last admission is forced to "now", so an immediate event falls inside the window.
#[test]
fn second_otel_event_within_throttle_is_rejected() {
let filter = ThrottledEnvFilter::new(EnvFilter::try_new("error").unwrap(), 1_000);
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64;
filter.last_otel_event.store(now, Ordering::Relaxed);
let seen = count_events(filter, || {
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "test");
});
assert_eq!(seen, 0);
}
// The last admission is forced to just over one window ago, so the event passes again.
#[test]
fn otel_event_after_throttle_period_is_allowed() {
let throttle_ms = 1_000u64;
let filter = ThrottledEnvFilter::new(EnvFilter::try_new("error").unwrap(), throttle_ms);
let past = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64 - throttle_ms - 1;
filter.last_otel_event.store(past, Ordering::Relaxed);
let seen = count_events(filter, || {
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "test");
});
assert_eq!(seen, 1);
}
// Throttling state must not affect targets outside `opentelemetry`.
#[test]
fn non_otel_event_is_not_throttled() {
let filter = ThrottledEnvFilter::new(EnvFilter::try_new("debug").unwrap(), 1_000);
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64;
filter.last_otel_event.store(now, Ordering::Relaxed);
let seen = count_events(filter, || {
tracing::debug!(target: "amaru::stages", "test");
});
assert_eq!(seen, 1);
}
// Admitting the first event must start a new window that rejects the second one.
// NOTE(review): relies on both events being emitted within the 100 ms window.
#[test]
fn throttle_period_advances_after_allowed_event() {
let filter = ThrottledEnvFilter::new(EnvFilter::try_new("error").unwrap(), 100);
let seen = count_events(filter, || {
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "first");
tracing::event!(target: "opentelemetry_sdk::internal", tracing::Level::ERROR, "second");
});
assert_eq!(seen, 1);
}
}