use std::{
fmt::Display,
marker::PhantomData,
};
use qubit_function::{
ArcRunnable,
ArcTester,
Callable,
CallableWith,
Runnable,
RunnableWith,
Tester,
};
use super::{
ExecutionContext,
ExecutionLogger,
ExecutionResult,
executor_builder::ExecutorBuilder,
executor_ready_builder::ExecutorReadyBuilder,
};
use crate::lock::Lock;
/// Executor that runs tasks under a lock's write access, guarded by a condition.
///
/// The condition (`tester`) is evaluated once before the lock is acquired and
/// again while it is held. Optional prepare / commit / rollback actions wrap
/// the locked section.
///
/// # Type parameters
///
/// * `L` - the lock guarding the shared data (used through the `Lock<T>` trait).
/// * `T` - the type of the data protected by `L`. The executor never stores a
///   `T`; it is only tracked through `PhantomData`.
pub struct DoubleCheckedLockExecutor<L = (), T = ()> {
    /// Lock whose `write` access wraps every task execution.
    lock: L,
    /// Condition evaluated both before and after acquiring the lock.
    tester: ArcTester,
    /// Logger used for unmet conditions and prepare/commit/rollback failures.
    logger: ExecutionLogger,
    /// Optional action run before the lock is acquired.
    prepare_action: Option<ArcRunnable<String>>,
    /// Optional action run after a completed prepare when the task did not succeed.
    rollback_prepare_action: Option<ArcRunnable<String>>,
    /// Optional action run after a completed prepare when the task succeeded.
    commit_prepare_action: Option<ArcRunnable<String>>,
    /// Marker for `T`. A `fn() -> T` pointer is used so the marker itself is
    /// always `Send`/`Sync` and does not imply ownership of a `T`.
    _phantom: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would also require `T: Clone`,
// even though no `T` is stored. Only the lock has to be cloneable.
impl<L: Clone, T> Clone for DoubleCheckedLockExecutor<L, T> {
    fn clone(&self) -> Self {
        Self {
            lock: self.lock.clone(),
            tester: self.tester.clone(),
            logger: self.logger.clone(),
            prepare_action: self.prepare_action.clone(),
            rollback_prepare_action: self.rollback_prepare_action.clone(),
            commit_prepare_action: self.commit_prepare_action.clone(),
            _phantom: PhantomData,
        }
    }
}
impl DoubleCheckedLockExecutor<(), ()> {
#[inline]
pub fn builder() -> ExecutorBuilder {
ExecutorBuilder::default()
}
}
impl<L, T> DoubleCheckedLockExecutor<L, T>
where
L: Lock<T>,
{
#[inline]
pub fn new(builder: ExecutorReadyBuilder<L, T>) -> Self {
Self {
lock: builder.lock,
tester: builder.tester,
logger: builder.logger,
prepare_action: builder.prepare_action,
rollback_prepare_action: builder.rollback_prepare_action,
commit_prepare_action: builder.commit_prepare_action,
_phantom: builder._phantom,
}
}
#[inline]
pub fn call<C, R, E>(&self, task: C) -> ExecutionContext<R, E>
where
C: Callable<R, E> + Send + 'static,
R: Send + 'static,
E: Display + Send + 'static,
{
let mut task = task;
let result = self.execute_with_write_lock(move |_data| task.call());
ExecutionContext::new(result)
}
#[inline]
pub fn execute<Rn, E>(&self, task: Rn) -> ExecutionContext<(), E>
where
Rn: Runnable<E> + Send + 'static,
E: Display + Send + 'static,
{
let mut task = task;
let result = self.execute_with_write_lock(move |_data| task.run());
ExecutionContext::new(result)
}
#[inline]
pub fn call_with<C, R, E>(&self, task: C) -> ExecutionContext<R, E>
where
C: CallableWith<T, R, E> + Send + 'static,
R: Send + 'static,
E: Display + Send + 'static,
{
let mut task = task;
let result = self.execute_with_write_lock(move |data| task.call_with(data));
ExecutionContext::new(result)
}
#[inline]
pub fn execute_with<Rn, E>(&self, task: Rn) -> ExecutionContext<(), E>
where
Rn: RunnableWith<T, E> + Send + 'static,
E: Display + Send + 'static,
{
let mut task = task;
let result = self.execute_with_write_lock(move |data| task.run_with(data));
ExecutionContext::new(result)
}
fn execute_with_write_lock<R, E, F>(&self, task: F) -> ExecutionResult<R, E>
where
E: Display + Send + 'static,
F: FnOnce(&mut T) -> Result<R, E>,
{
if !self.tester.test() {
self.log_unmet_condition();
return ExecutionResult::unmet();
}
let prepare_completed = match self.run_prepare_action() {
Ok(completed) => completed,
Err(error) => return ExecutionResult::prepare_failed(error),
};
let result = self.lock.write(|data| {
if !self.tester.test() {
self.log_unmet_condition();
return ExecutionResult::unmet();
}
match task(data) {
Ok(value) => ExecutionResult::success(value),
Err(error) => ExecutionResult::task_failed(error),
}
});
if prepare_completed {
self.finalize_prepare(result)
} else {
result
}
}
fn run_prepare_action(&self) -> Result<bool, String> {
let Some(mut prepare_action) = self.prepare_action.clone() else {
return Ok(false);
};
if let Err(error) = prepare_action.run() {
self.logger.log_prepare_failed(&error);
return Err(error);
}
Ok(true)
}
fn finalize_prepare<R, E>(&self, mut result: ExecutionResult<R, E>) -> ExecutionResult<R, E>
where
E: Display + Send + 'static,
{
if result.is_success() {
if let Some(mut commit_prepare_action) = self.commit_prepare_action.clone()
&& let Err(error) = commit_prepare_action.run()
{
self.logger.log_prepare_commit_failed(&error);
result = ExecutionResult::prepare_commit_failed(error);
}
return result;
}
let original = if let ExecutionResult::Failed(error) = &result {
error.to_string()
} else {
"Condition not met".to_string()
};
if let Some(mut rollback_prepare_action) = self.rollback_prepare_action.clone()
&& let Err(error) = rollback_prepare_action.run()
{
self.logger.log_prepare_rollback_failed(&error);
result = ExecutionResult::prepare_rollback_failed(original, error);
}
result
}
fn log_unmet_condition(&self) {
self.logger.log_unmet_condition();
}
}