use std::sync::Arc;
#[cfg(feature = "sync")]
pub mod sync;
pub fn invoke<'f, T: 'f, F: FnOnce() -> T + 'f>(mode: &Mode<'f, T>, task: F) -> T {
let mut task: Box<dyn FnOnce() -> T + 'f> = Box::new(task);
if let Some(ref mode_combiner) = mode.mode_combiner {
for mode_wrapper in mode_combiner.iter() {
task = mode_wrapper.wrapper_ref().wrap(task);
}
}
task()
}
/// Runs tasks, optionally wrapped in a [`Mode`], surrounded by `pre_invoke` / `post_invoke`
/// hooks.
///
/// Every method has a default body, so an implementor overrides only what it needs.
pub trait Invoker {
    /// Hook run before the task (and before any mode is applied). Does nothing by default.
    fn pre_invoke(&self) {}

    /// Runs `task` with `mode` applied; shorthand for
    /// `invoke_with_mode_optional(Some(mode), task)`.
    fn invoke_with_mode<'f, T: 'f, F: FnOnce() -> T + 'f>(
        &'f self,
        mode: &'f Mode<'f, T>,
        task: F,
    ) -> T {
        self.invoke_with_mode_optional(Some(mode), task)
    }

    /// Runs `task` without a mode; shorthand for `invoke_with_mode_optional(None, task)`.
    fn invoke<'f, T: 'f, F: FnOnce() -> T + 'f>(&'f self, task: F) -> T {
        self.invoke_with_mode_optional(None, task)
    }

    /// Drives one invocation: `pre_invoke`, then `do_invoke`, then `post_invoke`.
    ///
    /// When [`invoke_post_invoke_on_panic`](Invoker::invoke_post_invoke_on_panic) returns
    /// `true`, a [`Sentinel`] guard makes `post_invoke` run even if the task panics.
    /// On the normal path the guard is disarmed, and `post_invoke` is called exactly once.
    fn invoke_with_mode_optional<'f, T: 'f, F: FnOnce() -> T + 'f>(
        &'f self,
        mode: Option<&'f Mode<'f, T>>,
        task: F,
    ) -> T {
        self.pre_invoke();
        // Armed only on request: dropping it while still armed (i.e. during unwinding)
        // fires `post_invoke`.
        let mut guard = if self.invoke_post_invoke_on_panic() {
            Some(Sentinel {
                invoker_ref: self,
                cancelled: false,
            })
        } else {
            None
        };
        let output = self.do_invoke(mode, task);
        // The task returned normally: disarm the guard so the hook below is the only call.
        if let Some(armed) = guard.as_mut() {
            armed.cancelled = true;
        }
        self.post_invoke();
        output
    }

    /// Executes `task`, wrapping it with `mode` through [`invoke`] when one is given.
    fn do_invoke<'f, T: 'f, F: FnOnce() -> T + 'f>(
        &'f self,
        mode: Option<&'f Mode<'f, T>>,
        task: F,
    ) -> T {
        match mode {
            Some(active) => invoke(active, task),
            None => task(),
        }
    }

    /// Hook run after the task. Does nothing by default.
    fn post_invoke(&self) {}

    /// Whether `post_invoke` should also run when the task panics. Defaults to `false`.
    fn invoke_post_invoke_on_panic(&self) -> bool {
        false
    }

    /// Chains `inner` inside `self`, producing a [`CombinedInvoker`].
    fn and_then<I: Invoker>(self, inner: I) -> CombinedInvoker<Self, I>
    where
        Self: Sized,
    {
        CombinedInvoker { outer: self, inner }
    }
}
/// Drop guard that calls [`Invoker::post_invoke`] on the referenced invoker when dropped,
/// unless it has been cancelled first.
///
/// The default [`Invoker::invoke_with_mode_optional`] arms one when
/// [`Invoker::invoke_post_invoke_on_panic`] returns `true`. This lets the hook also run
/// when the task panics and the stack unwinds.
pub struct Sentinel<'a, I: Invoker + ?Sized> {
    /// Invoker whose `post_invoke` hook fires on drop.
    invoker_ref: &'a I,
    /// Set to `true` on the normal (non-panicking) path to disarm the guard.
    cancelled: bool,
}

