use std::{
    collections::HashMap,
    hash::{Hash, Hasher},
    pin::Pin,
    sync::Arc,
    task::{Context as TaskContext, Poll},
    time::{Duration, Instant},
};
use async_trait::async_trait;
use futures_core::Stream;
use tracing::{Instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::telemetry::capability_metrics::{
    SlowThresholds, record_cache_op, record_messaging_publish,
};
use crate::telemetry::propagate::{extract_context_from_map, inject_current_context_map};
use crate::Error;
use crate::traits::{
    Cache, Config, Database, DeliveredMessage, EventBus, MessageId, SecretStore, SubscribeOptions,
    Subscription,
};
/// Decorator around a capability implementation held in an `Arc`.
///
/// The trait impls for this type forward each call to `inner`. Depending on
/// the capability, they add tracing spans, metrics samples and warning logs
/// for failed or slow operations.
pub struct Instrumented<T: ?Sized> {
    /// Wrapped implementation that every call is forwarded to.
    inner: Arc<T>,
    /// Latency thresholds used to decide when an operation is logged as slow.
    slow: SlowThresholds,
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound, but
// only the `Arc` handle (and the thresholds) need cloning.
impl<T: ?Sized> Clone for Instrumented<T> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        let slow = self.slow.clone();
        Self { inner, slow }
    }
}

impl<T: ?Sized> Instrumented<T> {
    /// Wraps `inner`, flagging operations slower than the given `slow` thresholds.
    pub fn new(inner: Arc<T>, slow: SlowThresholds) -> Self {
        Instrumented { inner, slow }
    }

    /// Wraps `inner` using `SlowThresholds::default()`.
    pub fn with_defaults(inner: Arc<T>) -> Self {
        let slow = SlowThresholds::default();
        Self::new(inner, slow)
    }
}
/// Cache decorator. Each operation gets a `cache.*` span, a metrics sample via
/// `record_cache_op`, and a warning (through `slow_or_error_log`) when it fails
/// or exceeds `SlowThresholds::cache_op`. Keys are only recorded as
/// `key_hash` digests, never verbatim.
///
/// The span is attached to the inner future with `tracing::Instrument::instrument`
/// instead of keeping a `Span::enter` guard alive across `.await`. A guard held
/// over a suspension point leaves the span entered on the worker thread while
/// this future is parked, mis-attributing unrelated work; the guard is also
/// `!Send`.
#[async_trait]
impl<T: Cache + ?Sized> Cache for Instrumented<T> {
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
        let span = tracing::info_span!(
            "cache.get",
            cache.system = self.inner.system(),
            cache.op = "get",
            cache.key.hash = key_hash(key),
            cache.hit = tracing::field::Empty,
        );
        let start = Instant::now();
        let result = self.inner.get(key).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        let outcome = match &result {
            Ok(Some(_)) => "hit",
            Ok(None) => "miss",
            Err(_) => "error",
        };
        // Record on `span` itself: once `.instrument` has finished,
        // `Span::current()` is the caller's span, not this one.
        if let Ok(value) = &result {
            span.record("cache.hit", value.is_some());
        }
        span.in_scope(|| {
            record_cache_op(self.inner.system(), "get", outcome, elapsed);
            slow_or_error_log(elapsed, self.slow.cache_op, &result, "cache.get", key);
        });
        result
    }

    async fn set(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<(), Error> {
        let span = tracing::info_span!(
            "cache.set",
            cache.system = self.inner.system(),
            cache.op = "set",
            cache.key.hash = key_hash(key),
            ttl_ms = ttl.map(|d| d.as_millis() as u64),
        );
        let start = Instant::now();
        let result = self.inner.set(key, value, ttl).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        let outcome = if result.is_ok() { "ok" } else { "error" };
        span.in_scope(|| {
            record_cache_op(self.inner.system(), "set", outcome, elapsed);
            slow_or_error_log(elapsed, self.slow.cache_op, &result, "cache.set", key);
        });
        result
    }

    async fn del(&self, key: &str) -> Result<(), Error> {
        let span = tracing::info_span!(
            "cache.del",
            cache.system = self.inner.system(),
            cache.op = "del",
            cache.key.hash = key_hash(key),
        );
        let start = Instant::now();
        let result = self.inner.del(key).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        let outcome = if result.is_ok() { "ok" } else { "error" };
        span.in_scope(|| {
            record_cache_op(self.inner.system(), "del", outcome, elapsed);
            slow_or_error_log(elapsed, self.slow.cache_op, &result, "cache.del", key);
        });
        result
    }

    async fn set_nx(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> Result<bool, Error> {
        let span = tracing::info_span!(
            "cache.set_nx",
            cache.system = self.inner.system(),
            cache.op = "setnx",
            cache.key.hash = key_hash(key),
            ttl_ms = ttl.map(|d| d.as_millis() as u64),
            cache.created = tracing::field::Empty,
        );
        let start = Instant::now();
        let result = self.inner.set_nx(key, value, ttl).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        // `Ok(false)` means the key already existed and nothing was written.
        let outcome = match &result {
            Ok(true) => "ok",
            Ok(false) => "exists",
            Err(_) => "error",
        };
        if let Ok(created) = &result {
            span.record("cache.created", *created);
        }
        span.in_scope(|| {
            record_cache_op(self.inner.system(), "setnx", outcome, elapsed);
            slow_or_error_log(elapsed, self.slow.cache_op, &result, "cache.set_nx", key);
        });
        result
    }

    fn system(&self) -> &'static str {
        self.inner.system()
    }
}
#[async_trait]
impl<T: Database + ?Sized> Database for Instrumented<T> {
fn url(&self) -> &str {
self.inner.url()
}
fn system(&self) -> &'static str {
self.inner.system()
}
}
/// Secret-store decorator. `get` runs inside a `secret.get` span and logs a
/// warning on failure. Neither the secret key nor the error text is recorded
/// on the span or in the warning; only the provider and the elapsed time are.
///
/// The span is attached with `tracing::Instrument::instrument` rather than a
/// `Span::enter` guard kept alive across `.await`. Such a guard would leave the
/// span entered on the thread while the future is parked, and it is `!Send`.
#[async_trait]
impl<T: SecretStore + ?Sized> SecretStore for Instrumented<T> {
    async fn get(&self, key: &str) -> Result<String, Error> {
        let span = tracing::info_span!("secret.get", secret.provider = self.inner.provider());
        let start = Instant::now();
        let result = self.inner.get(key).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        if result.is_err() {
            span.in_scope(|| {
                tracing::warn!(
                    elapsed_ms = elapsed.as_millis() as u64,
                    provider = self.inner.provider(),
                    "secret.get failed",
                );
            });
        }
        result
    }

