use std::sync::{Arc, Mutex};
use super::super::{
AdapterFuture, Create, DbAdapter, DbRecord, Delete, DeleteMany, FindMany, Update, UpdateMany,
};
use crate::context::request_state::current_request_path;
use crate::error::OpenAuthError;
use crate::plugin::{
PluginDatabaseAfterInput, PluginDatabaseBeforeAction, PluginDatabaseBeforeInput,
PluginDatabaseHook, PluginDatabaseHookContext, PluginDatabaseOperation,
};
/// Queue of after-hook inputs whose execution has been deferred.
///
/// The storage sits behind an `Arc<Mutex<_>>`, so cloning the queue yields another
/// handle to the same pending list rather than an independent copy.
#[derive(Clone, Default)]
pub(super) struct AfterHookQueue {
    /// Pending after-hook inputs, in the order they were pushed.
    inputs: Arc<Mutex<Vec<PluginDatabaseAfterInput>>>,
}
impl AfterHookQueue {
    /// Locks the pending list, translating a poisoned mutex into
    /// `OpenAuthError::LockPoisoned`.
    fn lock_inputs(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, Vec<PluginDatabaseAfterInput>>, OpenAuthError> {
        self.inputs.lock().map_err(|_| OpenAuthError::LockPoisoned {
            context: "after hook queue",
        })
    }

    /// Appends `input` to the end of the pending list.
    ///
    /// # Errors
    ///
    /// Returns `OpenAuthError::LockPoisoned` if the queue mutex is poisoned.
    fn push(&self, input: PluginDatabaseAfterInput) -> Result<(), OpenAuthError> {
        let mut pending = self.lock_inputs()?;
        pending.push(input);
        Ok(())
    }

    /// Drains the queue and runs the matching after hooks for every drained input,
    /// in the order the inputs were pushed.
    ///
    /// The pending list is emptied up front, so if a hook fails the inputs that had
    /// not been processed yet are dropped rather than left in the queue.
    ///
    /// # Errors
    ///
    /// Returns `OpenAuthError::LockPoisoned` if the queue mutex is poisoned, or the
    /// first error returned while running the after hooks.
    pub(super) async fn run<A>(
        &self,
        hooks: &[PluginDatabaseHook],
        adapter: &A,
    ) -> Result<(), OpenAuthError>
    where
        A: DbAdapter,
    {
        // The guard is a temporary of this statement, so the lock is released before
        // any hook is awaited below.
        let drained = std::mem::take(&mut *self.lock_inputs()?);
        for pending_input in drained {
            run_after_hooks(hooks, pending_input, adapter).await?;
        }
        Ok(())
    }
}
/// Creates a record through `inner`, running the matching plugin hooks around the write.
///
/// Before hooks run first and may replace the query or cancel the operation. After a
/// successful create, the after-hook input is pushed onto `after_queue` when one is
/// supplied; otherwise the after hooks run immediately.
///
/// # Errors
///
/// Propagates errors from the before hooks, from `inner.create`, and from queueing or
/// running the after hooks. Returns an `InvalidConfig` error if the before hooks hand
/// back an input for a different operation.
pub(super) fn hooked_create<'a, A>(
    inner: &'a A,
    hooks: Arc<Vec<PluginDatabaseHook>>,
    after_queue: Option<AfterHookQueue>,
    query: Create,
) -> AdapterFuture<'a, DbRecord>
where
    A: DbAdapter,
{
    Box::pin(async move {
        let hook_list = hooks.as_slice();
        let before_input = PluginDatabaseBeforeInput::Create(query);
        let query = match run_before_hooks(hook_list, before_input, inner).await? {
            PluginDatabaseBeforeInput::Create(approved) => approved,
            unexpected => {
                return Err(mismatched_continue_input(
                    PluginDatabaseOperation::Create,
                    unexpected,
                ));
            }
        };

        let created = inner.create(query.clone()).await?;

        let after_input = PluginDatabaseAfterInput::Create {
            query,
            result: created.clone(),
        };
        run_or_queue_after_hooks(after_queue.as_ref(), hook_list, after_input, inner).await?;
        Ok(created)
    })
}
/// Updates a single record through `inner`, running the matching plugin hooks around
/// the write.
///
/// Before hooks run first and may replace the query or cancel the operation. After a
/// successful update, the after-hook input (carrying the adapter's result) is pushed
/// onto `after_queue` when one is supplied; otherwise the after hooks run immediately.
///
/// # Errors
///
/// Propagates errors from the before hooks, from `inner.update`, and from queueing or
/// running the after hooks. Returns an `InvalidConfig` error if the before hooks hand
/// back an input for a different operation.
pub(super) fn hooked_update<'a, A>(
    inner: &'a A,
    hooks: Arc<Vec<PluginDatabaseHook>>,
    after_queue: Option<AfterHookQueue>,
    query: Update,
) -> AdapterFuture<'a, Option<DbRecord>>
where
    A: DbAdapter,
{
    Box::pin(async move {
        let hook_list = hooks.as_slice();
        let before_input = PluginDatabaseBeforeInput::Update(query);
        let query = match run_before_hooks(hook_list, before_input, inner).await? {
            PluginDatabaseBeforeInput::Update(approved) => approved,
            unexpected => {
                return Err(mismatched_continue_input(
                    PluginDatabaseOperation::Update,
                    unexpected,
                ));
            }
        };

        let updated = inner.update(query.clone()).await?;

        let after_input = PluginDatabaseAfterInput::Update {
            query,
            result: updated.clone(),
        };
        run_or_queue_after_hooks(after_queue.as_ref(), hook_list, after_input, inner).await?;
        Ok(updated)
    })
}
/// Runs a bulk update through `inner`, running the matching plugin hooks around the
/// write, and returns the value reported by `inner.update_many`.
///
/// Before hooks run first and may replace the query or cancel the operation. After a
/// successful update, the after-hook input is pushed onto `after_queue` when one is
/// supplied; otherwise the after hooks run immediately.
///
/// # Errors
///
/// Propagates errors from the before hooks, from `inner.update_many`, and from queueing
/// or running the after hooks. Returns an `InvalidConfig` error if the before hooks
/// hand back an input for a different operation.
pub(super) fn hooked_update_many<'a, A>(
    inner: &'a A,
    hooks: Arc<Vec<PluginDatabaseHook>>,
    after_queue: Option<AfterHookQueue>,
    query: UpdateMany,
) -> AdapterFuture<'a, u64>
where
    A: DbAdapter,
{
    Box::pin(async move {
        let hook_list = hooks.as_slice();
        let before_input = PluginDatabaseBeforeInput::UpdateMany(query);
        let query = match run_before_hooks(hook_list, before_input, inner).await? {
            PluginDatabaseBeforeInput::UpdateMany(approved) => approved,
            unexpected => {
                return Err(mismatched_continue_input(
                    PluginDatabaseOperation::UpdateMany,
                    unexpected,
                ));
            }
        };

        let affected = inner.update_many(query.clone()).await?;

        let after_input = PluginDatabaseAfterInput::UpdateMany {
            query,
            result: affected,
        };
        run_or_queue_after_hooks(after_queue.as_ref(), hook_list, after_input, inner).await?;
        Ok(affected)
    })
}
/// Deletes through `inner`, running the matching plugin hooks around the write.
///
/// A snapshot lookup (limited to one record) is performed with the query's model and
/// filters before the before hooks run; a failed lookup yields an empty snapshot list.
/// Before hooks may replace the query and snapshots or cancel the operation. After a
/// successful delete, the after-hook input is pushed onto `after_queue` when one is
/// supplied; otherwise the after hooks run immediately.
///
/// # Errors
///
/// Propagates errors from the before hooks, from `inner.delete`, and from queueing or
/// running the after hooks. Returns an `InvalidConfig` error if the before hooks hand
/// back an input for a different operation.
pub(super) fn hooked_delete<'a, A>(
    inner: &'a A,
    hooks: Arc<Vec<PluginDatabaseHook>>,
    after_queue: Option<AfterHookQueue>,
    query: Delete,
) -> AdapterFuture<'a, ()>
where
    A: DbAdapter,
{
    Box::pin(async move {
        let hook_list = hooks.as_slice();

        // NOTE(review): the snapshot is taken with the caller's original filters; if a
        // before hook rewrites the query it may not describe what is actually deleted.
        // Confirm this is intended.
        let initial_snapshots = load_delete_snapshots(
            inner,
            query.model.clone(),
            query.where_clauses.clone(),
            Some(1),
        )
        .await;

        let before_input = PluginDatabaseBeforeInput::Delete {
            query,
            snapshots: initial_snapshots,
        };
        let (query, snapshots) = match run_before_hooks(hook_list, before_input, inner).await? {
            PluginDatabaseBeforeInput::Delete { query, snapshots } => (query, snapshots),
            unexpected => {
                return Err(mismatched_continue_input(
                    PluginDatabaseOperation::Delete,
                    unexpected,
                ));
            }
        };

        inner.delete(query.clone()).await?;

        let after_input = PluginDatabaseAfterInput::Delete { query, snapshots };
        run_or_queue_after_hooks(after_queue.as_ref(), hook_list, after_input, inner).await?;
        Ok(())
    })
}
/// Runs a bulk delete through `inner`, running the matching plugin hooks around the
/// write, and returns the value reported by `inner.delete_many`.
///
/// An unlimited snapshot lookup is performed with the query's model and filters before
/// the before hooks run; a failed lookup yields an empty snapshot list. Before hooks
/// may replace the query and snapshots or cancel the operation. After a successful
/// delete, the after-hook input is pushed onto `after_queue` when one is supplied;
/// otherwise the after hooks run immediately.
///
/// # Errors
///
/// Propagates errors from the before hooks, from `inner.delete_many`, and from queueing
/// or running the after hooks. Returns an `InvalidConfig` error if the before hooks
/// hand back an input for a different operation.
pub(super) fn hooked_delete_many<'a, A>(
    inner: &'a A,
    hooks: Arc<Vec<PluginDatabaseHook>>,
    after_queue: Option<AfterHookQueue>,
    query: DeleteMany,
) -> AdapterFuture<'a, u64>
where
    A: DbAdapter,
{
    Box::pin(async move {
        let hook_list = hooks.as_slice();

        // NOTE(review): snapshots are taken with the caller's original filters; if a
        // before hook rewrites the query they may not describe what is actually
        // deleted. Confirm this is intended.
        let initial_snapshots = load_delete_snapshots(
            inner,
            query.model.clone(),
            query.where_clauses.clone(),
            None,
        )
        .await;

        let before_input = PluginDatabaseBeforeInput::DeleteMany {
            query,
            snapshots: initial_snapshots,
        };
        let (query, snapshots) = match run_before_hooks(hook_list, before_input, inner).await? {
            PluginDatabaseBeforeInput::DeleteMany { query, snapshots } => (query, snapshots),
            unexpected => {
                return Err(mismatched_continue_input(
                    PluginDatabaseOperation::DeleteMany,
                    unexpected,
                ));
            }
        };

        let removed = inner.delete_many(query.clone()).await?;

        let after_input = PluginDatabaseAfterInput::DeleteMany {
            query,
            snapshots,
            result: removed,
        };
        run_or_queue_after_hooks(after_queue.as_ref(), hook_list, after_input, inner).await?;
        Ok(removed)
    })
}
/// Looks up the records matching `where_clauses` on `model` so they can be handed to
/// delete hooks as snapshots.
///
/// `limit` is forwarded to the lookup unchanged. The lookup is best-effort: if
/// `find_many` fails, the error is discarded and an empty list is returned.
async fn load_delete_snapshots<A>(
    inner: &A,
    model: String,
    where_clauses: Vec<super::super::Where>,
    limit: Option<usize>,
) -> Vec<DbRecord>
where
    A: DbAdapter,
{
    let mut lookup = FindMany::new(model);
    lookup.limit = limit;
    lookup.where_clauses = where_clauses;

    match inner.find_many(lookup).await {
        Ok(records) => records,
        // Snapshot loading must not block the delete itself, so failures are ignored.
        Err(_) => Vec::new(),
    }
}
/// Runs every before hook registered for the operation of `input`, in registration
/// order, threading the (possibly rewritten) input from one hook to the next.
///
/// Hooks registered for a different operation, and hooks without a `before` handler,
/// are skipped.
///
/// # Errors
///
/// * Propagates any error returned by a handler.
/// * Returns the hook's own error when a handler answers with `Cancel`.
/// * Returns an `InvalidConfig` error when a handler continues with an input whose
///   operation differs from the one being executed.
async fn run_before_hooks<A>(
    hooks: &[PluginDatabaseHook],
    mut input: PluginDatabaseBeforeInput,
    adapter: &A,
) -> Result<PluginDatabaseBeforeInput, OpenAuthError>
where
    A: DbAdapter,
{
    let operation = input.operation();
    for hook in hooks.iter().filter(|hook| hook.operation == operation) {
        let Some(handler) = &hook.before else {
            continue;
        };
        // `hook_context` stores its own owned copy of the model name and its result
        // only borrows `adapter`, so the borrow of `input` ends here and `input` can be
        // moved into the handler without an intermediate `String`.
        let context = hook_context(hook, operation, input.model(), adapter);
        input = match handler(context, input).await? {
            PluginDatabaseBeforeAction::Continue(next) => next,
            PluginDatabaseBeforeAction::Cancel(error) => return Err(error),
        };
        // A hook must hand back the same kind of input it was given.
        if input.operation() != operation {
            return Err(mismatched_continue_input(operation, input));
        }
    }
    Ok(input)
}
/// Runs every after hook registered for the operation of `input`, in registration
/// order, giving each handler its own clone of the input.
///
/// Hooks registered for a different operation, and hooks without an `after` handler,
/// are skipped.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by a handler.
async fn run_after_hooks<A>(
    hooks: &[PluginDatabaseHook],
    input: PluginDatabaseAfterInput,
    adapter: &A,
) -> Result<(), OpenAuthError>
where
    A: DbAdapter,
{
    let operation = input.operation();
    let matching_hooks = hooks.iter().filter(|hook| hook.operation == operation);
    for hook in matching_hooks {
        if let Some(handler) = &hook.after {
            let context = hook_context(hook, operation, input.model(), adapter);
            handler(context, input.clone()).await?;
        }
    }
    Ok(())
}
/// Defers `input` onto `queue` when a queue is supplied; otherwise runs the matching
/// after hooks right away.
///
/// # Errors
///
/// Returns the error from pushing onto the queue, or from running the after hooks.
async fn run_or_queue_after_hooks<A>(
    queue: Option<&AfterHookQueue>,
    hooks: &[PluginDatabaseHook],
    input: PluginDatabaseAfterInput,
    adapter: &A,
) -> Result<(), OpenAuthError>
where
    A: DbAdapter,
{
    match queue {
        Some(deferred) => deferred.push(input),
        None => run_after_hooks(hooks, input, adapter).await,
    }
}
/// Builds the context value handed to a plugin database hook handler.
///
/// The context owns copies of the plugin id (empty when the hook reports none), the
/// hook name and the model name, and borrows only `adapter`. The request path is
/// `None` when `current_request_path` fails or reports no path.
fn hook_context<'a, A>(
    hook: &PluginDatabaseHook,
    operation: PluginDatabaseOperation,
    model: &str,
    adapter: &'a A,
) -> PluginDatabaseHookContext<'a>
where
    A: DbAdapter + 'a,
{
    let plugin_id = hook.plugin_id().unwrap_or_default().to_owned();
    let hook_name = hook.name.clone();
    let model = model.to_owned();
    let request_path = current_request_path().ok().flatten();

    PluginDatabaseHookContext {
        plugin_id,
        hook_name,
        operation,
        model,
        adapter,
        request_path,
    }
}
/// Builds the `InvalidConfig` error reported when a before hook continues with an
/// input belonging to a different operation than the one being executed.
///
/// `expected` is the operation in progress; `actual` is the input the hook returned.
fn mismatched_continue_input(
    expected: PluginDatabaseOperation,
    actual: PluginDatabaseBeforeInput,
) -> OpenAuthError {
    let returned = actual.operation();
    let message = format!(
        "database before hook for {expected:?} returned {:?} input",
        returned
    );
    OpenAuthError::InvalidConfig(message)
}