1use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use std::sync::{Arc, Mutex};
14use std::time::Duration;
15
16use qubit_atomic::AtomicRef;
17use qubit_common::BoxError;
18use qubit_function::{Consumer, Function};
19use qubit_retry::{
20 AttemptFailure, AttemptFailureDecision, Retry, RetryContext, RetryError, RetryOptions,
21};
22
23use crate::cas_decision::CasDecision;
24use crate::cas_outcome::CasOutcome;
25use crate::cas_success::CasSuccess;
26use crate::error::{CasAttemptFailure, CasError, CasErrorKind};
27use crate::event::{CasContext, CasEvent, CasHooks};
28use crate::observability::{
29 CasAlert, CasObservabilityConfig, CasObservabilityMode, ListenerPanicPolicy,
30};
31use crate::options::CasTimeoutPolicy;
32use crate::report::{CasExecutionOutcome, CasExecutionReport, CasReportBuilder};
33use crate::strategy::CasStrategy;
34
35use super::cas_builder::CasBuilder;
36
37#[derive(Debug, Clone)]
39pub struct CasExecutor<T, E = BoxError> {
40 options: RetryOptions,
42 attempt_timeout: Option<Duration>,
44 timeout_policy: CasTimeoutPolicy,
46 observability: CasObservabilityConfig,
48 marker: PhantomData<fn() -> (T, E)>,
50}
51
/// Successful result of a single CAS attempt, before the retry context has
/// been attached to it.
enum AttemptSuccess<T, R> {
    /// The compare-and-set succeeded and a new value was installed.
    Updated {
        /// Value that was loaded before the operation ran.
        previous: Arc<T>,
        /// Value that the compare-and-set installed.
        current: Arc<T>,
        /// Output produced by the operation.
        output: R,
    },
    /// The operation finished without requesting a write.
    Finished {
        /// Value that was loaded before the operation ran.
        current: Arc<T>,
        /// Output produced by the operation.
        output: R,
    },
}
63
64struct CasReportFinishContext {
66 attempts_total: u32,
67 max_attempts: u32,
68 max_operation_elapsed: Option<Duration>,
69 max_total_elapsed: Option<Duration>,
70 outcome: CasExecutionOutcome,
71}
72
73impl CasReportFinishContext {
74 #[inline]
75 fn new(
76 attempts_total: u32,
77 max_attempts: u32,
78 max_operation_elapsed: Option<Duration>,
79 max_total_elapsed: Option<Duration>,
80 outcome: CasExecutionOutcome,
81 ) -> Self {
82 Self {
83 attempts_total,
84 max_attempts,
85 max_operation_elapsed,
86 max_total_elapsed,
87 outcome,
88 }
89 }
90}
91
92impl<T, E> CasExecutor<T, E> {
93 #[inline]
98 pub fn builder() -> CasBuilder<T, E> {
99 CasBuilder::new()
100 }
101
102 pub fn from_options(options: RetryOptions) -> Result<Self, qubit_retry::RetryConfigError> {
113 Self::builder().options(options).build()
114 }
115
116 pub fn latency_first() -> Self {
121 Self::builder()
122 .build_latency_first()
123 .expect("latency-first CAS strategy must be valid")
124 }
125
126 pub fn contention_adaptive() -> Self {
131 Self::builder()
132 .build_contention_adaptive()
133 .expect("contention-adaptive CAS strategy must be valid")
134 }
135
136 pub fn reliability_first() -> Self {
141 Self::builder()
142 .build_reliability_first()
143 .expect("reliability-first CAS strategy must be valid")
144 }
145
146 pub fn with_strategy(strategy: CasStrategy) -> Self {
154 Self::builder()
155 .strategy(strategy)
156 .build()
157 .expect("built-in CAS strategy must be valid")
158 }
159
160 #[inline]
171 pub(crate) fn new(
172 options: RetryOptions,
173 attempt_timeout: Option<Duration>,
174 timeout_policy: CasTimeoutPolicy,
175 observability: CasObservabilityConfig,
176 ) -> Self {
177 Self {
178 options,
179 attempt_timeout,
180 timeout_policy,
181 observability,
182 marker: PhantomData,
183 }
184 }
185
186 #[inline]
191 pub fn options(&self) -> &RetryOptions {
192 &self.options
193 }
194
195 #[inline]
200 pub fn attempt_timeout(&self) -> Option<Duration> {
201 self.attempt_timeout
202 }
203
204 #[inline]
209 pub fn timeout_policy(&self) -> CasTimeoutPolicy {
210 self.timeout_policy
211 }
212
213 #[inline]
218 pub fn observability(&self) -> &CasObservabilityConfig {
219 &self.observability
220 }
221
222 pub fn execute<R, O>(&self, state: &AtomicRef<T>, operation: O) -> CasOutcome<T, R, E>
232 where
233 T: 'static,
234 E: 'static,
235 O: Function<T, CasDecision<T, R, E>>,
236 {
237 self.execute_with_hooks(state, operation, CasHooks::new())
238 }
239
240 pub fn execute_with_hooks<R, O>(
251 &self,
252 state: &AtomicRef<T>,
253 operation: O,
254 hooks: CasHooks,
255 ) -> CasOutcome<T, R, E>
256 where
257 T: 'static,
258 E: 'static,
259 O: Function<T, CasDecision<T, R, E>>,
260 {
261 let success_context = Arc::new(Mutex::new(None));
262 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
263 self.emit_started(&hooks, &report_builder);
264 let retry = self.build_retry(
265 &hooks,
266 Arc::clone(&success_context),
267 Arc::clone(&report_builder),
268 );
269 let attempt = retry.run(|| self.run_sync_attempt(state, &operation));
270 self.finish_execution(attempt, hooks, success_context, report_builder)
271 }
272
273 #[cfg(feature = "tokio")]
282 pub async fn execute_async<R, O, Fut>(
283 &self,
284 state: &AtomicRef<T>,
285 operation: O,
286 ) -> CasOutcome<T, R, E>
287 where
288 T: 'static,
289 E: 'static,
290 O: Fn(Arc<T>) -> Fut,
291 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
292 {
293 self.execute_async_with_hooks(state, operation, CasHooks::new())
294 .await
295 }
296
297 #[cfg(feature = "tokio")]
307 pub async fn execute_async_with_hooks<R, O, Fut>(
308 &self,
309 state: &AtomicRef<T>,
310 operation: O,
311 hooks: CasHooks,
312 ) -> CasOutcome<T, R, E>
313 where
314 T: 'static,
315 E: 'static,
316 O: Fn(Arc<T>) -> Fut,
317 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
318 {
319 let success_context = Arc::new(Mutex::new(None));
320 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
321 self.emit_started(&hooks, &report_builder);
322 let retry = self.build_retry(
323 &hooks,
324 Arc::clone(&success_context),
325 Arc::clone(&report_builder),
326 );
327 let attempt = retry
328 .run_async(|| self.run_async_attempt(state, &operation))
329 .await;
330 self.finish_execution(attempt, hooks, success_context, report_builder)
331 }
332
333 fn build_retry(
343 &self,
344 hooks: &CasHooks,
345 success_context: Arc<Mutex<Option<RetryContext>>>,
346 report_builder: Arc<Mutex<CasReportBuilder>>,
347 ) -> Retry<CasAttemptFailure<T, E>>
348 where
349 T: 'static,
350 E: 'static,
351 {
352 let event_hook = hooks.event_hook();
353 let timeout_policy = self.timeout_policy;
354 let attempt_timeout = self.attempt_timeout;
355 let observability = self.observability.clone();
356
357 let mut builder = Retry::<CasAttemptFailure<T, E>>::builder()
358 .options(self.options.clone())
359 .on_success(move |context: &RetryContext| {
360 *success_context
361 .lock()
362 .expect("CAS success context slot should be lockable") = Some(*context);
363 })
364 .on_failure(
365 move |failure: &AttemptFailure<CasAttemptFailure<T, E>>, context: &RetryContext| {
366 let failure = match failure {
367 AttemptFailure::Panic(_) | AttemptFailure::Executor(_) => {
368 return AttemptFailureDecision::UseDefault;
369 }
370 AttemptFailure::Error(failure) => failure,
371 AttemptFailure::Timeout => {
372 unreachable!("CAS executor manages async timeouts explicitly")
373 }
374 };
375 let cas_context = CasContext::new(context, attempt_timeout);
376 {
377 let mut report = report_builder
378 .lock()
379 .expect("CAS report builder should be lockable");
380 match failure {
381 CasAttemptFailure::Conflict { .. } => report.record_conflict(),
382 CasAttemptFailure::Retry { .. } => report.record_retry_error(),
383 CasAttemptFailure::Abort { .. } => report.record_abort(),
384 CasAttemptFailure::Timeout { .. } => report.record_timeout(),
385 }
386 }
387 if Self::should_emit_events(&observability, &event_hook) {
388 Self::dispatch_event(
389 &observability,
390 event_hook
391 .as_ref()
392 .expect("event hook should exist when events are emitted"),
393 CasEvent::AttemptFailed {
394 context: cas_context,
395 kind: Self::failure_kind(failure),
396 },
397 );
398 }
399 match failure {
400 CasAttemptFailure::Conflict { .. } | CasAttemptFailure::Retry { .. } => {
401 if Self::should_emit_events(&observability, &event_hook) {
402 Self::dispatch_event(
403 &observability,
404 event_hook
405 .as_ref()
406 .expect("event hook should exist when events are emitted"),
407 CasEvent::RetryRequested {
408 context: cas_context,
409 },
410 );
411 }
412 AttemptFailureDecision::Retry
413 }
414 CasAttemptFailure::Abort { .. } => AttemptFailureDecision::Abort,
415 CasAttemptFailure::Timeout { .. } => match timeout_policy {
416 CasTimeoutPolicy::Retry => {
417 if Self::should_emit_events(&observability, &event_hook) {
418 Self::dispatch_event(
419 &observability,
420 event_hook.as_ref().expect(
421 "event hook should exist when events are emitted",
422 ),
423 CasEvent::RetryRequested {
424 context: cas_context,
425 },
426 );
427 }
428 AttemptFailureDecision::Retry
429 }
430 CasTimeoutPolicy::Abort => AttemptFailureDecision::Abort,
431 },
432 }
433 },
434 );
435
436 if self.observability.listener_panic_policy() == ListenerPanicPolicy::Isolate {
437 builder = builder.isolate_listener_panics();
438 }
439 builder
440 .build()
441 .expect("validated CAS executor configuration must build retry policy")
442 }
443
444 fn run_sync_attempt<R, O>(
453 &self,
454 state: &AtomicRef<T>,
455 operation: &O,
456 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
457 where
458 O: Function<T, CasDecision<T, R, E>>,
459 {
460 let current = state.load();
461 match operation.apply(current.as_ref()) {
462 CasDecision::Update { next, output } => {
463 match state.compare_set(¤t, Arc::clone(&next)) {
464 Ok(()) => Ok(AttemptSuccess::Updated {
465 previous: current,
466 current: next,
467 output,
468 }),
469 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
470 }
471 }
472 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
473 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
474 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
475 }
476 }
477
478 #[cfg(feature = "tokio")]
487 async fn run_async_attempt<R, O, Fut>(
488 &self,
489 state: &AtomicRef<T>,
490 operation: &O,
491 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
492 where
493 O: Fn(Arc<T>) -> Fut,
494 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
495 {
496 let current = state.load();
497 let decision = if let Some(timeout) = self.attempt_timeout {
498 match tokio::time::timeout(timeout, operation(Arc::clone(¤t))).await {
499 Ok(decision) => decision,
500 Err(_) => return Err(CasAttemptFailure::timeout(current)),
501 }
502 } else {
503 operation(Arc::clone(¤t)).await
504 };
505
506 match decision {
507 CasDecision::Update { next, output } => {
508 match state.compare_set(¤t, Arc::clone(&next)) {
509 Ok(()) => Ok(AttemptSuccess::Updated {
510 previous: current,
511 current: next,
512 output,
513 }),
514 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
515 }
516 }
517 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
518 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
519 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
520 }
521 }
522
523 fn finish_execution<R>(
533 &self,
534 attempt: Result<AttemptSuccess<T, R>, RetryError<CasAttemptFailure<T, E>>>,
535 hooks: CasHooks,
536 success_context: Arc<Mutex<Option<RetryContext>>>,
537 report_builder: Arc<Mutex<CasReportBuilder>>,
538 ) -> CasOutcome<T, R, E>
539 where
540 T: 'static,
541 E: 'static,
542 {
543 match attempt {
544 Ok(success) => {
545 let context = success_context
546 .lock()
547 .expect("CAS success context slot should be lockable")
548 .take()
549 .expect("retry success hook must capture CAS success context");
550 let attempts_total = context.attempt();
551 let max_attempts = context.max_attempts();
552 let max_operation_elapsed = context.max_operation_elapsed();
553 let max_total_elapsed = context.max_total_elapsed();
554 let outcome = match success {
555 AttemptSuccess::Updated { .. } => CasExecutionOutcome::SuccessUpdated,
556 AttemptSuccess::Finished { .. } => CasExecutionOutcome::SuccessFinished,
557 };
558 let success = self.enrich_success(success, context);
559 let report = self.finish_report(
560 &hooks,
561 report_builder,
562 CasReportFinishContext::new(
563 attempts_total,
564 max_attempts,
565 max_operation_elapsed,
566 max_total_elapsed,
567 outcome,
568 ),
569 );
570 CasOutcome::new(Ok(success), report)
571 }
572 Err(error) => {
573 let error = CasError::new(error, self.attempt_timeout);
574 let context = error.context();
575 let outcome = Self::error_outcome(error.kind());
576 let report = self.finish_report(
577 &hooks,
578 report_builder,
579 CasReportFinishContext::new(
580 context.attempt(),
581 context.max_attempts(),
582 context.max_operation_elapsed(),
583 context.max_total_elapsed(),
584 outcome,
585 ),
586 );
587 CasOutcome::new(Err(error), report)
588 }
589 }
590 }
591
592 fn enrich_success<R>(
601 &self,
602 success: AttemptSuccess<T, R>,
603 context: RetryContext,
604 ) -> CasSuccess<T, R> {
605 let context = CasContext::new(&context, self.attempt_timeout);
606 match success {
607 AttemptSuccess::Updated {
608 previous,
609 current,
610 output,
611 } => CasSuccess::updated(previous, current, output, context),
612 AttemptSuccess::Finished { current, output } => {
613 CasSuccess::finished(current, output, context)
614 }
615 }
616 }
617
618 fn emit_started(&self, hooks: &CasHooks, report_builder: &Arc<Mutex<CasReportBuilder>>)
624 where
625 T: 'static,
626 E: 'static,
627 {
628 if hooks.event_hook().is_none()
629 || self.observability.mode() == CasObservabilityMode::ReportOnly
630 {
631 return;
632 }
633 let started_at = report_builder
634 .lock()
635 .expect("CAS report builder should be lockable")
636 .started_at();
637 let event_hook = hooks.event_hook();
638 Self::dispatch_event(
639 &self.observability,
640 event_hook
641 .as_ref()
642 .expect("event hook should exist when events are emitted"),
643 CasEvent::ExecutionStarted { started_at },
644 );
645 }
646
647 fn finish_report(
661 &self,
662 hooks: &CasHooks,
663 report_builder: Arc<Mutex<CasReportBuilder>>,
664 ctx: CasReportFinishContext,
665 ) -> CasExecutionReport
666 where
667 T: 'static,
668 E: 'static,
669 {
670 let report = report_builder
671 .lock()
672 .expect("CAS report builder should be lockable")
673 .finish(
674 ctx.attempts_total,
675 ctx.max_attempts,
676 ctx.max_operation_elapsed,
677 ctx.max_total_elapsed,
678 ctx.outcome,
679 );
680 let event_hook = hooks.event_hook();
681 if Self::should_emit_events(&self.observability, &event_hook) {
682 Self::dispatch_event(
683 &self.observability,
684 event_hook
685 .as_ref()
686 .expect("event hook should exist when events are emitted"),
687 CasEvent::ExecutionFinished {
688 report: report.clone(),
689 },
690 );
691 }
692 if self.observability.mode() == CasObservabilityMode::EventStreamWithAlert
693 && let Some(thresholds) = self.observability.contention_thresholds()
694 && report.is_contention_hot(&thresholds)
695 {
696 Self::dispatch_alert(
697 &self.observability,
698 &hooks.alert_hook(),
699 CasAlert::contention(report.clone(), thresholds),
700 );
701 }
702 report
703 }
704
705 #[inline]
713 fn error_outcome(kind: CasErrorKind) -> CasExecutionOutcome {
714 match kind {
715 CasErrorKind::Abort => CasExecutionOutcome::ErrorAbort,
716 CasErrorKind::Conflict => CasExecutionOutcome::ErrorConflictExhausted,
717 CasErrorKind::RetryExhausted => CasExecutionOutcome::ErrorRetryExhausted,
718 CasErrorKind::AttemptTimeout => CasExecutionOutcome::ErrorAttemptTimeout,
719 CasErrorKind::MaxOperationElapsedExceeded => {
720 CasExecutionOutcome::ErrorMaxOperationElapsedExceeded
721 }
722 CasErrorKind::MaxTotalElapsedExceeded => {
723 CasExecutionOutcome::ErrorMaxTotalElapsedExceeded
724 }
725 }
726 }
727
728 #[inline]
736 fn failure_kind(failure: &CasAttemptFailure<T, E>) -> crate::error::CasAttemptFailureKind {
737 failure.kind()
738 }
739
740 fn dispatch_event(
742 observability: &CasObservabilityConfig,
743 hook: &crate::event::CasEventHook,
744 event: CasEvent,
745 ) where
746 T: 'static,
747 E: 'static,
748 {
749 match observability.listener_panic_policy() {
750 ListenerPanicPolicy::Propagate => hook.accept(&event),
751 ListenerPanicPolicy::Isolate => {
752 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&event)));
753 }
754 }
755 }
756
757 #[inline]
759 fn should_emit_events(
760 observability: &CasObservabilityConfig,
761 hook: &Option<crate::event::CasEventHook>,
762 ) -> bool {
763 observability.mode() != CasObservabilityMode::ReportOnly && hook.is_some()
764 }
765
766 fn dispatch_alert(
768 observability: &CasObservabilityConfig,
769 hook: &Option<crate::event::CasAlertHook>,
770 alert: CasAlert,
771 ) {
772 if let Some(hook) = hook {
773 match observability.listener_panic_policy() {
774 ListenerPanicPolicy::Propagate => hook.accept(&alert),
775 ListenerPanicPolicy::Isolate => {
776 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&alert)));
777 }
778 }
779 }
780 }
781}