1use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use std::sync::{Arc, Mutex};
14use std::time::Duration;
15
16use qubit_atomic::AtomicRef;
17use qubit_common::BoxError;
18use qubit_function::{Consumer, Function};
19use qubit_retry::{
20 AttemptFailure, AttemptFailureDecision, Retry, RetryContext, RetryError, RetryOptions,
21};
22
23use crate::decision::CasDecision;
24use crate::error::{CasAttemptFailure, CasError, CasErrorKind};
25use crate::event::{CasContext, CasEvent, CasHooks};
26use crate::observability::{
27 CasAlert, CasObservabilityConfig, CasObservabilityMode, ListenerPanicPolicy,
28};
29use crate::options::CasTimeoutPolicy;
30use crate::outcome::CasOutcome;
31use crate::report::{CasExecutionOutcome, CasExecutionReport, CasReportBuilder};
32use crate::strategy::CasStrategy;
33use crate::success::CasSuccess;
34
35use super::cas_builder::CasBuilder;
36
37#[derive(Debug, Clone)]
39pub struct CasExecutor<T, E = BoxError> {
40 options: RetryOptions,
42 attempt_timeout: Option<Duration>,
44 timeout_policy: CasTimeoutPolicy,
46 observability: CasObservabilityConfig,
48 marker: PhantomData<fn() -> (T, E)>,
50}
51
/// Successful result of a single CAS attempt, before it is combined with the
/// final retry context into a [`CasSuccess`].
enum AttemptSuccess<T, R> {
    /// The compare-and-set succeeded.
    Updated {
        /// Value the attempt loaded and replaced.
        previous: Arc<T>,
        /// Value written by the compare-and-set.
        current: Arc<T>,
        /// Output returned by the operation alongside the new value.
        output: R,
    },
    /// The operation finished without writing; `current` is the value it saw.
    Finished { current: Arc<T>, output: R },
}
63
64impl<T, E> CasExecutor<T, E> {
65 #[inline]
70 pub fn builder() -> CasBuilder<T, E> {
71 CasBuilder::new()
72 }
73
74 pub fn from_options(options: RetryOptions) -> Result<Self, qubit_retry::RetryConfigError> {
85 Self::builder().options(options).build()
86 }
87
88 pub fn latency_first() -> Self {
93 Self::builder()
94 .build_latency_first()
95 .expect("latency-first CAS strategy must be valid")
96 }
97
98 pub fn contention_adaptive() -> Self {
103 Self::builder()
104 .build_contention_adaptive()
105 .expect("contention-adaptive CAS strategy must be valid")
106 }
107
108 pub fn reliability_first() -> Self {
113 Self::builder()
114 .build_reliability_first()
115 .expect("reliability-first CAS strategy must be valid")
116 }
117
118 pub fn with_strategy(strategy: CasStrategy) -> Self {
126 Self::builder()
127 .strategy(strategy)
128 .build()
129 .expect("built-in CAS strategy must be valid")
130 }
131
132 #[inline]
143 pub(crate) fn new(
144 options: RetryOptions,
145 attempt_timeout: Option<Duration>,
146 timeout_policy: CasTimeoutPolicy,
147 observability: CasObservabilityConfig,
148 ) -> Self {
149 Self {
150 options,
151 attempt_timeout,
152 timeout_policy,
153 observability,
154 marker: PhantomData,
155 }
156 }
157
158 #[inline]
163 pub fn options(&self) -> &RetryOptions {
164 &self.options
165 }
166
167 #[inline]
172 pub fn attempt_timeout(&self) -> Option<Duration> {
173 self.attempt_timeout
174 }
175
176 #[inline]
181 pub fn timeout_policy(&self) -> CasTimeoutPolicy {
182 self.timeout_policy
183 }
184
185 #[inline]
190 pub fn observability(&self) -> &CasObservabilityConfig {
191 &self.observability
192 }
193
194 pub fn execute<R, O>(&self, state: &AtomicRef<T>, operation: O) -> CasOutcome<T, R, E>
204 where
205 T: 'static,
206 E: 'static,
207 O: Function<T, CasDecision<T, R, E>>,
208 {
209 self.execute_with_hooks(state, operation, CasHooks::new())
210 }
211
212 pub fn execute_with_hooks<R, O>(
223 &self,
224 state: &AtomicRef<T>,
225 operation: O,
226 hooks: CasHooks,
227 ) -> CasOutcome<T, R, E>
228 where
229 T: 'static,
230 E: 'static,
231 O: Function<T, CasDecision<T, R, E>>,
232 {
233 let success_context = Arc::new(Mutex::new(None));
234 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
235 self.emit_started(&hooks, &report_builder);
236 let retry = self.build_retry(
237 &hooks,
238 Arc::clone(&success_context),
239 Arc::clone(&report_builder),
240 );
241 let attempt = retry.run(|| self.run_sync_attempt(state, &operation));
242 self.finish_execution(attempt, hooks, success_context, report_builder)
243 }
244
245 #[cfg(feature = "tokio")]
254 pub async fn execute_async<R, O, Fut>(
255 &self,
256 state: &AtomicRef<T>,
257 operation: O,
258 ) -> CasOutcome<T, R, E>
259 where
260 T: 'static,
261 E: 'static,
262 O: Fn(Arc<T>) -> Fut,
263 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
264 {
265 self.execute_async_with_hooks(state, operation, CasHooks::new())
266 .await
267 }
268
269 #[cfg(feature = "tokio")]
279 pub async fn execute_async_with_hooks<R, O, Fut>(
280 &self,
281 state: &AtomicRef<T>,
282 operation: O,
283 hooks: CasHooks,
284 ) -> CasOutcome<T, R, E>
285 where
286 T: 'static,
287 E: 'static,
288 O: Fn(Arc<T>) -> Fut,
289 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
290 {
291 let success_context = Arc::new(Mutex::new(None));
292 let report_builder = Arc::new(Mutex::new(CasReportBuilder::start()));
293 self.emit_started(&hooks, &report_builder);
294 let retry = self.build_retry(
295 &hooks,
296 Arc::clone(&success_context),
297 Arc::clone(&report_builder),
298 );
299 let attempt = retry
300 .run_async(|| self.run_async_attempt(state, &operation))
301 .await;
302 self.finish_execution(attempt, hooks, success_context, report_builder)
303 }
304
305 fn build_retry(
315 &self,
316 hooks: &CasHooks,
317 success_context: Arc<Mutex<Option<RetryContext>>>,
318 report_builder: Arc<Mutex<CasReportBuilder>>,
319 ) -> Retry<CasAttemptFailure<T, E>>
320 where
321 T: 'static,
322 E: 'static,
323 {
324 let event_hook = hooks.event_hook();
325 let timeout_policy = self.timeout_policy;
326 let attempt_timeout = self.attempt_timeout;
327 let observability = self.observability.clone();
328
329 let mut builder = Retry::<CasAttemptFailure<T, E>>::builder()
330 .options(self.options.clone())
331 .on_success(move |context: &RetryContext| {
332 *success_context
333 .lock()
334 .expect("CAS success context slot should be lockable") = Some(*context);
335 })
336 .on_failure(
337 move |failure: &AttemptFailure<CasAttemptFailure<T, E>>, context: &RetryContext| {
338 let failure = match failure {
339 AttemptFailure::Panic(_) | AttemptFailure::Executor(_) => {
340 return AttemptFailureDecision::UseDefault;
341 }
342 AttemptFailure::Error(failure) => failure,
343 AttemptFailure::Timeout => {
344 unreachable!("CAS executor manages async timeouts explicitly")
345 }
346 };
347 let cas_context = CasContext::new(context, attempt_timeout);
348 {
349 let mut report = report_builder
350 .lock()
351 .expect("CAS report builder should be lockable");
352 match failure {
353 CasAttemptFailure::Conflict { .. } => report.record_conflict(),
354 CasAttemptFailure::Retry { .. } => report.record_retry_error(),
355 CasAttemptFailure::Abort { .. } => report.record_abort(),
356 CasAttemptFailure::Timeout { .. } => report.record_timeout(),
357 }
358 }
359 if Self::should_emit_events(&observability, &event_hook) {
360 Self::dispatch_event(
361 &observability,
362 event_hook
363 .as_ref()
364 .expect("event hook should exist when events are emitted"),
365 CasEvent::AttemptFailed {
366 context: cas_context,
367 kind: Self::failure_kind(failure),
368 },
369 );
370 }
371 match failure {
372 CasAttemptFailure::Conflict { .. } | CasAttemptFailure::Retry { .. } => {
373 if Self::should_emit_events(&observability, &event_hook) {
374 Self::dispatch_event(
375 &observability,
376 event_hook
377 .as_ref()
378 .expect("event hook should exist when events are emitted"),
379 CasEvent::RetryRequested {
380 context: cas_context,
381 },
382 );
383 }
384 AttemptFailureDecision::Retry
385 }
386 CasAttemptFailure::Abort { .. } => AttemptFailureDecision::Abort,
387 CasAttemptFailure::Timeout { .. } => match timeout_policy {
388 CasTimeoutPolicy::Retry => {
389 if Self::should_emit_events(&observability, &event_hook) {
390 Self::dispatch_event(
391 &observability,
392 event_hook.as_ref().expect(
393 "event hook should exist when events are emitted",
394 ),
395 CasEvent::RetryRequested {
396 context: cas_context,
397 },
398 );
399 }
400 AttemptFailureDecision::Retry
401 }
402 CasTimeoutPolicy::Abort => AttemptFailureDecision::Abort,
403 },
404 }
405 },
406 );
407
408 if self.observability.listener_panic_policy() == ListenerPanicPolicy::Isolate {
409 builder = builder.isolate_listener_panics();
410 }
411 builder
412 .build()
413 .expect("validated CAS executor configuration must build retry policy")
414 }
415
416 fn run_sync_attempt<R, O>(
425 &self,
426 state: &AtomicRef<T>,
427 operation: &O,
428 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
429 where
430 O: Function<T, CasDecision<T, R, E>>,
431 {
432 let current = state.load();
433 match operation.apply(current.as_ref()) {
434 CasDecision::Update { next, output } => {
435 match state.compare_set(¤t, Arc::clone(&next)) {
436 Ok(()) => Ok(AttemptSuccess::Updated {
437 previous: current,
438 current: next,
439 output,
440 }),
441 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
442 }
443 }
444 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
445 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
446 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
447 }
448 }
449
450 #[cfg(feature = "tokio")]
459 async fn run_async_attempt<R, O, Fut>(
460 &self,
461 state: &AtomicRef<T>,
462 operation: &O,
463 ) -> Result<AttemptSuccess<T, R>, CasAttemptFailure<T, E>>
464 where
465 O: Fn(Arc<T>) -> Fut,
466 Fut: std::future::Future<Output = CasDecision<T, R, E>>,
467 {
468 let current = state.load();
469 let decision = if let Some(timeout) = self.attempt_timeout {
470 match tokio::time::timeout(timeout, operation(Arc::clone(¤t))).await {
471 Ok(decision) => decision,
472 Err(_) => return Err(CasAttemptFailure::timeout(current)),
473 }
474 } else {
475 operation(Arc::clone(¤t)).await
476 };
477
478 match decision {
479 CasDecision::Update { next, output } => {
480 match state.compare_set(¤t, Arc::clone(&next)) {
481 Ok(()) => Ok(AttemptSuccess::Updated {
482 previous: current,
483 current: next,
484 output,
485 }),
486 Err(actual) => Err(CasAttemptFailure::conflict(actual)),
487 }
488 }
489 CasDecision::Finish { output } => Ok(AttemptSuccess::Finished { current, output }),
490 CasDecision::Retry(error) => Err(CasAttemptFailure::retry(current, error)),
491 CasDecision::Abort(error) => Err(CasAttemptFailure::abort(current, error)),
492 }
493 }
494
495 fn finish_execution<R>(
505 &self,
506 attempt: Result<AttemptSuccess<T, R>, RetryError<CasAttemptFailure<T, E>>>,
507 hooks: CasHooks,
508 success_context: Arc<Mutex<Option<RetryContext>>>,
509 report_builder: Arc<Mutex<CasReportBuilder>>,
510 ) -> CasOutcome<T, R, E>
511 where
512 T: 'static,
513 E: 'static,
514 {
515 match attempt {
516 Ok(success) => {
517 let context = success_context
518 .lock()
519 .expect("CAS success context slot should be lockable")
520 .take()
521 .expect("retry success hook must capture CAS success context");
522 let attempts_total = context.attempt();
523 let max_attempts = context.max_attempts();
524 let max_elapsed = context.max_elapsed();
525 let outcome = match success {
526 AttemptSuccess::Updated { .. } => CasExecutionOutcome::SuccessUpdated,
527 AttemptSuccess::Finished { .. } => CasExecutionOutcome::SuccessFinished,
528 };
529 let success = self.enrich_success(success, context);
530 let report = self.finish_report(
531 &hooks,
532 report_builder,
533 attempts_total,
534 max_attempts,
535 max_elapsed,
536 outcome,
537 );
538 CasOutcome::new(Ok(success), report)
539 }
540 Err(error) => {
541 let error = CasError::new(error, self.attempt_timeout);
542 let context = error.context();
543 let outcome = Self::error_outcome(error.kind());
544 let report = self.finish_report(
545 &hooks,
546 report_builder,
547 context.attempt(),
548 context.max_attempts(),
549 context.max_elapsed(),
550 outcome,
551 );
552 CasOutcome::new(Err(error), report)
553 }
554 }
555 }
556
557 fn enrich_success<R>(
566 &self,
567 success: AttemptSuccess<T, R>,
568 context: RetryContext,
569 ) -> CasSuccess<T, R> {
570 let context = CasContext::new(&context, self.attempt_timeout);
571 match success {
572 AttemptSuccess::Updated {
573 previous,
574 current,
575 output,
576 } => CasSuccess::updated(previous, current, output, context),
577 AttemptSuccess::Finished { current, output } => {
578 CasSuccess::finished(current, output, context)
579 }
580 }
581 }
582
583 fn emit_started(&self, hooks: &CasHooks, report_builder: &Arc<Mutex<CasReportBuilder>>)
589 where
590 T: 'static,
591 E: 'static,
592 {
593 if hooks.event_hook().is_none()
594 || self.observability.mode() == CasObservabilityMode::ReportOnly
595 {
596 return;
597 }
598 let started_at = report_builder
599 .lock()
600 .expect("CAS report builder should be lockable")
601 .started_at();
602 let event_hook = hooks.event_hook();
603 Self::dispatch_event(
604 &self.observability,
605 event_hook
606 .as_ref()
607 .expect("event hook should exist when events are emitted"),
608 CasEvent::ExecutionStarted { started_at },
609 );
610 }
611
612 fn finish_report(
629 &self,
630 hooks: &CasHooks,
631 report_builder: Arc<Mutex<CasReportBuilder>>,
632 attempts_total: u32,
633 max_attempts: u32,
634 max_elapsed: Option<Duration>,
635 outcome: CasExecutionOutcome,
636 ) -> CasExecutionReport
637 where
638 T: 'static,
639 E: 'static,
640 {
641 let report = report_builder
642 .lock()
643 .expect("CAS report builder should be lockable")
644 .finish(attempts_total, max_attempts, max_elapsed, outcome);
645 let event_hook = hooks.event_hook();
646 if Self::should_emit_events(&self.observability, &event_hook) {
647 Self::dispatch_event(
648 &self.observability,
649 event_hook
650 .as_ref()
651 .expect("event hook should exist when events are emitted"),
652 CasEvent::ExecutionFinished {
653 report: report.clone(),
654 },
655 );
656 }
657 if self.observability.mode() == CasObservabilityMode::EventStreamWithAlert
658 && let Some(thresholds) = self.observability.contention_thresholds()
659 && report.is_contention_hot(&thresholds)
660 {
661 Self::dispatch_alert(
662 &self.observability,
663 &hooks.alert_hook(),
664 CasAlert::contention(report.clone(), thresholds),
665 );
666 }
667 report
668 }
669
670 #[inline]
678 fn error_outcome(kind: CasErrorKind) -> CasExecutionOutcome {
679 match kind {
680 CasErrorKind::Abort => CasExecutionOutcome::ErrorAbort,
681 CasErrorKind::Conflict => CasExecutionOutcome::ErrorConflictExhausted,
682 CasErrorKind::RetryExhausted => CasExecutionOutcome::ErrorRetryExhausted,
683 CasErrorKind::AttemptTimeout => CasExecutionOutcome::ErrorAttemptTimeout,
684 CasErrorKind::MaxElapsedExceeded => CasExecutionOutcome::ErrorMaxElapsedExceeded,
685 }
686 }
687
688 #[inline]
696 fn failure_kind(failure: &CasAttemptFailure<T, E>) -> crate::error::CasAttemptFailureKind {
697 failure.kind()
698 }
699
700 fn dispatch_event(
702 observability: &CasObservabilityConfig,
703 hook: &crate::event::CasEventHook,
704 event: CasEvent,
705 ) where
706 T: 'static,
707 E: 'static,
708 {
709 match observability.listener_panic_policy() {
710 ListenerPanicPolicy::Propagate => hook.accept(&event),
711 ListenerPanicPolicy::Isolate => {
712 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&event)));
713 }
714 }
715 }
716
717 #[inline]
719 fn should_emit_events(
720 observability: &CasObservabilityConfig,
721 hook: &Option<crate::event::CasEventHook>,
722 ) -> bool {
723 observability.mode() != CasObservabilityMode::ReportOnly && hook.is_some()
724 }
725
726 fn dispatch_alert(
728 observability: &CasObservabilityConfig,
729 hook: &Option<crate::event::CasAlertHook>,
730 alert: CasAlert,
731 ) {
732 if let Some(hook) = hook {
733 match observability.listener_panic_policy() {
734 ListenerPanicPolicy::Propagate => hook.accept(&alert),
735 ListenerPanicPolicy::Isolate => {
736 let _ = catch_unwind(AssertUnwindSafe(|| hook.accept(&alert)));
737 }
738 }
739 }
740 }
741}