revoke_trace/
sampler.rs

1use opentelemetry::trace::{SamplingDecision as OtelSamplingDecision, SamplingResult, TraceId};
2use opentelemetry_sdk::trace::ShouldSample;
3use std::collections::HashMap;
4use std::sync::Arc;
5
/// Outcome of a sampling decision for a single span.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SamplingDecision {
    /// Record the span and mark it as sampled.
    RecordAndSample,
    /// Record the span but do not mark it as sampled.
    RecordOnly,
    /// Neither record nor sample the span.
    Drop,
}
16
17impl From<SamplingDecision> for OtelSamplingDecision {
18    fn from(decision: SamplingDecision) -> Self {
19        match decision {
20            SamplingDecision::RecordAndSample => OtelSamplingDecision::RecordAndSample,
21            SamplingDecision::RecordOnly => OtelSamplingDecision::RecordOnly,
22            SamplingDecision::Drop => OtelSamplingDecision::Drop,
23        }
24    }
25}
26
27/// 采样策略
28#[derive(Clone)]
29pub enum SamplingStrategy {
30    /// 始终采样
31    AlwaysOn,
32    /// 从不采样
33    AlwaysOff,
34    /// 基于概率采样
35    Probability(f64),
36    /// 基于 trace ID 的确定性采样
37    TraceIdRatio(f64),
38    /// 自定义采样函数
39    Custom(Arc<dyn Fn(&TraceId, &str) -> SamplingDecision + Send + Sync>),
40}
41
42impl std::fmt::Debug for SamplingStrategy {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        match self {
45            Self::AlwaysOn => write!(f, "AlwaysOn"),
46            Self::AlwaysOff => write!(f, "AlwaysOff"),
47            Self::Probability(rate) => f.debug_tuple("Probability").field(rate).finish(),
48            Self::TraceIdRatio(ratio) => f.debug_tuple("TraceIdRatio").field(ratio).finish(),
49            Self::Custom(_) => write!(f, "Custom(...)"),
50        }
51    }
52}
53
54/// 追踪采样器
55#[derive(Clone, Debug)]
56pub struct TraceSampler {
57    strategy: SamplingStrategy,
58    attributes: HashMap<String, String>,
59}
60
61impl TraceSampler {
62    /// 创建新的采样器
63    pub fn new(strategy: SamplingStrategy) -> Self {
64        Self {
65            strategy,
66            attributes: HashMap::new(),
67        }
68    }
69
70    /// 创建始终采样的采样器
71    pub fn always_on() -> Self {
72        Self::new(SamplingStrategy::AlwaysOn)
73    }
74
75    /// 创建从不采样的采样器
76    pub fn always_off() -> Self {
77        Self::new(SamplingStrategy::AlwaysOff)
78    }
79
80    /// 创建概率采样器
81    pub fn probability(rate: f64) -> Self {
82        let rate = rate.clamp(0.0, 1.0);
83        Self::new(SamplingStrategy::Probability(rate))
84    }
85
86    /// 创建基于 trace ID 的采样器
87    pub fn trace_id_ratio(ratio: f64) -> Self {
88        let ratio = ratio.clamp(0.0, 1.0);
89        Self::new(SamplingStrategy::TraceIdRatio(ratio))
90    }
91
92    /// 添加采样属性
93    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
94        self.attributes.insert(key.into(), value.into());
95        self
96    }
97
98    /// 做出采样决策
99    pub fn should_sample(
100        &self,
101        trace_id: &TraceId,
102        name: &str,
103        parent_sampled: Option<bool>,
104    ) -> (SamplingDecision, HashMap<String, String>) {
105        let decision = match &self.strategy {
106            SamplingStrategy::AlwaysOn => SamplingDecision::RecordAndSample,
107            SamplingStrategy::AlwaysOff => SamplingDecision::Drop,
108            SamplingStrategy::Probability(rate) => {
109                if rand::random::<f64>() < *rate {
110                    SamplingDecision::RecordAndSample
111                } else {
112                    SamplingDecision::Drop
113                }
114            }
115            SamplingStrategy::TraceIdRatio(ratio) => {
116                // 基于 trace ID 的确定性采样
117                let trace_id_bytes = trace_id.to_bytes();
118                let hash = u64::from_be_bytes([
119                    trace_id_bytes[8],
120                    trace_id_bytes[9],
121                    trace_id_bytes[10],
122                    trace_id_bytes[11],
123                    trace_id_bytes[12],
124                    trace_id_bytes[13],
125                    trace_id_bytes[14],
126                    trace_id_bytes[15],
127                ]);
128                let threshold = (ratio * u64::MAX as f64) as u64;
129
130                if hash < threshold {
131                    SamplingDecision::RecordAndSample
132                } else {
133                    SamplingDecision::Drop
134                }
135            }
136            SamplingStrategy::Custom(sampler) => sampler(trace_id, name),
137        };
138
139        // 如果有父 span 且父 span 被采样,则子 span 也应该被采样
140        let final_decision = if let Some(true) = parent_sampled {
141            match decision {
142                SamplingDecision::Drop => SamplingDecision::RecordOnly,
143                _ => decision,
144            }
145        } else {
146            decision
147        };
148
149        (final_decision, self.attributes.clone())
150    }
151}
152
153impl Default for TraceSampler {
154    fn default() -> Self {
155        Self::always_on()
156    }
157}
158
159/// OpenTelemetry 采样器适配器
160#[derive(Debug, Clone)]
161pub struct OtelSamplerAdapter {
162    sampler: TraceSampler,
163}
164
165impl OtelSamplerAdapter {
166    pub fn new(sampler: TraceSampler) -> Self {
167        Self { sampler }
168    }
169}
170
171impl ShouldSample for OtelSamplerAdapter {
172    fn should_sample(
173        &self,
174        parent_context: Option<&opentelemetry::Context>,
175        trace_id: TraceId,
176        name: &str,
177        _span_kind: &opentelemetry::trace::SpanKind,
178        _attributes: &[opentelemetry::KeyValue],
179        _links: &[opentelemetry::trace::Link],
180    ) -> SamplingResult {
181        use opentelemetry::trace::TraceContextExt;
182
183        let parent_sampled = parent_context.and_then(|ctx| {
184            let span = ctx.span();
185            let span_context = span.span_context();
186            if span_context.is_valid() {
187                Some(span_context.trace_flags().is_sampled())
188            } else {
189                None
190            }
191        });
192
193        let (decision, attributes) = self.sampler.should_sample(&trace_id, name, parent_sampled);
194
195        let otel_attributes = attributes
196            .into_iter()
197            .map(|(k, v)| opentelemetry::KeyValue::new(k, v))
198            .collect();
199
200        SamplingResult {
201            decision: decision.into(),
202            attributes: otel_attributes,
203            trace_state: Default::default(),
204        }
205    }
206}
207
// The external `rand` crate provides the random draws used by `SamplingStrategy::Probability`.
209use rand;
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_strategies() {
        let always_on = TraceSampler::always_on();
        let always_off = TraceSampler::always_off();
        let probability = TraceSampler::probability(0.5);

        let trace_id = TraceId::from_bytes([1; 16]);

        let (on_decision, _) = always_on.should_sample(&trace_id, "test", None);
        assert_eq!(on_decision, SamplingDecision::RecordAndSample);

        let (off_decision, _) = always_off.should_sample(&trace_id, "test", None);
        assert_eq!(off_decision, SamplingDecision::Drop);

        // Probability sampling should sample some calls and drop others.
        let sampled_count = (0..1000)
            .filter(|_| {
                probability.should_sample(&trace_id, "test", None).0
                    == SamplingDecision::RecordAndSample
            })
            .count();

        // Expect roughly 50%. With 1000 draws at p = 0.5, sigma is about 15.8, so the
        // 400..600 window is more than 6 sigma wide on each side and flakes are negligible.
        assert!(sampled_count > 400 && sampled_count < 600);
    }

    #[test]
    fn test_probability_rate_is_clamped() {
        let trace_id = TraceId::from_bytes([1; 16]);
        let always = TraceSampler::probability(1.5);
        let never = TraceSampler::probability(-0.5);

        for _ in 0..100 {
            let (high, _) = always.should_sample(&trace_id, "test", None);
            let (low, _) = never.should_sample(&trace_id, "test", None);
            assert_eq!(high, SamplingDecision::RecordAndSample);
            assert_eq!(low, SamplingDecision::Drop);
        }
    }

    #[test]
    fn test_trace_id_ratio_sampling() {
        let sampler = TraceSampler::trace_id_ratio(0.5);

        // Extreme trace IDs: low half all zeros vs. all 0xFF.
        let trace_id1 = TraceId::from_bytes([0; 16]);
        let trace_id2 = TraceId::from_bytes([255; 16]);

        let (decision1, _) = sampler.should_sample(&trace_id1, "test", None);
        let (decision2, _) = sampler.should_sample(&trace_id2, "test", None);

        // At ratio 0.5 a zero low half is below the threshold and an all-0xFF one is above it.
        assert_eq!(decision1, SamplingDecision::RecordAndSample);
        assert_eq!(decision2, SamplingDecision::Drop);

        // Deterministic: the same trace ID always gets the same decision.
        let (decision1_repeat, _) = sampler.should_sample(&trace_id1, "test", None);
        let (decision2_repeat, _) = sampler.should_sample(&trace_id2, "test", None);
        assert_eq!(decision1, decision1_repeat);
        assert_eq!(decision2, decision2_repeat);

        // Ratio 0.0 samples nothing, not even the lowest trace ID.
        let none = TraceSampler::trace_id_ratio(0.0);
        assert_eq!(
            none.should_sample(&trace_id1, "test", None).0,
            SamplingDecision::Drop
        );
    }

    #[test]
    fn test_parent_sampling() {
        let sampler = TraceSampler::always_off();

        let trace_id = TraceId::from_bytes([1; 16]);

        // Without a parent the strategy's Drop stands.
        let (decision, _) = sampler.should_sample(&trace_id, "test", None);
        assert_eq!(decision, SamplingDecision::Drop);

        // A sampled parent means the child is at least recorded.
        let (decision_with_parent, _) = sampler.should_sample(&trace_id, "test", Some(true));
        assert_eq!(decision_with_parent, SamplingDecision::RecordOnly);

        // An unsampled parent does not promote anything.
        let (decision_unsampled_parent, _) =
            sampler.should_sample(&trace_id, "test", Some(false));
        assert_eq!(decision_unsampled_parent, SamplingDecision::Drop);

        // A sampled parent never downgrades a positive decision.
        let (on_with_parent, _) =
            TraceSampler::always_on().should_sample(&trace_id, "test", Some(true));
        assert_eq!(on_with_parent, SamplingDecision::RecordAndSample);
    }

    #[test]
    fn test_attributes_are_returned() {
        let sampler = TraceSampler::always_off()
            .with_attribute("env", "test")
            .with_attribute("env", "prod");

        let (_, attributes) = sampler.should_sample(&TraceId::from_bytes([1; 16]), "test", None);

        // A repeated key replaces the earlier value. Attributes are returned even for Drop.
        assert_eq!(attributes.len(), 1);
        assert_eq!(attributes.get("env").map(String::as_str), Some("prod"));
    }

    #[test]
    fn test_custom_strategy() {
        let sampler = TraceSampler::new(SamplingStrategy::Custom(Arc::new(
            |_: &TraceId, name: &str| {
                if name == "keep" {
                    SamplingDecision::RecordAndSample
                } else {
                    SamplingDecision::Drop
                }
            },
        )));
        let trace_id = TraceId::from_bytes([1; 16]);

        assert_eq!(
            sampler.should_sample(&trace_id, "keep", None).0,
            SamplingDecision::RecordAndSample
        );
        assert_eq!(
            sampler.should_sample(&trace_id, "other", None).0,
            SamplingDecision::Drop
        );
    }
}