alibabacloud-rum 0.1.0

Alibaba Cloud RUM SDK for native Rust applications.
Documentation
use rand::RngCore;
use uuid::Uuid;

use crate::error::{Result, RumError};
use crate::propagation::{
    encode_sw8_with_ids, generate_span_id, generate_trace_id, parse_traceparent,
};

/// Service name written into `sw8` headers, taken from this crate's package name at compile
/// time. `TraceHeaderWriter::generate_headers` passes it for both the 4th and 5th arguments of
/// `encode_sw8_with_ids` (NOTE(review): presumably service and service-instance — confirm
/// against that function's signature).
const DEFAULT_SW8_SERVICE: &str = env!("CARGO_PKG_NAME");
/// Endpoint value passed to `encode_sw8_with_ids` for every `sw8` header.
const DEFAULT_SW8_ENDPOINT: &str = "/";
/// Target value passed as the last argument of `encode_sw8_with_ids` for every `sw8` header.
const DEFAULT_SW8_TARGET: &str = "client";

/// Header format used to propagate a [`TraceContext`] on outgoing requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceProtocol {
    /// W3C Trace Context: a single `traceparent` header with hex trace and span ids.
    TraceParent,
    /// SkyWalking-style `sw8` header with a decimal span id.
    Sw8,
}

/// A trace id / span id pair tagged with the protocol it is formatted for.
///
/// Fields are private: instances come either from [`TraceContext::new`], which validates the
/// ids for the protocol, or from [`TraceContext::generate`], which creates random ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceContext {
    /// Trace identifier, in the format required by `protocol`.
    trace_id: String,
    /// Span identifier, in the format required by `protocol`.
    span_id: String,
    /// Propagation protocol these identifiers belong to.
    protocol: TraceProtocol,
}

impl TraceContext {
    /// Builds a context from caller-supplied ids after checking they are well formed for
    /// `protocol`.
    ///
    /// # Errors
    ///
    /// Returns [`RumError::InvalidTraceContext`] when:
    /// - `protocol` is [`TraceProtocol::TraceParent`] and `parse_traceparent` rejects
    ///   `00-{trace_id}-{span_id}-01`;
    /// - `protocol` is [`TraceProtocol::Sw8`] and `trace_id` is empty, or `span_id` is not a
    ///   positive decimal integer.
    pub fn new(
        trace_id: impl Into<String>,
        span_id: impl Into<String>,
        protocol: TraceProtocol,
    ) -> Result<Self> {
        let candidate = Self {
            trace_id: trace_id.into(),
            span_id: span_id.into(),
            protocol,
        };
        validate_trace_context(&candidate.trace_id, &candidate.span_id, candidate.protocol)?;
        Ok(candidate)
    }

    /// Creates a context with freshly generated random ids in `protocol`'s format.
    pub fn generate(protocol: TraceProtocol) -> Self {
        let trace_id = TraceGenerator::generate_trace_id(protocol);
        let span_id = TraceGenerator::generate_span_id(protocol);
        Self {
            trace_id,
            span_id,
            protocol,
        }
    }

    /// The trace identifier.
    pub fn trace_id(&self) -> &str {
        self.trace_id.as_str()
    }

    /// The span identifier.
    pub fn span_id(&self) -> &str {
        self.span_id.as_str()
    }

    /// The protocol these identifiers are formatted for.
    pub fn protocol(&self) -> TraceProtocol {
        self.protocol
    }
}

/// Stateless factory for random trace and span ids in each protocol's format.
pub struct TraceGenerator;

impl TraceGenerator {
    /// Returns a new random trace id for `protocol`.
    ///
    /// - `Sw8`: a random (v4) UUID in its "simple" form, i.e. 32 hex digits with no hyphens.
    /// - `TraceParent`: whatever `propagation::generate_trace_id` produces.
    pub fn generate_trace_id(protocol: TraceProtocol) -> String {
        match protocol {
            TraceProtocol::Sw8 => Uuid::new_v4().simple().to_string(),
            TraceProtocol::TraceParent => generate_trace_id(),
        }
    }

    /// Returns a new random span id for `protocol`.
    ///
    /// - `Sw8`: a positive decimal integer string.
    /// - `TraceParent`: whatever `propagation::generate_span_id` produces.
    pub fn generate_span_id(protocol: TraceProtocol) -> String {
        match protocol {
            TraceProtocol::Sw8 => generate_positive_decimal_span_id(),
            TraceProtocol::TraceParent => generate_span_id(),
        }
    }
}

/// Renders a [`TraceContext`] as HTTP header `(name, value)` pairs for its protocol.
pub struct TraceHeaderWriter;

impl TraceHeaderWriter {
    /// Builds the propagation headers for `context`; the result always holds exactly one pair.
    ///
    /// - `TraceParent`: `traceparent` = `00-{trace_id}-{span_id}-01`. The version is fixed at
    ///   `00` and the flags at `01`, so the request is always marked sampled.
    /// - `Sw8`: `sw8` = the output of `encode_sw8_with_ids`. A new random segment id is drawn
    ///   on every call, so repeated calls for the same context yield different values.
    pub fn generate_headers(context: &TraceContext) -> Vec<(String, String)> {
        let header = match context.protocol {
            TraceProtocol::TraceParent => {
                let value = format!("00-{}-{}-01", context.trace_id, context.span_id);
                (String::from("traceparent"), value)
            }
            TraceProtocol::Sw8 => {
                // Only trace_id and span_id come from the context; the segment id is fresh.
                let segment_id = Uuid::new_v4().simple().to_string();
                let value = encode_sw8_with_ids(
                    &context.trace_id,
                    &segment_id,
                    &context.span_id,
                    DEFAULT_SW8_SERVICE,
                    DEFAULT_SW8_SERVICE,
                    DEFAULT_SW8_ENDPOINT,
                    DEFAULT_SW8_TARGET,
                );
                (String::from("sw8"), value)
            }
        };
        vec![header]
    }

