1use crate::{
25 error::{FrozenError, FrozenResult},
26 hints,
27};
28use event_listener::{Event, EventListener, Listener};
29use std::{future, pin, ptr, sync, sync::atomic, task};
30
/// Epoch number handed out by [`Completion::increment_current_epoch`] and compared against the
/// durable epoch (see [`Completion::read_durable_epoch`]) to decide when an [`AckTicket`] resolves.
pub type TEpoch = u64;
33
34#[derive(Debug)]
35struct AckError(atomic::AtomicPtr<FrozenError>);
36
37impl Default for AckError {
38 fn default() -> Self {
39 Self(atomic::AtomicPtr::new(ptr::null_mut()))
40 }
41}
42
43impl Drop for AckError {
44 fn drop(&mut self) {
45 let err_ptr = self.0.load(atomic::Ordering::Acquire);
46 if !err_ptr.is_null() {
47 let _ = unsafe { Box::from_raw(err_ptr) };
48 }
49 }
50}
51
52#[derive(Debug)]
72pub struct Completion {
73 current_epoch: atomic::AtomicU64,
74 durable_epoch: atomic::AtomicU64,
75 error: AckError,
76 event: Event,
77}
78
79impl Default for Completion {
80 fn default() -> Self {
81 Self {
82 current_epoch: atomic::AtomicU64::new(0),
83 durable_epoch: atomic::AtomicU64::new(0),
84 error: AckError::default(),
85 event: Event::new(),
86 }
87 }
88}
89
90impl Completion {
91 #[inline]
108 pub fn increment_current_epoch(&self) -> TEpoch {
109 self.current_epoch.fetch_add(1, atomic::Ordering::AcqRel).wrapping_add(1)
110 }
111
112 #[inline]
128 pub fn mark_epoch_as_durable(&self, epoch: TEpoch) {
129 self.durable_epoch.store(epoch, atomic::Ordering::Release);
130 }
131
132 #[inline]
143 pub fn get_err(&self) -> Option<FrozenError> {
144 let curr_err = self.error.0.load(atomic::Ordering::Acquire);
145 if hints::unlikely(!curr_err.is_null()) {
146 let frozen_error = unsafe { (*curr_err).clone() };
147 return Some(frozen_error);
148 }
149
150 None
151 }
152
153 #[inline]
167 pub fn set_err(&self, new_error: FrozenError) {
168 let boxed_error = Box::into_raw(Box::new(new_error));
169 let old_err = self.error.0.swap(boxed_error, atomic::Ordering::AcqRel);
170
171 if hints::unlikely(!old_err.is_null()) {
172 let _ = unsafe { Box::from_raw(old_err) };
173 }
174 }
175
176 #[inline]
193 pub fn del_err(&self) {
194 let old_err = self.error.0.swap(ptr::null_mut(), atomic::Ordering::AcqRel);
195 if hints::unlikely(!old_err.is_null()) {
196 let _ = unsafe { Box::from_raw(old_err) };
197 }
198 }
199
200 #[inline]
213 pub fn read_current_epoch(&self) -> TEpoch {
214 self.current_epoch.load(atomic::Ordering::Acquire)
215 }
216
217 #[inline]
230 pub fn read_durable_epoch(&self) -> TEpoch {
231 self.durable_epoch.load(atomic::Ordering::Acquire)
232 }
233
234 #[inline]
248 pub fn notify_all_listeners(&self) {
249 self.event.notify(usize::MAX);
250 }
251}
252
253#[derive(Debug)]
269pub struct AckTicket {
270 epoch: TEpoch,
271 completion: sync::Arc<Completion>,
272 listener: Option<EventListener>,
273}
274
275impl AckTicket {
276 #[inline]
290 pub const fn new(epoch: TEpoch, completion: sync::Arc<Completion>) -> Self {
291 Self { epoch, completion, listener: None }
292 }
293
294 #[inline(always)]
308 pub const fn epoch(&self) -> TEpoch {
309 self.epoch
310 }
311
312 #[inline(always)]
340 pub fn wait(&self) -> FrozenResult<TEpoch> {
341 loop {
342 if self.is_ready() {
343 return Ok(self.epoch);
344 }
345
346 if let Some(frozen_err) = self.completion.get_err() {
347 return Err(frozen_err);
348 }
349
350 let listener = self.completion.event.listen();
351
352 if self.is_ready() {
353 return Ok(self.epoch);
354 }
355
356 if let Some(err) = self.completion.get_err() {
357 return Err(err);
358 }
359
360 listener.wait();
361 }
362 }
363
364 #[inline]
365 fn is_ready(&self) -> bool {
366 self.completion.read_durable_epoch() >= self.epoch
367 }
368}
369
370impl future::Future for AckTicket {
371 type Output = FrozenResult<TEpoch>;
372
373 fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
374 loop {
375 if self.is_ready() {
376 return task::Poll::Ready(Ok(self.epoch));
377 }
378
379 if let Some(frozen_err) = self.completion.get_err() {
380 return task::Poll::Ready(Err(frozen_err));
381 }
382
383 if self.listener.is_none() {
384 self.listener = Some(self.completion.event.listen());
385
386 continue;
388 }
389
390 let listener = self.listener.as_mut().unwrap();
391 match pin::Pin::new(listener).poll(cx) {
392 task::Poll::Ready(()) => {
393 self.listener = None;
394
395 continue;
397 }
398
399 task::Poll::Pending => {
400 return task::Poll::Pending;
401 }
402 }
403 }
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrCode;
    use std::{sync, thread, time};

    /// How long background threads wait before publishing progress or a failure.
    const NOTIFY_DELAY: time::Duration = time::Duration::from_millis(0x0A);

    /// Builds the "io failure" error shared by most tests.
    fn io_failure() -> FrozenError {
        FrozenError::new(0x10, 0x20, ErrCode::new(0x30, "io"), "failure")
    }

    /// Spawns a thread that waits `NOTIFY_DELAY`, applies `publish` to `completion`,
    /// and then wakes every listener.
    fn publish_later<F>(completion: &sync::Arc<Completion>, publish: F)
    where
        F: FnOnce(&Completion) + Send + 'static,
    {
        let shared = sync::Arc::clone(completion);
        thread::spawn(move || {
            thread::sleep(NOTIFY_DELAY);

            publish(&*shared);
            shared.notify_all_listeners();
        });
    }

    mod completion {
        use super::*;

        #[test]
        fn ok_increment_current_epoch() {
            let completion = Completion::default();

            for expected in 1..=3 {
                assert_eq!(completion.increment_current_epoch(), expected);
            }
        }

        #[test]
        fn ok_mark_epoch_as_durable() {
            let completion = Completion::default();
            completion.mark_epoch_as_durable(0x0C);

            assert_eq!(completion.read_durable_epoch(), 0x0C);
        }

        #[test]
        fn ok_set_get_err() {
            let completion = Completion::default();
            let stored = io_failure();

            completion.set_err(stored.clone());
            assert_eq!(completion.get_err(), Some(stored));
        }

        #[test]
        fn ok_del_err() {
            let completion = Completion::default();

            completion.set_err(io_failure());
            assert!(completion.get_err().is_some());

            completion.del_err();
            assert!(completion.get_err().is_none());
        }

        #[test]
        fn ok_set_err_overwrites_previous() {
            let completion = Completion::default();
            let first = FrozenError::new(0x10, 0x20, ErrCode::new(0x30, "io"), "first");
            let second = FrozenError::new(0x11, 0x21, ErrCode::new(0x31, "sync"), "second");

            completion.set_err(first);
            completion.set_err(second.clone());

            assert_eq!(completion.get_err(), Some(second));
        }
    }

    mod ack_ticket {
        use super::*;
        use futures::executor::block_on;

        #[test]
        fn ok_new() {
            let ticket = AckTicket::new(0x23, sync::Arc::new(Completion::default()));

            assert_eq!(ticket.epoch(), 0x23);
        }

        #[test]
        fn ok_await_when_epoch_already_durable() {
            let completion = sync::Arc::new(Completion::default());
            completion.mark_epoch_as_durable(0x0A);

            let outcome = block_on(AckTicket::new(0x0A, completion));

            assert_eq!(outcome.expect("ticket must complete"), 0x0A);
        }

        #[test]
        fn ok_await_after_durability_progress() {
            let completion = sync::Arc::new(Completion::default());
            let ticket = AckTicket::new(1, sync::Arc::clone(&completion));

            publish_later(&completion, |shared| shared.mark_epoch_as_durable(1));

            assert_eq!(block_on(ticket).expect("ticket must complete"), 1);
        }

        #[test]
        fn err_await_when_error_is_present() {
            let completion = sync::Arc::new(Completion::default());
            let expected_error = io_failure();
            completion.set_err(expected_error.clone());

            let outcome = block_on(AckTicket::new(1, completion));

            assert_eq!(outcome.expect_err("ticket must fail"), expected_error);
        }

        #[test]
        fn err_await_when_error_arrives_later() {
            let completion = sync::Arc::new(Completion::default());
            let ticket = AckTicket::new(1, sync::Arc::clone(&completion));
            let expected_error = io_failure();

            let published = expected_error.clone();
            publish_later(&completion, move |shared| shared.set_err(published));

            assert_eq!(block_on(ticket).expect_err("ticket must fail"), expected_error);
        }

        #[test]
        fn ok_multiple_tickets_waiting_for_same_epoch() {
            let completion = sync::Arc::new(Completion::default());
            let [ticket_1, ticket_2, ticket_3] =
                [(); 3].map(|()| AckTicket::new(1, sync::Arc::clone(&completion)));

            publish_later(&completion, |shared| shared.mark_epoch_as_durable(1));

            assert_eq!(block_on(ticket_1).expect("ticket_1 must complete"), 1);
            assert_eq!(block_on(ticket_2).expect("ticket_2 must complete"), 1);
            assert_eq!(block_on(ticket_3).expect("ticket_3 must complete"), 1);
        }

        #[test]
        fn ok_multiple_epochs_complete_in_order() {
            let completion = sync::Arc::new(Completion::default());
            let [ticket_1, ticket_2, ticket_3] =
                [1, 2, 3].map(|epoch| AckTicket::new(epoch, sync::Arc::clone(&completion)));

            completion.mark_epoch_as_durable(3);

            assert_eq!(block_on(ticket_1).expect("ticket_1 must complete"), 1);
            assert_eq!(block_on(ticket_2).expect("ticket_2 must complete"), 2);
            assert_eq!(block_on(ticket_3).expect("ticket_3 must complete"), 3);
        }
    }

    mod ticket_wait {
        use super::*;

        #[test]
        fn ok_wait_when_epoch_already_durable() {
            let completion = sync::Arc::new(Completion::default());
            completion.mark_epoch_as_durable(1);

            let ticket = AckTicket::new(1, completion);

            assert_eq!(ticket.wait().expect("ticket must complete"), 1);
        }

        #[test]
        fn ok_wait_after_durability_progress() {
            let completion = sync::Arc::new(Completion::default());
            let ticket = AckTicket::new(1, sync::Arc::clone(&completion));

            publish_later(&completion, |shared| shared.mark_epoch_as_durable(1));

            assert_eq!(ticket.wait().expect("ticket must complete"), 1);
        }

        #[test]
        fn err_wait_when_error_is_present() {
            let completion = sync::Arc::new(Completion::default());
            let expected = io_failure();
            completion.set_err(expected.clone());

            let ticket = AckTicket::new(1, completion);

            assert_eq!(ticket.wait().expect_err("ticket must fail"), expected);
        }

        #[test]
        fn err_wait_when_error_arrives_later() {
            let completion = sync::Arc::new(Completion::default());
            let ticket = AckTicket::new(1, sync::Arc::clone(&completion));
            let expected = io_failure();

            let published = expected.clone();
            publish_later(&completion, move |shared| shared.set_err(published));

            assert_eq!(ticket.wait().expect_err("ticket must fail"), expected);
        }
    }
}