formality_core/
fixed_point.rs

1use std::cell::RefCell;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::thread::LocalKey;
5
6mod stack;
7pub use stack::FixedPointStack;
8
9pub fn fixed_point<Input, Output>(
10    tracing_span: impl Fn(&Input) -> tracing::Span,
11    storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
12    args: Input,
13    default_value: impl Fn(&Input) -> Output,
14    next_value: impl Fn(Input) -> Output,
15) -> Output
16where
17    Input: Value,
18    Output: Value,
19{
20    stacker::maybe_grow(32 * 1024, 1024 * 1024, || {
21        FixedPoint {
22            tracing_span,
23            storage,
24            default_value,
25            next_value,
26        }
27        .apply(args)
28    })
29}
30
31struct FixedPoint<Input, Output, DefaultValue, NextValue, TracingSpan>
32where
33    Input: Value,
34    Output: Value,
35{
36    tracing_span: TracingSpan,
37    storage: &'static LocalKey<RefCell<FixedPointStack<Input, Output>>>,
38    default_value: DefaultValue,
39    next_value: NextValue,
40}
41
/// Bound required of every input and output flowing through a fixed-point
/// computation: values are cloned, compared for equality, hashed, and
/// debug-printed, and must not borrow anything shorter than `'static`.
pub trait Value: Clone + Eq + Debug + Hash + 'static {}

/// Blanket impl: any type meeting the bounds is automatically a [`Value`].
impl<T> Value for T where T: Clone + Eq + Debug + Hash + 'static {}
44
45impl<Input, Output, DefaultValue, NextValue, TracingSpan>
46    FixedPoint<Input, Output, DefaultValue, NextValue, TracingSpan>
47where
48    Input: Value,
49    Output: Value,
50    DefaultValue: Fn(&Input) -> Output,
51    NextValue: Fn(Input) -> Output,
52    TracingSpan: Fn(&Input) -> tracing::Span,
53{
54    fn apply(&self, input: Input) -> Output {
55        if let Some(r) = self.with_stack(|stack| stack.search(&input)) {
56            tracing::debug!("recursive call to {:?}, yielding {:?}", input, r);
57            return r;
58        }
59
60        self.with_stack(|stack| {
61            let default_value = (self.default_value)(&input);
62            stack.push(&input, default_value);
63        });
64
65        loop {
66            let span = (self.tracing_span)(&input);
67            let _guard = span.enter();
68            let output = (self.next_value)(input.clone());
69            tracing::debug!(?output);
70            if !self.with_stack(|stack| stack.update_output(&input, output)) {
71                break;
72            } else {
73                tracing::debug!("output is different from previous iteration, re-executing until fixed point is reached");
74            }
75        }
76
77        self.with_stack(|stack| stack.pop(&input))
78    }
79
80    fn with_stack<R>(&self, f: impl FnOnce(&mut FixedPointStack<Input, Output>) -> R) -> R {
81        self.storage.with(|v| f(&mut *v.borrow_mut()))
82    }
83}