    fn provider(&self) -> &'static str {
        self.inner.provider()
    }
}
/// Config decorator. `get` runs inside a `config.get` span and logs a warning
/// on failure. The path is recorded verbatim, unlike cache keys. `watch` and
/// `source` are forwarded without instrumentation.
///
/// The span is attached with `tracing::Instrument::instrument` rather than a
/// `Span::enter` guard kept alive across `.await`. Such a guard would leave the
/// span entered on the thread while the future is parked, and it is `!Send`.
#[async_trait]
impl<T: Config + ?Sized> Config for Instrumented<T> {
    async fn get(&self, path: &str) -> Result<Option<Vec<u8>>, Error> {
        let span = tracing::info_span!(
            "config.get",
            config.source = self.inner.source(),
            config.path = path,
        );
        let start = Instant::now();
        let result = self.inner.get(path).instrument(span.clone()).await;
        let elapsed = start.elapsed();
        if result.is_err() {
            span.in_scope(|| {
                tracing::warn!(
                    elapsed_ms = elapsed.as_millis() as u64,
                    source = self.inner.source(),
                    path,
                    "config.get failed",
                );
            });
        }
        result
    }

    fn watch(
        &self,
        path: &str,
        interval: std::time::Duration,
    ) -> tokio::sync::watch::Receiver<Option<Vec<u8>>> {
        self.inner.watch(path, interval)
    }

    fn source(&self) -> &'static str {
        self.inner.source()
    }
}
/// Event-bus decorator.
///
/// Both publish entry points go through `publish_inner`, which:
/// - wraps the publish in a `messaging.publish` span;
/// - propagates trace context in the message headers;
/// - records a metrics sample;
/// - warns on failure or when `SlowThresholds::messaging_publish` is exceeded.
///
/// Subscriptions are wrapped in `InstrumentedSubscription`.
#[async_trait]
impl<T: EventBus + ?Sized> EventBus for Instrumented<T> {
    async fn publish(&self, subject: &str, payload: &[u8]) -> Result<MessageId, Error> {
        // An empty header map has no `traceparent`, so `publish_inner` always
        // injects the trace context here.
        self.publish_inner(subject, payload, HashMap::new()).await
    }

    async fn publish_with_headers(
        &self,
        subject: &str,
        payload: &[u8],
        headers: HashMap<String, String>,
    ) -> Result<MessageId, Error> {
        // A caller-supplied `traceparent` is respected; see `publish_inner`.
        self.publish_inner(subject, payload, headers).await
    }

    async fn subscribe(
        &self,
        subject_pattern: &str,
        group: &str,
        opts: SubscribeOptions,
    ) -> Result<Subscription, Error> {
        let inner_sub = self.inner.subscribe(subject_pattern, group, opts).await?;
        let system: &'static str = self.inner.system();
        Ok(Subscription::new(InstrumentedSubscription {
            inner: inner_sub,
            system,
            group: group.to_string(),
            subject: subject_pattern.to_string(),
        }))
    }

    fn system(&self) -> &'static str {
        self.inner.system()
    }
}

