use rand::RngCore;
use uuid::Uuid;
use crate::error::{Result, RumError};
use crate::propagation::{
encode_sw8_with_ids, generate_span_id, generate_trace_id, parse_traceparent,
};
/// Service name written into `sw8` headers. It is the crate's package name, read at compile time
/// via `env!("CARGO_PKG_NAME")`, and is passed to `encode_sw8_with_ids` twice (presumably as
/// both service and service instance — confirm against `encode_sw8_with_ids`).
const DEFAULT_SW8_SERVICE: &str = env!("CARGO_PKG_NAME");
/// Endpoint value passed to `encode_sw8_with_ids` for every `sw8` header.
const DEFAULT_SW8_ENDPOINT: &str = "/";
/// Target value passed as the last argument of `encode_sw8_with_ids` (presumably the
/// target/peer address field of the sw8 header — confirm against the encoder).
const DEFAULT_SW8_TARGET: &str = "client";
/// Header format used to propagate a trace context on outgoing requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraceProtocol {
    /// W3C Trace Context `traceparent` header.
    TraceParent,
    /// Apache SkyWalking `sw8` header.
    Sw8,
}

/// A trace id and span id together with the protocol whose format they follow.
///
/// Build one with `TraceContext::new` (ids are validated) or
/// `TraceContext::generate` (ids are freshly generated).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceContext {
    /// Trace identifier, formatted for `protocol`.
    trace_id: String,
    /// Span identifier, formatted for `protocol`.
    span_id: String,
    /// Protocol the ids belong to; selects the header written for this context.
    protocol: TraceProtocol,
}
impl TraceContext {
    /// Creates a context from caller-supplied ids, checking them against the
    /// rules of `protocol` first.
    ///
    /// # Errors
    ///
    /// Returns `RumError::InvalidTraceContext` when `trace_id` or `span_id`
    /// is not acceptable for `protocol`.
    pub fn new(
        trace_id: impl Into<String>,
        span_id: impl Into<String>,
        protocol: TraceProtocol,
    ) -> Result<Self> {
        let candidate = Self {
            trace_id: trace_id.into(),
            span_id: span_id.into(),
            protocol,
        };
        validate_trace_context(&candidate.trace_id, &candidate.span_id, candidate.protocol)?;
        Ok(candidate)
    }

    /// Creates a context with freshly generated ids in the format of `protocol`.
    ///
    /// The ids come straight from `TraceGenerator` and are not passed through
    /// validation.
    pub fn generate(protocol: TraceProtocol) -> Self {
        let fresh_trace = TraceGenerator::generate_trace_id(protocol);
        let fresh_span = TraceGenerator::generate_span_id(protocol);
        Self {
            trace_id: fresh_trace,
            span_id: fresh_span,
            protocol,
        }
    }

    /// The trace identifier.
    pub fn trace_id(&self) -> &str {
        self.trace_id.as_str()
    }

    /// The span identifier.
    pub fn span_id(&self) -> &str {
        self.span_id.as_str()
    }

    /// The protocol these ids are formatted for.
    pub fn protocol(&self) -> TraceProtocol {
        self.protocol
    }
}
/// Generates fresh trace and span identifiers in the format each protocol expects.
pub struct TraceGenerator;

impl TraceGenerator {
    /// Returns a new random trace id for `protocol`.
    ///
    /// * `Sw8`: a random v4 UUID in its hyphen-free ("simple") text form.
    /// * `TraceParent`: delegated to `propagation::generate_trace_id`.
    pub fn generate_trace_id(protocol: TraceProtocol) -> String {
        match protocol {
            TraceProtocol::Sw8 => {
                let id = Uuid::new_v4();
                id.simple().to_string()
            }
            // Bare name resolves to the function imported from `propagation`,
            // not to this associated function.
            TraceProtocol::TraceParent => generate_trace_id(),
        }
    }

