#![feature(never_type)]
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::sync::atomic::{AtomicU64, Ordering};
/// Opaque identifier used as the unwind payload of a continuation.
///
/// Every call to [`Token::next`] yields a value that differs from all values
/// previously returned in this process (until the 64-bit counter wraps).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Token(u64);

impl Token {
    /// Mints a fresh, process-wide unique token.
    fn next() -> Self {
        // A single atomic read-modify-write is enough for uniqueness: all RMWs
        // on one atomic are totally ordered, so `Relaxed` cannot hand out the
        // same value twice. No other memory is synchronised through this counter.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
    }
}
/// Runs `body` in a loop driven by a "repeat" continuation.
///
/// `body` is first called with `initial_payload` and a continuation `repeat`.
/// If `body` returns normally, that value is returned from this function.
/// If `body` instead calls `repeat(p)`, control unwinds out of `body` (running
/// the destructors of its live locals along the way) back into this function,
/// which then calls `body` again with `p` as the new payload.
///
/// Each call mints its own `Token` and uses it as the unwind payload. A
/// `repeat` belonging to an enclosing call may therefore be invoked from inside
/// a nested call: the nested call sees a token that is not its own and simply
/// re-raises it until it reaches its owner.
///
/// # Notes
///
/// * The mechanism relies on `catch_unwind`/`resume_unwind`, so it only works
///   with the `panic = "unwind"` strategy; under `panic = "abort"` calling
///   `repeat` would abort the process.
/// * Each call to `repeat` boxes the token for the unwind payload.
/// * Any state `body` mutated before calling `repeat` stays mutated; the
///   closure is wrapped in `AssertUnwindSafe`, so the compiler will not
///   warn about this.
///
/// # Panics
///
/// Any unwind escaping `body` that is not this call's own token is re-raised
/// unchanged with `resume_unwind` (which does not run the panic hook again).
pub fn call_with_repeat_continuation<Return, Payload, Body>(
    initial_payload: Payload,
    mut body: Body,
) -> Return
where
    Body: FnMut(Payload, &mut dyn FnMut(Payload) -> !) -> Return,
{
    // Identifies unwinds started by *this* call's continuation, as opposed to
    // ordinary panics or continuations belonging to enclosing calls.
    let token = Token::next();
    // Payload for the next invocation of `body`; `repeat` refills it just
    // before unwinding.
    let mut pending = Some(initial_payload);
    loop {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let current = pending.take().unwrap();
            let repeat: &mut dyn FnMut(Payload) -> ! = &mut |next| {
                pending.replace(next);
                resume_unwind(Box::new(token))
            };
            body(current, repeat)
        }));
        let unwound = match outcome {
            Ok(result) => return result,
            Err(unwound) => unwound,
        };
        // Only our own token restarts the loop; everything else keeps unwinding.
        if unwound.downcast_ref::<Token>() != Some(&token) {
            resume_unwind(unwound);
        }
    }
}
/// Calls `body` once, handing it an escape continuation `throw`.
///
/// If `body` returns `t` normally, the result is `Ok(t)`. If `body` calls
/// `throw(e)`, control unwinds out of `body` (running destructors of its live
/// locals) and the result is `Err(e)`.
///
/// This is built on [`call_with_repeat_continuation`]. The loop payload is
/// `Option<E>`. The first iteration (payload `None`) runs `body`. Calling
/// `throw(e)` restarts the loop with `Some(e)`, and that iteration returns
/// `Err(e)` without running `body` again. It inherits the same requirements
/// as the repeat continuation, including the `panic = "unwind"` strategy.
///
/// # Panics
///
/// Panics raised by `body` itself propagate to the caller unchanged.
pub fn call_with_escape_continuation<T, E, Body>(
    body: Body,
) -> Result<T, E>
where
    Body: FnOnce(&mut dyn FnMut(E) -> !) -> T,
{
    // `body` is `FnOnce`, but the loop closure must be `FnMut`; park it in an
    // `Option` so it can be moved out on the first (and only) run.
    let mut pending_body = Some(body);
    call_with_repeat_continuation(None, move |thrown: Option<E>, retry| match thrown {
        Some(error) => Err(error),
        None => {
            let Some(run) = pending_body.take() else {
                // `retry` is only ever called with `Some`, so a second `None`
                // iteration cannot happen.
                unreachable!("Loop in call/ec")
            };
            Ok(run(&mut |error| retry(Some(error))))
        }
    })
}
#[cfg(test)]
mod test {
    use super::*;

