// dbnexus 0.1.3
//
// An enterprise-grade database abstraction layer for Rust with built-in permission control and connection pooling
// Copyright (c) 2026 Kirky.X
//
// Licensed under the MIT License
// See LICENSE file in the project root for full license information.

//! 分布式追踪模块
//!
//! 提供基于 OpenTelemetry 的分布式追踪功能。
//! 支持 OTLP 和标准输出导出器。

pub mod attributes;
pub mod context;

/// 采样器模块
///
/// 提供追踪采样率配置功能,允许控制追踪的覆盖率以减少性能影响。
pub mod sampler;

use opentelemetry::KeyValue;
use opentelemetry::global;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::{Config, TracerProvider};
use std::collections::HashMap;

/// Handle returned by a successful tracing initialisation.
///
/// The guard stores a `TracerProvider` in a private field. Dropping the guard
/// calls `global::shutdown_tracer_provider()` (see its `Drop` implementation),
/// so it has to be bound to a variable that lives for as long as tracing is
/// needed. The `#[must_use]` attribute makes the compiler warn when the guard
/// is discarded immediately.
#[must_use = "dropping the guard shuts down the global tracer provider; bind it to a variable"]
pub struct TracingGuard {
    // Held only to keep the provider alive; it is never read directly.
    _provider: TracerProvider,
}

impl Drop for TracingGuard {
    fn drop(&mut self) {
        global::shutdown_tracer_provider();
    }
}

/// Initialises distributed tracing and registers it globally.
///
/// # Parameters
///
/// * `exporter` - Exporter name; it is lower-cased before being compared.
///   `"otlp"` selects the OTLP path, while `"stdout"` and every other value
///   select the stdout path.
/// * `endpoint` - Endpoint handed to the OTLP path; ignored otherwise.
///
/// # Returns
///
/// A [`TracingGuard`] holding the created provider. Dropping it shuts down the
/// global tracer provider.
///
/// # Errors
///
/// Returns the stringified error produced while building the provider.
pub async fn init(exporter: &str, endpoint: &str) -> Result<TracingGuard, String> {
    // "stdout" and unrecognised names share one fallback, so only "otlp"
    // needs a dedicated branch.
    let provider = if exporter.to_lowercase() == "otlp" {
        init_otlp(endpoint).await?
    } else {
        init_stdout()?
    };

    // The global registry receives a clone; the guard keeps the original.
    global::set_tracer_provider(provider.clone());
    global::set_text_map_propagator(TraceContextPropagator::default());

    Ok(TracingGuard { _provider: provider })
}

/// Initialises distributed tracing with a configured sampling rate.
///
/// # Parameters
///
/// * `exporter` - Exporter name; it is lower-cased before being compared.
///   `"otlp"` selects the OTLP path, while `"stdout"` and every other value
///   select the stdout path.
/// * `endpoint` - OTLP endpoint (only used when `exporter` is `"otlp"`).
/// * `sampling_rate` - Sampling rate, documented range `[0.0, 1.0]`; it is
///   forwarded unchanged to `sampler::create_trace_provider_with_sampling`.
///
/// # Errors
///
/// Returns the stringified exporter error, or
/// `"Failed to create sampler: …"` when the sampler helper fails.
pub async fn init_with_sampling(exporter: &str, endpoint: &str, sampling_rate: f64) -> Result<TracingGuard, String> {
    let exporter_provider = if exporter.to_lowercase() == "otlp" {
        init_otlp(endpoint).await?
    } else {
        init_stdout()?
    };

    // NOTE(review): the provider registered globally below comes from the
    // sampler helper, while the exporter-backed provider built above is only
    // stored in the guard. Whether the sampled provider exports anywhere
    // depends on the sampler module — confirm this split is intended.
    let sampled_provider = match crate::tracing::sampler::create_trace_provider_with_sampling(sampling_rate) {
        Ok(created) => created,
        Err(e) => return Err(format!("Failed to create sampler: {}", e)),
    };

    global::set_tracer_provider(sampled_provider);
    global::set_text_map_propagator(TraceContextPropagator::default());

    Ok(TracingGuard {
        _provider: exporter_provider,
    })
}

/// Builds a tracer provider that exports through the OTLP tonic exporter.
///
/// The provider's resource carries `service.name = "dbnexus"`, and the
/// exporter is pointed at `endpoint`.
///
/// # Errors
///
/// Returns the pipeline installation error converted to a `String`.
async fn init_otlp(endpoint: &str) -> Result<TracerProvider, String> {
    let service_name = KeyValue::new("service.name", "dbnexus");
    let trace_config = Config::default().with_resource(opentelemetry_sdk::Resource::new(vec![service_name]));
    let span_exporter = opentelemetry_otlp::new_exporter().tonic().with_endpoint(endpoint);

    opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_exporter(span_exporter)
        .with_trace_config(trace_config)
        .install_simple()
        .map_err(|err| err.to_string())
}

/// Builds the tracer provider used for the "stdout" exporter choice.
///
/// The provider's resource carries `service.name = "dbnexus"`.
///
/// NOTE(review): despite the name, this builds the same OTLP tonic pipeline
/// as the OTLP path, with the literal string `"stdout"` as its endpoint — it
/// does not use a dedicated stdout exporter. Confirm whether that is intended.
///
/// # Errors
///
/// Returns the pipeline installation error converted to a `String`.
fn init_stdout() -> Result<TracerProvider, String> {
    let service_name = KeyValue::new("service.name", "dbnexus");
    let trace_config = Config::default().with_resource(opentelemetry_sdk::Resource::new(vec![service_name]));
    let span_exporter = opentelemetry_otlp::new_exporter().tonic().with_endpoint("stdout");

    opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_exporter(span_exporter)
        .with_trace_config(trace_config)
        .install_simple()
        .map_err(|err| err.to_string())
}

/// Injects tracing context into a header map.
///
/// Hands `headers` to the globally registered text-map propagator, which
/// writes its entries into the map.
pub fn inject(headers: &mut HashMap<String, String>) {
    global::get_text_map_propagator(|text_map| text_map.inject(headers));
}

/// Runs the globally registered text-map propagator's extraction over `headers`.
///
/// NOTE(review): the value returned by the propagator is discarded and this
/// function returns nothing, so callers cannot obtain the extracted context —
/// confirm whether it should be returned or attached somewhere.
pub fn extract(headers: &HashMap<String, String>) {
    global::get_text_map_propagator(|text_map| {
        // The extracted value is intentionally dropped, as in the original code.
        drop(text_map.extract(headers));
    });
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs init → inject → extract → drop once for each exporter name:
    /// `"stdout"`, an unknown name (fallback path), and `"otlp"`.
    #[tokio::test]
    async fn test_tracing_init_and_propagation() {
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "1".to_string());

        // (exporter, endpoint, message used if initialisation fails)
        let cases = [
            ("stdout", "unused", "init stdout"),
            ("unknown", "unused", "init fallback"),
            ("otlp", "http://localhost:4317", "init otlp"),
        ];

        for (exporter, endpoint, failure_msg) in cases {
            let guard = init(exporter, endpoint).await.expect(failure_msg);
            inject(&mut headers);
            extract(&headers);
            drop(guard);
        }
    }
}