1use std::marker::PhantomData;
13use std::panic::{AssertUnwindSafe, catch_unwind};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
17use qubit_atomic::AtomicRef;
18use qubit_error::BoxError;
19use qubit_function::{Consumer, Function};
20use qubit_retry::{
21 AttemptFailure, AttemptFailureDecision, Retry, RetryContext, RetryError, RetryOptions,
22};
23
24use crate::cas_decision::CasDecision;
25use crate::cas_outcome::CasOutcome;
26use crate::cas_success::CasSuccess;
27use crate::error::{CasAttemptFailure, CasError, CasErrorKind};
28use crate::event::{CasContext, CasEvent, CasHooks};
29use crate::observability::{
30 CasAlert, CasObservabilityConfig, CasObservabilityMode, ListenerPanicPolicy,
31};
32use crate::options::CasTimeoutPolicy;
33use crate::report::{CasExecutionOutcome, CasExecutionReport, CasReportBuilder};
34use crate::strategy::CasStrategy;
35
36use super::cas_builder::CasBuilder;
37
#[derive(Debug, Clone)]
/// Runs compare-and-swap (CAS) update loops against an [`AtomicRef`], retrying
/// conflicts and retryable decisions according to a [`RetryOptions`] policy.
///
/// `T` is the value type held by the atomic reference; `E` is the error type an
/// operation may return through `CasDecision::Retry` / `CasDecision::Abort`
/// (defaults to [`BoxError`]).
pub struct CasExecutor<T, E = BoxError> {
    /// Retry options cloned into the underlying `Retry` policy for every execution.
    options: RetryOptions,
    /// Optional per-attempt timeout. Only the async execution path enforces it;
    /// both paths pass it through to the contexts they report.
    attempt_timeout: Option<Duration>,
    /// Whether a timed-out attempt is retried or aborts the execution.
    timeout_policy: CasTimeoutPolicy,
    /// Event/alert emission mode, contention thresholds and listener panic policy.
    observability: CasObservabilityConfig,
    /// Ties the executor to `T` and `E` without storing values of either; the
    /// `fn() -> (T, E)` form means the executor does not behave as an owner of them.
    marker: PhantomData<fn() -> (T, E)>,
}
52
/// Successful result of one CAS attempt, before it is enriched with retry context.
enum AttemptSuccess<T, R> {
    /// The operation chose to stop without writing anything.
    Finished {
        /// Snapshot the operation observed.
        current: Arc<T>,
        /// Caller-defined result produced by the operation.
        output: R,
    },
    /// The operation's proposed value was installed by compare-and-set.
    Updated {
        /// Snapshot that was replaced.
        previous: Arc<T>,
        /// Value that is now stored.
        current: Arc<T>,
        /// Caller-defined result produced by the operation.
        output: R,
    },
}
64
65struct CasReportFinishContext {
67 attempts_total: u32,
68 max_attempts: u32,
69 max_operation_elapsed: Option<Duration>,
70 max_total_elapsed: Option<Duration>,
71 outcome: CasExecutionOutcome,
72}
73
74impl CasReportFinishContext {
75 #[inline]
76 fn new(
77 attempts_total: u32,
78 max_attempts: u32,
79 max_operation_elapsed: Option<Duration>,
80 max_total_elapsed: Option<Duration>,
81 outcome: CasExecutionOutcome,
82 ) -> Self {
83 Self {
84 attempts_total,
85 max_attempts,
86 max_operation_elapsed,
87 max_total_elapsed,
88 outcome,
89 }
90 }
91}
92
93impl<T, E> CasExecutor<T, E> {
94 #[inline]
99 pub fn builder() -> CasBuilder<T, E> {
100 CasBuilder::new()
101 }
102
103 pub fn from_options(options: RetryOptions) -> Result<Self, qubit_retry::RetryConfigError> {
114 Self::builder().options(options).build()
115 }
116
117 pub fn latency_first() -> Self {
122 Self::builder()
123 .build_latency_first()
124 .expect("latency-first CAS strategy must be valid")
125 }
126
127 pub fn contention_adaptive() -> Self {
132 Self::builder()
133 .build_contention_adaptive()
134 .expect("contention-adaptive CAS strategy must be valid")
135 }
136
137 pub fn reliability_first() -> Self {
142 Self::builder()
143 .build_reliability_first()
144 .expect("reliability-first CAS strategy must be valid")
145 }
146
147 pub fn with_strategy(strategy: CasStrategy) -> Self {
155 Self::builder()
156 .strategy(strategy)
157 .build()
158 .expect("built-in CAS strategy must be valid")
159 }
160
161 #[inline]
172 pub(crate) fn new(
173 options: RetryOptions,
174 attempt_timeout: Option<Duration>,
175 timeout_policy: CasTimeoutPolicy,
176 observability: CasObservabilityConfig,
177 ) -> Self {
178 Self {
179 options,
180 attempt_timeout,
181 timeout_policy,
182 observability,
183 marker: PhantomData,
184 }
185 }
186
187 #[inline]
192 pub fn options(&self) -> &RetryOptions {
193 &self.options
194 }
195
196 #[inline]
201 pub fn attempt_timeout(&self) -> Option<Duration> {
202 self.attempt_timeout
203 }
204
205 #[inline]
210 pub fn timeout_policy(&self) -> CasTimeoutPolicy {
211 self.timeout_policy
212 }
213
214 #[inline]
219 pub fn observability(&self) -> &CasObservabilityConfig {
220 &self.observability
221 }
222
223 pub fn execute<R, O>(&self, state: &AtomicRef<T>, operation: O) -> CasOutcome<T, R, E>
233 where
234 T: 'static,
235 E: 'static,
236 O: Function<T, CasDecision<T, R, E>>,
237 {
238 self.execute_with_hooks(state, operation, CasHooks::new())
239 }
240
241 pub fn execute_with_hooks<R, O>(
252 &self,
253 state: &AtomicRef<T>,
254 operation: O,
255 hooks: CasHooks,
256 ) -> CasOutcome<T, R, E>
257 where
258 T: 'static,
259 E: 'static,
260 O: Function<T, CasDecision<T, R, E>>,
261 {
262 let success_context = Arc::new(Mutex::new(None));
263 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
264 self.emit_started(&hooks, &report_builder);
265 let retry = self.build_retry(
266 &hooks,
267 Arc::clone(&success_context),
268 Arc::clone(&report_builder),
269 );
270 let attempt = retry.run(|| self.run_sync_attempt(state, &operation));
271 self.finish_execution(attempt, hooks, success_context, report_builder)
272 }
273
274 #[cfg(feature = "tokio")]
283 pub async fn execute_async<R, O, Fut>(
284 &self,
285 state: &AtomicRef<T>,
286 operation: O,
287 ) -> CasOutcome<T, R, E>
288 where
289 T: 'static,
290 E: 'static,
291 O: Fn(Arc<T>) -> Fut,
292 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
293 {
294 self.execute_async_with_hooks(state, operation, CasHooks::new())
295 .await
296 }
297
298 #[cfg(feature = "tokio")]
308 pub async fn execute_async_with_hooks<R, O, Fut>(
309 &self,
310 state: &AtomicRef<T>,
311 operation: O,
312 hooks: CasHooks,
313 ) -> CasOutcome<T, R, E>
314 where
315 T: 'static,
316 E: 'static,
317 O: Fn(Arc<T>) -> Fut,
318 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
319 {
320 let success_context = Arc::new(Mutex::new(None));
321 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
322 self.emit_started(&hooks, &report_builder);
323 let retry = self.build_retry(
324 &hooks,
325 Arc::clone(&success_context),
326 Arc::clone(&report_builder),
327 );
328 let attempt = retry
329 .run_async(|| self.run_async_attempt(state, &operation))
330 .await;
331 self.finish_execution(attempt, hooks, success_context, report_builder)
332 }
333
334 fn build_retry(
344 &self,
345 hooks: &CasHooks,
346 success_context: Arc<Mutex<Option<RetryContext>>>,
347 report_builder: Arc<Mutex<CasReportBuilder>>,
348 ) -> Retry<CasAttemptFailure<T, E>>
349 where
350 T: 'static,
351 E: 'static,
352 {
353 let event_hook = hooks.event_hook();
354 let timeout_policy = self.timeout_policy;
355 let attempt_timeout = self.attempt_timeout;
356 let observability = self.observability.clone();
357
358 let mut builder = Retry::<CasAttemptFailure<T, E>>::builder()
359 .options(self.options.clone())
360 .on_success(move |context: &RetryContext| {
361 *success_context
362 .lock()
363 .expect("CAS success context slot should be lockable") = Some(*context);
364 })
365 .on_failure(
366 move |failure: &AttemptFailure<CasAttemptFailure<T, E>>, context: &RetryContext| {
367 let failure = match failure {
368 AttemptFailure::Panic(_) | AttemptFailure::Executor(_) => {
369 return AttemptFailureDecision::UseDefault;
370 }
371 AttemptFailure::Error(failure) => failure,
372 AttemptFailure::Timeout => {
373 unreachable!("CAS executor manages async timeouts explicitly")
374 }
375 };
376 let cas_context = CasContext::new(context, attempt_timeout);
377 {
378 let mut report = report_builder
379 .lock()
380 .expect("CAS report builder should be lockable");
381 match failure {
382 CasAttemptFailure::Conflict { .. } => report.record_conflict(),
383 CasAttemptFailure::Retry { .. } => report.record_retry_error(),
384 CasAttemptFailure::Abort { .. } => report.record_abort(),
385 CasAttemptFailure::Timeout { .. } => report.record_timeout(),
386 }
387 }
388 if Self::should_emit_events(&observability, &event_hook) {
389 Self::dispatch_event(
390 &observability,
391 event_hook
392 .as_ref()
393 .expect("event hook should exist when events are emitted"),
394 CasEvent::AttemptFailed {
395 context: cas_context,
396 kind: Self::failure_kind(failure),
397 },
398 );
399 }
400 match failure {
401 CasAttemptFailure::Conflict { .. } | CasAttemptFailure::Retry { .. } => {
402 if Self::should_emit_events(&observability, &event_hook) {
403 Self::dispatch_event(
404 &observability,
405 event_hook
406 .as_ref()
407 .expect("event hook should exist when events are emitted"),
408 CasEvent::RetryRequested {
409 context: cas_context,
410 },
411 );
412 }
413 AttemptFailureDecision::Retry
414 }
415 CasAttemptFailure::Abort { .. } => AttemptFailureDecision::Abort,
416 CasAttemptFailure::Timeout { .. } => match timeout_policy {
417 CasTimeoutPolicy::Retry => {
418 if Self::should_emit_events(&observability, &event_hook) {
419 Self::dispatch_event(
420 &observability,
421 event_hook.as_ref().expect(
422 "event hook should exist when events are emitted",
423 ),
424 CasEvent::RetryRequested {
425 context: cas_context,
426 },
427 );
428 }
429 AttemptFailureDecision::Retry
430 }
431 CasTimeoutPolicy::Abort => AttemptFailureDecision::Abort,
432 },
433 }
434 },
435 );
436
437 if self.observability.listener_panic_policy() == ListenerPanicPolicy::Isolate {
438 builder = builder.isolate_listener_panics();
439 }
440 builder
441 .build()
442 .expect("validated CAS executor configuration must build retry policy")
443 }
444
445 fn run_sync_attempt<R, O>(
454 &self,
455 state: &AtomicRef<T>,
456 operation: &O,
457 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
458 where
459 O: Function<T, CasDecision<T, R, E>>,
460 {
461 let current = state.load();
462 match operation.apply(current.as_ref()) {
463 CasDecision::Update { next, output } => {
464 match state.compare_set(¤t, Arc::clone(&next)) {
465 Ok(()) => Ok(AttemptSuccess::Updated {
466 previous: current,
467 current: next,
468 output,
469 }),
470 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
471 }
472 }
473 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
474 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
475 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
476 }
477 }
478
479 #[cfg(feature = "tokio")]
488 async fn run_async_attempt<R, O, Fut>(
489 &self,
490 state: &AtomicRef<T>,
491 operation: &O,
492 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
493 where
494 O: Fn(Arc<T>) -> Fut,
495 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
496 {
497 let current = state.load();
498 let decision = if let Some(timeout) = self.attempt_timeout {
499 match tokio::time::timeout(timeout, operation(Arc::clone(¤t))).await {
500 Ok(decision) => decision,
501 Err(_) => return Err(CasAttemptFailure::timeout(current)),
502 }
503 } else {
504 operation(Arc::clone(¤t)).await
505 };
506
507 match decision {
508 CasDecision::Update { next, output } => {
509 match state.compare_set(¤t, Arc::clone(&next)) {
510 Ok(()) => Ok(AttemptSuccess::Updated {
511 previous: current,
512 current: next,
513 output,
514 }),
515 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
516 }
517 }
518 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
519 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
520 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
521 }
522 }
523
524 fn finish_execution<R>(
534 &self,
535 attempt: Result<AttemptSuccess<T, R>, RetryError<CasAttemptFailure<T, E>>>,
536 hooks: CasHooks,
537 success_context: Arc<Mutex<Option<RetryContext>>>,
538 report_builder: Arc<Mutex<CasReportBuilder>>,
539 ) -> CasOutcome<T, R, E>
540 where
541 T: 'static,
542 E: 'static,
543 {
544 match attempt {
545 Ok(success) => {
546 let context = success_context
547 .lock()
548 .expect("CAS success context slot should be lockable")
549 .take()
550 .expect("retry success hook must capture CAS success context");
551 let attempts_total = context.attempt();
552 let max_attempts = context.max_attempts();
553 let max_operation_elapsed = context.max_operation_elapsed();
554 let max_total_elapsed = context.max_total_elapsed();
555 let outcome = match success {
556 AttemptSuccess::Updated { .. } => CasExecutionOutcome::SuccessUpdated,
557 AttemptSuccess::Finished { .. } => CasExecutionOutcome::SuccessFinished,
558 };
559 let success = self.enrich_success(success, context);
560 let report = self.finish_report(
561 &hooks,
562 report_builder,
563 CasReportFinishContext::new(
564 attempts_total,
565 max_attempts,
566 max_operation_elapsed,
567 max_total_elapsed,
568 outcome,
569 ),
570 );
571 CasOutcome::new(Ok(success), report)
572 }
573 Err(error) => {
574 let error = CasError::new(error, self.attempt_timeout);
575 let context = error.context();
576 let outcome = Self::error_outcome(error.kind());
577 let report = self.finish_report(
578 &hooks,
579 report_builder,
580 CasReportFinishContext::new(
581 context.attempt(),
582 context.max_attempts(),
583 context.max_operation_elapsed(),
584 context.max_total_elapsed(),
585 outcome,
586 ),
587 );
588 CasOutcome::new(Err(error), report)
589 }
590 }
591 }
592
593 fn enrich_success<R>(
602 &self,
603 success: AttemptSuccess<T, R>,
604 context: RetryContext,
605 ) -> CasSuccess<T, R> {
606 let context = CasContext::new(&context, self.attempt_timeout);
607 match success {
608 AttemptSuccess::Updated {
609 previous,
610 current,
611 output,
612 } => CasSuccess::updated(previous, current, output, context),
613 AttemptSuccess::Finished { current, output } => {
614 CasSuccess::finished(current, output, context)
615 }
616 }
617 }
618
619 fn emit_started(&self, hooks: &CasHooks, report_builder: &Arc<Mutex<CasReportBuilder>>)
625 where
626 T: 'static,
627 E: 'static,
628 {
629 if hooks.event_hook().is_none()
630 || self.observability.mode() == CasObservabilityMode::ReportOnly
631 {
632 return;
633 }
634 let started_at = report_builder
635 .lock()
636 .expect("CAS report builder should be lockable")
637 .started_at();
638 let event_hook = hooks.event_hook();
639 Self::dispatch_event(
640 &self.observability,
641 event_hook
642 .as_ref()
643 .expect("event hook should exist when events are emitted"),
644 CasEvent::ExecutionStarted { started_at },
645 );
646 }
647
648 fn finish_report(
662 &self,
663 hooks: &CasHooks,
664 report_builder: Arc<Mutex<CasReportBuilder>>,
665 ctx: CasReportFinishContext,
666 ) -> CasExecutionReport
667 where
668 T: 'static,
669 E: 'static,
670 {
671 let report = report_builder
672 .lock()
673 .expect("CAS report builder should be lockable")
674 .finish(
675 ctx.attempts_total,
676 ctx.max_attempts,
677 ctx.max_operation_elapsed,
678 ctx.max_total_elapsed,
679 ctx.outcome,
680 );
681 let event_hook = hooks.event_hook();
682 if Self::should_emit_events(&self.observability, &event_hook) {
683 Self::dispatch_event(
684 &self.observability,
685 event_hook
686 .as_ref()
687 .expect("event hook should exist when events are emitted"),
688 CasEvent::ExecutionFinished {
689 report: report.clone(),
690 },
691 );
692 }
693 if self.observability.mode() == CasObservabilityMode::EventStreamWithAlert
694 && let Some(thresholds) = self.observability.contention_thresholds()
695 && report.is_contention_hot(&thresholds)
696 {
697 Self::dispatch_alert(
698 &self.observability,
699 &hooks.alert_hook(),
700 CasAlert::contention(report.clone(), thresholds),
701 );
702 }
703 report
704 }
705
706 #[inline]
714 fn error_outcome(kind: CasErrorKind) -> CasExecutionOutcome {
715 match kind {
716 CasErrorKind::Abort => CasExecutionOutcome::ErrorAbort,
717 CasErrorKind::Conflict => CasExecutionOutcome::ErrorConflictExhausted,
718 CasErrorKind::RetryExhausted => CasExecutionOutcome::ErrorRetryExhausted,
719 CasErrorKind::AttemptTimeout => CasExecutionOutcome::ErrorAttemptTimeout,
720 CasErrorKind::MaxOperationElapsedExceeded => {
721 CasExecutionOutcome::ErrorMaxOperationElapsedExceeded
722 }
723 CasErrorKind::MaxTotalElapsedExceeded => {
724 CasExecutionOutcome::ErrorMaxTotalElapsedExceeded
725 }
726 }
727 }
728
729 #[inline]
737 fn failure_kind(failure: &CasAttemptFailure<T, E>) -> crate::error::CasAttemptFailureKind {
738 failure.kind()
739 }
740
741 fn dispatch_event(
743 observability: &CasObservabilityConfig,
744 hook: &crate::event::CasEventHook,
745 event: CasEvent,
746 ) where
747 T: 'static,
748 E: 'static,
749 {
750 match observability.listener_panic_policy() {
751 ListenerPanicPolicy::Propagate => hook.accept(&event),
752 ListenerPanicPolicy::Isolate => {
753 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&event)));
754 }
755 }
756 }
757
758 #[inline]
760 fn should_emit_events(
761 observability: &CasObservabilityConfig,
762 hook: &Option<crate::event::CasEventHook>,
763 ) -> bool {
764 observability.mode() != CasObservabilityMode::ReportOnly && hook.is_some()
765 }
766
767 fn dispatch_alert(
769 observability: &CasObservabilityConfig,
770 hook: &Option<crate::event::CasAlertHook>,
771 alert: CasAlert,
772 ) {
773 if let Some(hook) = hook {
774 match observability.listener_panic_policy() {
775 ListenerPanicPolicy::Propagate => hook.accept(&alert),
776 ListenerPanicPolicy::Isolate => {
777 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&alert)));
778 }
779 }
780 }
781 }
782}