redis_objects/
counters.rs

//! Objects and helpers for publishing metrics in an efficient manner.
2
3use std::{borrow::BorrowMut, sync::Arc};
4use std::marker::PhantomData;
5use std::time::Duration;
6
7use log::{error, info};
8use rand::Rng;
9use redis::AsyncCommands;
10use serde::Serialize;
11use parking_lot::Mutex;
12use serde_json::json;
13
14use crate::{retry_call, ErrorTypes, RedisObjects};
15
16/// Trait for metric messages being exported 
17pub trait MetricMessage: Serialize + Default + Send + Sync + 'static {}
18impl<T: Serialize + Default + Send + Sync + 'static> MetricMessage for T {}
19
20/// A builder to help configure a metrics counter that exports regularly to redis
21/// 
22/// This struct also acts as the internal config object for the counter once built.
23pub struct AutoExportingMetricsBuilder<Message: MetricMessage> {
24    channel_name: String,
25    counter_name: Option<String>,
26    counter_type: String,
27    host: String,
28    store: Arc<RedisObjects>,
29    data_type: PhantomData<Message>,
30    export_zero: bool,
31    export_interval: Duration,
32
33    /// Notification that wakes up the background task causing it to export early
34    export_notify: tokio::sync::Notify,
35}
36
37impl<Message: MetricMessage> AutoExportingMetricsBuilder<Message> {
38
39    pub (crate) fn new(store: Arc<RedisObjects>, channel_name: String, counter_type: String) -> Self {
40        Self {
41            channel_name,
42            counter_name: None,
43            counter_type,
44            host: format!("{:x}", rand::rng().random::<u128>()),
45            store,
46            export_zero: true,
47            export_interval: Duration::from_secs(5),
48            data_type: Default::default(),
49            export_notify: tokio::sync::Notify::new(),
50        }
51    }
52
53    /// Set the name field for this counter
54    pub fn counter_name(mut self, value: String) -> Self {
55        self.counter_name = Some(value); self
56    }
57
58    /// Set the hostname, otherwise a random id is used.
59    pub fn host(mut self, value: String) -> Self {
60        self.host = value; self
61    }
62
63    /// Set the export interval 
64    pub fn export_interval(mut self, value: Duration) -> Self {
65        self.export_interval = value; self
66    }
67
68    /// Configure if messages should be sent when no content has been added
69    pub fn export_zero(mut self, value: bool) -> Self {
70        self.export_zero = value; self
71    }
72
73    /// Launch the auto exporting process and return a handle for incrementing metrics
74    pub fn start(self) -> AutoExportingMetrics<Message> {
75        let current = Arc::new(Mutex::new(Message::default()));
76        let metrics = AutoExportingMetrics{
77            config: Arc::new(self),
78            current,
79        };
80
81        // start the background exporter
82        metrics.clone().exporter();
83
84        // return the original as metric interface 
85        metrics
86    }   
87
88    // /// build an empty message with current exporter settings
89    // fn empty_message(&self) -> Message {
90    //     let counter_name = match &self.counter_name {
91    //         Some(name) => name,
92    //         None => &self.counter_type,
93    //     };
94
95    //     Message::new(&self.counter_type, counter_name, &self.host)
96    // }
97}
98
/// Increase a field of the message currently held by a counter.
///
/// `$counter` is any expression with a `lock()` method returning mutable access to
/// the message (for example an `AutoExportingMetrics` handle). Supported forms:
///
/// * `increment!(counter, field)` adds `1` to `field`.
/// * `increment!(counter, field, value)` adds `value` to `field`.
/// * `increment!(timer, counter, field)` calls `field.increment(0.0)`.
/// * `increment!(timer, counter, field, value)` calls `field.increment(value)`.
#[macro_export]
macro_rules! increment {
    // The `timer` arms must come first: arms are tried in order, and the generic
    // three argument arm would otherwise capture `increment!(timer, counter, field)`
    // by treating `timer` as the counter expression.
    (timer, $counter:expr, $field:ident) => {
        $crate::increment!(timer, $counter, $field, 0.0)
    };
    (timer, $counter:expr, $field:ident, $value:expr) => {
        $counter.lock().$field.increment($value)
    };
    // `$crate::` keeps the recursive calls working when the macro is invoked by
    // path without being imported into the caller's scope.
    ($counter:expr, $field:ident) => {
        $crate::increment!($counter, $field, 1)
    };
    ($counter:expr, $field:ident, $value:expr) => {
        $counter.lock().$field += $value
    };
}
115pub use increment;
116
117
118/// A wrapper around a Message class that adds periodic backup.
119///
120/// At the specified interval and (best efforts) program exit, the current message will be 
121/// exported to the given channel and reset with the message Default.
122pub struct AutoExportingMetrics<Message: MetricMessage> {
123    config: Arc<AutoExportingMetricsBuilder<Message>>,
124    current: Arc<Mutex<Message>>
125}
126
127impl<Message: MetricMessage> Clone for AutoExportingMetrics<Message> {
128    fn clone(&self) -> Self {
129        Self { config: self.config.clone(), current: self.current.clone() }
130    }
131}
132
133impl<Message: MetricMessage> AutoExportingMetrics<Message> {
134    /// Launch the background export worker
135    fn exporter(mut self) {
136        tokio::spawn(async move {
137            while let Err(err) = self.export_loop().await {
138                error!("Error in metrics exporter {}: {}", self.config.counter_type, err);
139                tokio::time::sleep(Duration::from_secs(5)).await;
140            }
141        });
142    }
143
144    fn is_zero(&self, obj: &serde_json::Value) -> bool {
145        if let Some(number) = obj.as_i64() {
146            if number == 0 {
147                return true
148            }
149        } 
150        if let Some(number) = obj.as_u64() {
151            if number == 0 {
152                return true
153            }
154        } 
155        if let Some(number) = obj.as_f64() {
156            if number == 0.0 {
157                return true
158            }
159        } 
160        false 
161    }
162
163    fn is_all_zero(&self, obj: &serde_json::Value) -> bool {
164        if let Some(obj) = obj.as_object() {
165            for value in obj.values() {
166                if !self.is_zero(value) {
167                    return false
168                }
169            }
170            true
171        } else {
172            false
173        }
174    }
175
176    async fn export_once(&mut self) -> Result<(), ErrorTypes> {
177        // Fetch the message that needs to be sent
178        let outgoing = self.reset();
179
180        // create mapping
181        let mut outgoing = serde_json::to_value(&outgoing)?;
182
183        // check if we will export this message
184        if self.config.export_zero || !self.is_all_zero(&outgoing) {
185            // add extra fields
186            if let Some(obj) = outgoing.as_object_mut() {
187                obj.insert("type".to_owned(), json!(self.config.counter_type));
188                obj.insert("name".to_owned(), json!(self.config.counter_name));
189                obj.insert("host".to_owned(), json!(self.config.host));
190            }
191
192            // send the message
193            let data = serde_json::to_string(&outgoing)?;
194            let _recievers: u32 = retry_call!(self.config.store.pool, publish, self.config.channel_name.as_str(), data.as_str())?;                
195        }
196        Ok(())
197    }
198
199    async fn export_loop(&mut self) -> Result<(), ErrorTypes> {
200        loop {
201            // wait for the configured duration (or we get notified to do it now)
202            let _ = tokio::time::timeout(self.config.export_interval, self.config.export_notify.notified()).await;
203            self.export_once().await?;
204
205            // check if the public object has been dropped
206            if Arc::strong_count(&self.current) == 1 {
207                info!("Stopping metrics exporter: {}", self.config.channel_name);
208                self.export_once().await?; // make sure we report any last minute messages that happened during/after the last export
209                return Ok(())
210            }
211        }
212    }
213
214    /// Get a writeable guard holding the message that will next be exported 
215    /// Rather than using this directly the increment macro can be used
216    pub fn lock(&self) -> parking_lot::MutexGuard<Message> {
217        self.current.lock()
218    }
219
220    /// Replace the current outgoing message with an empty one 
221    /// returns the replaced message
222    pub fn reset(&self) -> Message {
223        let mut message: Message = Default::default();
224        std::mem::swap(&mut message, self.current.lock().borrow_mut());
225        message
226    }
227
228    /// Trigger the background task to export immediately
229    pub fn export(&self) {
230        self.config.export_notify.notify_one()
231    }
232
233//     def set(self, name, value):
234//         try:
235//             if name not in self.counter_schema:
236//                 raise ValueError(f"{name} is not an accepted counter for this module: f{self.counter_schema}")
237//             with self.lock:
238//                 self.values[name] = value
239//                 return value
240//         except Exception:  # Don't let increment fail anything.
241//             log.exception("Setting Metric")
242//             return 0
243
244//     def increment_execution_time(self, name, execution_time):
245//         try:
246//             if name not in self.timer_schema:
247//                 raise ValueError(f"{name} is not an accepted counter for this module: f{self.timer_schema}")
248//             with self.lock:
249//                 self.counts[name + ".c"] += 1
250//                 self.counts[name + ".t"] += execution_time
251//                 return execution_time
252//         except Exception:  # Don't let increment fail anything.
253//             log.exception("Incrementing counter")
254//             return 0
255
256
257
258}
259
260impl<M: MetricMessage> Drop for AutoExportingMetrics<M> {
261    fn drop(&mut self) {
262        if Arc::strong_count(&self.current) <= 2 {
263            self.export()
264        }
265    }
266}
267
268
/// Set up debug level logging for tests; failures to install the logger are ignored
/// (for example when it was already installed by another test).
#[cfg(test)]
fn init() {
    let mut logging = env_logger::builder();
    logging.filter_level(log::LevelFilter::Debug);
    logging.is_test(true);
    let _ = logging.try_init();
}
273
274#[tokio::test]
275async fn auto_exporting_counter() {
276    use log::info;
277    init();
278
279    use serde::Deserialize;
280    use crate::test::redis_connection;
281    let connection = redis_connection().await;
282    info!("redis connected");
283
284    #[derive(Debug, Serialize, Deserialize, PartialEq, Default)]
285    struct MetricKind {
286        started: u64,
287        finished: u64,
288    }
289
290    // Subscribe on the pubsub being used
291    let mut subscribe = connection.subscribe_json::<MetricKind>("test_metrics_channel".to_owned()).await;
292
293    {   
294        info!("Fast export");
295        // setup an exporter that sends metrics automatically very fast
296        let counter = connection.auto_exporting_metrics::<MetricKind>("test_metrics_channel".to_owned(), "component-x".to_owned())
297            .export_interval(Duration::from_micros(10))
298            .export_zero(false)
299            .start();
300
301        // Send a non default quantity via timer
302        increment!(counter, started, 5);
303        info!("Waiting for export");
304        assert_eq!(subscribe.recv().await.unwrap().unwrap(), MetricKind{started: 5, finished: 0});
305    }
306
307    {   
308        info!("slow export");
309        // setup a slow export
310        let counter = connection.auto_exporting_metrics::<MetricKind>("test_metrics_channel".to_owned(), "component-x".to_owned())
311            .export_interval(Duration::from_secs(1000))
312            .export_zero(false)
313            .start();
314
315        // set some quantities then erase them
316        increment!(counter, started);
317        increment!(counter, finished);
318        counter.reset();
319
320        // Send a default quantities explicity
321        increment!(counter, started);
322        increment!(counter, started);
323        increment!(counter, finished);
324        counter.export();
325        assert_eq!(subscribe.recv().await.unwrap().unwrap(), MetricKind{started: 2, finished: 1});
326
327        // send a message and let the drop signal an export
328        increment!(counter, finished, 5);
329        increment!(counter, finished);
330    }
331
332    let result = tokio::time::timeout(Duration::from_secs(10), subscribe.recv()).await.unwrap();
333    assert_eq!(result.unwrap().unwrap(), MetricKind{started: 0, finished: 6});
334}
335
336    
337    // # noinspection PyShadowingNames
338    // def test_basic_counters(redis_connection):
339    //     if redis_connection:
340    //         from assemblyline.remote.datatypes.counters import Counters
341    //         with Counters('test-counter') as ct:
342    //             ct.delete()
343    
344    //             for x in range(10):
345    //                 ct.inc('t1')
346    //             for x in range(20):
347    //                 ct.inc('t2', value=2)
348    //             ct.dec('t1')
349    //             ct.dec('t2')
350    //             assert sorted(ct.get_queues()) == ['test-counter-t1',
351    //                                                'test-counter-t2']
352    //             assert ct.get_queues_sizes() == {'test-counter-t1': 9,
353    //                                              'test-counter-t2': 39}
354    //             ct.reset_queues()
355    //             assert ct.get_queues_sizes() == {'test-counter-t1': 0,
356    //                                              'test-counter-t2': 0}
357    
358    
359    // # noinspection PyShadowingNames
360    // def test_tracked_counters(redis_connection):
361    //     if redis_connection:
362    //         from assemblyline.remote.datatypes.counters import Counters
363    //         with Counters('tracked-test-counter', track_counters=True) as ct:
364    //             ct.delete()
365    
366    //             for x in range(10):
367    //                 ct.inc('t1')
368    //             for x in range(20):
369    //                 ct.inc('t2', value=2)
370    //             assert ct.tracker.keys() == ['t1', 't2']
371    //             ct.dec('t1')
372    //             ct.dec('t2')
373    //             assert ct.tracker.keys() == []
374    //             assert sorted(ct.get_queues()) == ['tracked-test-counter-t1',
375    //                                                'tracked-test-counter-t2']
376    //             assert ct.get_queues_sizes() == {'tracked-test-counter-t1': 9,
377    //                                              'tracked-test-counter-t2': 39}
378    //             ct.reset_queues()
379    //             assert ct.get_queues_sizes() == {'tracked-test-counter-t1': 0,
380    //                                              'tracked-test-counter-t2': 0}