    /// Returns a new random span id for `protocol`.
    ///
    /// * `Sw8`: a positive integer rendered in decimal.
    /// * `TraceParent`: delegated to `propagation::generate_span_id`.
    pub fn generate_span_id(protocol: TraceProtocol) -> String {
        match protocol {
            TraceProtocol::Sw8 => generate_positive_decimal_span_id(),
            TraceProtocol::TraceParent => generate_span_id(),
        }
    }
}
/// Renders a `TraceContext` into outgoing HTTP header `(name, value)` pairs.
pub struct TraceHeaderWriter;

impl TraceHeaderWriter {
    /// Returns the propagation headers for `context`; each protocol currently
    /// yields exactly one header.
    ///
    /// * `TraceParent`: `traceparent` = `00-{trace_id}-{span_id}-01`.
    /// * `Sw8`: `sw8` built by `encode_sw8_with_ids` from the context ids, a
    ///   freshly generated segment id and the `DEFAULT_SW8_*` constants.
    pub fn generate_headers(context: &TraceContext) -> Vec<(String, String)> {
        let (header_name, header_value) = match context.protocol {
            TraceProtocol::TraceParent => (
                "traceparent",
                format!("00-{}-{}-01", context.trace_id, context.span_id),
            ),
            TraceProtocol::Sw8 => {
                // The segment id is not stored on the context, so every call
                // draws a new random one.
                let segment = Uuid::new_v4().simple().to_string();
                let encoded = encode_sw8_with_ids(
                    &context.trace_id,
                    &segment,
                    &context.span_id,
                    DEFAULT_SW8_SERVICE,
                    DEFAULT_SW8_SERVICE,
                    DEFAULT_SW8_ENDPOINT,
                    DEFAULT_SW8_TARGET,
                );
                ("sw8", encoded)
            }
        };
        vec![(header_name.to_string(), header_value)]
    }

    /// Returns the first header produced by `generate_headers`, or `None` if
    /// it produced none.
    pub fn generate_single_header(context: &TraceContext) -> Option<(String, String)> {
        let mut headers = Self::generate_headers(context).into_iter();
        headers.next()
    }
}
/// Largest span id that fits the SkyWalking `sw8` span id field.
///
/// SkyWalking models span ids as signed 32-bit integers (`int32 parentSpanId`
/// / `int32 spanId` in its tracing protocol). A larger decimal value cannot be
/// decoded by SkyWalking receivers, so the propagated header would be dropped
/// and the trace link lost.
const MAX_SW8_SPAN_ID: u32 = 0x7FFF_FFFF;

/// Checks that `trace_id` and `span_id` are well-formed for `protocol`.
///
/// * `TraceParent`: the ids must be accepted by `parse_traceparent` when
///   wrapped as `00-{trace_id}-{span_id}-01`.
/// * `Sw8`: `trace_id` must be non-empty, and `span_id` must be a plain
///   decimal integer in `1..=MAX_SW8_SPAN_ID`.
///
/// # Errors
///
/// Returns `RumError::InvalidTraceContext` naming the offending field.
fn validate_trace_context(trace_id: &str, span_id: &str, protocol: TraceProtocol) -> Result<()> {
    match protocol {
        TraceProtocol::TraceParent => {
            // Round-trip through the real parser so validation accepts exactly
            // what the propagation layer accepts.
            if parse_traceparent(&format!("00-{trace_id}-{span_id}-01")).is_none() {
                return Err(RumError::InvalidTraceContext {
                    field: "trace_context",
                    message: "traceparent context requires a 32-char lowercase hex trace_id and a 16-char lowercase hex span_id, both non-zero".to_string(),
                });
            }
        }
        TraceProtocol::Sw8 => {
            if trace_id.is_empty() {
                return Err(RumError::InvalidTraceContext {
                    field: "trace_context.trace_id",
                    message: "sw8 trace_id must not be empty".to_string(),
                });
            }
            if parse_sw8_span_id(span_id).is_none() {
                return Err(RumError::InvalidTraceContext {
                    field: "trace_context.span_id",
                    message: format!(
                        "sw8 span_id must be a positive decimal integer no greater than {MAX_SW8_SPAN_ID}"
                    ),
                });
            }
        }
    }
    Ok(())
}

