use std::{collections::HashSet, error::Error, fmt};
/// Policy applied to related rows when the row they point at is being deleted.
///
/// [`Collector::collect`] decides what to do with the related ids based on this value,
/// and the [`fmt::Display`] impl renders it as an upper-case keyword.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OnDelete {
    /// Queue the related rows for deletion too.
    Cascade,
    /// Refuse the deletion; the collector records a protection message and returns an error.
    Protect,
    /// Nothing is queued by the collector for this policy.
    SetNull,
    /// Nothing is queued by the collector for this policy.
    SetDefault,
    /// Nothing is queued by the collector for this policy.
    DoNothing,
    /// Refuse the deletion with an error, without recording anything in the collector.
    Restrict,
}
impl fmt::Display for OnDelete {
    /// Writes the upper-case keyword for the policy (e.g. `SET NULL`).
    ///
    /// The keyword goes through [`fmt::Formatter::pad`], so width, alignment and
    /// precision flags such as `{:>12}` or `{:.3}` are honoured exactly as for `str`.
    /// Without flags the output is just the keyword.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Cascade => "CASCADE",
            Self::Protect => "PROTECT",
            Self::SetNull => "SET NULL",
            Self::SetDefault => "SET DEFAULT",
            Self::DoNothing => "DO NOTHING",
            Self::Restrict => "RESTRICT",
        };
        f.pad(keyword)
    }
}
/// Error returned by [`Collector::collect`] when a policy forbids the deletion.
///
/// Each variant carries the complete human-readable message, which `Display`
/// prints verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeletionError {
    /// Produced for [`OnDelete::Protect`].
    ProtectedError(String),
    /// Produced for [`OnDelete::Restrict`].
    RestrictedError(String),
}

impl fmt::Display for DeletionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = match self {
            Self::ProtectedError(text) => text,
            Self::RestrictedError(text) => text,
        };
        f.write_str(text)
    }
}

// Only a message is carried, so there is no underlying `source()` to expose.
impl Error for DeletionError {}
/// Accumulates the outcome of applying [`OnDelete`] policies to related rows.
pub struct Collector {
    /// Ids queued for deletion, grouped per table in the order tables were first seen.
    to_delete: Vec<(String, Vec<String>)>,
    /// Messages recorded for protected relations; while non-empty, `can_delete` is false.
    protected: Vec<String>,
    /// Callbacks registered through `add_post_delete_callback`, in registration order.
    post_delete_callbacks: Vec<Box<dyn Fn() + Send>>,
}
impl Default for Collector {
fn default() -> Self {
Self::new()
}
}
impl Collector {
#[must_use]
pub fn new() -> Self {
Self {
to_delete: Vec::new(),
protected: Vec::new(),
post_delete_callbacks: Vec::new(),
}
}
pub fn collect(
&mut self,
table: &str,
ids: Vec<String>,
policy: OnDelete,
) -> Result<(), DeletionError> {
if ids.is_empty() {
return Ok(());
}
match policy {
OnDelete::Cascade => {
if let Some((_, existing_ids)) =
self.to_delete.iter_mut().find(|(name, _)| name == table)
{
existing_ids.extend(ids);
} else {
self.to_delete.push((table.to_owned(), ids));
}
Ok(())
}
OnDelete::Protect => {
let message = format!(
"Cannot delete {table}; protected related objects: {}",
ids.join(", ")
);
self.protected.push(message.clone());
Err(DeletionError::ProtectedError(message))
}
OnDelete::Restrict => Err(DeletionError::RestrictedError(format!(
"Cannot delete {table}; restricted related objects: {}",
ids.join(", ")
))),
OnDelete::SetNull | OnDelete::SetDefault | OnDelete::DoNothing => Ok(()),
}
}
#[must_use]
pub fn can_delete(&self) -> bool {
self.protected.is_empty()
}
#[must_use]
pub fn delete_count(&self) -> usize {
self.to_delete.iter().map(|(_, ids)| ids.len()).sum()
}
pub fn add_post_delete_callback<F: Fn() + Send + 'static>(&mut self, callback: F) {
self.post_delete_callbacks.push(Box::new(callback));
}
}
/// Unit tests for the deletion policies, [`Collector`] bookkeeping and error formatting.
#[cfg(test)]
mod tests {
    use super::{Collector, DeletionError, OnDelete};

    // --- `OnDelete` keyword rendering -------------------------------------------------

    #[test]
    fn on_delete_display_cascade() {
        assert_eq!(OnDelete::Cascade.to_string(), "CASCADE");
    }

    #[test]
    fn on_delete_display_protect() {
        assert_eq!(OnDelete::Protect.to_string(), "PROTECT");
    }

    #[test]
    fn on_delete_display_set_null() {
        assert_eq!(OnDelete::SetNull.to_string(), "SET NULL");
    }

    #[test]
    fn on_delete_display_set_default() {
        assert_eq!(OnDelete::SetDefault.to_string(), "SET DEFAULT");
    }

    #[test]
    fn on_delete_display_do_nothing() {
        assert_eq!(OnDelete::DoNothing.to_string(), "DO NOTHING");
    }

    #[test]
    fn on_delete_display_restrict() {
        assert_eq!(OnDelete::Restrict.to_string(), "RESTRICT");
    }

    // --- `Collector` construction -----------------------------------------------------

