use std::{
any::Any,
fmt::Display,
marker::PhantomData,
panic::{self, AssertUnwindSafe},
};
use qubit_function::{
ArcRunnable, ArcTester, Callable, CallableWith, Runnable, RunnableWith, Tester,
};
use super::{
ExecutionContext, ExecutionLogger, ExecutionResult, ExecutorError,
executor_builder::ExecutorBuilder, executor_ready_builder::ExecutorReadyBuilder,
};
use crate::lock::Lock;
/// Runs tasks under a write lock, guarded by a condition (the tester) that is
/// evaluated twice: once before the lock is acquired and again while it is held.
///
/// An optional prepare action runs after the first check passes and before the
/// lock is taken. If it completed, it is followed by the commit action when the
/// task succeeded, or by the rollback action otherwise.
///
/// `L` is the lock type and `T` the data it guards (`L: Lock<T>`). A builder is
/// obtained from [`DoubleCheckedLockExecutor::builder`].
#[derive(Clone)]
pub struct DoubleCheckedLockExecutor<L = (), T = ()> {
    /// Lock whose write access wraps every task; `*_with` tasks also receive
    /// `&mut T` from it.
    lock: L,
    /// Condition evaluated before locking and re-evaluated under the lock.
    tester: ArcTester,
    /// Receives unmet-condition and prepare/commit/rollback failure events.
    logger: ExecutionLogger,
    /// Optional action run after the first check passes, before locking.
    prepare_action: Option<ArcRunnable<CallbackError>>,
    /// Optional action run when the task did not succeed, provided the prepare
    /// action ran successfully.
    rollback_prepare_action: Option<ArcRunnable<CallbackError>>,
    /// Optional action run when the task succeeded, provided the prepare action
    /// ran successfully.
    commit_prepare_action: Option<ArcRunnable<CallbackError>>,
    /// When `true`, panics from the tester, task and callbacks are caught and
    /// turned into errors; when `false` they propagate to the caller.
    catch_panics: bool,
    /// Ties `T` to this type without storing one; `fn() -> T` keeps the
    /// `Send`/`Sync` auto traits independent of `T`.
    _phantom: PhantomData<fn() -> T>,
}
impl DoubleCheckedLockExecutor<(), ()> {
#[inline]
pub fn builder() -> ExecutorBuilder {
ExecutorBuilder::default()
}
}
impl<L, T> DoubleCheckedLockExecutor<L, T>
where
    L: Lock<T>,
{
    /// Creates an executor from a fully configured builder.
    ///
    /// Every setting (lock, tester, logger, prepare/rollback/commit actions and
    /// the panic-catching flag) is moved out of `builder` unchanged.
    #[inline]
    pub fn new(builder: ExecutorReadyBuilder<L, T>) -> Self {
        Self {
            lock: builder.lock,
            tester: builder.tester,
            logger: builder.logger,
            prepare_action: builder.prepare_action,
            rollback_prepare_action: builder.rollback_prepare_action,
            commit_prepare_action: builder.commit_prepare_action,
            catch_panics: builder.catch_panics,
            _phantom: builder._phantom,
        }
    }

    /// Runs a value-producing `task` under the write lock.
    ///
    /// The task runs only if the tester passes both before and after the lock
    /// is acquired. It does not receive the guarded data; see
    /// [`call_with`](Self::call_with) for that.
    #[inline]
    pub fn call<C, R, E>(&self, mut task: C) -> ExecutionContext<R, E>
    where
        C: Callable<R, E>,
        E: Display,
    {
        let result = self.execute_with_write_lock(move |_data| task.call());
        ExecutionContext::new(result)
    }

    /// Runs a unit-returning `task` under the write lock.
    ///
    /// Same checks as [`call`](Self::call); the task does not receive the
    /// guarded data (see [`execute_with`](Self::execute_with)).
    #[inline]
    pub fn execute<Rn, E>(&self, mut task: Rn) -> ExecutionContext<(), E>
    where
        Rn: Runnable<E>,
        E: Display,
    {
        let result = self.execute_with_write_lock(move |_data| task.run());
        ExecutionContext::new(result)
    }

    /// Runs a value-producing `task` with mutable access to the guarded data.
    ///
    /// The task runs only if the tester passes both before and after the lock
    /// is acquired.
    #[inline]
    pub fn call_with<C, R, E>(&self, mut task: C) -> ExecutionContext<R, E>
    where
        C: CallableWith<T, R, E>,
        E: Display,
    {
        let result = self.execute_with_write_lock(move |data| task.call_with(data));
        ExecutionContext::new(result)
    }

    /// Runs a unit-returning `task` with mutable access to the guarded data.
    ///
    /// Same checks as [`call_with`](Self::call_with).
    #[inline]
    pub fn execute_with<Rn, E>(&self, mut task: Rn) -> ExecutionContext<(), E>
    where
        Rn: RunnableWith<T, E>,
        E: Display,
    {
        let result = self.execute_with_write_lock(move |data| task.run_with(data));
        ExecutionContext::new(result)
    }

    /// Returns this executor with panic catching enabled or disabled.
    ///
    /// When enabled, panics from the tester, task and prepare callbacks are
    /// reported as errors instead of unwinding through the caller.
    #[inline]
    pub fn set_catch_panics(mut self, catch_panics: bool) -> Self {
        self.catch_panics = catch_panics;
        self
    }

    /// Deprecated alias for [`set_catch_panics`](Self::set_catch_panics).
    #[deprecated(note = "Use `set_catch_panics` instead to align with setter naming.")]
    #[inline]
    pub fn with_catch_panics(self, catch_panics: bool) -> Self {
        self.set_catch_panics(catch_panics)
    }

    /// Returns whether panics are caught and converted into errors.
    #[inline]
    pub fn catch_panics(&self) -> bool {
        self.catch_panics
    }

    /// Core flow shared by all public entry points.
    ///
    /// 1. Evaluate the tester without the lock; if unmet, log and return
    ///    "unmet" without running anything else.
    /// 2. Run the optional prepare action (still outside the lock).
    /// 3. Under the write lock, re-evaluate the tester and, if it passes, run
    ///    `task` with mutable access to the guarded data.
    /// 4. If the prepare action completed, commit or roll it back depending on
    ///    whether the result is a success.
    fn execute_with_write_lock<R, E, F>(&self, task: F) -> ExecutionResult<R, E>
    where
        E: Display,
        F: FnOnce(&mut T) -> Result<R, E>,
    {
        let first_check = match self.run_tester() {
            Ok(passed) => passed,
            Err(error) => {
                return ExecutionResult::from_executor_error(ExecutorError::Panic(error));
            }
        };
        if !first_check {
            self.log_unmet_condition();
            return ExecutionResult::unmet();
        }
        let prepare_completed = match self.run_prepare_action() {
            Ok(completed) => completed,
            Err(error) => {
                return ExecutionResult::from_executor_error(ExecutorError::PrepareFailed(error));
            }
        };
        let result = self.lock.write(|data| {
            // Re-check while holding the lock: the condition may have changed
            // between the unlocked check above and lock acquisition.
            let passed = match self.run_tester() {
                Ok(passed) => passed,
                Err(error) => {
                    return ExecutionResult::from_executor_error(ExecutorError::Panic(error));
                }
            };
            if !passed {
                return ExecutionResult::unmet();
            }
            match self.try_run("task", || task(data)) {
                Ok(Ok(value)) => ExecutionResult::success(value),
                Ok(Err(error)) => ExecutionResult::task_failed(error),
                Err(error) => ExecutionResult::from_executor_error(ExecutorError::Panic(error)),
            }
        });
        // Logged after the lock has been released.
        if result.is_unmet() {
            self.log_unmet_condition();
        }
        if prepare_completed {
            self.finalize_prepare(result)
        } else {
            result
        }
    }

    /// Evaluates the tester; a caught panic becomes a `CallbackError` of type
    /// `"tester"`.
    fn run_tester(&self) -> Result<bool, CallbackError> {
        self.try_run("tester", || self.tester.test())
    }

    /// Runs the optional prepare action.
    ///
    /// Returns `Ok(false)` when none is configured, `Ok(true)` when it ran
    /// successfully, and the (logged) error when it failed or panicked.
    fn run_prepare_action(&self) -> Result<bool, CallbackError> {
        let Some(mut prepare_action) = self.prepare_action.clone() else {
            return Ok(false);
        };
        match self.try_run("prepare", move || prepare_action.run()) {
            Ok(Ok(_)) => Ok(true),
            Ok(Err(error)) | Err(error) => {
                self.logger.log_prepare_failed(&error);
                Err(error)
            }
        }
    }

    /// Commits the prepare action after a successful task, or rolls it back
    /// after any other outcome.
    fn finalize_prepare<R, E>(&self, result: ExecutionResult<R, E>) -> ExecutionResult<R, E>
    where
        E: Display,
    {
        if result.is_success() {
            self.commit_prepare(result)
        } else {
            self.rollback_prepare(result)
        }
    }

    /// Runs the optional commit action; a failure or panic is logged and
    /// replaces `result` with a `PrepareCommitFailed` error.
    fn commit_prepare<R, E>(&self, result: ExecutionResult<R, E>) -> ExecutionResult<R, E>
    where
        E: Display,
    {
        let Some(mut commit_prepare_action) = self.commit_prepare_action.clone() else {
            return result;
        };
        match self.try_run("prepare_commit", move || commit_prepare_action.run()) {
            Ok(Ok(_)) => result,
            Ok(Err(error)) | Err(error) => {
                self.logger.log_prepare_commit_failed(&error);
                ExecutionResult::from_executor_error(ExecutorError::PrepareCommitFailed(error))
            }
        }
    }

    /// Runs the optional rollback action; a failure or panic is logged and
    /// replaces `result` with a "prepare rollback failed" result that records
    /// both the original outcome and the rollback error message.
    fn rollback_prepare<R, E>(&self, result: ExecutionResult<R, E>) -> ExecutionResult<R, E>
    where
        E: Display,
    {
        let Some(mut rollback_prepare_action) = self.rollback_prepare_action.clone() else {
            return result;
        };
        match self.try_run("prepare_rollback", move || rollback_prepare_action.run()) {
            Ok(Ok(_)) => result,
            Ok(Err(error)) | Err(error) => {
                self.logger.log_prepare_rollback_failed(&error);
                // Describe the original outcome only when it is actually needed,
                // so the normal (rollback succeeded / not configured) path does
                // not allocate.
                let original = if let ExecutionResult::Failed(failure) = &result {
                    failure.to_string()
                } else {
                    "Condition not met".to_string()
                };
                ExecutionResult::prepare_rollback_failed(original, error.message())
            }
        }
    }

    /// Invokes `callback`, optionally catching a panic.
    ///
    /// When `catch_panics` is `false` the callback is called directly and any
    /// panic propagates. Otherwise a panic is converted into a `CallbackError`
    /// tagged with `callback_type` and carrying the panic message.
    fn try_run<R>(
        &self,
        callback_type: &'static str,
        callback: impl FnOnce() -> R,
    ) -> Result<R, CallbackError> {
        if !self.catch_panics {
            return Ok(callback());
        }
        // `AssertUnwindSafe` is required because callbacks borrow `self` (and,
        // for tasks, `&mut T`). After a caught panic that state may be partially
        // updated; the panic is reported as an error rather than resumed.
        panic::catch_unwind(AssertUnwindSafe(callback)).map_err(|payload| {
            CallbackError::with_type(callback_type, panic_payload_to_message(&*payload))
        })
    }

    /// Reports an unmet condition to the logger.
    fn log_unmet_condition(&self) {
        self.logger.log_unmet_condition();
    }
}
/// Error type used by the prepare/commit/rollback callbacks, and the type into
/// which caught panics are converted (via `CallbackError::with_type`).
type CallbackError = super::callback_error::CallbackError;
/// Converts a panic payload into a human-readable message.
///
/// `panic!` payloads are normally a `&'static str` (literal message) or a
/// `String` (formatted message); both are returned verbatim. Any other payload
/// falls back to its `Debug` representation.
fn panic_payload_to_message(payload: &(dyn Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", payload))
}