use std::{convert::TryFrom, time::Duration};
use super::trace_id::{TraceId, TruncatedTime};
#[derive(Clone, Default)]
pub struct TraceIdValidator {
max_time_difference: Option<Duration>,
}
const CANNOT_BE_ZERO: &str = "cannot be zero";
const HIGHEST_BIT_MUST_BE_ZERO: &str = "highest bit must be zero";
const INVALID_TIMESTAMP: &str = "invalid timestamp";
impl TraceIdValidator {
pub fn max_time_difference(&mut self, time_lag: impl Into<Option<Duration>>) -> Self {
self.max_time_difference = time_lag.into();
self.clone()
}
pub fn validate(&self, raw_trace_id: u64) -> Result<TraceId, &'static str> {
let trace_id = TraceId::try_from(raw_trace_id).map_err(|_| CANNOT_BE_ZERO)?;
let layout = trace_id.to_layout();
if raw_trace_id & (1 << 63) != 0 {
return Err(HIGHEST_BIT_MUST_BE_ZERO);
}
if let Some(time_lag) = self.max_time_difference {
let truncated_now = TruncatedTime::now();
let delta = truncated_now.abs_delta(layout.timestamp);
if i64::from(delta) > time_lag.as_secs() as i64 {
return Err(INVALID_TIMESTAMP);
}
}
Ok(trace_id)
}
}
#[test]
fn it_works() {
    // Places `secs` into the timestamp bits (starting at bit 38) of a raw trace id.
    fn at(secs: u64) -> u64 {
        secs << 38
    }

    // Size of the 25-bit truncated timestamp space; values just below `SPAN`
    // wrap around and count as close to `0`.
    const SPAN: u64 = 1 << 25;

    elfo_utils::time::with_mock(|mock| {
        let validator = TraceIdValidator::default().max_time_difference(Duration::from_secs(5));

        // Structural checks on the raw value itself.
        assert_eq!(validator.validate(0), Err(CANNOT_BE_ZERO));
        assert_eq!(validator.validate(1 << 63), Err(HIGHEST_BIT_MUST_BE_ZERO));

        // Initially, timestamps exactly 5s away in either direction (including
        // across the wrap-around) are accepted...
        assert!(validator.validate(at(5)).is_ok());
        assert!(validator.validate(at(SPAN - 5)).is_ok());

        // ...while 6s away is already too far.
        assert_eq!(validator.validate(at(6)), Err(INVALID_TIMESTAMP));
        assert_eq!(validator.validate(at(SPAN - 6)), Err(INVALID_TIMESTAMP));

        // Move the mocked clock 10s forward: both 16 and 4 are now 6s off.
        mock.advance(Duration::from_secs(10));
        assert_eq!(validator.validate(at(16)), Err(INVALID_TIMESTAMP));
        assert_eq!(validator.validate(at(4)), Err(INVALID_TIMESTAMP));
    });
}