    #[test]
    fn collector_new_starts_empty() {
        let collector = Collector::new();
        assert!(collector.can_delete());
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.to_delete.is_empty());
        assert!(collector.protected.is_empty());
    }

    #[test]
    fn collector_default_matches_new() {
        let collector = Collector::default();
        assert!(collector.can_delete());
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.to_delete.is_empty());
        assert!(collector.protected.is_empty());
        assert!(collector.post_delete_callbacks.is_empty());
    }

    // --- `Collector::collect` per policy ----------------------------------------------

    #[test]
    fn collect_cascade_tracks_ids() {
        let mut collector = Collector::new();
        collector
            .collect(
                "authors",
                vec![String::from("1"), String::from("2")],
                OnDelete::Cascade,
            )
            .expect("cascade should succeed");
        assert_eq!(collector.delete_count(), 2);
        assert_eq!(
            collector.to_delete,
            vec![(
                String::from("authors"),
                vec![String::from("1"), String::from("2")]
            )]
        );
    }

    #[test]
    fn collect_cascade_merges_same_table() {
        let mut collector = Collector::new();
        collector
            .collect("authors", vec![String::from("1")], OnDelete::Cascade)
            .expect("first cascade should succeed");
        collector
            .collect("authors", vec![String::from("2")], OnDelete::Cascade)
            .expect("second cascade should succeed");
        assert_eq!(collector.delete_count(), 2);
        assert_eq!(
            collector.to_delete[0].1,
            vec![String::from("1"), String::from("2")]
        );
    }

    #[test]
    fn collect_cascade_keeps_tables_separate() {
        let mut collector = Collector::new();
        collector
            .collect("authors", vec![String::from("1")], OnDelete::Cascade)
            .expect("authors cascade should succeed");
        collector
            .collect(
                "books",
                vec![String::from("10"), String::from("11")],
                OnDelete::Cascade,
            )
            .expect("books cascade should succeed");
        assert_eq!(collector.delete_count(), 3);
        assert_eq!(
            collector.to_delete,
            vec![
                (String::from("authors"), vec![String::from("1")]),
                (
                    String::from("books"),
                    vec![String::from("10"), String::from("11")]
                ),
            ]
        );
    }

    #[test]
    fn collect_protect_returns_error_and_blocks_deletion() {
        let mut collector = Collector::new();
        let error = collector
            .collect("authors", vec![String::from("1")], OnDelete::Protect)
            .expect_err("protect should fail");
        assert_eq!(
            error,
            DeletionError::ProtectedError(
                "Cannot delete authors; protected related objects: 1".to_string()
            )
        );
        assert!(!collector.can_delete());
        assert_eq!(
            collector.protected,
            vec![String::from(
                "Cannot delete authors; protected related objects: 1"
            )]
        );
    }

    #[test]
    fn collect_protect_lists_every_id_and_accumulates() {
        let mut collector = Collector::new();
        let error = collector
            .collect(
                "books",
                vec![String::from("1"), String::from("2")],
                OnDelete::Protect,
            )
            .expect_err("protect should fail");
        assert_eq!(
            error,
            DeletionError::ProtectedError(
                "Cannot delete books; protected related objects: 1, 2".to_string()
            )
        );
        collector
            .collect("authors", vec![String::from("3")], OnDelete::Protect)
            .expect_err("second protect should fail");
        assert_eq!(collector.protected.len(), 2);
        assert!(!collector.can_delete());
        assert_eq!(collector.delete_count(), 0);
    }

    #[test]
    fn collect_restrict_returns_error() {
        let mut collector = Collector::new();
        let error = collector
            .collect("authors", vec![String::from("1")], OnDelete::Restrict)
            .expect_err("restrict should fail");
        assert_eq!(
            error,
            DeletionError::RestrictedError(
                "Cannot delete authors; restricted related objects: 1".to_string()
            )
        );
        // Restrict reports the problem but does not record it as protected.
        assert!(collector.can_delete());
        assert_eq!(collector.delete_count(), 0);
    }

    #[test]
    fn collect_set_null_does_not_queue_delete() {
        let mut collector = Collector::new();
        collector
            .collect("authors", vec![String::from("1")], OnDelete::SetNull)
            .expect("set null should succeed");
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.can_delete());
    }

    #[test]
    fn collect_set_default_does_not_queue_delete() {
        let mut collector = Collector::new();
        collector
            .collect("authors", vec![String::from("1")], OnDelete::SetDefault)
            .expect("set default should succeed");
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.can_delete());
    }

    #[test]
    fn collect_do_nothing_does_not_queue_delete() {
        let mut collector = Collector::new();
        collector
            .collect("authors", vec![String::from("1")], OnDelete::DoNothing)
            .expect("do nothing should succeed");
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.can_delete());
    }

    #[test]
    fn collect_empty_ids_is_a_no_op() {
        let mut collector = Collector::new();
        collector
            .collect("authors", Vec::new(), OnDelete::Cascade)
            .expect("empty collections should be ignored");
        assert_eq!(collector.delete_count(), 0);
        assert!(collector.to_delete.is_empty());
    }

    #[test]
    fn collect_empty_ids_never_errors() {
        let mut collector = Collector::new();
        for &policy in &[OnDelete::Protect, OnDelete::Restrict] {
            collector
                .collect("authors", Vec::new(), policy)
                .expect("empty collections should be ignored for every policy");
        }
        assert!(collector.can_delete());
        assert!(collector.protected.is_empty());
    }

    // --- callbacks and error formatting -----------------------------------------------

    #[test]
    fn add_post_delete_callback_registers_callback() {
        let mut collector = Collector::new();
        collector.add_post_delete_callback(|| {});
        assert_eq!(collector.post_delete_callbacks.len(), 1);
    }

    #[test]
    fn protected_error_display_uses_message() {
        let error = DeletionError::ProtectedError("protected".to_string());
        assert_eq!(error.to_string(), "protected");
    }

    #[test]
    fn restricted_error_display_uses_message() {
        let error = DeletionError::RestrictedError("restricted".to_string());
        assert_eq!(error.to_string(), "restricted");
    }
}