use std::collections::HashMap;
#[cfg(not(miri))]
use std::path::Path;
use tailcall::*;
/// Returns `input!`, computed by an accumulator-passing inner function that
/// recurses through `tailcall::call!`.
fn factorial(input: u64) -> u64 {
// Multiplies `accumulator` by `input` and recurses on `input - 1`; once
// `input` reaches 0 the accumulated product is returned.
#[tailcall]
fn factorial_inner(accumulator: u64, input: u64) -> u64 {
if input > 0 {
tailcall::call! { factorial_inner(accumulator * input, input - 1) }
} else {
accumulator
}
}
// Seed the accumulator with 1 so that `factorial(0)` is 1.
factorial_inner(1, input)
}
/// Checks `factorial` against the known values of `0!` through `4!`.
#[test]
fn test_factorial_correctness() {
    // (input, expected factorial) pairs, checked in ascending order.
    let cases: [(u64, u64); 5] = [(0, 1), (1, 1), (2, 2), (3, 6), (4, 24)];
    for (input, expected) in cases {
        assert_eq!(factorial(input), expected);
    }
}
/// Returns `input!` while recording intermediate state in `memo`.
///
/// For every value of `input` visited on the way down to 0 (inclusive), `memo`
/// receives an entry mapping that value to the accumulator held at that moment,
/// so the entry for key 0 holds the final result.
fn memoized_factorial(input: u64, memo: &mut HashMap<u64, u64>) -> u64 {
// Same accumulator-passing recursion as a plain factorial, with the `&mut`
// map threaded through each `tailcall::call!`.
#[tailcall]
fn factorial_inner(accumulator: u64, input: u64, memo: &mut HashMap<u64, u64>) -> u64 {
// Recorded before the branch, so the terminal step (`input == 0`) is stored too.
memo.insert(input, accumulator);
if input > 0 {
tailcall::call! { factorial_inner(accumulator * input, input - 1, memo) }
} else {
accumulator
}
}
factorial_inner(1, input, memo)
}
// Adds every item yielded by `int_iter` to `accum` and returns the total.
// The iterator is moved into each recursive `tailcall::call!`, so exactly one
// item is consumed per step; an exhausted iterator returns `accum` unchanged.
#[tailcall]
// No caller of `add_iter` appears in the visible source, hence the lint suppression.
#[allow(dead_code)]
fn add_iter<'a, I>(int_iter: I, accum: i32) -> i32
where
I: Iterator<Item = &'a i32>,
{
// Rebind as mutable so `next()` can advance the by-value parameter.
let mut int_iter = int_iter;
match int_iter.next() {
Some(i) => tailcall::call! { add_iter(int_iter, accum + i) },
None => accum,
}
}
/// Verifies both the result of `memoized_factorial(4, ..)` and the accumulator
/// snapshots it leaves behind in the map for keys 0 through 4.
#[test]
fn test_memoized_factorial_correctness() {
    let mut memo = HashMap::new();

    let result = memoized_factorial(4, &mut memo);
    assert_eq!(result, 24);

    // Key = `input` at the time of the visit, value = accumulator at that time.
    let expected_entries: [(u64, u64); 5] = [(0, 24), (1, 24), (2, 12), (3, 4), (4, 1)];
    for (key, value) in expected_entries {
        assert_eq!(memo.get(&key), Some(&value));
    }
}
/// Fixture module whose `countdown` recurses through the qualified path
/// `self::countdown` combined with an explicit `return`; the `tailcall`
/// attribute is imported locally inside the module.
mod qualified_calls {
use tailcall::tailcall;
// Counts `input` down one step at a time; the only non-recursive exit is
// reached when `input` is 0, so the returned value is always 0.
#[tailcall]
pub fn countdown(input: u64) -> u64 {
if input > 0 {
return tailcall::call! { self::countdown(input - 1) };
}
input
}
}
/// Exercises `qualified_calls::countdown`, which pairs a `self::`-qualified
/// recursive path with an explicit `return`.
#[test]
fn test_qualified_tailcall_path_and_explicit_return() {
    let remaining = qualified_calls::countdown(5);
    assert_eq!(remaining, 0);
}
// Counts `input` down to 0 via self-recursion in `tailcall::call!`, using an
// early `return` for the base case. Always returns 0.
#[tailcall]
fn optimized_countdown(input: u64) -> u64 {
if input == 0 {
return 0;
}
tailcall::call! { optimized_countdown(input - 1) }
}
/// Calls `optimized_countdown` directly and through
/// `__tailcall_build_optimized_countdown_thunk`; both must yield 0.
///
/// The builder is not defined in the visible source; presumably it is emitted
/// by the `#[tailcall]` expansion of `optimized_countdown`.
#[test]
fn test_simple_self_tail_recursion_keeps_hidden_thunk_builder() {
    let direct = optimized_countdown(10);
    assert_eq!(direct, 0);

    let via_thunk = __tailcall_build_optimized_countdown_thunk(10).call();
    assert_eq!(via_thunk, 0);
}
// Generic variant of a countdown: decrements `input` to 0 while carrying
// `value` through unchanged, then returns `value`.
#[tailcall]
fn generic_countdown<T: Copy>(input: u64, value: T) -> T {
if input == 0 {
return value;
}
tailcall::call! { generic_countdown(input - 1, value) }
}
/// Calls `generic_countdown` directly and through
/// `__tailcall_build_generic_countdown_thunk`; both must hand back the payload.
///
/// The builder is not defined in the visible source; presumably it is emitted
/// by the `#[tailcall]` expansion of `generic_countdown`.
#[test]
fn test_generic_self_tail_recursion_keeps_hidden_thunk_builder() {
    let direct = generic_countdown(10, 42_u32);
    assert_eq!(direct, 42);

    let via_thunk = __tailcall_build_generic_countdown_thunk(10, 42_u32).call();
    assert_eq!(via_thunk, 42);
}
// Countdown fixture in which the parameter `input` is shadowed by a new `let`
// binding before the recursive `tailcall::call!`. Always returns 0.
#[tailcall]
fn countdown_with_shadowing(input: u64) -> u64 {
if input == 0 {
return 0;
}
// Deliberate shadowing of the parameter; the recursive call receives the new binding.
let input = input - 1;
tailcall::call! { countdown_with_shadowing(input) }
}
/// Calls `countdown_with_shadowing` directly and through
/// `__tailcall_build_countdown_with_shadowing_thunk`; both must yield 0.
///
/// The builder is not defined in the visible source; presumably it is emitted
/// by the `#[tailcall]` expansion of `countdown_with_shadowing`.
#[test]
fn test_shadowing_falls_back_to_thunk_backend() {
    let direct = countdown_with_shadowing(10);
    assert_eq!(direct, 0);

    let via_thunk = __tailcall_build_countdown_with_shadowing_thunk(10).call();
    assert_eq!(via_thunk, 0);
}
/// Computes the greatest common divisor of `a` and `b` with the Euclidean
/// algorithm and returns it together with every `(a, b)` pair visited, in
/// order, including the terminal pair whose second component is 0.
fn gcd_with_trace(a: u64, b: u64) -> (u64, Vec<(u64, u64)>) {
// One Euclid step per invocation: record the current pair, then either
// return `a` (when `b == 0`) or recurse on `(b, a % b)`.
#[tailcall]
fn gcd_inner(a: u64, b: u64, trace: &mut Vec<(u64, u64)>) -> u64 {
trace.push((a, b));
match b {
0 => a,
_ => {
// Non-tail setup inside a nested block ahead of the recursive call.
let next = (b, a % b);
tailcall::call! { gcd_inner(next.0, next.1, trace) }
}
}
}
let mut trace = Vec::new();
let gcd = gcd_inner(a, b, &mut trace);
(gcd, trace)
}
/// `gcd_with_trace(48, 18)` must report a GCD of 6 and the four Euclid steps
/// that lead to it.
#[test]
fn test_tailcall_with_nested_match_and_mutable_state() {
    let expected_steps: Vec<(u64, u64)> = vec![(48, 18), (18, 12), (12, 6), (6, 0)];

    let (divisor, steps) = gcd_with_trace(48, 18);

    assert_eq!(divisor, 6);
    assert_eq!(steps, expected_steps);
}
/// Skips leading space (`' '`) and comma (`','`) bytes of `input` and returns
/// the number of bytes that remain after them.
///
/// Note that the return value is the length of the remainder, not the number
/// of separator bytes skipped.
fn skip_leading_separators(input: &str) -> usize {
// Peels one separator byte off the front per step via a slice pattern; the
// first non-separator byte (or an empty slice) ends the recursion.
#[tailcall]
fn skip_leading_separators_inner(rest: &[u8]) -> usize {
match rest {
[b' ' | b',', tail @ ..] => {
tailcall::call! { skip_leading_separators_inner(tail) }
}
_ => rest.len(),
}
}
skip_leading_separators_inner(input.as_bytes())
}
/// Checks `skip_leading_separators` on inputs with no leading separators,
/// mixed leading separators, only a separator, and the empty string.
#[test]
fn test_tailcall_over_borrowed_input_with_state_machine_logic() {
    // (input, bytes remaining once leading ' ' / ',' are skipped)
    let cases: [(&str, usize); 4] = [("10, 20,3", 8), (" , abc", 3), (",", 0), ("", 0)];
    for (input, remaining) in cases {
        assert_eq!(skip_leading_separators(input), remaining);
    }
}
// Returns whether `x` is even. Mutually recursive with `is_odd_macro`: each
// step decrements `x` and hands off to the other function via `tailcall::call!`.
#[tailcall]
fn is_even_macro(x: u128) -> bool {
if x == 0 {
true
} else {
tailcall::call! { is_odd_macro(x - 1) }
}
}
// Returns whether `x` is odd. Mutually recursive with `is_even_macro`: each
// step decrements `x` and hands off to the other function via `tailcall::call!`.
#[tailcall]
fn is_odd_macro(x: u128) -> bool {
if x == 0 {
false
} else {
tailcall::call! { is_even_macro(x - 1) }
}
}
/// Runs the mutually recursive free functions `is_even_macro` / `is_odd_macro`
/// roughly ten thousand steps deep on one even and one odd input.
#[test]
fn test_mutual_recursion_via_macros() {
    let (even_input, odd_input) = (10_000, 10_001);

    assert!(is_even_macro(even_input));
    assert!(!is_even_macro(odd_input));
    assert!(!is_odd_macro(even_input));
    assert!(is_odd_macro(odd_input));
}
/// Stateless fixture type carrying mutually recursive `&self` methods.
struct MethodParity;
impl MethodParity {
// Returns whether `x` is even; hands off to `is_odd` on `x - 1` through a
// method-call form of `tailcall::call!`.
#[tailcall]
fn is_even(&self, x: u32) -> bool {
if x == 0 {
true
} else {
tailcall::call! { self.is_odd(x - 1) }
}
}
// Returns whether `x` is odd; hands off to `is_even` on `x - 1` through a
// method-call form of `tailcall::call!`.
#[tailcall]
fn is_odd(&self, x: u32) -> bool {
if x == 0 {
false
} else {
tailcall::call! { self.is_even(x - 1) }
}
}
}
/// Runs the mutually recursive methods `MethodParity::is_even` / `is_odd`
/// roughly ten thousand steps deep on one even and one odd input.
#[test]
fn test_mutual_recursion_via_methods() {
    let parity = MethodParity;
    let (even_input, odd_input) = (10_000, 10_001);

    assert!(parity.is_even(even_input));
    assert!(!parity.is_even(odd_input));
    assert!(!parity.is_odd(even_input));
    assert!(parity.is_odd(odd_input));
}
/// Fixture type with mutable state, used to exercise `&mut self` methods.
#[derive(Default)]
struct MethodAccumulator {
// Number of times `tick_down` has run on this value.
steps: usize,
}
impl MethodAccumulator {
// Increments `steps` once per invocation (including the final one where
// `remaining == 0`) and returns the resulting `steps` count, so a fresh
// accumulator called with `remaining = n` returns `n + 1`.
#[tailcall]
fn tick_down(&mut self, remaining: u32) -> u32 {
self.steps += 1;
if remaining == 0 {
// `as` silently truncates if `steps` ever exceeds `u32::MAX`.
self.steps as u32
} else {
tailcall::call! { self.tick_down(remaining - 1) }
}
}
}
/// A fresh `MethodAccumulator` ticking down from 8 runs nine steps; both the
/// return value and the `steps` field must reflect that.
#[test]
fn test_mutable_receiver_methods_work_with_tailcall() {
    let mut counter = MethodAccumulator::default();

    let reported = counter.tick_down(8);

    assert_eq!(reported, 9);
    assert_eq!(counter.steps, 9);
}
/// Drives `tick_down` through `__tailcall_build_tick_down_thunk` instead of the
/// public method and expects the same nine steps.
///
/// The builder is not defined in the visible source; presumably it is emitted
/// by the `#[tailcall]` expansion of `MethodAccumulator::tick_down`.
#[test]
fn test_self_recursive_methods_keep_hidden_thunk_builder() {
    let mut counter = MethodAccumulator::default();

    let reported = counter.__tailcall_build_tick_down_thunk(8).call();

    assert_eq!(reported, 9);
    assert_eq!(counter.steps, 9);
}
/// Stateless fixture type carrying a generic, self-recursive `&self` method.
struct GenericMethodCounter;
impl GenericMethodCounter {
// Decrements `remaining` to 0 while carrying `value` through unchanged, then
// returns `value`.
#[tailcall]
fn countdown<T: Copy>(&self, remaining: u32, value: T) -> T {
if remaining == 0 {
value
} else {
tailcall::call! { self.countdown(remaining - 1, value) }
}
}
}
/// Calls `GenericMethodCounter::countdown` directly and through
/// `__tailcall_build_countdown_thunk`; both must hand back the payload.
///
/// The builder is not defined in the visible source; presumably it is emitted
/// by the `#[tailcall]` expansion of `GenericMethodCounter::countdown`.
#[test]
fn test_generic_self_recursive_methods_keep_hidden_thunk_builder() {
    let counter = GenericMethodCounter;

    let direct = counter.countdown(8, 7_u32);
    assert_eq!(direct, 7);

    let via_thunk = counter.__tailcall_build_countdown_thunk(8, 7_u32).call();
    assert_eq!(via_thunk, 7);
}
// Adds the byte length of this crate's `Cargo.toml` to `n` on every step until
// `n` is at least 1_000, then returns `n`.
// The metadata lookup runs before the recursive call on each invocation and
// panics (via `unwrap`) if the file's metadata cannot be read.
// NOTE(review): compiled out under Miri, presumably because of the filesystem
// access; confirm.
#[cfg(not(miri))]
#[tailcall]
fn recurse_with_metadata(n: u64) -> u64 {
// `env!` resolves `CARGO_MANIFEST_DIR` at compile time.
let file = Path::new(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml");
let len = file.metadata().unwrap().len();
if n >= 1_000 {
n
} else {
// Progress depends on `len > 0`; a zero-length file would never reach the threshold.
tailcall::call! { recurse_with_metadata(len + n) }
}
}
/// `recurse_with_metadata` performs setup work ahead of its recursive call;
/// starting from 1 it must finish with a value of at least 1_000.
#[cfg(not(miri))]
#[test]
fn test_issue_18_non_tail_setup_before_recursive_call() {
    assert!(recurse_with_metadata(1) >= 1_000);
}
// Returns the sum of all positive even numbers less than or equal to `n`.
// Mixes two kinds of recursion in one function: odd `n` (and `n == 1`) recurse
// through `tailcall::call!`, while even `n` makes a plain recursive call
// because its result is still needed for the addition afterwards.
#[tailcall]
fn mixed_recursion_sum(n: u64) -> u64 {
match n {
0 => 0,
1 => tailcall::call! { mixed_recursion_sum(0) },
_ if n % 2 == 0 => {
// Not a tail call: `n + partial` still has to be evaluated after it returns.
let partial = mixed_recursion_sum(n - 1);
n + partial
}
_ => tailcall::call! { mixed_recursion_sum(n - 1) },
}
}
/// Checks `mixed_recursion_sum` for `n` in `0..=6`; each expected value is the
/// sum of the positive even numbers up to `n`.
#[test]
fn test_mixed_recursion_allows_plain_non_tail_calls() {
    // (n, expected sum), in ascending order of `n`.
    let cases: [(u64, u64); 7] = [(0, 0), (1, 0), (2, 2), (3, 2), (4, 6), (5, 6), (6, 12)];
    for (n, expected) in cases {
        assert_eq!(mixed_recursion_sum(n), expected);
    }
}