1use opentelemetry::trace::{SamplingDecision as OtelSamplingDecision, SamplingResult, TraceId};
2use opentelemetry_sdk::trace::ShouldSample;
3use std::collections::HashMap;
4use std::sync::Arc;
5
/// Outcome of a sampling check for a single span.
///
/// Each variant converts one-to-one into the OpenTelemetry sampling decision
/// of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SamplingDecision {
    /// Record the span and mark it as sampled.
    RecordAndSample,
    /// Record the span without marking it as sampled.
    RecordOnly,
    /// Neither record nor sample the span.
    Drop,
}
16
17impl From<SamplingDecision> for OtelSamplingDecision {
18 fn from(decision: SamplingDecision) -> Self {
19 match decision {
20 SamplingDecision::RecordAndSample => OtelSamplingDecision::RecordAndSample,
21 SamplingDecision::RecordOnly => OtelSamplingDecision::RecordOnly,
22 SamplingDecision::Drop => OtelSamplingDecision::Drop,
23 }
24 }
25}
26
27#[derive(Clone)]
29pub enum SamplingStrategy {
30 AlwaysOn,
32 AlwaysOff,
34 Probability(f64),
36 TraceIdRatio(f64),
38 Custom(Arc<dyn Fn(&TraceId, &str) -> SamplingDecision + Send + Sync>),
40}
41
42impl std::fmt::Debug for SamplingStrategy {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::AlwaysOn => write!(f, "AlwaysOn"),
46 Self::AlwaysOff => write!(f, "AlwaysOff"),
47 Self::Probability(rate) => f.debug_tuple("Probability").field(rate).finish(),
48 Self::TraceIdRatio(ratio) => f.debug_tuple("TraceIdRatio").field(ratio).finish(),
49 Self::Custom(_) => write!(f, "Custom(...)"),
50 }
51 }
52}
53
54#[derive(Clone, Debug)]
56pub struct TraceSampler {
57 strategy: SamplingStrategy,
58 attributes: HashMap<String, String>,
59}
60
61impl TraceSampler {
62 pub fn new(strategy: SamplingStrategy) -> Self {
64 Self {
65 strategy,
66 attributes: HashMap::new(),
67 }
68 }
69
70 pub fn always_on() -> Self {
72 Self::new(SamplingStrategy::AlwaysOn)
73 }
74
75 pub fn always_off() -> Self {
77 Self::new(SamplingStrategy::AlwaysOff)
78 }
79
80 pub fn probability(rate: f64) -> Self {
82 let rate = rate.clamp(0.0, 1.0);
83 Self::new(SamplingStrategy::Probability(rate))
84 }
85
86 pub fn trace_id_ratio(ratio: f64) -> Self {
88 let ratio = ratio.clamp(0.0, 1.0);
89 Self::new(SamplingStrategy::TraceIdRatio(ratio))
90 }
91
92 pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
94 self.attributes.insert(key.into(), value.into());
95 self
96 }
97
98 pub fn should_sample(
100 &self,
101 trace_id: &TraceId,
102 name: &str,
103 parent_sampled: Option<bool>,
104 ) -> (SamplingDecision, HashMap<String, String>) {
105 let decision = match &self.strategy {
106 SamplingStrategy::AlwaysOn => SamplingDecision::RecordAndSample,
107 SamplingStrategy::AlwaysOff => SamplingDecision::Drop,
108 SamplingStrategy::Probability(rate) => {
109 if rand::random::<f64>() < *rate {
110 SamplingDecision::RecordAndSample
111 } else {
112 SamplingDecision::Drop
113 }
114 }
115 SamplingStrategy::TraceIdRatio(ratio) => {
116 let trace_id_bytes = trace_id.to_bytes();
118 let hash = u64::from_be_bytes([
119 trace_id_bytes[8],
120 trace_id_bytes[9],
121 trace_id_bytes[10],
122 trace_id_bytes[11],
123 trace_id_bytes[12],
124 trace_id_bytes[13],
125 trace_id_bytes[14],
126 trace_id_bytes[15],
127 ]);
128 let threshold = (ratio * u64::MAX as f64) as u64;
129
130 if hash < threshold {
131 SamplingDecision::RecordAndSample
132 } else {
133 SamplingDecision::Drop
134 }
135 }
136 SamplingStrategy::Custom(sampler) => sampler(trace_id, name),
137 };
138
139 let final_decision = if let Some(true) = parent_sampled {
141 match decision {
142 SamplingDecision::Drop => SamplingDecision::RecordOnly,
143 _ => decision,
144 }
145 } else {
146 decision
147 };
148
149 (final_decision, self.attributes.clone())
150 }
151}
152
153impl Default for TraceSampler {
154 fn default() -> Self {
155 Self::always_on()
156 }
157}
158
159#[derive(Debug, Clone)]
161pub struct OtelSamplerAdapter {
162 sampler: TraceSampler,
163}
164
165impl OtelSamplerAdapter {
166 pub fn new(sampler: TraceSampler) -> Self {
167 Self { sampler }
168 }
169}
170
171impl ShouldSample for OtelSamplerAdapter {
172 fn should_sample(
173 &self,
174 parent_context: Option<&opentelemetry::Context>,
175 trace_id: TraceId,
176 name: &str,
177 _span_kind: &opentelemetry::trace::SpanKind,
178 _attributes: &[opentelemetry::KeyValue],
179 _links: &[opentelemetry::trace::Link],
180 ) -> SamplingResult {
181 use opentelemetry::trace::TraceContextExt;
182
183 let parent_sampled = parent_context.and_then(|ctx| {
184 let span = ctx.span();
185 let span_context = span.span_context();
186 if span_context.is_valid() {
187 Some(span_context.trace_flags().is_sampled())
188 } else {
189 None
190 }
191 });
192
193 let (decision, attributes) = self.sampler.should_sample(&trace_id, name, parent_sampled);
194
195 let otel_attributes = attributes
196 .into_iter()
197 .map(|(k, v)| opentelemetry::KeyValue::new(k, v))
198 .collect();
199
200 SamplingResult {
201 decision: decision.into(),
202 attributes: otel_attributes,
203 trace_state: Default::default(),
204 }
205 }
206}
207
208use rand;
210
#[cfg(test)]
mod tests {
    use super::*;