    /// Tokens minted concurrently from several threads must never collide,
    /// either within one thread or across threads.
    #[test]
    fn tokens_unique() {
        use std::{collections::HashSet, thread};
        const N_THREADS: usize = 8;
        const TOKENS_PER_THREAD: usize = 1024;
        // Sized by `N_THREADS` (not a literal) so the two cannot drift apart.
        let mut handles = Vec::with_capacity(N_THREADS);
        for _ in 0..N_THREADS {
            handles.push(thread::spawn(move || {
                let mut set = HashSet::with_capacity(TOKENS_PER_THREAD);
                for _ in 0..TOKENS_PER_THREAD {
                    let token = Token::next();
                    if !set.insert(token) {
                        return Err(token);
                    }
                }
                Ok(set)
            }));
        }
        let mut full_set = HashSet::with_capacity(TOKENS_PER_THREAD * N_THREADS);
        for handle in handles {
            let subset = handle
                .join()
                .expect("thread panicked")
                .expect("thread saw duplicate token");
            for token in subset {
                if !full_set.insert(token) {
                    panic!("duplicate token while merging thread subsets");
                }
            }
        }
        assert_eq!(full_set.len(), TOKENS_PER_THREAD * N_THREADS);
    }

    /// A body that never calls `repeat` runs exactly once.
    #[test]
    fn unused_callrepeat() {
        let zero = call_with_repeat_continuation(2, |two, _repeat| two - two);
        assert_eq!(zero, 0);
    }

    /// `repeat` drives a counting loop until the body returns normally.
    #[test]
    fn loop_callrepeat() {
        let kibi = call_with_repeat_continuation(
            0,
            |acc, repeat| if acc == 1024 { acc } else { repeat(acc + 1) },
        );
        assert_eq!(kibi, 1024)
    }

    /// An outer `repeat` invoked from inside a nested call must pass through
    /// the inner call and restart the outer loop.
    #[test]
    fn nested_callrepeat() {
        let two = call_with_repeat_continuation(0, |outer_payload, outer_repeat| {
            if outer_payload == 2 {
                outer_payload
            } else {
                call_with_repeat_continuation(outer_payload, |inner_payload, inner_repeat| {
                    if inner_payload % 2 == 0 {
                        inner_repeat(inner_payload + 1)
                    } else {
                        outer_repeat(inner_payload + 1)
                    }
                })
            }
        });
        assert_eq!(two, 2)
    }

    /// Ordinary panics inside the body propagate with their original payload.
    #[test]
    #[should_panic(expected = "test panic")]
    fn panicking_callrepeat() {
        call_with_repeat_continuation(0, |payload, repeat| {
            if payload == 10 {
                panic!("test panic")
            } else {
                repeat(payload + 1)
            }
        })
    }

    /// Not calling `throw` yields `Ok`.
    #[test]
    fn unused_callec() {
        let zero: Result<i32, i32> = call_with_escape_continuation(|_throw| 1 - 1);
        assert_eq!(zero, Ok(0))
    }

    /// Calling `throw` yields `Err` with the thrown value.
    #[test]
    fn throw_callec() {
        let zero: Result<i32, i32> = call_with_escape_continuation(|throw| throw(0));
        assert_eq!(zero, Err(0))
    }

    /// `T` and `E` can both be inferred from the body.
    #[test]
    fn callec_type_infer() {
        for i in 0..256 {
            let res = call_with_escape_continuation(
                |throw| if i % 2 == 0 { i } else { throw("odd!") },
            );
            if i % 2 == 0 {
                assert_eq!(res, Ok(i))
            } else {
                assert_eq!(res, Err("odd!"))
            }
        }
    }

    /// Throwing to an outer escape continuation from inside a nested call/ec
    /// must skip the inner call entirely.
    #[test]
    fn nested_callec_outer_throw() {
        let res: Result<Result<i32, i32>, &str> = call_with_escape_continuation(|outer| {
            call_with_escape_continuation(|_inner| outer("escaped"))
        });
        assert_eq!(res, Err("escaped"));
    }

    /// Ordinary panics inside a call/ec body propagate with their original payload.
    #[test]
    #[should_panic(expected = "test panic")]
    fn panicking_callec() {
        let _: Result<(), ()> = call_with_escape_continuation(|_throw| panic!("test panic"));
    }
}