impl<I: Invoker + ?Sized> Drop for Sentinel<'_, I> {
    fn drop(&mut self) {
        if self.cancelled {
            // Disarmed: the invoking code calls `post_invoke` itself on the normal path.
            return;
        }
        self.invoker_ref.post_invoke();
    }
}
/// Invoker with no behaviour of its own. It keeps every default [`Invoker`] method, so tasks
/// run directly (with the mode applied, if one is given) and the hooks are no-ops.
///
/// It is stateless, so it derives the common value traits: it can be printed, defaulted,
/// copied and compared.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BaseInvoker {}
// `BaseInvoker` overrides nothing: it uses every default `Invoker` method unchanged.
impl Invoker for BaseInvoker {}
/// Invoker built by [`Invoker::and_then`] or [`CombinedInvoker::combine`] that nests
/// `inner` inside `outer`.
///
/// The struct itself carries no trait bounds. The `Invoker` requirements live on the
/// `impl` blocks that need them, so the type can be named, cloned and debug-printed
/// without restating them.
#[derive(Debug, Clone)]
pub struct CombinedInvoker<O, I> {
    /// Invoker whose invocation wraps the `inner` one.
    outer: O,
    /// Invoker run nested inside `outer`.
    inner: I,
}
impl<O: Invoker, I: Invoker> CombinedInvoker<O, I> {
    /// Nests `inner` inside `outer`. This is the same value the default
    /// [`Invoker::and_then`] builds for `outer.and_then(inner)`.
    pub fn combine(outer: O, inner: I) -> CombinedInvoker<O, I> {
        Self { inner, outer }
    }
}
impl<O: Invoker, I: Invoker> Invoker for CombinedInvoker<O, I> {
    /// Runs `task` through `outer`, with `inner`'s invocation nested inside it.
    ///
    /// When both sides use the default drivers, hooks fire in this order:
    /// `outer.pre_invoke`, `inner.pre_invoke`, the task, `inner.post_invoke`,
    /// `outer.post_invoke`.
    ///
    /// `mode` is applied exactly once around the task: it is handed to `inner` only, and
    /// `outer` receives `None`.
    fn invoke_with_mode_optional<'f, T: 'f, F: FnOnce() -> T + 'f>(
        &'f self,
        mode: Option<&'f Mode<'f, T>>,
        task: F,
    ) -> T {
        // The default `do_invoke` applies whatever mode it is given. Forwarding `mode` to
        // both invokers would wrap the task once per invoker in the chain (e.g. a doubling
        // mode on a two-invoker chain would quadruple the result). Only the innermost
        // invocation gets it.
        self.outer.invoke_with_mode_optional(None, move || {
            self.inner.invoke_with_mode_optional(mode, task)
        })
    }
}
/// Ordered set of [`ModeWrapper`]s that [`invoke`] applies around a task.
///
/// Start from [`Mode::new`] (or `Default`) and add wrappers with [`Mode::with`]. The first
/// wrapper added becomes the outermost layer around the task.
pub struct Mode<'m, T: 'm> {
    /// Head of the combiner chain, or `None` when no wrapper has been added. With the
    /// built-in `DelegatingModeCombiner`, the head is the most recently added wrapper.
    mode_combiner: Option<Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync>>,
}

impl<'m, T: 'm> Mode<'m, T> {
    /// Creates a mode with no wrappers; invoking with it just runs the task.
    pub fn new() -> Self {
        Mode {
            mode_combiner: None,
        }
    }

    /// Adds `mode_wrapper` to this mode and returns the updated mode.
    ///
    /// The new wrapper's combiner is joined onto the existing chain via
    /// [`ModeCombiner::combine`]; on an empty mode it becomes the whole chain.
    pub fn with<M: ModeWrapper<'m, T> + 'm + Send + Sync>(mut self, mode_wrapper: M) -> Self {
        let head = match self.mode_combiner.take() {
            Some(existing) => existing.combine(mode_wrapper.into_combiner()),
            None => mode_wrapper.into_combiner(),
        };
        self.mode_combiner = Some(head);
        self
    }
}