    /// Returns the first header produced by [`Self::generate_headers`], or `None` if it
    /// produced none.
    pub fn generate_single_header(context: &TraceContext) -> Option<(String, String)> {
        let headers = Self::generate_headers(context);
        headers.into_iter().next()
    }
}

/// Checks that `trace_id` and `span_id` are well formed for `protocol`.
///
/// # Errors
///
/// Returns [`RumError::InvalidTraceContext`] with:
/// - field `"trace_context"` for `TraceParent` when `parse_traceparent` rejects
///   `00-{trace_id}-{span_id}-01`;
/// - field `"trace_context.trace_id"` for `Sw8` when `trace_id` is empty;
/// - field `"trace_context.span_id"` for `Sw8` when `span_id` contains anything other than
///   ASCII digits or does not parse as a non-zero `u64`. The digit-only check rejects inputs
///   such as `"+42"` that `u64::from_str` would otherwise accept.
fn validate_trace_context(trace_id: &str, span_id: &str, protocol: TraceProtocol) -> Result<()> {
    match protocol {
        TraceProtocol::TraceParent => {
            let header_value = format!("00-{trace_id}-{span_id}-01");
            if parse_traceparent(&header_value).is_some() {
                Ok(())
            } else {
                Err(RumError::InvalidTraceContext {
                    field: "trace_context",
                    message: "traceparent context requires a 32-char lowercase hex trace_id and a 16-char lowercase hex span_id, both non-zero".to_string(),
                })
            }
        }
        TraceProtocol::Sw8 => {
            if trace_id.is_empty() {
                return Err(RumError::InvalidTraceContext {
                    field: "trace_context.trace_id",
                    message: "sw8 trace_id must not be empty".to_string(),
                });
            }
            let digits_only = span_id.bytes().all(|b| b.is_ascii_digit());
            let positive = matches!(span_id.parse::<u64>(), Ok(number) if number > 0);
            if digits_only && positive {
                Ok(())
            } else {
                Err(RumError::InvalidTraceContext {
                    field: "trace_context.span_id",
                    message: "sw8 span_id must be a positive decimal integer".to_string(),
                })
            }
        }
    }
}

/// Returns a random `sw8` span id: a decimal string whose value lies in `1..=i64::MAX`.
///
/// The top bit of a random `u64` is cleared, which keeps the value within `i64::MAX`
/// (presumably so consumers that parse it as a signed 64-bit integer accept it). A zero
/// draw is raised to 1 so the id is always positive.
fn generate_positive_decimal_span_id() -> String {
    const LOW_63_BITS: u64 = i64::MAX as u64;
    let raw = rand::thread_rng().next_u64();
    std::cmp::max(raw & LOW_63_BITS, 1).to_string()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::propagation::decode_sw8;

    /// Generated traceparent ids have W3C lengths and round-trip into a single header.
    #[test]
    fn traceparent_context_generates_w3c_ids_and_header() {
        let ctx = TraceContext::generate(TraceProtocol::TraceParent);
        let (trace_id, span_id) = (ctx.trace_id(), ctx.span_id());

        assert_eq!(trace_id.len(), 32);
        assert_eq!(span_id.len(), 16);

        let expected = vec![(
            String::from("traceparent"),
            format!("00-{trace_id}-{span_id}-01"),
        )];
        assert_eq!(TraceHeaderWriter::generate_headers(&ctx), expected);
    }

    /// Generated sw8 ids decode back out of the `sw8` header unchanged.
    #[test]
    fn sw8_context_generates_protocol_specific_ids_and_header() {
        let ctx = TraceContext::generate(TraceProtocol::Sw8);

        assert_eq!(ctx.trace_id().len(), 32);
        let span_number: u64 = ctx.span_id().parse().unwrap();
        assert!(span_number > 0);

        let headers = TraceHeaderWriter::generate_headers(&ctx);
        let (name, value) = &headers[0];
        assert_eq!(name, "sw8");
        let decoded = decode_sw8(value).unwrap();
        assert_eq!(decoded.trace_id, ctx.trace_id());
        assert_eq!(decoded.segment_id.len(), 32);
        assert_eq!(decoded.parent_span_id, ctx.span_id());
    }

    /// A caller-supplied sw8 span id is written verbatim into the header's span field.
    #[test]
    fn sw8_context_writes_public_span_id_to_sw8_span_id_field() {
        let ctx = TraceContext::new("sw8-trace-id", "42", TraceProtocol::Sw8).unwrap();

        let headers = TraceHeaderWriter::generate_headers(&ctx);
        let decoded = decode_sw8(&headers[0].1).unwrap();

        assert_eq!(decoded.trace_id, ctx.trace_id());
        assert_eq!(decoded.segment_id.len(), 32);
        assert_eq!(decoded.parent_span_id, ctx.span_id());
    }

    /// `TraceContext::new` accepts well-formed ids and rejects a non-decimal sw8 span id.
    #[test]
    fn trace_context_validates_protocol_specific_ids() {
        let w3c = TraceContext::new(
            "4bf92f3577b34da6a3ce929d0e0e4736",
            "00f067aa0ba902b7",
            TraceProtocol::TraceParent,
        );
        assert!(w3c.is_ok());

        let sw8_ok = TraceContext::new("trace", "42", TraceProtocol::Sw8);
        assert!(sw8_ok.is_ok());

        let sw8_bad = TraceContext::new("trace", "not-decimal", TraceProtocol::Sw8);
        assert!(sw8_bad.is_err());
    }
}