elfo_core/tracing/validator.rs

1use std::{convert::TryFrom, time::Duration};
2
3use super::trace_id::{TraceId, TruncatedTime};
4use crate::{node, time};
5
/// A [`TraceId`] validator for raw trace ids coming from outside the elfo system.
///
/// By default, it checks the following properties:
/// * The id is not zero.
/// * The id fits into 63 bits, i.e. its highest bit is zero.
/// * The id does not carry this node's `node_no`: raw trace ids are expected to
///   be produced outside the elfo system, so our own node number is rejected.
///
/// Optionally, it can also check the time difference, see
/// [`TraceIdValidator::max_time_difference`].
#[derive(Debug, Clone, Default)]
pub struct TraceIdValidator {
    /// Maximum allowed absolute difference between the current time and the
    /// timestamp embedded in a raw trace id; `None` disables the time check.
    max_time_difference: Option<Duration>,
}
19
// Error descriptions returned by `TraceIdValidator::validate`.

/// The raw trace id is zero.
const CANNOT_BE_ZERO: &str = "cannot be zero";
/// The raw trace id has its highest (64th) bit set.
const HIGHEST_BIT_MUST_BE_ZERO: &str = "highest bit must be zero";
/// The raw trace id carries this node's `node_no`.
const INVALID_NODE_NO: &str = "invalid node no";
/// The embedded timestamp is too far from the current time.
const INVALID_TIMESTAMP: &str = "invalid timestamp";
25
26impl TraceIdValidator {
27    /// Allowed time difference between now and timestamp in a raw trace id.
28    /// Checks the absolute difference to handle both situations:
29    /// * Too old timestamp, possible incorrectly generated.
30    /// * Too new timestamp, something wrong with time synchronization.
31    pub fn max_time_difference(&mut self, time_lag: impl Into<Option<Duration>>) -> Self {
32        self.max_time_difference = time_lag.into();
33        self.clone()
34    }
35
36    /// Validates a raw trace id transforms it into [`TraceId`] if valid.
37    pub fn validate(&self, raw_trace_id: u64) -> Result<TraceId, &'static str> {
38        let trace_id = TraceId::try_from(raw_trace_id).map_err(|_| CANNOT_BE_ZERO)?;
39        let layout = trace_id.to_layout();
40
41        // The highest bit must be zero.
42        if raw_trace_id & (1 << 63) != 0 {
43            return Err(HIGHEST_BIT_MUST_BE_ZERO);
44        }
45
46        // We don't allow to specify valid `node_no` for now,
47        // but at least we can check that it isn't this node.
48        if layout.node_no == node::node_no() {
49            return Err(INVALID_NODE_NO);
50        }
51
52        if let Some(time_lag) = self.max_time_difference {
53            let truncated_now = TruncatedTime::from(time::now());
54            let delta = truncated_now.abs_delta(layout.timestamp);
55
56            if i64::from(delta) > time_lag.as_secs() as i64 {
57                return Err(INVALID_TIMESTAMP);
58            }
59        }
60
61        Ok(trace_id)
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Shifts `secs` to bit 38 and above; this test treats those bits as the
    /// truncated timestamp field of a raw trace id.
    fn at(secs: u64) -> u64 {
        secs << 38
    }

    #[test]
    fn it_works() {
        let max_lag = Duration::from_secs(5);
        let validator = TraceIdValidator::default().max_time_difference(max_lag);

        // Structural checks.
        assert_eq!(validator.validate(0), Err(CANNOT_BE_ZERO));
        assert_eq!(validator.validate(1 << 63), Err(HIGHEST_BIT_MUST_BE_ZERO));
        let own_node = u64::from(node::node_no()) << 22;
        assert_eq!(validator.validate(own_node), Err(INVALID_NODE_NO));

        // At the initial mocked time (these checks treat it as truncated time 0),
        // timestamps up to 5s away in either direction are accepted, including
        // ones just below the 2^25 wrap-around point; 6s away is rejected.
        let wrap: u64 = 1 << 25;
        for &raw in &[at(5), at(wrap - 5)] {
            assert!(validator.validate(raw).is_ok());
        }
        for &raw in &[at(6), at(wrap - 6)] {
            assert_eq!(validator.validate(raw), Err(INVALID_TIMESTAMP));
        }

        // After advancing the clock by 10s, both 16 and 4 are 6s away from now.
        time::advance(Duration::from_secs(10));
        for &raw in &[at(16), at(4)] {
            assert_eq!(validator.validate(raw), Err(INVALID_TIMESTAMP));
        }
    }
}