1use serde::Deserialize;
4use serde::Deserializer;
5use serde::Serialize;
6use serde::de;
7use serde::de::Unexpected;
8use serde::de::Visitor;
9use std::collections::HashMap;
10use std::fmt;
11use tracing_core::Event;
12use tracing_core::{Metadata, subscriber::Interest};
13use tracing_subscriber::filter;
14use tracing_subscriber::layer;
15
16use super::ConfigError;
17
18#[derive(Default, Debug, Serialize, Copy, Clone, PartialEq, PartialOrd)]
20#[repr(u8)]
21pub enum TelemetryLevel {
22 OFF = 0,
24 ERROR = 1,
26 WARN = 2,
28 INFO = 3,
30 DEBUG = 4,
32 #[default]
34 TRACE = 5,
35}
36
37impl From<TelemetryLevel> for filter::LevelFilter {
38 fn from(val: TelemetryLevel) -> Self {
39 match val {
40 TelemetryLevel::OFF => filter::LevelFilter::OFF,
41 TelemetryLevel::ERROR => filter::LevelFilter::ERROR,
42 TelemetryLevel::WARN => filter::LevelFilter::WARN,
43 TelemetryLevel::INFO => filter::LevelFilter::INFO,
44 TelemetryLevel::DEBUG => filter::LevelFilter::DEBUG,
45 TelemetryLevel::TRACE => filter::LevelFilter::TRACE,
46 }
47 }
48}
49
50impl From<TelemetryLevel> for &str {
51 fn from(val: TelemetryLevel) -> Self {
52 match val {
53 TelemetryLevel::OFF => "off",
54 TelemetryLevel::ERROR => "error",
55 TelemetryLevel::WARN => "warn",
56 TelemetryLevel::INFO => "info",
57 TelemetryLevel::DEBUG => "debug",
58 TelemetryLevel::TRACE => "trace",
59 }
60 }
61}
62
63impl TryFrom<&str> for TelemetryLevel {
64 type Error = ConfigError;
65
66 fn try_from(value: &str) -> Result<Self, Self::Error> {
67 match value.to_lowercase().as_str() {
68 "off" => Ok(TelemetryLevel::OFF),
69 "error" => Ok(TelemetryLevel::ERROR),
70 "warn" => Ok(TelemetryLevel::WARN),
71 "info" => Ok(TelemetryLevel::INFO),
72 "debug" => Ok(TelemetryLevel::DEBUG),
73 "trace" => Ok(TelemetryLevel::TRACE),
74 _ => Err(ConfigError::WrongValue(
75 "TelemetryLevel".into(),
76 value.to_string(),
77 )),
78 }
79 }
80}
81
82impl Visitor<'_> for TelemetryLevel {
83 type Value = TelemetryLevel;
84
85 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
86 formatter.write_str(
87 "Telemetry Level from values: off[0], error[1], warn[2], info[3], debug[4], trace[5]",
88 )
89 }
90
91 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
92 where
93 E: de::Error,
94 {
95 match s.to_lowercase().as_str() {
96 "off" => Ok(TelemetryLevel::OFF),
97 "error" => Ok(TelemetryLevel::ERROR),
98 "warn" => Ok(TelemetryLevel::WARN),
99 "info" => Ok(TelemetryLevel::INFO),
100 "debug" => Ok(TelemetryLevel::DEBUG),
101 "trace" => Ok(TelemetryLevel::TRACE),
102 _ => Err(de::Error::invalid_value(Unexpected::Str(s), &self)),
103 }
104 }
105
106 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
107 where
108 E: de::Error,
109 {
110 match value {
111 0 => Ok(TelemetryLevel::OFF),
112 1 => Ok(TelemetryLevel::ERROR),
113 2 => Ok(TelemetryLevel::WARN),
114 3 => Ok(TelemetryLevel::INFO),
115 4 => Ok(TelemetryLevel::DEBUG),
116 5 => Ok(TelemetryLevel::TRACE),
117 _ => Err(de::Error::invalid_value(Unexpected::Signed(value), &self)),
118 }
119 }
120
121 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
122 where
123 E: de::Error,
124 {
125 self.visit_i64(value as i64)
126 }
127}
128
129impl<'de> Deserialize<'de> for TelemetryLevel {
130 fn deserialize<D>(deserializer: D) -> Result<TelemetryLevel, D::Error>
131 where
132 D: Deserializer<'de>,
133 {
134 deserializer.deserialize_any(TelemetryLevel::default())
135 }
136}
137
138#[derive(Debug, Clone)]
156pub struct TelemetryFilter {
157 proc_levels: HashMap<String, filter::LevelFilter>,
158 pub(crate) level: filter::LevelFilter,
159}
160
161impl TelemetryFilter {
162 pub fn new(level: filter::LevelFilter) -> TelemetryFilter {
164 TelemetryFilter {
165 proc_levels: HashMap::new(),
166 level,
167 }
168 }
169
170 pub fn clone_with_level(&self, level: TelemetryLevel) -> TelemetryFilter {
172 let mut filter = self.clone();
173 let level: filter::LevelFilter = level.into();
174 if level < filter.level {
175 filter.level = level;
176 }
177
178 filter
179 }
180
181 pub fn add_proc_filter(&mut self, proc_name: String, level: filter::LevelFilter) {
183 self.proc_levels.insert(proc_name, level);
184 }
185
186 fn is_enabled(&self, metadata: &Metadata<'_>) -> bool {
187 let level = if let Some(value) = self.proc_levels.get(metadata.name()) {
188 value
189 } else if let Some(value) = self.proc_levels.get(metadata.target()) {
190 value
191 } else {
192 &self.level
193 };
194
195 metadata.level() <= level
196 }
197}
198
199impl Default for TelemetryFilter {
200 fn default() -> TelemetryFilter {
201 TelemetryFilter {
202 proc_levels: HashMap::new(),
203 level: filter::LevelFilter::TRACE,
204 }
205 }
206}
207
208impl<S> layer::Filter<S> for TelemetryFilter {
209 fn enabled(&self, metadata: &Metadata<'_>, _: &layer::Context<'_, S>) -> bool {
210 self.is_enabled(metadata)
211 }
212
213 fn callsite_enabled(&self, metadata: &'static Metadata<'static>) -> Interest {
214 if self.is_enabled(metadata) {
215 Interest::always()
216 } else {
217 Interest::never()
218 }
219 }
220
221 fn event_enabled(&self, event: &Event<'_>, _: &layer::Context<'_, S>) -> bool {
222 self.is_enabled(event.metadata())
223 }
224
225 fn max_level_hint(&self) -> Option<filter::LevelFilter> {
226 Some(self.level)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Every level, in declaration (least to most verbose) order.
    const ALL_LEVELS: [TelemetryLevel; 6] = [
        TelemetryLevel::OFF,
        TelemetryLevel::ERROR,
        TelemetryLevel::WARN,
        TelemetryLevel::INFO,
        TelemetryLevel::DEBUG,
        TelemetryLevel::TRACE,
    ];

    #[test]
    fn telemetry_level() {
        assert!(
            TelemetryLevel::try_from("warn").unwrap() < TelemetryLevel::INFO,
            "{:?} < Info",
            TelemetryLevel::try_from("warn")
        );
        assert_eq!(
            "The config parameter TelemetryLevel have an incorrect value `wrong`".to_owned(),
            TelemetryLevel::try_from("wrong").err().unwrap().to_string()
        );

        assert_eq!(
            filter::LevelFilter::DEBUG,
            filter::LevelFilter::from(TelemetryLevel::DEBUG)
        );
    }

    #[test]
    fn telemetry_level_name_round_trip() {
        for level in ALL_LEVELS {
            let name: &str = level.into();
            assert_eq!(TelemetryLevel::try_from(name).unwrap(), level);
            // Parsing ignores case.
            let upper = name.to_uppercase();
            assert_eq!(TelemetryLevel::try_from(upper.as_str()).unwrap(), level);
        }
    }

    #[test]
    fn telemetry_level_order_and_default() {
        assert!(ALL_LEVELS.windows(2).all(|pair| pair[0] < pair[1]));
        assert_eq!(TelemetryLevel::default(), TelemetryLevel::TRACE);
    }

    #[test]
    fn telemetry_level_to_level_filter() {
        let expected = [
            filter::LevelFilter::OFF,
            filter::LevelFilter::ERROR,
            filter::LevelFilter::WARN,
            filter::LevelFilter::INFO,
            filter::LevelFilter::DEBUG,
            filter::LevelFilter::TRACE,
        ];
        for (level, want) in ALL_LEVELS.iter().zip(expected.iter()) {
            assert_eq!(filter::LevelFilter::from(*level), *want);
        }
    }

    #[test]
    fn clone_with_level_never_raises_verbosity() {
        let base = TelemetryFilter::new(filter::LevelFilter::INFO);
        assert_eq!(
            base.clone_with_level(TelemetryLevel::ERROR).level,
            filter::LevelFilter::ERROR
        );
        assert_eq!(
            base.clone_with_level(TelemetryLevel::TRACE).level,
            filter::LevelFilter::INFO
        );
    }
}