use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::field::{Field, Visit};
use tracing::{Event, Subscriber};
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::Layer;
use super::component_log::{ComponentLogRegistry, LogLevel, LogMessage};
use crate::channels::ComponentType;
use std::sync::OnceLock;
/// Capacity of the bounded channel created in `get_or_init_global_registry`
/// to carry `LogMessage`s to the log worker.
const LOG_CHANNEL_CAPACITY: usize = 10_000;
/// Process-wide component log registry; populated on the first call to
/// `get_or_init_global_registry`.
static GLOBAL_LOG_REGISTRY: OnceLock<Arc<ComponentLogRegistry>> = OnceLock::new();
/// Sending half of the log channel; set while the global registry is being
/// initialized and read by `ComponentLogLayer::on_event`.
static GLOBAL_LOG_SENDER: OnceLock<mpsc::Sender<LogMessage>> = OnceLock::new();
/// Returns the process-wide [`ComponentLogRegistry`], creating it on first use.
///
/// The first call builds the registry and the bounded log channel, stores the
/// channel sender in `GLOBAL_LOG_SENDER`, starts the log worker thread and
/// installs the tracing subscriber. Every call returns a clone of the stored
/// `Arc`.
pub fn get_or_init_global_registry() -> Arc<ComponentLogRegistry> {
    let shared = GLOBAL_LOG_REGISTRY.get_or_init(|| {
        let created = Arc::new(ComponentLogRegistry::new());
        let (sender, receiver) = mpsc::channel::<LogMessage>(LOG_CHANNEL_CAPACITY);

        // The result is ignored: if a sender is already stored, it stays in place.
        let _ = GLOBAL_LOG_SENDER.set(sender);

        // A failed spawn is reported on stderr instead of being propagated;
        // initialization continues without a worker.
        if let Err(e) = spawn_log_worker(Arc::clone(&created), receiver) {
            eprintln!("drasi-lib: failed to spawn log worker thread: {e}. Component logs will not be captured.");
        }

        init_tracing_internal(Arc::clone(&created));
        created
    });

    Arc::clone(shared)
}
/// Starts the `drasi-log-worker` thread that drains the log channel.
///
/// The thread builds its own current-thread Tokio runtime and hands every
/// received message to `registry.log` until `rx.recv()` yields `None`. If the
/// runtime cannot be built, the thread prints the error to stderr and exits.
///
/// # Errors
///
/// Returns the `std::io::Error` produced by `std::thread::Builder::spawn`
/// when the thread cannot be started.
fn spawn_log_worker(
    registry: Arc<ComponentLogRegistry>,
    mut rx: mpsc::Receiver<LogMessage>,
) -> std::result::Result<(), std::io::Error> {
    let worker = move || {
        let built = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build();
        let rt = match built {
            Ok(rt) => rt,
            Err(e) => {
                eprintln!("drasi-lib: failed to create log worker runtime: {e}");
                return;
            }
        };

        rt.block_on(async move {
            loop {
                let Some(message) = rx.recv().await else {
                    break;
                };
                registry.log(message).await;
            }
        });
    };

    // The join handle is discarded, so the worker thread runs detached.
    std::thread::Builder::new()
        .name("drasi-log-worker".to_string())
        .spawn(worker)
        .map(|_handle| ())
}
/// Deprecated shim kept for existing callers.
///
/// Makes sure the global registry is initialized, then emits a warning when
/// `log_registry` is not the same allocation as the global registry. The
/// supplied registry is never installed.
#[deprecated(
    since = "0.4.0",
    note = "Use get_or_init_global_registry() instead, which handles initialization automatically"
)]
pub fn init_tracing(log_registry: Arc<ComponentLogRegistry>) {
    // One call both triggers initialization (if still needed) and yields the
    // handle to compare against.
    let global = get_or_init_global_registry();
    if Arc::ptr_eq(&log_registry, &global) {
        return;
    }

    tracing::warn!(
        "init_tracing called with custom registry, but global registry already initialized. \
         The provided registry will be ignored. Use get_or_init_global_registry() instead."
    );
}
/// Builds the tracing subscriber stack and tries to install it globally.
///
/// The stack is, in order: an `EnvFilter` (taken from the environment, or
/// `"info"` when that fails), a [`ComponentLogLayer`] wrapping `log_registry`,
/// and a `fmt` layer with targets and levels enabled. Failures from
/// `LogTracer::init` and `set_global_default` are ignored.
fn init_tracing_internal(log_registry: Arc<ComponentLogRegistry>) {
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::{fmt, EnvFilter};

    let _ = tracing_log::LogTracer::init();

    let env_filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    let component_layer = ComponentLogLayer::new(log_registry);
    let console_layer = fmt::layer().with_target(true).with_level(true);

    let subscriber = tracing_subscriber::registry()
        .with(env_filter)
        .with(component_layer)
        .with(console_layer);

    let _ = tracing::subscriber::set_global_default(subscriber);
}
/// Deprecated shim kept for existing callers.
///
/// Returns `false` without doing anything when the global registry is already
/// set. Otherwise it initializes the global registry, warns when
/// `log_registry` is not the same allocation as the global one, and returns
/// `true`. The supplied registry is never installed.
#[deprecated(
    since = "0.4.0",
    note = "Use get_or_init_global_registry() instead, which handles initialization automatically"
)]
pub fn try_init_tracing(log_registry: Arc<ComponentLogRegistry>) -> bool {
    match GLOBAL_LOG_REGISTRY.get() {
        Some(_) => false,
        None => {
            let global = get_or_init_global_registry();
            if !Arc::ptr_eq(&log_registry, &global) {
                tracing::warn!(
                    "try_init_tracing called with custom registry, but initialization uses global registry. \
                     The provided registry will be ignored."
                );
            }
            true
        }
    }
}
/// Tracing layer that turns events emitted inside component spans into
/// `LogMessage`s and sends them through `GLOBAL_LOG_SENDER`.
pub struct ComponentLogLayer {
/// Registry handle supplied to `ComponentLogLayer::new`.
registry: Arc<ComponentLogRegistry>,
}
impl ComponentLogLayer {
pub fn new(registry: Arc<ComponentLogRegistry>) -> Self {
Self { registry }
}
}
impl<S> Layer<S> for ComponentLogLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    /// Reads the component identity fields from a new span's attributes and,
    /// when they form a complete `ComponentInfo`, stores it in the span's
    /// extensions.
    fn on_new_span(
        &self,
        attrs: &tracing::span::Attributes<'_>,
        id: &tracing::span::Id,
        ctx: Context<'_, S>,
    ) {
        let mut fields = ComponentInfoVisitor::default();
        attrs.record(&mut fields);

        let Some(info) = fields.into_component_info() else {
            return;
        };
        let Some(span) = ctx.span(id) else {
            return;
        };
        span.extensions_mut().insert(info);
    }

    /// Forwards an event to the log channel when the event's span, or one of
    /// its ancestors, carries a `ComponentInfo`.
    ///
    /// Events outside any component span are skipped. The send uses
    /// `try_send` and its result is ignored, so a message that cannot be
    /// queued is dropped.
    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        // Walk from the event's span towards the root; stop at the first span
        // that has component info attached.
        let mut cursor = ctx.event_span(event);
        let info = loop {
            let Some(span) = cursor else {
                return;
            };
            if let Some(found) = extract_component_info(&span) {
                break found;
            }
            cursor = span.parent();
        };

        let level = convert_level(*event.metadata().level());
        let text = extract_message(event);
        let log_message = LogMessage::with_instance(
            level,
            text,
            info.instance_id,
            info.component_id,
            info.component_type,
        );

        let Some(sender) = GLOBAL_LOG_SENDER.get() else {
            return;
        };
        // Best effort: a failed send is deliberately ignored.
        let _ = sender.try_send(log_message);
    }
}
/// Component identity attached to a span's extensions by
/// `ComponentLogLayer::on_new_span` and read back when handling events.
#[derive(Clone)]
struct ComponentInfo {
/// Value of the span's `instance_id` field; empty when the span did not record one.
instance_id: String,
/// Value of the span's `component_id` field.
component_id: String,
/// Component kind parsed from the span's `component_type` field.
component_type: ComponentType,
}
/// Returns a clone of the `ComponentInfo` stored in the span's extensions,
/// or `None` when the span has none.
fn extract_component_info<S>(
    span: &tracing_subscriber::registry::SpanRef<'_, S>,
) -> Option<ComponentInfo>
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    span.extensions().get::<ComponentInfo>().cloned()
}
/// Field visitor that collects the `instance_id`, `component_id` and
/// `component_type` fields of a span as strings.
#[derive(Default)]
struct ComponentInfoVisitor {
/// Recorded `instance_id` field, if any.
instance_id: Option<String>,
/// Recorded `component_id` field, if any.
component_id: Option<String>,
/// Recorded `component_type` field as raw text; parsed in `into_component_info`.
component_type: Option<String>,
}
impl Visit for ComponentInfoVisitor {
    /// Stores the `Debug` rendering of a recognized field, with any leading
    /// and trailing `"` characters trimmed. Other fields are ignored.
    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
        let slot = match field.name() {
            "instance_id" => &mut self.instance_id,
            "component_id" => &mut self.component_id,
            "component_type" => &mut self.component_type,
            _ => return,
        };
        let rendered = format!("{value:?}");
        *slot = Some(rendered.trim_matches('"').to_string());
    }

    /// Stores a recognized string field verbatim. Other fields are ignored.
    fn record_str(&mut self, field: &Field, value: &str) {
        let slot = match field.name() {
            "instance_id" => &mut self.instance_id,
            "component_id" => &mut self.component_id,
            "component_type" => &mut self.component_type,
            _ => return,
        };
        *slot = Some(value.to_string());
    }
}
impl ComponentInfoVisitor {
fn into_component_info(self) -> Option<ComponentInfo> {
let component_id = self.component_id?;
let component_type = self
.component_type
.as_deref()
.and_then(parse_component_type)?;
Some(ComponentInfo {
instance_id: self.instance_id.unwrap_or_default(),
component_id,
component_type,
})
}
}
/// Parses a component type name, ignoring ASCII case.
///
/// Accepts `"source"`, `"query"` and `"reaction"` in any letter case and
/// returns the matching [`ComponentType`]; any other input yields `None`.
/// No trimming is performed, so surrounding whitespace makes the input
/// unrecognized.
fn parse_component_type(s: &str) -> Option<ComponentType> {
    // `eq_ignore_ascii_case` compares in place, avoiding the temporary
    // `String` that `to_lowercase()` would allocate on every call.
    if s.eq_ignore_ascii_case("source") {
        Some(ComponentType::Source)
    } else if s.eq_ignore_ascii_case("query") {
        Some(ComponentType::Query)
    } else if s.eq_ignore_ascii_case("reaction") {
        Some(ComponentType::Reaction)
    } else {
        None
    }
}
/// Field visitor that separates an event's `message` field from its other fields.
#[derive(Default)]
struct MessageVisitor {
/// Text of the `message` field, if the event recorded one.
message: Option<String>,
/// Every other recorded field, rendered as `name=value`.
fields: Vec<String>,
}
impl Visit for MessageVisitor {
    /// Records a field through its `Debug` rendering.
    ///
    /// The `message` field is stored with leading and trailing `"` characters
    /// trimmed; any other field is appended to `fields` as `name=value`.
    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
        match field.name() {
            "message" => {
                let rendered = format!("{value:?}");
                self.message = Some(rendered.trim_matches('"').to_string());
            }
            name => self.fields.push(format!("{}={value:?}", name)),
        }
    }

    /// Records a string field verbatim: `message` is stored as-is, any other
    /// field is appended to `fields` as `name=value`.
    fn record_str(&mut self, field: &Field, value: &str) {
        match field.name() {
            "message" => self.message = Some(value.to_string()),
            name => self.fields.push(format!("{}={}", name, value)),
        }
    }
}
fn extract_message(event: &Event<'_>) -> String {
let mut visitor = MessageVisitor::default();
event.record(&mut visitor);
if let Some(msg) = visitor.message {
msg
} else if !visitor.fields.is_empty() {
visitor.fields.join(", ")
} else {
event.metadata().name().to_string()
}
}
fn convert_level(level: tracing::Level) -> LogLevel {
match level {
tracing::Level::ERROR => LogLevel::Error,
tracing::Level::WARN => LogLevel::Warn,
tracing::Level::INFO => LogLevel::Info,
tracing::Level::DEBUG => LogLevel::Debug,
tracing::Level::TRACE => LogLevel::Trace,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known names parse regardless of letter case; unknown names are rejected.
    #[test]
    fn test_parse_component_type() {
        for spelling in ["source", "Source", "SOURCE"] {
            assert_eq!(parse_component_type(spelling), Some(ComponentType::Source));
        }
        assert_eq!(parse_component_type("query"), Some(ComponentType::Query));
        assert_eq!(
            parse_component_type("reaction"),
            Some(ComponentType::Reaction)
        );
        assert_eq!(parse_component_type("unknown"), None);
    }

    /// Every tracing level maps to the `LogLevel` variant of the same name.
    #[test]
    fn test_convert_level() {
        let expectations = [
            (tracing::Level::ERROR, LogLevel::Error),
            (tracing::Level::WARN, LogLevel::Warn),
            (tracing::Level::INFO, LogLevel::Info),
            (tracing::Level::DEBUG, LogLevel::Debug),
            (tracing::Level::TRACE, LogLevel::Trace),
        ];
        for (input, expected) in expectations {
            assert_eq!(convert_level(input), expected);
        }
    }
}