Skip to main content

libdd_sampling/
sampling_rule_config.rs

1// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Display;
7use std::ops::Deref;
8use std::str::FromStr;
9
10/// Configuration for a single sampling rule
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12pub struct SamplingRuleConfig {
13    /// The sample rate to apply (0.0-1.0)
14    pub sample_rate: f64,
15
16    /// Optional service name pattern to match
17    #[serde(default)]
18    pub service: Option<String>,
19
20    /// Optional span name pattern to match
21    #[serde(default)]
22    pub name: Option<String>,
23
24    /// Optional resource name pattern to match
25    #[serde(default)]
26    pub resource: Option<String>,
27
28    /// Tags that must match (key-value pairs).
29    ///
30    /// Accepts either the map shape `{"env": "prod"}` or the Remote Config
31    /// wire shape `[{"key": "env", "value_glob": "prod"}]`. Internally both
32    /// normalize to the map shape; the list-shape entries are required to
33    /// have both `key` and `value_glob` (missing either rejects the rule).
34    #[serde(default, deserialize_with = "deserialize_tags")]
35    pub tags: HashMap<String, String>,
36
37    /// Where this rule comes from (customer, dynamic, default).
38    /// Not exposed in the public `datadog-opentelemetry` API — set automatically
39    /// during conversion from the public `SamplingRuleConfig` type.
40    #[serde(default = "default_provenance")]
41    pub provenance: String,
42}
43
44impl Default for SamplingRuleConfig {
45    fn default() -> Self {
46        // Keep `Default` in sync with the serde defaults so that constructing a config
47        // with `..Default::default()` matches what deserialization would produce.
48        Self {
49            sample_rate: 0.0,
50            service: None,
51            name: None,
52            resource: None,
53            tags: HashMap::new(),
54            provenance: default_provenance(),
55        }
56    }
57}
58
59impl Display for SamplingRuleConfig {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "{}", serde_json::json!(self))
62    }
63}
64
/// Serde default for `SamplingRuleConfig::provenance`: rules whose input does
/// not say where they came from are labelled `"default"`.
fn default_provenance() -> String {
    String::from("default")
}
68
69/// Deserializes the `tags` field, accepting either:
70///   - map shape:  `{"env": "prod", "region": "us-east-1"}`
71///   - list shape: `[{"key": "env", "value_glob": "prod"}, ...]`
72///
73/// A list entry missing `key` or `value_glob` produces a deserialization
74/// error; we never silently drop entries because that could broaden a
75/// tag-constrained sampling rule.
76fn deserialize_tags<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
77where
78    D: serde::Deserializer<'de>,
79{
80    use serde::de::{MapAccess, SeqAccess, Visitor};
81    use std::fmt;
82
83    #[derive(serde::Deserialize)]
84    struct ListEntry {
85        key: String,
86        value_glob: String,
87    }
88
89    struct TagsVisitor;
90
91    impl<'de> Visitor<'de> for TagsVisitor {
92        type Value = HashMap<String, String>;
93
94        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
95            f.write_str("a map of string to string or a list of {key, value_glob} objects")
96        }
97
98        fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
99        where
100            M: MapAccess<'de>,
101        {
102            let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
103            while let Some((k, v)) = access.next_entry::<String, String>()? {
104                map.insert(k, v);
105            }
106            Ok(map)
107        }
108
109        fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
110        where
111            S: SeqAccess<'de>,
112        {
113            let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
114            while let Some(entry) = access.next_element::<ListEntry>()? {
115                map.insert(entry.key, entry.value_glob);
116            }
117            Ok(map)
118        }
119    }
120
121    deserializer.deserialize_any(TagsVisitor)
122}
123
124#[derive(Debug, Default, Clone, PartialEq)]
125pub struct ParsedSamplingRules {
126    pub rules: Vec<SamplingRuleConfig>,
127}
128
129impl Deref for ParsedSamplingRules {
130    type Target = [SamplingRuleConfig];
131
132    fn deref(&self) -> &Self::Target {
133        &self.rules
134    }
135}
136
137impl From<ParsedSamplingRules> for Vec<SamplingRuleConfig> {
138    fn from(parsed: ParsedSamplingRules) -> Self {
139        parsed.rules
140    }
141}
142
143impl FromStr for ParsedSamplingRules {
144    type Err = serde_json::Error;
145
146    fn from_str(s: &str) -> Result<Self, Self::Err> {
147        if s.trim().is_empty() {
148            return Ok(ParsedSamplingRules::default());
149        }
150        // DD_TRACE_SAMPLING_RULES is expected to be a JSON array of SamplingRuleConfig objects.
151        let rules_vec: Vec<SamplingRuleConfig> = serde_json::from_str(s)?;
152        Ok(ParsedSamplingRules { rules: rules_vec })
153    }
154}
155
156impl Display for ParsedSamplingRules {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        write!(
159            f,
160            "{}",
161            serde_json::to_string(&self.rules).unwrap_or_default()
162        )
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    // --- SamplingRuleConfig ---

    #[test]
    fn test_sampling_rule_config_defaults() {
        let config = SamplingRuleConfig::default();
        assert_eq!(config.sample_rate, 0.0);
        assert!(config.service.is_none());
        assert!(config.name.is_none());
        assert!(config.resource.is_none());
        assert!(config.tags.is_empty());
        // `Default` matches the serde default for `provenance`.
        assert_eq!(config.provenance, "default");
    }

    #[test]
    fn test_sampling_rule_config_default_matches_serde_default() {
        // Constructing from an empty-but-valid JSON object must yield the same value
        // as `Default::default()`.
        let from_serde: SamplingRuleConfig =
            serde_json::from_str(r#"{"sample_rate": 0.0}"#).unwrap();
        assert_eq!(from_serde, SamplingRuleConfig::default());
    }

    #[test]
    fn test_sampling_rule_config_serde_default_provenance() {
        // When provenance is absent from JSON, serde fills it in as "default"
        let json = r#"{"sample_rate": 0.5}"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.provenance, "default");
    }

    #[test]
    fn test_sampling_rule_config_deserialize_full() {
        let json = r#"{
            "sample_rate": 0.5,
            "service": "my-service",
            "name": "http.*",
            "resource": "/api/*",
            "tags": {"env": "prod"},
            "provenance": "customer"
        }"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_rate, 0.5);
        assert_eq!(config.service.as_deref(), Some("my-service"));
        assert_eq!(config.name.as_deref(), Some("http.*"));
        assert_eq!(config.resource.as_deref(), Some("/api/*"));
        assert_eq!(config.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(config.provenance, "customer");
    }

    #[test]
    fn test_sampling_rule_config_deserialize_minimal() {
        let json = r#"{"sample_rate": 1.0}"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_rate, 1.0);
        assert!(config.service.is_none());
        assert_eq!(config.provenance, "default");
    }

    #[test]
    fn test_sampling_rule_config_missing_sample_rate_rejects() {
        // `sample_rate` has no serde default, so a rule without it is invalid.
        let json = r#"{"service": "svc"}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_roundtrip() {
        let original = SamplingRuleConfig {
            sample_rate: 0.25,
            service: Some("svc".into()),
            name: Some("op".into()),
            resource: Some("/res".into()),
            tags: HashMap::from([("k".into(), "v".into())]),
            provenance: "dynamic".into(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: SamplingRuleConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_sampling_rule_config_tags_accepts_map_shape() {
        // Already supported — guard against regression.
        let json = r#"{
            "sample_rate": 0.5,
            "service": "svc",
            "tags": {"env": "prod", "region": "us-east-1"}
        }"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(
            cfg.tags.get("region").map(String::as_str),
            Some("us-east-1")
        );
    }

    #[test]
    fn test_sampling_rule_config_tags_accepts_rc_list_shape() {
        // Remote Config wire shape: list of {key, value_glob} entries.
        let json = r#"{
            "sample_rate": 0.5,
            "service": "svc",
            "tags": [
                {"key": "env", "value_glob": "prod"},
                {"key": "region", "value_glob": "us-east-1"}
            ]
        }"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(
            cfg.tags.get("region").map(String::as_str),
            Some("us-east-1")
        );
    }

    #[test]
    fn test_sampling_rule_config_tags_accepts_empty_list() {
        let json = r#"{"sample_rate": 0.5, "tags": []}"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.tags.is_empty());
    }

    #[test]
    fn test_sampling_rule_config_tags_list_with_malformed_entry_rejects() {
        // A list entry missing `value_glob` must reject the whole rule rather
        // than silently dropping the entry — silently dropping a constraint could
        // broaden a tag-constrained rule and produce a security-relevant change
        // in sampling decisions.
        let json = r#"{
            "sample_rate": 0.5,
            "tags": [
                {"key": "env", "value_glob": "prod"},
                {"key": "region"}
            ]
        }"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_list_missing_key_rejects() {
        // Same rule as above, for the other required field of a list entry.
        let json = r#"{"sample_rate": 0.5, "tags": [{"value_glob": "prod"}]}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_non_string_value_rejects() {
        let json = r#"{"sample_rate": 0.5, "tags": {"env": 1}}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_scalar_rejects() {
        // Neither a map nor a list: must not be coerced into anything.
        let json = r#"{"sample_rate": 0.5, "tags": "env:prod"}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_absent_defaults_to_empty() {
        let json = r#"{"sample_rate": 0.5}"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.tags.is_empty());
    }

    #[test]
    fn test_sampling_rule_config_display() {
        let config = SamplingRuleConfig {
            sample_rate: 1.0,
            service: Some("svc".into()),
            ..Default::default()
        };
        let s = config.to_string();
        assert!(s.contains("sample_rate"));
        assert!(s.contains("svc"));
    }

    #[test]
    fn test_sampling_rule_config_display_roundtrips() {
        // `Display` emits JSON that deserializes back to the same rule.
        let original = SamplingRuleConfig {
            sample_rate: 0.25,
            service: Some("svc".into()),
            tags: HashMap::from([("env".into(), "prod".into())]),
            provenance: "customer".into(),
            ..Default::default()
        };
        let restored: SamplingRuleConfig = serde_json::from_str(&original.to_string()).unwrap();
        assert_eq!(original, restored);
    }

    // --- ParsedSamplingRules ---

    #[test]
    fn test_parsed_sampling_rules_empty_string() {
        let parsed: ParsedSamplingRules = "".parse().unwrap();
        assert!(parsed.rules.is_empty());
    }

    #[test]
    fn test_parsed_sampling_rules_whitespace_only() {
        let parsed: ParsedSamplingRules = "   ".parse().unwrap();
        assert!(parsed.rules.is_empty());
    }

    #[test]
    fn test_parsed_sampling_rules_empty_array() {
        let parsed: ParsedSamplingRules = "[]".parse().unwrap();
        assert!(parsed.rules.is_empty());
    }

    #[test]
    fn test_parsed_sampling_rules_valid_json() {
        let json = r#"[{"sample_rate": 0.5, "service": "svc"}, {"sample_rate": 1.0}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        assert_eq!(parsed.rules.len(), 2);
        assert_eq!(parsed.rules[0].sample_rate, 0.5);
        assert_eq!(parsed.rules[1].sample_rate, 1.0);
    }

    #[test]
    fn test_parsed_sampling_rules_invalid_json() {
        let result: Result<ParsedSamplingRules, _> = "not json".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parsed_sampling_rules_rejects_single_object() {
        // The input must be an array of rules, not a bare rule object.
        let result: Result<ParsedSamplingRules, _> = r#"{"sample_rate": 0.5}"#.parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parsed_sampling_rules_deref() {
        let json = r#"[{"sample_rate": 0.5}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        // Deref to &[SamplingRuleConfig]
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].sample_rate, 0.5);
    }

    #[test]
    fn test_parsed_sampling_rules_into_vec() {
        let json = r#"[{"sample_rate": 0.5}, {"sample_rate": 1.0}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let vec: Vec<SamplingRuleConfig> = parsed.into();
        assert_eq!(vec.len(), 2);
    }

    #[test]
    fn test_parsed_sampling_rules_display() {
        let json = r#"[{"sample_rate":0.5}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let s = parsed.to_string();
        assert!(s.contains("sample_rate"));
        assert!(s.contains("0.5"));
    }

    #[test]
    fn test_parsed_sampling_rules_display_roundtrips() {
        // `Display` output is accepted by `FromStr` and reproduces the same rules,
        // including list-shaped tags normalized to the map shape.
        let json = r#"[
            {"sample_rate": 0.5, "service": "svc",
             "tags": [{"key": "env", "value_glob": "prod"}]},
            {"sample_rate": 1.0, "provenance": "dynamic"}
        ]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let reparsed: ParsedSamplingRules = parsed.to_string().parse().unwrap();
        assert_eq!(parsed, reparsed);
    }

    #[test]
    fn test_parsed_sampling_rules_default_is_empty() {
        let parsed = ParsedSamplingRules::default();
        assert!(parsed.rules.is_empty());
    }
}