1use chrono::{offset::TimeZone, DateTime, Duration, Utc};
95use either::Either;
96use priq::PriorityQueue;
97use std::sync::{Arc, Mutex, MutexGuard};
98use tokio::sync::oneshot::{channel, Receiver, Sender};
99
100#[derive(Clone, Debug)]
106pub struct Clock<Tz: TimeZone> {
107 inner: ClockInner<Tz>,
108 timezone: Tz,
109}
110
111impl Clock<Utc> {
112 pub fn new() -> Self {
114 Self::new_with_timezone(Utc)
115 }
116}
117
118impl<Tz: TimeZone> Clock<Tz>
119where
120 <Tz as TimeZone>::Offset: core::fmt::Display,
121{
122 pub fn new_with_timezone(timezone: Tz) -> Self {
124 Self {
125 inner: ClockInner::Wall,
126 timezone,
127 }
128 }
129
130 pub fn new_fake(start: DateTime<Tz>) -> Self {
135 Self {
136 timezone: start.timezone(),
137 inner: ClockInner::Fake(FakeClock::new(start)),
138 }
139 }
140
141 pub fn now(&self) -> DateTime<Tz> {
143 match &self.inner {
144 ClockInner::Wall => Utc::now().with_timezone(&self.timezone),
145 ClockInner::Fake(f) => f.now(),
146 }
147 }
148
149 pub fn is_fake(&self) -> bool {
155 match &self.inner {
156 ClockInner::Wall => false,
157 ClockInner::Fake(_) => true,
158 }
159 }
160
161 pub fn split(&self) -> Clock<Tz> {
181 Self {
182 inner: self.inner.split(),
183 timezone: self.timezone.clone(),
184 }
185 }
186
187 pub async fn advance(&mut self, duration: Duration) -> (DateTime<Tz>, Duration) {
191 match &mut self.inner {
192 ClockInner::Wall => panic!("Attempted to advance system clock"),
193 ClockInner::Fake(f) => f.advance(duration).await,
194 }
195 }
196
197 pub fn advance_blocking(&mut self, duration: Duration) -> (DateTime<Tz>, Duration) {
201 let r = match &mut self.inner {
202 ClockInner::Wall => panic!("Attempted to advance system clock"),
203 ClockInner::Fake(f) => f.advance_blocking(duration),
204 };
205 r
206 }
207
208 pub async fn sleep(&mut self, duration: Duration) {
210 match &mut self.inner {
211 ClockInner::Wall => tokio::time::sleep(duration.to_std().unwrap()).await,
212 ClockInner::Fake(f) => f.sleep(duration).await,
213 }
214 }
215
216 pub fn sleep_blocking(&mut self, duration: Duration) {
220 match &mut self.inner {
221 ClockInner::Wall => std::thread::sleep(duration.to_std().unwrap()),
222 ClockInner::Fake(f) => f.sleep_blocking(duration),
223 }
224 }
225}
226
227impl Default for Clock<Utc> {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233#[derive(Clone, Debug)]
234enum ClockInner<Tz: TimeZone> {
235 Wall,
236 Fake(FakeClock<Tz>),
237}
238
239impl<Tz: TimeZone> ClockInner<Tz>
240where
241 <Tz as TimeZone>::Offset: core::fmt::Display,
242{
243 pub fn split(&self) -> Self {
244 match &self {
245 Self::Wall => Self::Wall,
246 Self::Fake(f) => Self::Fake(f.split()),
247 }
248 }
249}
250
251#[derive(Debug)]
252struct FakeInner<Tz: TimeZone> {
253 now: DateTime<Tz>,
254 sleepers: PriorityQueue<DateTime<Tz>, Sender<()>>,
255 advancer: Option<Sender<()>>,
256 threads: usize,
257}
258
259#[derive(Debug)]
260struct FakeGroup<Tz: TimeZone> {
261 inner: Arc<Mutex<FakeInner<Tz>>>,
262}
263
264impl<Tz: TimeZone> Clone for FakeGroup<Tz> {
265 fn clone(&self) -> Self {
266 let mut v = self.inner.lock().unwrap();
267 v.threads += 1;
268
269 Self {
270 inner: self.inner.clone(),
271 }
272 }
273}
274
275impl<Tz: TimeZone> Drop for FakeGroup<Tz> {
276 fn drop(&mut self) {
277 let mut v = self.inner.lock().unwrap();
278 v.threads -= 1;
279 if v.sleepers.len() + 1 == v.threads {
280 if let Some(advancer) = v.advancer.take() {
281 advancer.send(()).unwrap();
282 }
283 }
284 }
285}
286
287#[derive(Clone, Debug)]
288struct FakeClock<Tz: TimeZone> {
289 group: Arc<FakeGroup<Tz>>,
290}
291
292impl<Tz: TimeZone> FakeClock<Tz>
293where
294 <Tz as TimeZone>::Offset: core::fmt::Display,
295{
296 pub(crate) fn new(start: DateTime<Tz>) -> Self {
297 Self {
298 group: Arc::new(FakeGroup {
299 inner: Arc::new(Mutex::new(FakeInner {
300 now: start,
301 sleepers: Default::default(),
302 advancer: None,
303 threads: 1,
304 })),
305 }),
306 }
307 }
308
309 pub(crate) fn split(&self) -> Self {
310 Self {
311 group: Arc::new(FakeGroup::clone(&self.group)),
312 }
313 }
314
315 pub(crate) fn now(&self) -> DateTime<Tz> {
316 self.group.inner.lock().unwrap().now.clone()
317 }
318
319 fn do_advance(
320 &self,
321 mut v: MutexGuard<FakeInner<Tz>>,
322 duration: Duration,
323 ) -> (DateTime<Tz>, Duration) {
324 let start = v.now.clone();
325 let mut end = start.clone() + duration;
326 while let Some((time, _)) = v.sleepers.peek() {
327 if time <= &end {
328 let (time, sleeper) = v.sleepers.pop().unwrap();
329 sleeper.send(()).expect("Failed to wake sleeper");
330 end = time.clone();
331 } else {
332 break;
333 }
334 }
335 v.now = end.clone();
336 (end.clone(), end - start)
337 }
338
339 fn pre_advance(&self, duration: Duration) -> Either<(DateTime<Tz>, Duration), Receiver<()>> {
340 let mut v = self.group.inner.lock().unwrap();
341 if v.advancer.is_some() {
342 panic!("Cannot advance from two threads simultaneously");
343 }
344
345 match v.sleepers.len() + 1 {
346 x if x < v.threads => {
347 let (tx, rx) = channel();
348 v.advancer = Some(tx);
349 Either::Right(rx)
350 }
351 x if x == v.threads => Either::Left(self.do_advance(v, duration)),
352 _ => panic!("Too many threads"),
353 }
354 }
355
356 pub(crate) async fn advance(&self, duration: Duration) -> (DateTime<Tz>, Duration) {
357 let rx = match self.pre_advance(duration) {
358 Either::Left(d) => {
359 return d;
360 }
361 Either::Right(rx) => rx,
362 };
363 rx.await.unwrap();
364 let v = self.group.inner.lock().unwrap();
365 self.do_advance(v, duration)
366 }
367
368 pub(crate) fn advance_blocking(&self, duration: Duration) -> (DateTime<Tz>, Duration) {
369 let rx = match self.pre_advance(duration) {
370 Either::Left(d) => {
371 return d;
372 }
373 Either::Right(rx) => rx,
374 };
375 rx.blocking_recv().unwrap();
376 let v = self.group.inner.lock().unwrap();
377 self.do_advance(v, duration)
378 }
379
380 fn sleep_common(&mut self, duration: Duration) -> Receiver<()> {
381 let mut v = self.group.inner.lock().unwrap();
382 let (tx, rx) = channel();
383 let wake_time = v.now.clone() + duration;
384 v.sleepers.put(wake_time, tx);
385 if v.sleepers.len() + 1 == v.threads {
386 if let Some(advancer) = v.advancer.take() {
387 advancer.send(()).unwrap();
388 }
389 }
390 rx
391 }
392
393 pub(crate) fn sleep_blocking(&mut self, duration: Duration) {
394 let rx = self.sleep_common(duration);
395 rx.blocking_recv().expect("Failed to wake up")
396 }
397
398 pub(crate) async fn sleep(&mut self, duration: Duration) {
399 let rx = self.sleep_common(duration);
400 rx.await.expect("Failed to wake up")
401 }
402}
403
#[cfg(test)]
mod test {
    use super::Clock;
    use chrono::{DateTime, Duration, Utc};

    /// Parses an RFC 2822 timestamp into a UTC `DateTime`.
    fn at(rfc2822: &str) -> DateTime<Utc> {
        DateTime::parse_from_rfc2822(rfc2822)
            .unwrap()
            .with_timezone(&Utc)
    }

    #[test]
    fn test_sync_wall_sleep() {
        let mut c = Clock::new();

        let start = c.now();
        c.sleep_blocking(Duration::seconds(5));
        let end = c.now();

        let ns = ((end - start) - Duration::seconds(5))
            .num_nanoseconds()
            .unwrap();
        assert!(
            ns.abs() < 250_000,
            "Slept for {} nanoseconds too many (duration {})",
            ns,
            (end - start)
        );
    }

    #[tokio::test]
    async fn test_async_wall_sleep() {
        let mut c = Clock::new();

        let start = c.now();
        c.sleep(Duration::seconds(5)).await;
        let end = c.now();

        let ns = ((end - start) - Duration::seconds(5))
            .num_nanoseconds()
            .unwrap();
        assert!(
            ns.abs() < 2_000_000,
            "Slept for {} nanoseconds too many (duration {})",
            ns,
            (end - start)
        );
    }

    #[test]
    fn test_sync_fake_sleep() {
        let start = at("Mon, 8 Aug 2022 15:21:00 GMT");
        let mut c = Clock::new_fake(start);

        let mut cs = c.split();
        let sleeper = std::thread::spawn(move || {
            cs.sleep_blocking(Duration::seconds(5));
            let end = cs.now();
            assert_eq!(start + Duration::seconds(5), end);
        });

        assert_eq!(
            c.advance_blocking(Duration::seconds(1)),
            (at("Mon, 8 Aug 2022 15:21:01 GMT"), Duration::seconds(1))
        );

        assert_eq!(
            c.advance_blocking(Duration::seconds(5)),
            (at("Mon, 8 Aug 2022 15:21:05 GMT"), Duration::seconds(4))
        );

        assert_eq!(
            c.advance_blocking(Duration::seconds(10)),
            (at("Mon, 8 Aug 2022 15:21:15 GMT"), Duration::seconds(10))
        );

        // A panic in a spawned thread does not fail the test unless the thread
        // is joined, so propagate the sleeper's assertion result here.
        sleeper.join().expect("sleeper thread panicked");
    }

    #[tokio::test]
    async fn test_async_fake_sleep() {
        let start = at("Mon, 8 Aug 2022 15:21:00 GMT");
        let mut c = Clock::new_fake(start);

        let mut cs = c.split();
        let sleeper = tokio::spawn(async move {
            cs.sleep(Duration::seconds(5)).await;
            let end = cs.now();
            assert_eq!(start + Duration::seconds(5), end);
        });

        assert_eq!(
            c.advance(Duration::seconds(1)).await,
            (at("Mon, 8 Aug 2022 15:21:01 GMT"), Duration::seconds(1))
        );

        assert_eq!(
            c.advance(Duration::seconds(5)).await,
            (at("Mon, 8 Aug 2022 15:21:05 GMT"), Duration::seconds(4))
        );

        assert_eq!(
            c.advance(Duration::seconds(10)).await,
            (at("Mon, 8 Aug 2022 15:21:15 GMT"), Duration::seconds(10))
        );

        // tokio captures task panics in the JoinHandle; surface them here.
        sleeper.await.expect("sleeper task panicked");
    }
}