use chrono::{NaiveDate, Utc};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Abstraction over the clock so that time-dependent code can be driven by
/// either the real system clock or a manually controlled one.
///
/// Implementors must be `Send + Sync + Debug`, which lets a source be stored
/// behind a shared pointer and used from several threads.
pub trait TimeSource: Send + Sync + std::fmt::Debug {
    /// Returns the current instant according to this source.
    fn now(&self) -> Instant;

    /// Lets `duration` pass according to this source.
    ///
    /// Whether this blocks the calling thread is up to the implementation.
    fn sleep(&self, duration: Duration);

    /// Returns the current date according to this source, as a string.
    ///
    /// The implementations in this file format it as `%Y-%m-%d`.
    fn today_date_string(&self) -> String;

    /// Returns how much time has passed between `earlier` and
    /// [`now`](TimeSource::now).
    ///
    /// If `earlier` is actually later than the current instant the result is
    /// [`Duration::ZERO`] rather than a panic.
    fn elapsed_since(&self, earlier: Instant) -> Duration {
        let current = self.now();
        current
            .checked_duration_since(earlier)
            .unwrap_or(Duration::ZERO)
    }
}
/// Reference-counted, dynamically dispatched handle to any [`TimeSource`].
///
/// Because the trait requires `Send + Sync`, clones of this handle can be
/// handed to other threads.
pub type SharedTimeSource = Arc<dyn TimeSource>;
/// [`TimeSource`] backed by the real clocks: `Instant::now()` for instants,
/// `std::thread::sleep` for sleeping and `Utc::now()` for the date.
///
/// This is a field-less unit struct, so it carries no state of its own.
#[derive(Debug, Clone, Copy, Default)]
pub struct RealTimeSource;
impl RealTimeSource {
pub fn new() -> Self {
Self
}
pub fn shared() -> SharedTimeSource {
Arc::new(Self)
}
}
impl TimeSource for RealTimeSource {
    /// Reads the system's monotonic clock.
    fn now(&self) -> Instant {
        Instant::now()
    }

    /// Blocks the calling thread via `std::thread::sleep`.
    fn sleep(&self, duration: Duration) {
        std::thread::sleep(duration);
    }

    /// Formats the current UTC date as `%Y-%m-%d`.
    fn today_date_string(&self) -> String {
        let utc_now = Utc::now();
        let formatted = utc_now.format("%Y-%m-%d");
        formatted.to_string()
    }
}
/// Manually driven [`TimeSource`]: its clock only moves when `advance` or
/// `sleep` is called, and it can be rewound to zero with `reset`.
#[derive(Debug)]
pub struct TestTimeSource {
// Logical nanoseconds elapsed; starts at 0 and is set back to 0 by `reset`.
logical_nanos: AtomicU64,
// Real instant captured in `new()`; `now()` returns this plus the logical time.
base_instant: Instant,
// UTC calendar date captured in `new()`; `today_date_string()` adds whole
// elapsed days to it.
base_date: NaiveDate,
}
impl Default for TestTimeSource {
fn default() -> Self {
Self::new()
}
}
impl TestTimeSource {
    /// Creates a clock with zero logical time elapsed.
    ///
    /// The real monotonic instant and the current UTC date are captured once
    /// here and act as the fixed origin that logical time is added to.
    pub fn new() -> Self {
        Self {
            logical_nanos: AtomicU64::new(0),
            base_instant: Instant::now(),
            base_date: Utc::now().date_naive(),
        }
    }

    /// Creates a clock wrapped in an [`Arc`].
    ///
    /// The concrete type is returned (rather than a trait object) so the
    /// caller keeps access to [`advance`](Self::advance) and
    /// [`reset`](Self::reset).
    pub fn shared() -> Arc<Self> {
        Arc::new(Self::new())
    }

    /// Moves the logical clock forward by `duration`.
    ///
    /// The clock stores nanoseconds in a `u64`. Both the conversion of
    /// `duration` and the addition saturate at `u64::MAX`, so an oversized
    /// advance pins the clock at its maximum instead of wrapping around and
    /// making time appear to go backwards.
    pub fn advance(&self, duration: Duration) {
        // Clamp in u128 first so the narrowing cast below cannot truncate.
        let delta = duration.as_nanos().min(u128::from(u64::MAX)) as u64;
        // `fetch_add` wraps on overflow; use a saturating atomic update
        // instead. The closure always returns `Some`, so this never fails.
        let _ = self
            .logical_nanos
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                Some(current.saturating_add(delta))
            });
    }

    /// Returns the logical time elapsed since construction or the last
    /// [`reset`](Self::reset).
    pub fn elapsed(&self) -> Duration {
        Duration::from_nanos(self.logical_nanos.load(Ordering::SeqCst))
    }

    /// Rewinds the logical clock to zero.
    pub fn reset(&self) {
        self.logical_nanos.store(0, Ordering::SeqCst);
    }

    /// Returns the elapsed logical time as raw nanoseconds.
    pub fn nanos(&self) -> u64 {
        self.logical_nanos.load(Ordering::SeqCst)
    }
}
impl TimeSource for TestTimeSource {
    /// Returns the instant captured at construction plus the logical time
    /// elapsed so far.
    fn now(&self) -> Instant {
        let offset = self.elapsed();
        self.base_instant + offset
    }