impl<'m, T: 'm> Default for Mode<'m, T> {
    /// Same as [`Mode::new`]: an empty mode.
    fn default() -> Self {
        Self::new()
    }
}
pub trait ModeWrapper<'m, T: 'm> {
fn wrap(self: Arc<Self>, task: Box<dyn FnOnce() -> T + 'm>) -> Box<dyn FnOnce() -> T + 'm>;
fn into_combiner(self) -> Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync>
where
Self: Sized + Send + Sync + 'm,
{
Box::new(DelegatingModeCombiner {
wrapper: Arc::new(self),
outer: None,
})
}
}
/// One link in a chain of [`ModeWrapper`]s built up by [`Mode::with`].
///
/// Each link exposes its own wrapper (`wrapper_ref`) and an optional *outer* link
/// (`get_outer`). [`invoke`] walks the chain with `iter` and wraps the task once per link.
pub trait ModeCombiner<'m, T: 'm> {
/// Joins `other` onto this chain and returns the resulting chain.
///
/// [`Mode::with`] calls this with the existing chain as `self` and the newly added
/// wrapper's combiner as `other`, then stores the returned value as the new chain.
fn combine(
&self,
other: Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync>,
) -> Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync>;
/// Returns the next link outwards, or `None` if this link ends the chain.
fn get_outer(&self) -> Option<&(dyn ModeCombiner<'m, T> + Send + Sync)>;
/// Sets the next link outwards (`DelegatingModeCombiner::combine` uses this to attach
/// the existing chain).
fn set_outer(&mut self, outer: Arc<dyn ModeCombiner<'m, T> + 'm + Send + Sync>);
/// Returns an iterator over the chain, starting at this link and following `get_outer`.
fn iter<'a>(&'a self) -> ModeCombinerIterator<'a, 'm, T>;
/// Returns a shared handle to this link's wrapper.
fn wrapper_ref(&self) -> Arc<dyn ModeWrapper<'m, T> + 'm + Send + Sync>;
}
/// Built-in [`ModeCombiner`]: one wrapper plus an optional shared pointer to the next link
/// outwards.
///
/// Created by the default [`ModeWrapper::into_combiner`]. Both fields are [`Arc`]s, so
/// cloning a link only bumps reference counts and never clones a wrapper.
pub struct DelegatingModeCombiner<'m, T> {
    /// The wrapper this link contributes.
    wrapper: Arc<dyn ModeWrapper<'m, T> + 'm + Send + Sync>,
    /// The next link outwards, if any.
    outer: Option<Arc<dyn ModeCombiner<'m, T> + 'm + Send + Sync>>,
}

// Written by hand because `#[derive(Clone)]` would require `T: Clone`. Only the `Arc`
// handles are cloned, so `T` needs no bound.
impl<T> Clone for DelegatingModeCombiner<'_, T> {
    fn clone(&self) -> Self {
        Self {
            outer: self.outer.as_ref().map(Arc::clone),
            wrapper: Arc::clone(&self.wrapper),
        }
    }
}

impl<'m, T> ModeCombiner<'m, T> for DelegatingModeCombiner<'m, T> {
    /// Makes a copy of `self` the outer link of `other` and returns `other` as the new head.
    ///
    /// NOTE(review): any outer link `other` already had is replaced, so passing a
    /// multi-link `other` drops the rest of its chain. `Mode::with` passes the result of
    /// `into_combiner`, which by default has no outer link, so that path is unaffected.
    fn combine(
        &self,
        mut other: Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync>,
    ) -> Box<dyn ModeCombiner<'m, T> + 'm + Send + Sync> {
        let snapshot: Arc<dyn ModeCombiner<'m, T> + 'm + Send + Sync> = Arc::new(self.clone());
        other.set_outer(snapshot);
        other
    }

    /// Borrows the next link outwards, if there is one.
    fn get_outer(&self) -> Option<&(dyn ModeCombiner<'m, T> + Send + Sync)> {
        self.outer.as_deref()
    }

    /// Replaces the next link outwards.
    fn set_outer(&mut self, outer: Arc<dyn ModeCombiner<'m, T> + 'm + Send + Sync>) {
        self.outer = Some(outer);
    }

    /// Iterates from this link outwards; the iterator yields `self` first.
    fn iter<'a>(&'a self) -> ModeCombinerIterator<'a, 'm, T> {
        ModeCombinerIterator {
            curr_combiner: None,
            mode_combiner: self,
        }
    }

    /// Hands out another reference to this link's wrapper.
    fn wrapper_ref(&self) -> Arc<dyn ModeWrapper<'m, T> + 'm + Send + Sync> {
        Arc::clone(&self.wrapper)
    }
}
/// Iterator over a [`ModeCombiner`] chain, from the starting link outwards.
///
/// It yields the starting combiner first and then each successive `get_outer()` link. It
/// stops at the first link that has no outer link. Once exhausted, it keeps returning
/// `None`.
pub struct ModeCombinerIterator<'a, 'm, T: 'm> {
    /// Link the iteration starts from.
    mode_combiner: &'a dyn ModeCombiner<'m, T>,
    /// Link yielded most recently; `None` until the first call to `next`.
    curr_combiner: Option<&'a dyn ModeCombiner<'m, T>>,
}

