ts_opentelemetry_sdk/trace/sampler/jaeger_remote/
sampler.rs

1use crate::runtime::RuntimeChannel;
2use crate::trace::sampler::jaeger_remote::remote::SamplingStrategyResponse;
3use crate::trace::sampler::jaeger_remote::sampling_strategy::Inner;
4use crate::trace::{BatchMessage, Sampler, ShouldSample};
5use futures_util::{stream, StreamExt as _};
6use http::Uri;
7use ts_opentelemetry_api::trace::{Link, SamplingResult, SpanKind, TraceError, TraceId};
8use ts_opentelemetry_api::{global, Context, Key, OrderMap, Value};
9use ts_opentelemetry_http::HttpClient;
10use std::str::FromStr;
11use std::sync::Arc;
12use std::time::Duration;
13
/// Endpoint queried for sampling strategies when the user does not configure one.
const DEFAULT_REMOTE_SAMPLER_ENDPOINT: &str = "http://localhost:5778/sampling";
15
16/// Builder for [`JaegerRemoteSampler`].
17/// See [Sampler::jaeger_remote] for details.
18#[derive(Debug)]
19pub struct JaegerRemoteSamplerBuilder<C, S, R>
20where
21    R: RuntimeChannel<BatchMessage>,
22    C: HttpClient + 'static,
23    S: ShouldSample + 'static,
24{
25    pub(crate) update_interval: Duration,
26    pub(crate) client: C,
27    pub(crate) endpoint: String,
28    pub(crate) default_sampler: S,
29    pub(crate) leaky_bucket_size: f64,
30    pub(crate) runtime: R,
31    pub(crate) service_name: String,
32}
33
34impl<C, S, R> JaegerRemoteSamplerBuilder<C, S, R>
35where
36    C: HttpClient + 'static,
37    S: ShouldSample + 'static,
38    R: RuntimeChannel<BatchMessage>,
39{
40    pub(crate) fn new<Svc>(
41        runtime: R,
42        http_client: C,
43        default_sampler: S,
44        service_name: Svc,
45    ) -> Self
46    where
47        Svc: Into<String>,
48    {
49        JaegerRemoteSamplerBuilder {
50            runtime,
51            update_interval: Duration::from_secs(60 * 5),
52            client: http_client,
53            endpoint: DEFAULT_REMOTE_SAMPLER_ENDPOINT.to_string(),
54            default_sampler,
55            leaky_bucket_size: 100.0,
56            service_name: service_name.into(),
57        }
58    }
59
60    /// Change how often the SDK should fetch the sampling strategy from remote servers
61    ///
62    /// By default it fetches every 5 minutes.
63    ///
64    /// A shorter interval have a performance overhead and should be avoid.
65    pub fn with_update_interval(self, interval: Duration) -> Self {
66        Self {
67            update_interval: interval,
68            ..self
69        }
70    }
71
72    /// The endpoint of remote servers.
73    ///
74    /// By default it's `http://localhost:5778/sampling`.
75    ///
76    /// If service name is provided as part of the endpoint, it will be ignored.
77    pub fn with_endpoint<Str: Into<String>>(self, endpoint: Str) -> Self {
78        Self {
79            endpoint: endpoint.into(),
80            ..self
81        }
82    }
83
84    /// The size of the leaky bucket.
85    ///
86    /// By default the size is 100.
87    ///
88    /// It's used when sampling strategy is rate limiting.
89    pub fn with_leaky_bucket_size(self, size: f64) -> Self {
90        Self {
91            leaky_bucket_size: size,
92            ..self
93        }
94    }
95
96    /// Build a [JaegerRemoteSampler] using provided configuration.
97    ///
98    /// Return errors if:
99    ///
100    /// - the endpoint provided is empty.
101    /// - the service name provided is empty.
102    pub fn build(self) -> Result<Sampler, TraceError> {
103        let endpoint = Self::get_endpoint(&self.endpoint, &self.service_name)
104            .map_err(|err_str| TraceError::Other(err_str.into()))?;
105
106        Ok(Sampler::JaegerRemote(JaegerRemoteSampler::new(
107            self.runtime,
108            self.update_interval,
109            self.client,
110            endpoint,
111            self.default_sampler,
112            self.leaky_bucket_size,
113        )))
114    }
115
116    fn get_endpoint(endpoint: &str, service_name: &str) -> Result<Uri, String> {
117        if endpoint.is_empty() || service_name.is_empty() {
118            return Err("endpoint and service name cannot be empty".to_string());
119        }
120        let mut endpoint = url::Url::parse(endpoint)
121            .unwrap_or_else(|_| url::Url::parse(DEFAULT_REMOTE_SAMPLER_ENDPOINT).unwrap());
122
123        endpoint
124            .query_pairs_mut()
125            .append_pair("service", service_name);
126
127        Uri::from_str(endpoint.as_str()).map_err(|_err| "invalid service name".to_string())
128    }
129}
130
131/// Sampler that fetches the sampling configuration from remotes.
132///
133/// It offers the following sampling strategies:
134/// - **Probabilistic**, fetch a probability between [0.0, 1.0] from remotes and use it to sample traces. If the probability is 0.0, it will never sample traces. If the probability is 1.0, it will always sample traces.
135/// - **Rate limiting**, ses a leaky bucket rate limiter to ensure that traces are sampled with a certain constant rate.
136/// - **Per Operations**, instead of sampling all traces, it samples traces based on the span name. Only probabilistic sampling is supported at the moment.
137///
138/// User can build a [`JaegerRemoteSampler`] by getting a [`JaegerRemoteSamplerBuilder`] from [`Sampler::jaeger_remote`].
139///
140/// Note that the backend doesn't need to be Jaeger so long as it supports jaeger remote sampling
141/// protocol.
142#[derive(Clone, Debug)]
143pub struct JaegerRemoteSampler {
144    inner: Arc<Inner>,
145    default_sampler: Arc<dyn ShouldSample + 'static>,
146}
147
148impl JaegerRemoteSampler {
149    fn new<C, R, S>(
150        runtime: R,
151        update_timeout: Duration,
152        client: C,
153        endpoint: Uri,
154        default_sampler: S,
155        leaky_bucket_size: f64,
156    ) -> Self
157    where
158        R: RuntimeChannel<BatchMessage>,
159        C: HttpClient + 'static,
160        S: ShouldSample + 'static,
161    {
162        let (shutdown_tx, shutdown_rx) = futures_channel::mpsc::channel(1);
163        let inner = Arc::new(Inner::new(leaky_bucket_size, shutdown_tx));
164        let sampler = JaegerRemoteSampler {
165            inner,
166            default_sampler: Arc::new(default_sampler),
167        };
168        Self::run_update_task(
169            runtime,
170            sampler.inner.clone(),
171            update_timeout,
172            client,
173            shutdown_rx,
174            endpoint,
175        );
176        sampler
177    }
178
179    // start a updating thread/task
180    fn run_update_task<C, R>(
181        runtime: R,
182        strategy: Arc<Inner>,
183        update_timeout: Duration,
184        client: C,
185        shutdown: futures_channel::mpsc::Receiver<()>,
186        endpoint: Uri,
187    ) where
188        R: RuntimeChannel<BatchMessage>,
189        C: HttpClient + 'static,
190    {
191        // todo: review if we need 'static here
192        let interval = runtime.interval(update_timeout);
193        runtime.spawn(Box::pin(async move {
194            // either update or shutdown
195            let mut update = Box::pin(stream::select(
196                shutdown.map(|_| false),
197                interval.map(|_| true),
198            ));
199
200            while let Some(should_update) = update.next().await {
201                if should_update {
202                    // poll next available configuration or shutdown
203                    // send request
204                    match Self::request_new_strategy(&client, endpoint.clone()).await {
205                        Ok(remote_strategy_resp) => strategy.update(remote_strategy_resp),
206                        Err(err_msg) => global::handle_error(TraceError::Other(err_msg.into())),
207                    };
208                } else {
209                    // shutdown
210                    break;
211                }
212            }
213        }));
214    }
215
216    async fn request_new_strategy<C>(
217        client: &C,
218        endpoint: Uri,
219    ) -> Result<SamplingStrategyResponse, String>
220    where
221        C: HttpClient,
222    {
223        let request = http::Request::get(endpoint)
224            .header("Content-Type", "application/json")
225            .body(Vec::new())
226            .unwrap();
227
228        let resp = client
229            .send(request)
230            .await
231            .map_err(|err| format!("the request is failed to send {}", err))?;
232
233        // process failures
234        if resp.status() != http::StatusCode::OK {
235            return Err(format!(
236                "the http response code is not 200 but {}",
237                resp.status()
238            ));
239        }
240
241        // deserialize the response
242        serde_json::from_slice(&resp.body()[..])
243            .map_err(|err| format!("cannot deserialize the response, {}", err))
244    }
245}
246
247impl ShouldSample for JaegerRemoteSampler {
248    fn should_sample(
249        &self,
250        parent_context: Option<&Context>,
251        trace_id: TraceId,
252        name: &str,
253        span_kind: &SpanKind,
254        attributes: &OrderMap<Key, Value>,
255        links: &[Link],
256    ) -> SamplingResult {
257        self.inner
258            .should_sample(parent_context, trace_id, name)
259            .unwrap_or_else(|| {
260                self.default_sampler.should_sample(
261                    parent_context,
262                    trace_id,
263                    name,
264                    span_kind,
265                    attributes,
266                    links,
267                )
268            })
269    }
270}
271
#[cfg(test)]
mod tests {
    use crate::trace::sampler::jaeger_remote::remote::SamplingStrategyType;
    use std::fmt::{Debug, Formatter};

    // Test-only `Debug` so that `assert_eq!` can print the strategy type on failure.
    impl Debug for SamplingStrategyType {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            let label = match self {
                SamplingStrategyType::Probabilistic => "Probabilistic",
                SamplingStrategyType::RateLimiting => "RateLimiting",
            };
            f.write_str(label)
        }
    }

    /// A probabilistic strategy response is deserialized with its type and sampling rate.
    #[test]
    fn deserialize_sampling_strategy_response() {
        let payload = r#"{
            "strategyType": "PROBABILISTIC",
            "probabilisticSampling": {
                "samplingRate": 0.5
            }
        }"#;
        let parsed: super::SamplingStrategyResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.strategy_type, SamplingStrategyType::Probabilistic);

        let probabilistic = parsed.probabilistic_sampling.unwrap();
        assert_eq!(probabilistic.sampling_rate, 0.5);
    }
}