    /// Does not block: advances the logical clock by `duration` and returns.
    fn sleep(&self, duration: Duration) {
        self.advance(duration);
    }

    /// Returns the date captured at construction plus the number of whole
    /// 86 400-second days of logical time elapsed, formatted as `%Y-%m-%d`.
    ///
    /// If the resulting date cannot be represented, the base date is used.
    fn today_date_string(&self) -> String {
        const SECS_PER_DAY: u64 = 86400;

        let whole_days = (self.elapsed().as_secs() / SECS_PER_DAY) as i64;
        let offset = chrono::Duration::days(whole_days);
        let date = match self.base_date.checked_add_signed(offset) {
            Some(shifted) => shifted,
            None => self.base_date,
        };
        date.format("%Y-%m-%d").to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The real clock must move forward across a real sleep.
    #[test]
    fn real_time_source_now_advances() {
        let ts = RealTimeSource::new();
        let t1 = ts.now();
        // Sleep through the trait so `RealTimeSource::sleep` is exercised too.
        ts.sleep(Duration::from_millis(1));
        let t2 = ts.now();
        assert!(t2 > t1);
    }

    /// The real date string must be parseable with the documented format.
    #[test]
    fn real_time_source_date_string_is_iso_formatted() {
        let date = RealTimeSource::new().today_date_string();
        assert!(NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok());
    }

    #[test]
    fn test_time_source_starts_at_zero() {
        let ts = TestTimeSource::new();
        assert_eq!(ts.nanos(), 0);
        assert_eq!(ts.elapsed(), Duration::ZERO);
    }

    #[test]
    fn test_time_source_advance() {
        let ts = TestTimeSource::new();
        let start = ts.now();
        ts.advance(Duration::from_secs(5));
        assert_eq!(ts.elapsed(), Duration::from_secs(5));
        // `>=` rather than `==`: `Instant + Duration` may round on some platforms.
        assert!(ts.elapsed_since(start) >= Duration::from_secs(5));
    }

    #[test]
    fn test_time_source_sleep_advances_time() {
        let ts = TestTimeSource::new();
        let start = ts.now();
        ts.sleep(Duration::from_millis(100));
        assert_eq!(ts.elapsed(), Duration::from_millis(100));
        assert!(ts.elapsed_since(start) >= Duration::from_millis(100));
    }

    #[test]
    fn test_time_source_reset() {
        let ts = TestTimeSource::new();
        ts.advance(Duration::from_secs(10));
        assert_eq!(ts.elapsed(), Duration::from_secs(10));
        ts.reset();
        assert_eq!(ts.elapsed(), Duration::ZERO);
    }

    /// Concurrent advances from two threads must all be counted.
    #[test]
    fn test_time_source_thread_safe() {
        use std::thread;
        let ts = Arc::new(TestTimeSource::new());
        let ts_clone = ts.clone();
        let handle = thread::spawn(move || {
            for _ in 0..100 {
                ts_clone.advance(Duration::from_millis(1));
            }
        });
        for _ in 0..100 {
            ts.advance(Duration::from_millis(1));
        }
        handle.join().unwrap();
        assert_eq!(ts.elapsed(), Duration::from_millis(200));
    }

    /// The logical date stays put within a day and moves by exactly one day
    /// once 86 400 logical seconds have elapsed.
    #[test]
    fn test_time_source_date_rolls_over_after_a_day() {
        let parse = |s: &str| {
            NaiveDate::parse_from_str(s, "%Y-%m-%d").expect("date string must be %Y-%m-%d")
        };
        let ts = TestTimeSource::new();
        let first_day = ts.today_date_string();

        ts.advance(Duration::from_secs(86_399));
        assert_eq!(ts.today_date_string(), first_day);

        ts.advance(Duration::from_secs(1));
        let next_day = ts.today_date_string();
        let days_between = parse(&next_day)
            .signed_duration_since(parse(&first_day))
            .num_days();
        assert_eq!(days_between, 1);

        ts.reset();
        assert_eq!(ts.today_date_string(), first_day);
    }

    /// Both implementations must be usable through the shared trait object.
    #[test]
    fn shared_time_source_works() {
        let real: SharedTimeSource = RealTimeSource::shared();
        let test: SharedTimeSource = TestTimeSource::shared();
        let _ = real.now();
        let _ = test.now();
    }
}