impl<'a, 'm, T: 'm> Iterator for ModeCombinerIterator<'a, 'm, T> {
    type Item = &'a dyn ModeCombiner<'m, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let upcoming: &'a dyn ModeCombiner<'m, T> = match self.curr_combiner {
            // First call: start at the head of the chain.
            None => self.mode_combiner,
            // Later calls: step outwards. At the end of the chain, return `None` without
            // moving, so every further call also returns `None`.
            Some(current) => current.get_outer()?,
        };
        self.curr_combiner = Some(upcoming);
        self.curr_combiner
    }
}
// Unit tests for `invoke`, the `Invoker` hooks (including `CombinedInvoker` chaining and
// the panic guard) and the lifetime flexibility of `Mode`.
#[cfg(test)]
mod tests {
use crate::{invoke, BaseInvoker, Invoker, Mode, ModeCombinerIterator, ModeWrapper};
use std::sync::{
atomic::{AtomicU16, Ordering},
Arc,
};
// Counters mutated by `CounterInvoker` / `MultInvoker`. Only `test_combined_invoker` uses
// those invokers, so other tests running in parallel do not disturb the values.
static PRE_COUNTER: AtomicU16 = AtomicU16::new(1);
static POST_COUNTER: AtomicU16 = AtomicU16::new(1);
// Mode layer that doubles the wrapped task's result.
struct MultiplyTwoMode {}
impl ModeWrapper<'static, i32> for MultiplyTwoMode {
fn wrap<'f>(
self: Arc<Self>,
task: Box<(dyn FnOnce() -> i32 + 'f)>,
) -> Box<(dyn FnOnce() -> i32 + 'f)> {
Box::new(move || {
return task() * 2;
})
}
}
// Mode layer that adds 2 to the wrapped task's result.
struct AddTwoMode {}
impl ModeWrapper<'static, i32> for AddTwoMode {
fn wrap<'f>(
self: Arc<Self>,
task: Box<(dyn FnOnce() -> i32 + 'f)>,
) -> Box<(dyn FnOnce() -> i32 + 'f)> {
Box::new(move || {
return task() + 2;
})
}
}
// Invoker whose hooks add 1 to the shared counters.
struct CounterInvoker {}
impl Invoker for CounterInvoker {
fn pre_invoke(&self) {
PRE_COUNTER.fetch_add(1, Ordering::Relaxed);
}
fn post_invoke(&self) {
POST_COUNTER.fetch_add(1, Ordering::Relaxed);
}
}
// Invoker whose hooks double the shared counters; mixing it with `CounterInvoker` makes
// the final values depend on hook order.
struct MultInvoker {}
impl Invoker for MultInvoker {
fn pre_invoke(&self) {
PRE_COUNTER
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x * 2))
.unwrap();
}
fn post_invoke(&self) {
POST_COUNTER
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x * 2))
.unwrap();
}
}
// Mode layer that runs the task but discards its result, returning its own borrowed
// string instead; exercises non-'static mode lifetimes.
struct StringRefMode<'a> {
str_ref: &'a str,
}
impl<'a> ModeWrapper<'a, &'a str> for StringRefMode<'a> {
fn wrap(
self: Arc<Self>,
task: Box<(dyn FnOnce() -> &'a str + 'a)>,
) -> Box<(dyn FnOnce() -> &'a str + 'a)> {
Box::new(move || {
task();
self.str_ref
})
}
}
// Pass-through mode whose output type itself borrows (a `ModeCombinerIterator`).
struct ModeCombinerIteratorMode {}
impl<'a, 'm> ModeWrapper<'a, ModeCombinerIterator<'a, 'm, &'m str>> for ModeCombinerIteratorMode {
fn wrap(
self: Arc<Self>,
task: Box<dyn FnOnce() -> ModeCombinerIterator<'a, 'm, &'m str> + 'a>,
) -> Box<dyn FnOnce() -> ModeCombinerIterator<'a, 'm, &'m str> + 'a> {
Box::new(move || task())
}
}
// The first wrapper added (MultiplyTwo) is the outermost layer:
// (2 + 2 + 2) * 2 == 12.
#[test]
fn it_works() {
let mode = Mode::new().with(MultiplyTwoMode {}).with(AddTwoMode {});
assert_eq!(invoke(&mode, || 2 + 2), 12);
}
// Chain: Counter -> Mult -> Mult -> Counter (outermost first).
// pre hooks run outer to inner:  1 +1=2, *2=4, *2=8, +1=9; task *3 = 27.
// post hooks run inner to outer: 1, task *3 = 3; +1=4, *2=8, *2=16, +1=17.
#[test]
fn test_combined_invoker() {
let invoker = CounterInvoker {}
.and_then(MultInvoker {})
.and_then(MultInvoker {})
.and_then(CounterInvoker {});
invoker.invoke(|| {
PRE_COUNTER
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x * 3))
.unwrap();
POST_COUNTER
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x * 3))
.unwrap();
});
assert_eq!(PRE_COUNTER.load(Ordering::Relaxed), 27);
assert_eq!(POST_COUNTER.load(Ordering::Relaxed), 17);
}
// Primarily a compile-time lifetime check: `iter` borrows `combiner`, which borrows `s`,
// and that borrowed value is returned through `invoke`.
#[test]
fn test_lifetime_iterator() {
let s = String::from("test");
let m = StringRefMode { str_ref: &s };
let combiner = m.into_combiner();
let iter = combiner.iter();
let mode = Mode::new().with(ModeCombinerIteratorMode {});
let _iter = invoke(&mode, move || iter);
}
// `StringRefMode` replaces the task's "fail" with the borrowed "test".
#[test]
fn test_lifetime_str_ref() {
let s = String::from("test");
let m = StringRefMode { str_ref: &s };
let mode = Mode::new().with(m);
assert_eq!("test", invoke(&mode, || { "fail" }));
}
// The returned reference may borrow data that lives shorter than the invoker.
#[test]
fn test_shorter_lifetime() {
let invoker = BaseInvoker {};
{
let shorter_lived_string = String::from("test");
let str_ref = invoker.invoke(|| &shorter_lived_string);
assert_eq!(str_ref, "test");
}
}
// The task panics inside a spawned thread; joining swallows the panic. Because
// `invoke_post_invoke_on_panic` returns true, the `Sentinel` guard still runs
// `post_invoke` once during unwinding.
#[test]
fn test_post_invoke_on_panic() {
struct TestPostInvoke {
counter: Arc<AtomicU16>,
}
impl Invoker for TestPostInvoke {
fn post_invoke(&self) {
self.counter.fetch_add(1, Ordering::Relaxed);
}
fn invoke_post_invoke_on_panic(&self) -> bool {
true
}
}
let counter = Arc::new(AtomicU16::new(0));
let test = TestPostInvoke {
counter: counter.clone(),
};
let handle = std::thread::spawn(move || {
test.invoke(|| {
panic!("test panic");
});
});
let _ = handle.join();
assert_eq!(counter.load(Ordering::Relaxed), 1);
}
// Same setup without the guard: the panic skips `post_invoke` entirely.
#[test]
fn test_not_post_invoke_on_panic() {
struct TestPostInvoke {
counter: Arc<AtomicU16>,
}
impl Invoker for TestPostInvoke {
fn post_invoke(&self) {
self.counter.fetch_add(1, Ordering::Relaxed);
}
fn invoke_post_invoke_on_panic(&self) -> bool {
false
}
}
let counter = Arc::new(AtomicU16::new(0));
let test = TestPostInvoke {
counter: counter.clone(),
};
let handle = std::thread::spawn(move || {
test.invoke(|| {
panic!("test panic");
});
});
let _ = handle.join();
assert_eq!(counter.load(Ordering::Relaxed), 0);
}
// A `Mode` built from Send + Sync wrappers can be shared with another thread via `Arc`;
// 5 * 4 == 20.
#[test]
fn test_send_mode() {
struct MultiplierMode {
multiplier: u16,
}
impl ModeWrapper<'static, u16> for MultiplierMode {
fn wrap(
self: Arc<Self>,
task: Box<(dyn FnOnce() -> u16)>,
) -> Box<(dyn FnOnce() -> u16)> {
Box::new(move || {
return task() * self.multiplier;
})
}
}
let mode = Arc::new(Mode::new().with(MultiplierMode { multiplier: 4 }));
let result = Arc::new(AtomicU16::new(0));
let m = mode.clone();
let r = result.clone();
let handle = std::thread::spawn(move || {
let result = invoke(&m, || 5);
r.store(result, Ordering::Relaxed);
});
handle.join().unwrap();
assert_eq!(result.load(Ordering::Relaxed), 20);
}
}