1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use std::cell::RefCell;
use std::fmt::Debug;
use std::hash::Hash;
use std::thread::LocalKey;

mod stack;
pub use stack::FixedPointStack;

/// Iterates `next_value(args)` until its output stops changing, using `storage` as this
/// thread's stack of in-progress inputs.
///
/// * `tracing_span` — builds the span entered around each `next_value` iteration.
/// * `storage` — thread-local stack recording each in-progress input with its latest output.
/// * `args` — the input to compute.
/// * `default_value` — the provisional output recorded for `args` before the first iteration;
///   a recursive request for `args` made while it is still being computed gets the value
///   currently recorded on the stack.
/// * `next_value` — computes one iteration; it is re-run until the recorded output is stable.
///
/// The whole computation runs under `stacker::maybe_grow`, so deep recursion through
/// `next_value` continues on a freshly allocated stack segment instead of overflowing.
pub fn fixed_point<Input, Output>(
    tracing_span: impl Fn(&Input) -> tracing::Span,
    storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
    args: Input,
    default_value: impl Fn(&Input) -> Output,
    next_value: impl Fn(Input) -> Output,
) -> Output
where
    Input: Value,
    Output: Value,
{
    // If fewer than RED_ZONE bytes of stack remain, continue on a new NEW_STACK_SIZE segment.
    const RED_ZONE: usize = 32 * 1024;
    const NEW_STACK_SIZE: usize = 1024 * 1024;

    stacker::maybe_grow(RED_ZONE, NEW_STACK_SIZE, move || {
        let computation = FixedPoint {
            tracing_span,
            storage,
            default_value,
            next_value,
        };
        computation.apply(args)
    })
}

/// The closures and thread-local stack for one fixed-point computation, assembled by
/// [`fixed_point`] and driven by its `apply` method.
struct FixedPoint<Input: Value, Output: Value, DefaultValue, NextValue, TracingSpan> {
    /// Builds the tracing span entered around each `next_value` iteration.
    tracing_span: TracingSpan,
    /// Thread-local stack of inputs currently being computed, with their latest outputs.
    storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
    /// Produces the provisional output pushed for an input before it is first computed.
    default_value: DefaultValue,
    /// Computes one iteration of the output for an input.
    next_value: NextValue,
}

/// Bound shared by fixed-point inputs and outputs: cloneable, comparable for equality,
/// hashable, debug-printable, and free of borrowed data (`'static`).
///
/// Every type meeting those bounds implements `Value` automatically via the blanket impl.
pub trait Value: Clone + Eq + Debug + Hash + 'static {}

impl<T> Value for T where T: Clone + Eq + Debug + Hash + 'static {}

impl<Input, Output, DefaultValue, NextValue, TracingSpan>
    FixedPoint<Input, Output, DefaultValue, NextValue, TracingSpan>
where
    Input: Value,
    Output: Value,
    DefaultValue: Fn(&Input) -> Output,
    NextValue: Fn(Input) -> Output,
    TracingSpan: Fn(&Input) -> tracing::Span,
{
    fn apply(&self, input: Input) -> Output {
        if let Some(r) = self.with_stack(|stack| stack.search(&input)) {
            tracing::debug!("recursive call to {:?}, yielding {:?}", input, r);
            return r;
        }

        self.with_stack(|stack| {
            let default_value = (self.default_value)(&input);
            stack.push(&input, default_value);
        });

        loop {
            let span = (self.tracing_span)(&input);
            let _guard = span.enter();
            let output = (self.next_value)(input.clone());
            tracing::debug!(?output);
            if !self.with_stack(|stack| stack.update_output(&input, output)) {
                break;
            } else {
                tracing::debug!("output is different from previous iteration, re-executing until fixed point is reached");
            }
        }

        self.with_stack(|stack| stack.pop(&input))
    }

    fn with_stack<R>(&self, f: impl FnOnce(&mut FixedPointStack<Input, Output>) -> R) -> R {
        self.storage.with(|v| f(&mut *v.borrow_mut()))
    }
}