ts_opentelemetry_sdk/trace/sampler/jaeger_remote/
sampler.rs1use crate::runtime::RuntimeChannel;
2use crate::trace::sampler::jaeger_remote::remote::SamplingStrategyResponse;
3use crate::trace::sampler::jaeger_remote::sampling_strategy::Inner;
4use crate::trace::{BatchMessage, Sampler, ShouldSample};
5use futures_util::{stream, StreamExt as _};
6use http::Uri;
7use ts_opentelemetry_api::trace::{Link, SamplingResult, SpanKind, TraceError, TraceId};
8use ts_opentelemetry_api::{global, Context, Key, OrderMap, Value};
9use ts_opentelemetry_http::HttpClient;
10use std::str::FromStr;
11use std::sync::Arc;
12use std::time::Duration;
13
/// Sampling endpoint used by a freshly created builder, and the fallback used
/// when a configured endpoint cannot be parsed as a URL.
const DEFAULT_REMOTE_SAMPLER_ENDPOINT: &str = "http://localhost:5778/sampling";
15
16#[derive(Debug)]
19pub struct JaegerRemoteSamplerBuilder<C, S, R>
20where
21 R: RuntimeChannel<BatchMessage>,
22 C: HttpClient + 'static,
23 S: ShouldSample + 'static,
24{
25 pub(crate) update_interval: Duration,
26 pub(crate) client: C,
27 pub(crate) endpoint: String,
28 pub(crate) default_sampler: S,
29 pub(crate) leaky_bucket_size: f64,
30 pub(crate) runtime: R,
31 pub(crate) service_name: String,
32}
33
34impl<C, S, R> JaegerRemoteSamplerBuilder<C, S, R>
35where
36 C: HttpClient + 'static,
37 S: ShouldSample + 'static,
38 R: RuntimeChannel<BatchMessage>,
39{
40 pub(crate) fn new<Svc>(
41 runtime: R,
42 http_client: C,
43 default_sampler: S,
44 service_name: Svc,
45 ) -> Self
46 where
47 Svc: Into<String>,
48 {
49 JaegerRemoteSamplerBuilder {
50 runtime,
51 update_interval: Duration::from_secs(60 * 5),
52 client: http_client,
53 endpoint: DEFAULT_REMOTE_SAMPLER_ENDPOINT.to_string(),
54 default_sampler,
55 leaky_bucket_size: 100.0,
56 service_name: service_name.into(),
57 }
58 }
59
60 pub fn with_update_interval(self, interval: Duration) -> Self {
66 Self {
67 update_interval: interval,
68 ..self
69 }
70 }
71
72 pub fn with_endpoint<Str: Into<String>>(self, endpoint: Str) -> Self {
78 Self {
79 endpoint: endpoint.into(),
80 ..self
81 }
82 }
83
84 pub fn with_leaky_bucket_size(self, size: f64) -> Self {
90 Self {
91 leaky_bucket_size: size,
92 ..self
93 }
94 }
95
96 pub fn build(self) -> Result<Sampler, TraceError> {
103 let endpoint = Self::get_endpoint(&self.endpoint, &self.service_name)
104 .map_err(|err_str| TraceError::Other(err_str.into()))?;
105
106 Ok(Sampler::JaegerRemote(JaegerRemoteSampler::new(
107 self.runtime,
108 self.update_interval,
109 self.client,
110 endpoint,
111 self.default_sampler,
112 self.leaky_bucket_size,
113 )))
114 }
115
116 fn get_endpoint(endpoint: &str, service_name: &str) -> Result<Uri, String> {
117 if endpoint.is_empty() || service_name.is_empty() {
118 return Err("endpoint and service name cannot be empty".to_string());
119 }
120 let mut endpoint = url::Url::parse(endpoint)
121 .unwrap_or_else(|_| url::Url::parse(DEFAULT_REMOTE_SAMPLER_ENDPOINT).unwrap());
122
123 endpoint
124 .query_pairs_mut()
125 .append_pair("service", service_name);
126
127 Uri::from_str(endpoint.as_str()).map_err(|_err| "invalid service name".to_string())
128 }
129}
130
131#[derive(Clone, Debug)]
143pub struct JaegerRemoteSampler {
144 inner: Arc<Inner>,
145 default_sampler: Arc<dyn ShouldSample + 'static>,
146}
147
148impl JaegerRemoteSampler {
149 fn new<C, R, S>(
150 runtime: R,
151 update_timeout: Duration,
152 client: C,
153 endpoint: Uri,
154 default_sampler: S,
155 leaky_bucket_size: f64,
156 ) -> Self
157 where
158 R: RuntimeChannel<BatchMessage>,
159 C: HttpClient + 'static,
160 S: ShouldSample + 'static,
161 {
162 let (shutdown_tx, shutdown_rx) = futures_channel::mpsc::channel(1);
163 let inner = Arc::new(Inner::new(leaky_bucket_size, shutdown_tx));
164 let sampler = JaegerRemoteSampler {
165 inner,
166 default_sampler: Arc::new(default_sampler),
167 };
168 Self::run_update_task(
169 runtime,
170 sampler.inner.clone(),
171 update_timeout,
172 client,
173 shutdown_rx,
174 endpoint,
175 );
176 sampler
177 }
178
179 fn run_update_task<C, R>(
181 runtime: R,
182 strategy: Arc<Inner>,
183 update_timeout: Duration,
184 client: C,
185 shutdown: futures_channel::mpsc::Receiver<()>,
186 endpoint: Uri,
187 ) where
188 R: RuntimeChannel<BatchMessage>,
189 C: HttpClient + 'static,
190 {
191 let interval = runtime.interval(update_timeout);
193 runtime.spawn(Box::pin(async move {
194 let mut update = Box::pin(stream::select(
196 shutdown.map(|_| false),
197 interval.map(|_| true),
198 ));
199
200 while let Some(should_update) = update.next().await {
201 if should_update {
202 match Self::request_new_strategy(&client, endpoint.clone()).await {
205 Ok(remote_strategy_resp) => strategy.update(remote_strategy_resp),
206 Err(err_msg) => global::handle_error(TraceError::Other(err_msg.into())),
207 };
208 } else {
209 break;
211 }
212 }
213 }));
214 }
215
216 async fn request_new_strategy<C>(
217 client: &C,
218 endpoint: Uri,
219 ) -> Result<SamplingStrategyResponse, String>
220 where
221 C: HttpClient,
222 {
223 let request = http::Request::get(endpoint)
224 .header("Content-Type", "application/json")
225 .body(Vec::new())
226 .unwrap();
227
228 let resp = client
229 .send(request)
230 .await
231 .map_err(|err| format!("the request is failed to send {}", err))?;
232
233 if resp.status() != http::StatusCode::OK {
235 return Err(format!(
236 "the http response code is not 200 but {}",
237 resp.status()
238 ));
239 }
240
241 serde_json::from_slice(&resp.body()[..])
243 .map_err(|err| format!("cannot deserialize the response, {}", err))
244 }
245}
246
247impl ShouldSample for JaegerRemoteSampler {
248 fn should_sample(
249 &self,
250 parent_context: Option<&Context>,
251 trace_id: TraceId,
252 name: &str,
253 span_kind: &SpanKind,
254 attributes: &OrderMap<Key, Value>,
255 links: &[Link],
256 ) -> SamplingResult {
257 self.inner
258 .should_sample(parent_context, trace_id, name)
259 .unwrap_or_else(|| {
260 self.default_sampler.should_sample(
261 parent_context,
262 trace_id,
263 name,
264 span_kind,
265 attributes,
266 links,
267 )
268 })
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::SamplingStrategyResponse;
    use crate::trace::sampler::jaeger_remote::remote::SamplingStrategyType;
    use std::fmt;

    // `assert_eq!` needs `Debug` on the compared values; this test-only impl
    // provides it (a derived impl on the type would conflict with this one).
    impl fmt::Debug for SamplingStrategyType {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let label = match self {
                SamplingStrategyType::Probabilistic => "Probabilistic",
                SamplingStrategyType::RateLimiting => "RateLimiting",
            };
            f.write_str(label)
        }
    }

    /// A probabilistic strategy document decodes into the matching type and rate.
    #[test]
    fn deserialize_sampling_strategy_response() {
        let payload = r#"{
            "strategyType": "PROBABILISTIC",
            "probabilisticSampling": {
                "samplingRate": 0.5
            }
        }"#;
        let response: SamplingStrategyResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(response.strategy_type, SamplingStrategyType::Probabilistic);
        let probabilistic = response.probabilistic_sampling.unwrap();
        assert_eq!(probabilistic.sampling_rate, 0.5);
    }
}