    /// `AlwaysOn` and `AlwaysOff` are deterministic; `Probability(0.5)` is
    /// checked statistically over 1000 draws.
    #[test]
    fn test_sampling_strategies() {
        let always_on = TraceSampler::always_on();
        let always_off = TraceSampler::always_off();
        let probability = TraceSampler::probability(0.5);

        let trace_id = TraceId::from_bytes([1; 16]);

        let (on_decision, _) = always_on.should_sample(&trace_id, "test", None);
        assert_eq!(on_decision, SamplingDecision::RecordAndSample);

        let (off_decision, _) = always_off.should_sample(&trace_id, "test", None);
        assert_eq!(off_decision, SamplingDecision::Drop);

        let sampled_count = (0..1000)
            .filter(|_| {
                let (decision, _) = probability.should_sample(&trace_id, "test", None);
                decision == SamplingDecision::RecordAndSample
            })
            .count();

        // Expected value is 500; the window is wide enough that a spurious
        // failure is vanishingly unlikely.
        assert!(sampled_count > 400 && sampled_count < 600);
    }

    /// Ratio sampling depends only on the trace id: the result is repeatable,
    /// and at a ratio of 0.5 the smallest id is sampled while the largest id
    /// is dropped.
    #[test]
    fn test_trace_id_ratio_sampling() {
        let sampler = TraceSampler::trace_id_ratio(0.5);

        let trace_id1 = TraceId::from_bytes([0; 16]);
        let trace_id2 = TraceId::from_bytes([255; 16]);

        let (decision1, _) = sampler.should_sample(&trace_id1, "test", None);
        let (decision2, _) = sampler.should_sample(&trace_id2, "test", None);

        // Low 64 bits of zero fall below the 50% threshold; all-ones do not.
        assert_eq!(decision1, SamplingDecision::RecordAndSample);
        assert_eq!(decision2, SamplingDecision::Drop);

        let (decision1_repeat, _) = sampler.should_sample(&trace_id1, "test", None);
        let (decision2_repeat, _) = sampler.should_sample(&trace_id2, "test", None);
        assert_eq!(decision1, decision1_repeat);
        assert_eq!(decision2, decision2_repeat);
    }

    /// A sampled parent upgrades a `Drop` decision to `RecordOnly`.
    #[test]
    fn test_parent_sampling() {
        let sampler = TraceSampler::always_off();

        let trace_id = TraceId::from_bytes([1; 16]);

        let (decision, _) = sampler.should_sample(&trace_id, "test", None);
        assert_eq!(decision, SamplingDecision::Drop);

        let (decision_with_parent, _) = sampler.should_sample(&trace_id, "test", Some(true));
        assert_eq!(decision_with_parent, SamplingDecision::RecordOnly);
    }
}