/// Parses an sw8 span id: ASCII digits only (no sign, no whitespace), with a
/// value in `1..=MAX_SW8_SPAN_ID`.
fn parse_sw8_span_id(span_id: &str) -> Option<u32> {
    // `u32::from_str` also accepts a leading `+`; insist on bare digits.
    if !span_id.bytes().all(|byte| byte.is_ascii_digit()) {
        return None;
    }
    // Empty input and values overflowing u32 fail to parse and become `None`.
    span_id
        .parse::<u32>()
        .ok()
        .filter(|value| (1..=MAX_SW8_SPAN_ID).contains(value))
}

/// Generates a random sw8 span id in `1..=MAX_SW8_SPAN_ID`, rendered in decimal.
fn generate_positive_decimal_span_id() -> String {
    // Masking to 31 bits keeps the id inside the signed 32-bit range. The single
    // all-zero draw is bumped to 1 because sw8 span ids must be positive here.
    let value = rand::thread_rng().next_u32() & MAX_SW8_SPAN_ID;
    value.max(1).to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::propagation::decode_sw8;

    /// A generated traceparent context has W3C-sized ids that appear verbatim in its header.
    #[test]
    fn traceparent_context_generates_w3c_ids_and_header() {
        let ctx = TraceContext::generate(TraceProtocol::TraceParent);
        assert_eq!(ctx.trace_id().len(), 32);
        assert_eq!(ctx.span_id().len(), 16);

        let expected_value = format!("00-{}-{}-01", ctx.trace_id(), ctx.span_id());
        let expected = vec![("traceparent".to_string(), expected_value)];
        assert_eq!(TraceHeaderWriter::generate_headers(&ctx), expected);
    }

    /// A generated sw8 context round-trips through the sw8 header encoder.
    #[test]
    fn sw8_context_generates_protocol_specific_ids_and_header() {
        let ctx = TraceContext::generate(TraceProtocol::Sw8);
        assert_eq!(ctx.trace_id().len(), 32);
        let span_value: u64 = ctx.span_id().parse().unwrap();
        assert!(span_value > 0);

        let headers = TraceHeaderWriter::generate_headers(&ctx);
        let (name, value) = &headers[0];
        assert_eq!(name, "sw8");

        let decoded = decode_sw8(value).unwrap();
        assert_eq!(decoded.trace_id, ctx.trace_id());
        assert_eq!(decoded.segment_id.len(), 32);
        assert_eq!(decoded.parent_span_id, ctx.span_id());
    }

    /// The public span id is what lands in the sw8 parent-span-id field.
    #[test]
    fn sw8_context_writes_public_span_id_to_sw8_span_id_field() {
        let ctx = TraceContext::new("sw8-trace-id", "42", TraceProtocol::Sw8).unwrap();
        let headers = TraceHeaderWriter::generate_headers(&ctx);
        let (_, value) = &headers[0];

        let decoded = decode_sw8(value).unwrap();
        assert_eq!(decoded.trace_id, ctx.trace_id());
        assert_eq!(decoded.segment_id.len(), 32);
        assert_eq!(decoded.parent_span_id, ctx.span_id());
    }

    /// Validation rules differ per protocol.
    #[test]
    fn trace_context_validates_protocol_specific_ids() {
        let w3c = TraceContext::new(
            "4bf92f3577b34da6a3ce929d0e0e4736",
            "00f067aa0ba902b7",
            TraceProtocol::TraceParent,
        );
        assert!(w3c.is_ok());

        let sw8_ok = TraceContext::new("trace", "42", TraceProtocol::Sw8);
        assert!(sw8_ok.is_ok());

        let sw8_bad = TraceContext::new("trace", "not-decimal", TraceProtocol::Sw8);
        assert!(sw8_bad.is_err());
    }
}