1use std::sync::mpsc::{channel, RecvTimeoutError, Sender};
8use std::sync::{Arc, Mutex};
9use std::thread::{self, JoinHandle};
10use std::time::Duration;
11
12use once_cell::sync::OnceCell;
13
14mod data;
15pub use data::METRICS_TARGET_NAME;
16use data::*;
17
18mod recorder;
19use recorder::*;
20
/// Interval at which the publisher thread drains, aggregates and emits the
/// metrics buffered by every registered thread.
const AGGREGATION_PERIOD: Duration = Duration::from_millis(5_000);
23
24static GLOBAL_SINK: OnceCell<MetricsSink> = OnceCell::new();
26
27thread_local! {
28 static LOCAL_SINK: OnceCell<ThreadMetricsSinkHandle> = OnceCell::new();
36}
37
38#[derive(Debug)]
40pub struct MetricsSink {
41 threads: Arc<Mutex<Vec<Arc<Mutex<ThreadMetricsSink>>>>>,
42}
43
44impl MetricsSink {
45 pub fn init() -> MetricsSinkHandle {
54 let sink = Self::new();
55
56 let (tx, rx) = channel();
57
58 let publisher_thread = {
59 let threads = Arc::clone(&sink.threads);
60 thread::spawn(move || {
61 loop {
62 match rx.recv_timeout(AGGREGATION_PERIOD) {
63 Ok(()) | Err(RecvTimeoutError::Disconnected) => break,
64 Err(RecvTimeoutError::Timeout) => Self::aggregate_and_publish(&threads),
65 }
66 }
67 Self::aggregate_and_publish(&threads);
71 })
72 };
73
74 let handle = MetricsSinkHandle {
75 shutdown: tx,
76 handle: Some(publisher_thread),
77 };
78
79 sink.install();
80 metrics::set_recorder(&MetricsRecorder).unwrap();
81
82 handle
83 }
84
85 fn new() -> MetricsSink {
86 let threads = Arc::new(Mutex::new(Vec::new()));
87
88 MetricsSink { threads }
89 }
90
91 fn install(self) {
92 GLOBAL_SINK.set(self).unwrap();
93 }
94
95 fn aggregate_and_publish(threads: &Mutex<Vec<Arc<Mutex<ThreadMetricsSink>>>>) {
96 let metrics = Self::aggregate(threads);
97 Self::publish(metrics);
98 }
99
100 fn aggregate(threads: &Mutex<Vec<Arc<Mutex<ThreadMetricsSink>>>>) -> Metrics {
101 let mut aggregate_metrics = Metrics::default();
102 let threads = threads.lock().unwrap();
103 for thread in threads.iter() {
104 let metrics = std::mem::take(&mut *thread.lock().unwrap());
105 aggregate_metrics.aggregate(metrics.metrics);
106 }
107 aggregate_metrics
108 }
109
110 fn publish(metrics: Metrics) {
111 metrics.emit();
112 }
113}
114
/// Owning handle for the background metrics publisher thread.
///
/// Dropping the handle, or calling [`MetricsSinkHandle::shutdown`], signals
/// the publisher to stop and blocks until its thread has exited.
#[derive(Debug)]
pub struct MetricsSinkHandle {
    /// Sending on this channel tells the publisher loop to stop.
    shutdown: Sender<()>,
    /// The publisher thread; `None` once it has been joined.
    handle: Option<JoinHandle<()>>,
}

impl MetricsSinkHandle {
    /// Stops the publisher thread and waits for it to finish.
    ///
    /// The work is done by the `Drop` implementation; this method just makes
    /// the shutdown point explicit at the call site.
    pub fn shutdown(self) {
        drop(self);
    }
}

