use core::{fmt, mem, ops};
/// A guard that owns a value together with an "undo" callback.
///
/// When the guard is dropped (for example when leaving scope through an early
/// `?` return or while a panic unwinds), the callback is invoked with the owned
/// value. Call [`Undo::commit`] to disarm the guard and get the value back
/// without running the callback.
///
/// The guard dereferences (mutably and immutably) to the wrapped value, so it
/// can be used in place of that value while armed.
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
where
    F: FnOnce(T),
{
    // Both fields are `ManuallyDrop` so that `Drop::drop` and `commit` can move
    // them out by value; each path takes each field exactly once.
    inner: mem::ManuallyDrop<T>,
    undo: mem::ManuallyDrop<F>,
}

impl<T, F> Drop for Undo<T, F>
where
    F: FnOnce(T),
{
    // Runs the undo callback, handing it ownership of the wrapped value.
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once for a given guard, and `commit`
        // wraps the guard in `ManuallyDrop` before taking the fields, so this
        // destructor never observes already-taken fields. After being taken
        // here, neither field is accessed again.
        let (inner, undo) = unsafe {
            (
                mem::ManuallyDrop::take(&mut self.inner),
                mem::ManuallyDrop::take(&mut self.undo),
            )
        };
        undo(inner);
    }
}

impl<T, F> fmt::Debug for Undo<T, F>
where
    F: FnOnce(T),
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format the wrapped value itself rather than its `ManuallyDrop`
        // wrapper, which is an internal detail of this type. The callback is
        // not required to be `Debug`, so it is shown as a placeholder.
        f.debug_struct("Undo")
            .field("inner", &*self.inner)
            .field("undo", &"..")
            .finish()
    }
}

impl<T, F> ops::Deref for Undo<T, F>
where
    F: FnOnce(T),
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<T, F> ops::DerefMut for Undo<T, F>
where
    F: FnOnce(T),
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

impl<T, F> Undo<T, F>
where
    F: FnOnce(T),
{
    /// Creates an armed guard around `inner`.
    ///
    /// Dropping the guard calls `undo(inner)`, unless the guard is first
    /// disarmed with [`Undo::commit`].
    pub fn new(inner: T, undo: F) -> Self {
        Self {
            inner: mem::ManuallyDrop::new(inner),
            undo: mem::ManuallyDrop::new(undo),
        }
    }

    /// Disarms `guard` and returns the wrapped value.
    ///
    /// The undo callback is dropped without being called. This is an
    /// associated function (call it as `Undo::commit(guard)`) so that
    /// method-call syntax on a guard keeps resolving to `T`'s methods through
    /// `Deref`.
    pub fn commit(guard: Self) -> T {
        // Suppress the guard's own destructor so the callback never runs.
        let mut guard = mem::ManuallyDrop::new(guard);
        // SAFETY: the guard's destructor is suppressed above, so each field is
        // taken exactly once here and never accessed again.
        let (inner, undo) = unsafe {
            (
                mem::ManuallyDrop::take(&mut guard.inner),
                mem::ManuallyDrop::take(&mut guard.undo),
            )
        };
        // `inner` is owned by this frame before the callback is dropped, so if
        // the callback's destructor panics, `inner` is dropped during
        // unwinding instead of being leaked.
        mem::drop(undo);
        inner
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::cell::Cell;
    use std::{panic, string::ToString};

    /// Test fixture: a counter that also remembers the highest value it has
    /// ever reached, so tests can see how far it got before a rollback.
    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        /// Runs the pre-check `f`; if it succeeds, bumps the counter and
        /// updates the high-water mark.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = self.max_value_seen.max(self.value);
            Ok(())
        }

        /// Steps the counter back by one.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Performs `n` checked increments as one unit: if any step returns an
        /// error or panics, the `Undo` guard reverts every increment already
        /// applied.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            // Shared with the rollback closure, which reads it only on drop.
            let applied = Cell::new(0);
            let mut guard = Undo::new(self, |counter| {
                (0..applied.get()).for_each(|_| counter.dec());
            });
            for _ in 0..n {
                guard.inc(&mut f)?;
                applied.set(applied.get() + 1);
            }
            // Every step succeeded: keep the increments.
            Undo::commit(guard);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let outcome = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(outcome.unwrap_err().to_string(), "uh oh");
        assert_eq!((counter.value, counter.max_value_seen), (0, 5));
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(outcome.is_err());
        assert_eq!((counter.value, counter.max_value_seen), (0, 5));
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let outcome = counter.inc_n(10, |_| Ok(()));
        assert!(outcome.is_ok());
        assert_eq!((counter.value, counter.max_value_seen), (10, 10));
    }
}