impl<T: EventBus + ?Sized> Instrumented<T> {
    /// Publishes `payload` to `subject` through the inner bus, inside a
    /// `messaging.publish` span.
    ///
    /// If `headers` has no `traceparent`, trace context is injected while the
    /// publish span is entered. The intent is that the propagated context names
    /// this span as the consumer's parent. Previously the context was injected
    /// before this span existed.
    ///
    /// NOTE(review): this assumes `inject_current_context_map` reads the current
    /// tracing span. If it reads a different notion of "current context", the
    /// injected value is what it was before. A `traceparent` supplied by the
    /// caller is never overwritten.
    ///
    /// # Errors
    /// Returns whatever error the inner `publish_with_headers` returns.
    async fn publish_inner(
        &self,
        subject: &str,
        payload: &[u8],
        mut headers: HashMap<String, String>,
    ) -> Result<MessageId, Error> {
        let span = tracing::info_span!(
            "messaging.publish",
            messaging.system = self.inner.system(),
            messaging.destination.name = subject,
            messaging.destination.kind = "topic",
            messaging.message.body.size = payload.len(),
        );
        if !headers.contains_key("traceparent") {
            // Synchronous, so entering the span briefly is safe here.
            span.in_scope(|| {
                inject_current_context_map(&mut headers);
            });
        }
        let start = Instant::now();
        // `instrument` instead of a `Span::enter` guard held across `.await`.
        // Such a guard would leave the span entered on the thread while the
        // future is parked, and it is `!Send`.
        let result = self
            .inner
            .publish_with_headers(subject, payload, headers)
            .instrument(span.clone())
            .await;
        let elapsed = start.elapsed();
        let outcome = if result.is_ok() { "ok" } else { "error" };
        span.in_scope(|| {
            record_messaging_publish(self.inner.system(), subject, outcome, elapsed);
            if elapsed > self.slow.messaging_publish || result.is_err() {
                tracing::warn!(
                    elapsed_ms = elapsed.as_millis() as u64,
                    subject,
                    error = result.as_ref().err().map(|e| e.to_string()),
                    "slow or failing messaging.publish",
                );
            }
        });
        result
    }
}
/// Stream wrapper produced by `Instrumented::subscribe`.
///
/// Messages are yielded unchanged. As a side effect, a `messaging.process`
/// span is created for each one, parented on the trace context extracted from
/// the message headers.
struct InstrumentedSubscription {
    /// Subscription returned by the wrapped event bus.
    inner: Subscription,
    /// The wrapped bus's `system()` identifier.
    system: &'static str,
    /// Consumer group passed to `subscribe`.
    group: String,
    /// Subject pattern passed to `subscribe`.
    subject: String,
}

impl InstrumentedSubscription {
    /// Creates the `messaging.process` span for `msg` and parents it on the
    /// trace context carried in `msg.headers`.
    ///
    /// NOTE(review): the span is dropped when this returns, so it closes before
    /// the consumer handles the message. It marks delivery rather than spanning
    /// processing. Covering processing would require handing the span to the
    /// consumer, e.g. via a field on `DeliveredMessage`.
    fn trace_delivery(&self, msg: &DeliveredMessage) {
        let span = tracing::info_span!(
            "messaging.process",
            messaging.system = self.system,
            messaging.destination.name = %msg.subject,
            messaging.consumer.id = %self.group,
            messaging.message.id = %msg.id,
            messaging.delivery_attempt = msg.delivery_attempt,
            subscription.subject_pattern = %self.subject,
        );
        span.set_parent(extract_context_from_map(&msg.headers));
    }
}

impl Stream for InstrumentedSubscription {
    type Item = DeliveredMessage;

    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        // `Pin::get_mut` relies on this type being `Unpin`, as the original
        // `as_mut().get_mut()` did.
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_next(cx);
        if let Poll::Ready(Some(msg)) = &polled {
            this.trace_delivery(msg);
        }
        polled
    }
}
/// Emits a `warn` event when an operation failed or ran longer than `threshold`.
///
/// The event carries the elapsed time in milliseconds, the `key_hash` digest of
/// `key_for_hash` (never the raw key), and the error's `Display` text when
/// `result` is an `Err`. Nothing is logged when the call succeeded within the
/// threshold.
fn slow_or_error_log<T>(
    elapsed: Duration,
    threshold: Duration,
    result: &Result<T, Error>,
    op: &'static str,
    key_for_hash: &str,
) {
    let failure = result.as_ref().err();
    let within_budget = elapsed <= threshold;
    if within_budget && failure.is_none() {
        return;
    }
    tracing::warn!(
        elapsed_ms = elapsed.as_millis() as u64,
        key.hash = key_hash(key_for_hash),
        error = failure.map(|e| e.to_string()),
        "slow or failing {op}",
        op = op,
    );
}
/// Digests a key so it can be attached to spans and logs without exposing the
/// key text itself.
///
/// Every `DefaultHasher::new()`/`default()` instance hashes identically, so the
/// digest is stable within one build. std does not promise the algorithm across
/// Rust releases, so use it for correlation, not as a persistent identifier.
/// The hash is unkeyed, so a low-entropy key can still be recovered by
/// guessing candidates and comparing digests.
fn key_hash(key: &str) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(key, &mut hasher);
    Hasher::finish(&hasher)
}