1use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Display;
7use std::ops::Deref;
8use std::str::FromStr;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12pub struct SamplingRuleConfig {
13 pub sample_rate: f64,
15
16 #[serde(default)]
18 pub service: Option<String>,
19
20 #[serde(default)]
22 pub name: Option<String>,
23
24 #[serde(default)]
26 pub resource: Option<String>,
27
28 #[serde(default, deserialize_with = "deserialize_tags")]
35 pub tags: HashMap<String, String>,
36
37 #[serde(default = "default_provenance")]
41 pub provenance: String,
42}
43
44impl Default for SamplingRuleConfig {
45 fn default() -> Self {
46 Self {
49 sample_rate: 0.0,
50 service: None,
51 name: None,
52 resource: None,
53 tags: HashMap::new(),
54 provenance: default_provenance(),
55 }
56 }
57}
58
59impl Display for SamplingRuleConfig {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", serde_json::json!(self))
62 }
63}
64
/// Serde default for `SamplingRuleConfig::provenance`.
///
/// The `Default` impl of `SamplingRuleConfig` uses it too, so both paths agree.
fn default_provenance() -> String {
    String::from("default")
}
68
69fn deserialize_tags<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
77where
78 D: serde::Deserializer<'de>,
79{
80 use serde::de::{MapAccess, SeqAccess, Visitor};
81 use std::fmt;
82
83 #[derive(serde::Deserialize)]
84 struct ListEntry {
85 key: String,
86 value_glob: String,
87 }
88
89 struct TagsVisitor;
90
91 impl<'de> Visitor<'de> for TagsVisitor {
92 type Value = HashMap<String, String>;
93
94 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
95 f.write_str("a map of string to string or a list of {key, value_glob} objects")
96 }
97
98 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
99 where
100 M: MapAccess<'de>,
101 {
102 let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
103 while let Some((k, v)) = access.next_entry::<String, String>()? {
104 map.insert(k, v);
105 }
106 Ok(map)
107 }
108
109 fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
110 where
111 S: SeqAccess<'de>,
112 {
113 let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
114 while let Some(entry) = access.next_element::<ListEntry>()? {
115 map.insert(entry.key, entry.value_glob);
116 }
117 Ok(map)
118 }
119 }
120
121 deserializer.deserialize_any(TagsVisitor)
122}
123
124#[derive(Debug, Default, Clone, PartialEq)]
125pub struct ParsedSamplingRules {
126 pub rules: Vec<SamplingRuleConfig>,
127}
128
129impl Deref for ParsedSamplingRules {
130 type Target = [SamplingRuleConfig];
131
132 fn deref(&self) -> &Self::Target {
133 &self.rules
134 }
135}
136
137impl From<ParsedSamplingRules> for Vec<SamplingRuleConfig> {
138 fn from(parsed: ParsedSamplingRules) -> Self {
139 parsed.rules
140 }
141}
142
143impl FromStr for ParsedSamplingRules {
144 type Err = serde_json::Error;
145
146 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 if s.trim().is_empty() {
148 return Ok(ParsedSamplingRules::default());
149 }
150 let rules_vec: Vec<SamplingRuleConfig> = serde_json::from_str(s)?;
152 Ok(ParsedSamplingRules { rules: rules_vec })
153 }
154}
155
156impl Display for ParsedSamplingRules {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(
159 f,
160 "{}",
161 serde_json::to_string(&self.rules).unwrap_or_default()
162 )
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    // ----- SamplingRuleConfig: defaults -------------------------------------

    #[test]
    fn test_sampling_rule_config_defaults() {
        let config = SamplingRuleConfig::default();
        assert_eq!(config.sample_rate, 0.0);
        assert!(config.service.is_none());
        assert!(config.name.is_none());
        assert!(config.resource.is_none());
        assert!(config.tags.is_empty());
        assert_eq!(config.provenance, "default");
    }

    #[test]
    fn test_sampling_rule_config_default_matches_serde_default() {
        let from_serde: SamplingRuleConfig =
            serde_json::from_str(r#"{"sample_rate": 0.0}"#).unwrap();
        assert_eq!(from_serde, SamplingRuleConfig::default());
    }

    #[test]
    fn test_sampling_rule_config_serde_default_provenance() {
        let json = r#"{"sample_rate": 0.5}"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.provenance, "default");
    }

    // ----- SamplingRuleConfig: deserialization ------------------------------

    #[test]
    fn test_sampling_rule_config_deserialize_full() {
        let json = r#"{
            "sample_rate": 0.5,
            "service": "my-service",
            "name": "http.*",
            "resource": "/api/*",
            "tags": {"env": "prod"},
            "provenance": "customer"
        }"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_rate, 0.5);
        assert_eq!(config.service.as_deref(), Some("my-service"));
        assert_eq!(config.name.as_deref(), Some("http.*"));
        assert_eq!(config.resource.as_deref(), Some("/api/*"));
        assert_eq!(config.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(config.provenance, "customer");
    }

    #[test]
    fn test_sampling_rule_config_deserialize_minimal() {
        let json = r#"{"sample_rate": 1.0}"#;
        let config: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_rate, 1.0);
        assert!(config.service.is_none());
        assert_eq!(config.provenance, "default");
    }

    #[test]
    fn test_sampling_rule_config_missing_sample_rate_rejects() {
        // `sample_rate` is the only field without a serde default.
        let result: Result<SamplingRuleConfig, _> =
            serde_json::from_str(r#"{"service": "svc"}"#);
        assert!(result.is_err(), "sample_rate must be required");
    }

    #[test]
    fn test_sampling_rule_config_roundtrip() {
        let original = SamplingRuleConfig {
            sample_rate: 0.25,
            service: Some("svc".into()),
            name: Some("op".into()),
            resource: Some("/res".into()),
            tags: HashMap::from([("k".into(), "v".into())]),
            provenance: "dynamic".into(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: SamplingRuleConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    // ----- SamplingRuleConfig: tags shapes ----------------------------------

    #[test]
    fn test_sampling_rule_config_tags_accepts_map_shape() {
        let json = r#"{
            "sample_rate": 0.5,
            "service": "svc",
            "tags": {"env": "prod", "region": "us-east-1"}
        }"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(
            cfg.tags.get("region").map(String::as_str),
            Some("us-east-1")
        );
    }

    #[test]
    fn test_sampling_rule_config_tags_accepts_rc_list_shape() {
        let json = r#"{
            "sample_rate": 0.5,
            "service": "svc",
            "tags": [
                {"key": "env", "value_glob": "prod"},
                {"key": "region", "value_glob": "us-east-1"}
            ]
        }"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.tags.get("env").map(String::as_str), Some("prod"));
        assert_eq!(
            cfg.tags.get("region").map(String::as_str),
            Some("us-east-1")
        );
    }

    #[test]
    fn test_sampling_rule_config_tags_list_last_duplicate_wins() {
        let json = r#"{
            "sample_rate": 0.5,
            "tags": [
                {"key": "env", "value_glob": "staging"},
                {"key": "env", "value_glob": "prod"}
            ]
        }"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.tags.len(), 1);
        assert_eq!(cfg.tags.get("env").map(String::as_str), Some("prod"));
    }

    #[test]
    fn test_sampling_rule_config_tags_list_with_malformed_entry_rejects() {
        let json = r#"{
            "sample_rate": 0.5,
            "tags": [
                {"key": "env", "value_glob": "prod"},
                {"key": "region"}
            ]
        }"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_map_with_non_string_value_rejects() {
        let json = r#"{"sample_rate": 0.5, "tags": {"env": 1}}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_scalar_rejects() {
        let json = r#"{"sample_rate": 0.5, "tags": "env:prod"}"#;
        let result: Result<SamplingRuleConfig, _> = serde_json::from_str(json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn test_sampling_rule_config_tags_absent_defaults_to_empty() {
        let json = r#"{"sample_rate": 0.5}"#;
        let cfg: SamplingRuleConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.tags.is_empty());
    }

    // ----- SamplingRuleConfig: Display --------------------------------------

    #[test]
    fn test_sampling_rule_config_display() {
        let config = SamplingRuleConfig {
            sample_rate: 1.0,
            service: Some("svc".into()),
            ..Default::default()
        };
        let s = config.to_string();
        assert!(s.contains("sample_rate"));
        assert!(s.contains("svc"));
    }

    #[test]
    fn test_sampling_rule_config_display_roundtrips() {
        // Display output must be valid JSON that deserializes back to the same rule.
        let config = SamplingRuleConfig {
            sample_rate: 0.75,
            service: Some("svc".into()),
            tags: HashMap::from([("env".into(), "prod".into())]),
            ..Default::default()
        };
        let restored: SamplingRuleConfig = serde_json::from_str(&config.to_string()).unwrap();
        assert_eq!(restored, config);
    }

    // ----- ParsedSamplingRules ----------------------------------------------

    #[test]
    fn test_parsed_sampling_rules_empty_string() {
        let parsed: ParsedSamplingRules = "".parse().unwrap();
        assert!(parsed.rules.is_empty());
    }

    #[test]
    fn test_parsed_sampling_rules_whitespace_only() {
        let parsed: ParsedSamplingRules = "   ".parse().unwrap();
        assert!(parsed.rules.is_empty());
    }

    #[test]
    fn test_parsed_sampling_rules_valid_json() {
        let json = r#"[{"sample_rate": 0.5, "service": "svc"}, {"sample_rate": 1.0}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        assert_eq!(parsed.rules.len(), 2);
        assert_eq!(parsed.rules[0].sample_rate, 0.5);
        assert_eq!(parsed.rules[1].sample_rate, 1.0);
    }

    #[test]
    fn test_parsed_sampling_rules_invalid_json() {
        let result: Result<ParsedSamplingRules, _> = "not json".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parsed_sampling_rules_non_array_rejects() {
        // A single rule object is not accepted in place of an array of rules.
        let result: Result<ParsedSamplingRules, _> = r#"{"sample_rate": 0.5}"#.parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_parsed_sampling_rules_deref() {
        let json = r#"[{"sample_rate": 0.5}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].sample_rate, 0.5);
    }

    #[test]
    fn test_parsed_sampling_rules_into_vec() {
        let json = r#"[{"sample_rate": 0.5}, {"sample_rate": 1.0}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let vec: Vec<SamplingRuleConfig> = parsed.into();
        assert_eq!(vec.len(), 2);
    }

    #[test]
    fn test_parsed_sampling_rules_display() {
        let json = r#"[{"sample_rate":0.5}]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let s = parsed.to_string();
        assert!(s.contains("sample_rate"));
        assert!(s.contains("0.5"));
    }

    #[test]
    fn test_parsed_sampling_rules_display_roundtrips() {
        // Display output must parse back into an identical rule list.
        let json = r#"[
            {"sample_rate": 0.5, "service": "svc"},
            {"sample_rate": 1.0, "tags": {"env": "prod"}}
        ]"#;
        let parsed: ParsedSamplingRules = json.parse().unwrap();
        let reparsed: ParsedSamplingRules = parsed.to_string().parse().unwrap();
        assert_eq!(reparsed, parsed);
    }

    #[test]
    fn test_parsed_sampling_rules_default_is_empty() {
        let parsed = ParsedSamplingRules::default();
        assert!(parsed.rules.is_empty());
    }
}