impl Drop for MetricsSinkHandle {
    fn drop(&mut self) {
        // A send error only means the receiver is already gone, so there is
        // nothing left to signal.
        let _ = self.shutdown.send(());
        if let Some(publisher) = self.handle.take() {
            // A panic on the publisher thread is deliberately not propagated
            // out of `drop`.
            let _ = publisher.join();
        }
    }
}
136
137#[derive(Debug, Default)]
138struct ThreadMetricsSink {
139 metrics: Metrics,
140}
141
142#[derive(Debug, Default)]
143struct ThreadMetricsSinkHandle {
144 inner: Arc<Mutex<ThreadMetricsSink>>,
145}
146
147impl ThreadMetricsSinkHandle {
148 pub(crate) fn with<F, T>(f: F) -> T
150 where
151 F: FnOnce(&ThreadMetricsSinkHandle) -> T,
152 {
153 LOCAL_SINK.with(|handle| {
154 let handle = handle.get_or_init(Self::init);
155 f(handle)
156 })
157 }
158
159 fn init() -> ThreadMetricsSinkHandle {
161 if let Some(global_sink) = GLOBAL_SINK.get() {
162 let me = Arc::new(Mutex::new(ThreadMetricsSink::default()));
163 global_sink.threads.lock().unwrap().push(Arc::clone(&me));
164 ThreadMetricsSinkHandle { inner: me }
165 } else {
166 panic!("global metrics sink must be installed first");
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use metrics::Label;
    use tracing_test::traced_test;

    /// Records counters and gauges on the test thread, then checks both the
    /// aggregated values and the log lines produced when they are published.
    #[traced_test]
    #[test]
    fn test_basic_metrics() {
        let sink = MetricsSink::new();
        // `install` consumes the sink, so keep our own handle to its thread list.
        let threads = Arc::clone(&sink.threads);

        sink.install();
        metrics::set_recorder(&MetricsRecorder).unwrap();

        // Expected counter sums: foo = 1 + 2 + 3 = 6, bar = 1 + 2 + 4 = 7.
        metrics::counter!("test_counter", 1, "type" => "foo");
        metrics::counter!("test_counter", 1, "type" => "bar");
        metrics::counter!("test_counter", 2, "type" => "foo");
        metrics::counter!("test_counter", 2, "type" => "bar");
        metrics::counter!("test_counter", 3, "type" => "foo");
        metrics::counter!("test_counter", 4, "type" => "bar");

        // For gauges the assertions below expect only the latest value per
        // label (foo = 2.0, bar = 3.0, n = 1).
        metrics::gauge!("test_gauge", 5.0, "type" => "foo");
        metrics::gauge!("test_gauge", 5.0, "type" => "bar");
        metrics::gauge!("test_gauge", 2.0, "type" => "foo");
        metrics::gauge!("test_gauge", 3.0, "type" => "bar");

        let foo = Label::new("type", "foo");
        let bar = Label::new("type", "bar");

        let aggregated = MetricsSink::aggregate(&threads);
        assert_eq!(aggregated.iter().count(), 4);

        for (key, data) in aggregated.iter() {
            assert_eq!(key.labels().count(), 1);
            let label = key.labels().next().unwrap();
            match data {
                Metric::Counter(counter) => {
                    assert_eq!(key.name(), "test_counter");
                    assert_eq!(counter.n, 3);
                    let expected_sum = if label == &foo {
                        6
                    } else if label == &bar {
                        7
                    } else {
                        panic!("wrong label")
                    };
                    assert_eq!(counter.sum, expected_sum);
                }
                Metric::Gauge(gauge) => {
                    assert_eq!(key.name(), "test_gauge");
                    assert_eq!(gauge.n, 1);
                    let expected_sum = if label == &foo {
                        2.0
                    } else if label == &bar {
                        3.0
                    } else {
                        panic!("wrong label")
                    };
                    assert_eq!(gauge.sum, expected_sum);
                }
                _ => panic!("wrong metric type"),
            }
        }

        MetricsSink::publish(aggregated);
        let expected_lines = &[
            "test_counter",
            "test_counter[type=bar]: 7 (n=3)",
            "test_counter[type=foo]: 6 (n=3)",
            "test_gauge[type=bar]: 3 (n=1)",
            "test_gauge[type=foo]: 2 (n=1)",
        ];
        for expected in expected_lines {
            assert!(logs_contain(expected));
        }
    }
}