1use std::marker::PhantomData;
13use std::panic::{
14 AssertUnwindSafe,
15 catch_unwind,
16};
17use std::sync::{
18 Arc,
19 Mutex,
20};
21use std::time::Duration;
22
23use qubit_atomic::AtomicRef;
24use qubit_error::BoxError;
25use qubit_function::{
26 Consumer,
27 Function,
28};
29use qubit_retry::{
30 AttemptFailure,
31 AttemptFailureDecision,
32 Retry,
33 RetryContext,
34 RetryError,
35 RetryOptions,
36};
37
38use crate::cas_decision::CasDecision;
39use crate::cas_outcome::CasOutcome;
40use crate::cas_success::CasSuccess;
41use crate::error::{
42 CasAttemptFailure,
43 CasError,
44 CasErrorKind,
45};
46use crate::event::{
47 CasContext,
48 CasEvent,
49 CasHooks,
50};
51use crate::observability::{
52 CasAlert,
53 CasObservabilityConfig,
54 CasObservabilityMode,
55 ListenerPanicPolicy,
56};
57use crate::options::CasTimeoutPolicy;
58use crate::report::{
59 CasExecutionOutcome,
60 CasExecutionReport,
61 CasReportBuilder,
62};
63use crate::strategy::CasStrategy;
64
65use super::cas_builder::CasBuilder;
66
67#[derive(Debug, Clone)]
69pub struct CasExecutor<T, E = BoxError> {
70 options: RetryOptions,
72 attempt_timeout: Option<Duration>,
74 timeout_policy: CasTimeoutPolicy,
76 observability: CasObservabilityConfig,
78 marker: PhantomData<fn() -> (T, E)>,
80}
81
/// Successful result of a single CAS attempt, before it is combined with the
/// retry context of the execution.
enum AttemptSuccess<T, R> {
    /// The attempt installed a new value.
    Updated {
        /// Value loaded at the start of the attempt, now replaced.
        previous: Arc<T>,
        /// Value written by the successful compare-and-set.
        current: Arc<T>,
        /// Output produced by the operation.
        output: R,
    },
    /// The attempt finished without writing anything.
    Finished {
        /// Value loaded at the start of the attempt.
        current: Arc<T>,
        /// Output produced by the operation.
        output: R,
    },
}
93
94struct CasReportFinishContext {
96 attempts_total: u32,
97 max_attempts: u32,
98 max_operation_elapsed: Option<Duration>,
99 max_total_elapsed: Option<Duration>,
100 outcome: CasExecutionOutcome,
101}
102
103impl CasReportFinishContext {
104 #[inline]
105 fn new(
106 attempts_total: u32,
107 max_attempts: u32,
108 max_operation_elapsed: Option<Duration>,
109 max_total_elapsed: Option<Duration>,
110 outcome: CasExecutionOutcome,
111 ) -> Self {
112 Self {
113 attempts_total,
114 max_attempts,
115 max_operation_elapsed,
116 max_total_elapsed,
117 outcome,
118 }
119 }
120}
121
122impl<T, E> CasExecutor<T, E> {
123 #[inline]
128 pub fn builder() -> CasBuilder<T, E> {
129 CasBuilder::new()
130 }
131
132 pub fn from_options(options: RetryOptions) -> Result<Self, qubit_retry::RetryConfigError> {
143 Self::builder().options(options).build()
144 }
145
146 pub fn latency_first() -> Self {
151 Self::builder()
152 .build_latency_first()
153 .expect("latency-first CAS strategy must be valid")
154 }
155
156 pub fn contention_adaptive() -> Self {
161 Self::builder()
162 .build_contention_adaptive()
163 .expect("contention-adaptive CAS strategy must be valid")
164 }
165
166 pub fn reliability_first() -> Self {
171 Self::builder()
172 .build_reliability_first()
173 .expect("reliability-first CAS strategy must be valid")
174 }
175
176 pub fn with_strategy(strategy: CasStrategy) -> Self {
184 Self::builder()
185 .strategy(strategy)
186 .build()
187 .expect("built-in CAS strategy must be valid")
188 }
189
190 #[inline]
201 pub(crate) fn new(
202 options: RetryOptions,
203 attempt_timeout: Option<Duration>,
204 timeout_policy: CasTimeoutPolicy,
205 observability: CasObservabilityConfig,
206 ) -> Self {
207 Self {
208 options,
209 attempt_timeout,
210 timeout_policy,
211 observability,
212 marker: PhantomData,
213 }
214 }
215
216 #[inline]
221 pub fn options(&self) -> &RetryOptions {
222 &self.options
223 }
224
225 #[inline]
230 pub fn attempt_timeout(&self) -> Option<Duration> {
231 self.attempt_timeout
232 }
233
234 #[inline]
239 pub fn timeout_policy(&self) -> CasTimeoutPolicy {
240 self.timeout_policy
241 }
242
243 #[inline]
248 pub fn observability(&self) -> &CasObservabilityConfig {
249 &self.observability
250 }
251
252 pub fn execute<R, O>(&self, state: &AtomicRef<T>, operation: O) -> CasOutcome<T, R, E>
262 where
263 T: 'static,
264 E: 'static,
265 O: Function<T, CasDecision<T, R, E>>,
266 {
267 self.execute_with_hooks(state, operation, CasHooks::new())
268 }
269
270 pub fn execute_with_hooks<R, O>(
281 &self,
282 state: &AtomicRef<T>,
283 operation: O,
284 hooks: CasHooks,
285 ) -> CasOutcome<T, R, E>
286 where
287 T: 'static,
288 E: 'static,
289 O: Function<T, CasDecision<T, R, E>>,
290 {
291 let success_context = Arc::new(Mutex::new(None));
292 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
293 self.emit_started(&hooks, &report_builder);
294 let retry = self.build_retry(
295 &hooks,
296 Arc::clone(&success_context),
297 Arc::clone(&report_builder),
298 );
299 let attempt = retry.run(|| self.run_sync_attempt(state, &operation));
300 self.finish_execution(attempt, hooks, success_context, report_builder)
301 }
302
303 #[cfg(feature = "tokio")]
312 pub async fn execute_async<R, O, Fut>(
313 &self,
314 state: &AtomicRef<T>,
315 operation: O,
316 ) -> CasOutcome<T, R, E>
317 where
318 T: 'static,
319 E: 'static,
320 O: Fn(Arc<T>) -> Fut,
321 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
322 {
323 self.execute_async_with_hooks(state, operation, CasHooks::new())
324 .await
325 }
326
327 #[cfg(feature = "tokio")]
337 pub async fn execute_async_with_hooks<R, O, Fut>(
338 &self,
339 state: &AtomicRef<T>,
340 operation: O,
341 hooks: CasHooks,
342 ) -> CasOutcome<T, R, E>
343 where
344 T: 'static,
345 E: 'static,
346 O: Fn(Arc<T>) -> Fut,
347 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
348 {
349 let success_context = Arc::new(Mutex::new(None));
350 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
351 self.emit_started(&hooks, &report_builder);
352 let retry = self.build_retry(
353 &hooks,
354 Arc::clone(&success_context),
355 Arc::clone(&report_builder),
356 );
357 let attempt = retry
358 .run_async(|| self.run_async_attempt(state, &operation))
359 .await;
360 self.finish_execution(attempt, hooks, success_context, report_builder)
361 }
362
363 fn build_retry(
373 &self,
374 hooks: &CasHooks,
375 success_context: Arc<Mutex<Option<RetryContext>>>,
376 report_builder: Arc<Mutex<CasReportBuilder>>,
377 ) -> Retry<CasAttemptFailure<T, E>>
378 where
379 T: 'static,
380 E: 'static,
381 {
382 let event_hook = hooks.event_hook();
383 let timeout_policy = self.timeout_policy;
384 let attempt_timeout = self.attempt_timeout;
385 let observability = self.observability.clone();
386
387 let mut builder = Retry::<CasAttemptFailure<T, E>>::builder()
388 .options(self.options.clone())
389 .on_success(move |context: &RetryContext| {
390 *success_context
391 .lock()
392 .expect("CAS success context slot should be lockable") = Some(*context);
393 })
394 .on_failure(
395 move |failure: &AttemptFailure<CasAttemptFailure<T, E>>, context: &RetryContext| {
396 let failure = match failure {
397 AttemptFailure::Panic(_) | AttemptFailure::Executor(_) => {
398 return AttemptFailureDecision::UseDefault;
399 }
400 AttemptFailure::Error(failure) => failure,
401 AttemptFailure::Timeout => {
402 unreachable!("CAS executor manages async timeouts explicitly")
403 }
404 };
405 let cas_context = CasContext::new(context, attempt_timeout);
406 {
407 let mut report = report_builder
408 .lock()
409 .expect("CAS report builder should be lockable");
410 match failure {
411 CasAttemptFailure::Conflict { .. } => report.record_conflict(),
412 CasAttemptFailure::Retry { .. } => report.record_retry_error(),
413 CasAttemptFailure::Abort { .. } => report.record_abort(),
414 CasAttemptFailure::Timeout { .. } => report.record_timeout(),
415 }
416 }
417 if Self::should_emit_events(&observability, &event_hook) {
418 Self::dispatch_event(
419 &observability,
420 event_hook
421 .as_ref()
422 .expect("event hook should exist when events are emitted"),
423 CasEvent::AttemptFailed {
424 context: cas_context,
425 kind: Self::failure_kind(failure),
426 },
427 );
428 }
429 match failure {
430 CasAttemptFailure::Conflict { .. } | CasAttemptFailure::Retry { .. } => {
431 if Self::should_emit_events(&observability, &event_hook) {
432 Self::dispatch_event(
433 &observability,
434 event_hook
435 .as_ref()
436 .expect("event hook should exist when events are emitted"),
437 CasEvent::RetryRequested {
438 context: cas_context,
439 },
440 );
441 }
442 AttemptFailureDecision::Retry
443 }
444 CasAttemptFailure::Abort { .. } => AttemptFailureDecision::Abort,
445 CasAttemptFailure::Timeout { .. } => match timeout_policy {
446 CasTimeoutPolicy::Retry => {
447 if Self::should_emit_events(&observability, &event_hook) {
448 Self::dispatch_event(
449 &observability,
450 event_hook.as_ref().expect(
451 "event hook should exist when events are emitted",
452 ),
453 CasEvent::RetryRequested {
454 context: cas_context,
455 },
456 );
457 }
458 AttemptFailureDecision::Retry
459 }
460 CasTimeoutPolicy::Abort => AttemptFailureDecision::Abort,
461 },
462 }
463 },
464 );
465
466 if self.observability.listener_panic_policy() == ListenerPanicPolicy::Isolate {
467 builder = builder.isolate_listener_panics();
468 }
469 builder
470 .build()
471 .expect("validated CAS executor configuration must build retry policy")
472 }
473
474 fn run_sync_attempt<R, O>(
483 &self,
484 state: &AtomicRef<T>,
485 operation: &O,
486 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
487 where
488 O: Function<T, CasDecision<T, R, E>>,
489 {
490 let current = state.load();
491 match operation.apply(current.as_ref()) {
492 CasDecision::Update { next, output } => {
493 match state.compare_set(¤t, Arc::clone(&next)) {
494 Ok(()) => Ok(AttemptSuccess::Updated {
495 previous: current,
496 current: next,
497 output,
498 }),
499 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
500 }
501 }
502 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
503 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
504 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
505 }
506 }
507
508 #[cfg(feature = "tokio")]
517 async fn run_async_attempt<R, O, Fut>(
518 &self,
519 state: &AtomicRef<T>,
520 operation: &O,
521 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
522 where
523 O: Fn(Arc<T>) -> Fut,
524 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
525 {
526 let current = state.load();
527 let decision = if let Some(timeout) = self.attempt_timeout {
528 match tokio::time::timeout(timeout, operation(Arc::clone(¤t))).await {
529 Ok(decision) => decision,
530 Err(_) => return Err(CasAttemptFailure::timeout(current)),
531 }
532 } else {
533 operation(Arc::clone(¤t)).await
534 };
535
536 match decision {
537 CasDecision::Update { next, output } => {
538 match state.compare_set(¤t, Arc::clone(&next)) {
539 Ok(()) => Ok(AttemptSuccess::Updated {
540 previous: current,
541 current: next,
542 output,
543 }),
544 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
545 }
546 }
547 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
548 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
549 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
550 }
551 }
552
553 fn finish_execution<R>(
563 &self,
564 attempt: Result<AttemptSuccess<T, R>, RetryError<CasAttemptFailure<T, E>>>,
565 hooks: CasHooks,
566 success_context: Arc<Mutex<Option<RetryContext>>>,
567 report_builder: Arc<Mutex<CasReportBuilder>>,
568 ) -> CasOutcome<T, R, E>
569 where
570 T: 'static,
571 E: 'static,
572 {
573 match attempt {
574 Ok(success) => {
575 let context = success_context
576 .lock()
577 .expect("CAS success context slot should be lockable")
578 .take()
579 .expect("retry success hook must capture CAS success context");
580 let attempts_total = context.attempt();
581 let max_attempts = context.max_attempts();
582 let max_operation_elapsed = context.max_operation_elapsed();
583 let max_total_elapsed = context.max_total_elapsed();
584 let outcome = match success {
585 AttemptSuccess::Updated { .. } => CasExecutionOutcome::SuccessUpdated,
586 AttemptSuccess::Finished { .. } => CasExecutionOutcome::SuccessFinished,
587 };
588 let success = self.enrich_success(success, context);
589 let report = self.finish_report(
590 &hooks,
591 report_builder,
592 CasReportFinishContext::new(
593 attempts_total,
594 max_attempts,
595 max_operation_elapsed,
596 max_total_elapsed,
597 outcome,
598 ),
599 );
600 CasOutcome::new(Ok(success), report)
601 }
602 Err(error) => {
603 let error = CasError::new(error, self.attempt_timeout);
604 let context = error.context();
605 let outcome = Self::error_outcome(error.kind());
606 let report = self.finish_report(
607 &hooks,
608 report_builder,
609 CasReportFinishContext::new(
610 context.attempt(),
611 context.max_attempts(),
612 context.max_operation_elapsed(),
613 context.max_total_elapsed(),
614 outcome,
615 ),
616 );
617 CasOutcome::new(Err(error), report)
618 }
619 }
620 }
621
622 fn enrich_success<R>(
631 &self,
632 success: AttemptSuccess<T, R>,
633 context: RetryContext,
634 ) -> CasSuccess<T, R> {
635 let context = CasContext::new(&context, self.attempt_timeout);
636 match success {
637 AttemptSuccess::Updated {
638 previous,
639 current,
640 output,
641 } => CasSuccess::updated(previous, current, output, context),
642 AttemptSuccess::Finished { current, output } => {
643 CasSuccess::finished(current, output, context)
644 }
645 }
646 }
647
648 fn emit_started(&self, hooks: &CasHooks, report_builder: &Arc<Mutex<CasReportBuilder>>)
654 where
655 T: 'static,
656 E: 'static,
657 {
658 if hooks.event_hook().is_none()
659 || self.observability.mode() == CasObservabilityMode::ReportOnly
660 {
661 return;
662 }
663 let started_at = report_builder
664 .lock()
665 .expect("CAS report builder should be lockable")
666 .started_at();
667 let event_hook = hooks.event_hook();
668 Self::dispatch_event(
669 &self.observability,
670 event_hook
671 .as_ref()
672 .expect("event hook should exist when events are emitted"),
673 CasEvent::ExecutionStarted { started_at },
674 );
675 }
676
677 fn finish_report(
691 &self,
692 hooks: &CasHooks,
693 report_builder: Arc<Mutex<CasReportBuilder>>,
694 ctx: CasReportFinishContext,
695 ) -> CasExecutionReport
696 where
697 T: 'static,
698 E: 'static,
699 {
700 let report = report_builder
701 .lock()
702 .expect("CAS report builder should be lockable")
703 .finish(
704 ctx.attempts_total,
705 ctx.max_attempts,
706 ctx.max_operation_elapsed,
707 ctx.max_total_elapsed,
708 ctx.outcome,
709 );
710 let event_hook = hooks.event_hook();
711 if Self::should_emit_events(&self.observability, &event_hook) {
712 Self::dispatch_event(
713 &self.observability,
714 event_hook
715 .as_ref()
716 .expect("event hook should exist when events are emitted"),
717 CasEvent::ExecutionFinished {
718 report: report.clone(),
719 },
720 );
721 }
722 if self.observability.mode() == CasObservabilityMode::EventStreamWithAlert
723 && let Some(thresholds) = self.observability.contention_thresholds()
724 && report.is_contention_hot(&thresholds)
725 {
726 Self::dispatch_alert(
727 &self.observability,
728 &hooks.alert_hook(),
729 CasAlert::contention(report.clone(), thresholds),
730 );
731 }
732 report
733 }
734
735 #[inline]
743 fn error_outcome(kind: CasErrorKind) -> CasExecutionOutcome {
744 match kind {
745 CasErrorKind::Abort => CasExecutionOutcome::ErrorAbort,
746 CasErrorKind::Conflict => CasExecutionOutcome::ErrorConflictExhausted,
747 CasErrorKind::RetryExhausted => CasExecutionOutcome::ErrorRetryExhausted,
748 CasErrorKind::AttemptTimeout => CasExecutionOutcome::ErrorAttemptTimeout,
749 CasErrorKind::MaxOperationElapsedExceeded => {
750 CasExecutionOutcome::ErrorMaxOperationElapsedExceeded
751 }
752 CasErrorKind::MaxTotalElapsedExceeded => {
753 CasExecutionOutcome::ErrorMaxTotalElapsedExceeded
754 }
755 }
756 }
757
758 #[inline]
766 fn failure_kind(failure: &CasAttemptFailure<T, E>) -> crate::error::CasAttemptFailureKind {
767 failure.kind()
768 }
769
770 fn dispatch_event(
772 observability: &CasObservabilityConfig,
773 hook: &crate::event::CasEventHook,
774 event: CasEvent,
775 ) where
776 T: 'static,
777 E: 'static,
778 {
779 match observability.listener_panic_policy() {
780 ListenerPanicPolicy::Propagate => hook.accept(&event),
781 ListenerPanicPolicy::Isolate => {
782 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&event)));
783 }
784 }
785 }
786
787 #[inline]
789 fn should_emit_events(
790 observability: &CasObservabilityConfig,
791 hook: &Option<crate::event::CasEventHook>,
792 ) -> bool {
793 observability.mode() != CasObservabilityMode::ReportOnly && hook.is_some()
794 }
795
796 fn dispatch_alert(
798 observability: &CasObservabilityConfig,
799 hook: &Option<crate::event::CasAlertHook>,
800 alert: CasAlert,
801 ) {
802 if let Some(hook) = hook {
803 match observability.listener_panic_policy() {
804 ListenerPanicPolicy::Propagate => hook.accept(&alert),
805 ListenerPanicPolicy::Isolate => {
806 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&alert)));
807 }
808 }
809 }
810 }
811}