use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
use tracing::{Event, Level, Subscriber};
use tracing_subscriber::Layer;
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::LookupSpan;
use crate::batch::BatchProcessor;
use crate::config::BetterStackConfig;
use crate::log_event::{LogEvent, LogValue};
/// Tracing layer that converts events into [`LogEvent`]s and pushes them onto
/// an unbounded channel consumed by a [`BatchProcessor`] task.
pub struct BetterStackLayer {
// Sending half of the channel feeding the batch processor. Starts as `None`
// and is filled in by `ensure_initialized` once a Tokio runtime handle is
// available on the calling thread.
sender: Arc<Mutex<Option<UnboundedSender<LogEvent>>>>,
// Shared configuration; a clone of it is handed to the batch processor when
// the processor task is spawned.
config: Arc<BetterStackConfig>,
}
impl BetterStackLayer {
pub fn new(config: BetterStackConfig) -> Self {
Self {
sender: Arc::new(Mutex::new(None)),
config: Arc::new(config),
}
}
pub fn builder(
ingesting_host: impl Into<String>,
source_token: impl Into<String>,
) -> crate::config::BetterStackConfigBuilder {
BetterStackConfig::builder(ingesting_host, source_token)
}
fn ensure_initialized(&self) {
let mut sender = self.sender.lock().unwrap();
if sender.is_none()
&& let Ok(handle) = tokio::runtime::Handle::try_current()
{
let (tx, rx) = unbounded_channel();
let config = Arc::clone(&self.config);
handle.spawn(async move {
let processor = BatchProcessor::new(rx, (*config).clone());
processor.run().await;
});
*sender = Some(tx);
}
}
fn send_event(&self, event: LogEvent) {
self.ensure_initialized();
if let Some(tx) = self.sender.lock().unwrap().as_ref()
&& tx.send(event).is_err()
{
eprintln!("Failed to send event to Better Stack: channel closed");
}
}
}
impl<S> Layer<S> for BetterStackLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    /// Builds a [`LogEvent`] from a tracing event and queues it for delivery.
    ///
    /// The `message` field becomes the log message (empty when absent or not
    /// a string); all remaining fields are attached as event fields. Location
    /// and span data are included according to the configuration.
    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        let metadata = event.metadata();

        let mut fields = HashMap::new();
        event.record(&mut FieldVisitor::new(&mut fields));

        let level = match *metadata.level() {
            Level::ERROR => "error",
            Level::WARN => "warn",
            Level::INFO => "info",
            Level::DEBUG => "debug",
            Level::TRACE => "trace",
        };

        // Take the message out of the field map so it is not sent twice.
        let message = match fields.remove("message") {
            Some(LogValue::String(text)) => text,
            _ => String::new(),
        };

        let mut log_event = LogEvent::new(level, message).with_target(metadata.target());

        if self.config.include_location {
            let file = metadata.file().map(|s| s.to_string());
            let line = metadata.line();
            log_event = log_event.with_location(file, line);
        }

        if self.config.include_spans
            && let Some(scope) = ctx.event_scope(event)
        {
            let mut span_fields = HashMap::new();
            let mut span_name = String::new();
            // Walk from the root span towards the event: the first non-empty
            // span name wins, and fields of inner spans overwrite same-named
            // fields of outer spans.
            for span in scope.from_root() {
                if span_name.is_empty() {
                    span_name = span.name().to_string();
                }
                let extensions = span.extensions();
                if let Some(storage) = extensions.get::<SpanFieldStorage>() {
                    span_fields.extend(storage.fields.clone());
                }
            }
            // Span data is attached only when there is both a name and at
            // least one captured field.
            if !span_name.is_empty() && !span_fields.is_empty() {
                log_event = log_event.with_span(span_name, span_fields);
            }
        }

        for (key, value) in fields {
            log_event.add_field(key, value);
        }

        self.send_event(log_event);
    }

    /// Captures the fields a span was created with and stores them in the
    /// span's extensions so `on_event` can attach them to events in scope.
    fn on_new_span(
        &self,
        attrs: &tracing::span::Attributes<'_>,
        id: &tracing::span::Id,
        ctx: Context<'_, S>,
    ) {
        if !self.config.include_spans {
            return;
        }
        // A missing span is unexpected here, but skipping it is preferable to
        // panicking inside the logging path.
        let Some(span) = ctx.span(id) else {
            return;
        };

        let mut fields = HashMap::new();
        attrs.record(&mut FieldVisitor::new(&mut fields));

        // `replace` instead of `insert`: `insert` panics when a value of the
        // same type is already stored, which happens if more than one
        // instance of this layer is attached to the same registry.
        span.extensions_mut().replace(SpanFieldStorage { fields });
    }

    /// Merges values recorded on a span after its creation (for example via
    /// `Span::record`) into the fields stored for that span.
    fn on_record(
        &self,
        id: &tracing::span::Id,
        values: &tracing::span::Record<'_>,
        ctx: Context<'_, S>,
    ) {
        if !self.config.include_spans {
            return;
        }
        let Some(span) = ctx.span(id) else {
            return;
        };

        let mut extensions = span.extensions_mut();
        match extensions.get_mut::<SpanFieldStorage>() {
            Some(storage) => values.record(&mut FieldVisitor::new(&mut storage.fields)),
            None => {
                // No storage yet for this span: start one from the recorded
                // values alone.
                let mut fields = HashMap::new();
                values.record(&mut FieldVisitor::new(&mut fields));
                extensions.replace(SpanFieldStorage { fields });
            }
        }
    }
}
/// Fields captured for a span; stored in the span's extensions by
/// `on_new_span` and read back by `on_event`.
struct SpanFieldStorage {
// Field name -> recorded value.
fields: HashMap<String, LogValue>,
}
/// `tracing` field visitor that writes every visited field into a borrowed
/// map, converting each value to a [`LogValue`].
struct FieldVisitor<'a> {
// Destination map, keyed by field name; a later field with the same name
// overwrites an earlier one.
fields: &'a mut HashMap<String, LogValue>,
}
impl<'a> FieldVisitor<'a> {
fn new(fields: &'a mut HashMap<String, LogValue>) -> Self {
Self { fields }
}
}
/// Stores each visited field in the target map under the field's name,
/// mapping the tracing value kinds onto [`LogValue`] variants.
impl<'a> tracing::field::Visit for FieldVisitor<'a> {
    /// Records a string value as `LogValue::String`.
    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
        self.fields.insert(
            field.name().to_string(),
            LogValue::String(value.to_string()),
        );
    }

    /// Records a signed integer as `LogValue::Number`.
    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
        self.fields
            .insert(field.name().to_string(), LogValue::Number(value));
    }

    /// Records an unsigned integer as `LogValue::Number` when it fits in an
    /// `i64`.
    ///
    /// Values above `i64::MAX` would wrap to a negative number under a plain
    /// `as` cast, so they are stored as their exact decimal string instead.
    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
        let recorded = match i64::try_from(value) {
            Ok(number) => LogValue::Number(number),
            Err(_) => LogValue::String(value.to_string()),
        };
        self.fields.insert(field.name().to_string(), recorded);
    }

    /// Records a boolean as `LogValue::Bool`.
    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
        self.fields
            .insert(field.name().to_string(), LogValue::Bool(value));
    }

    /// Records a floating-point value as `LogValue::Float`.
    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
        self.fields
            .insert(field.name().to_string(), LogValue::Float(value));
    }

    /// Records any other value as a string holding its `Debug` rendering.
    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
        self.fields.insert(
            field.name().to_string(),
            LogValue::String(format!("{value:?}")),
        );
    }
}