Skip to main content

prune_lang/interp/
strategy.rs

1use super::*;
2use itertools::Itertools;
3use std::fmt;
4
/// One branch of the interpreter's search state, as manipulated by the strategy
/// methods in this module: the current answer bindings plus the work still pending
/// on this branch.
#[derive(Clone, Debug)]
pub struct Branch {
    /// Branch depth; `Branch::new` initialises it to 0 and nothing in this file
    /// updates it.
    pub depth: usize,
    /// Each queried parameter paired with the term it is currently bound to;
    /// `merge` rewrites these terms with a unifier's substitution.
    pub answers: Vec<(Ident, TermVal<IdentCtx>)>,
    /// Primitive applications recorded on this branch (a primitive plus its atom
    /// arguments). NOTE(review): presumably pending constraints — confirm.
    pub prims: Vec<(Prim, Vec<AtomVal<IdentCtx>>)>,
    /// Pending predicate calls; the `*_strategy` methods pick which index to expand.
    pub calls: Vec<PredCall>,
}
12
13impl fmt::Display for Branch {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        writeln!(f, "##### depth: = {} #####", self.depth)?;
16
17        for (par, val) in &self.answers {
18            writeln!(f, "{par} = {val}")?;
19        }
20
21        for (prim, args) in &self.prims {
22            let args = args.iter().format(", ");
23            writeln!(f, "{prim:?}({args})")?;
24        }
25
26        for call in &self.calls {
27            writeln!(f, "{call}")?;
28        }
29
30        Ok(())
31    }
32}
33
34impl Branch {
35    pub fn new(pred: Ident, pars: Vec<Ident>, rule_cnt: usize) -> Branch {
36        let call = PredCall {
37            pred,
38            polys: Vec::new(),
39            args: pars.iter().map(|par| Term::Var(par.tag_ctx(0))).collect(),
40            looks: (0..rule_cnt).collect(),
41            depth: 0,
42        };
43
44        Branch {
45            depth: 0,
46            answers: pars
47                .iter()
48                .map(|par| (*par, Term::Var(par.tag_ctx(0))))
49                .collect(),
50            prims: Vec::new(),
51            calls: vec![call],
52        }
53    }
54
55    pub fn merge(&mut self, unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>>) {
56        for call in &mut self.calls {
57            for arg in &mut call.args {
58                *arg = unifier.subst(arg);
59            }
60        }
61
62        for (_par, val) in &mut self.answers {
63            *val = unifier.subst(val);
64        }
65    }
66
67    pub fn insert(&mut self, call_idx: usize, call: PredCall) {
68        self.calls.insert(call_idx, call);
69    }
70
71    pub fn remove(&mut self, call_idx: usize) -> PredCall {
72        self.calls.remove(call_idx)
73    }
74
75    pub fn random_strategy(&mut self, rng: &mut rand::rngs::ThreadRng) -> usize {
76        assert!(!self.calls.is_empty());
77        rng.random_range(0..self.calls.len())
78    }
79
80    pub fn left_biased_strategy(&mut self) -> usize {
81        assert!(!self.calls.is_empty());
82        0
83    }
84
85    pub fn interleave_strategy(&mut self) -> usize {
86        (0..self.calls.len())
87            .min_by_key(|idx| self.calls[*idx].depth)
88            .unwrap()
89    }
90
91    pub fn small_first_strategy(&mut self) -> usize {
92        (0..self.calls.len())
93            .min_by_key(|idx| {
94                let call = &self.calls[*idx];
95                call.looks.len() * 1000 + call.depth
96            })
97            .unwrap()
98    }
99
100    pub fn check_reduction(&self) -> Option<usize> {
101        (0..self.calls.len()).find(|idx| self.calls[*idx].looks.len() <= 1)
102    }
103}
104
/// A pending call to a predicate on some branch.
#[derive(Clone, Debug)]
// NOTE(review): every field appears to be read in this file; confirm this allow is
// still needed and, if so, document which item triggers the lint.
#[allow(dead_code)]
pub struct PredCall {
    /// The predicate being called.
    pub pred: Ident,
    /// Type-level arguments of the call; printed inside `[...]` by `Display` when
    /// non-empty. NOTE(review): presumably instantiations of polymorphic type
    /// parameters — confirm.
    pub polys: Vec<TermType>,
    /// Argument terms; `try_unify_rule_head` requires one per rule-head parameter.
    pub args: Vec<TermVal<IdentCtx>>,
    /// Indices of rules still considered candidates for this call; narrowed by
    /// `lookahead_update`.
    pub looks: Vec<usize>,
    /// Call depth; used as a priority key by the interleave and small-first
    /// strategies.
    pub depth: usize,
}
114
115impl fmt::Display for PredCall {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        let args = self.args.iter().format(", ");
118        if self.polys.is_empty() {
119            write!(f, "{}({})", self.pred, args)
120        } else {
121            let polys = self.polys.iter().format(", ");
122            write!(f, "{}[{}]({})", self.pred, polys, args)
123        }
124    }
125}
126
127impl PredCall {
128    fn try_unify_rule_head(&self, head: &[TermVal]) -> Result<(), ()> {
129        assert_eq!(head.len(), self.args.len());
130
131        let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
132        for (par, arg) in head.iter().zip(self.args.iter()) {
133            if unifier.unify(&par.tag_ctx(0), arg).is_err() {
134                return Err(());
135            }
136        }
137
138        Ok(())
139    }
140
141    pub fn lookahead_update(&mut self, rules: &[Rule]) {
142        let mut new_looks = self.looks.clone();
143        new_looks.retain(|look| self.try_unify_rule_head(&rules[*look].head).is_ok());
144        self.looks = new_looks
145    }
146}
147
/// Solves `x^(-a) + x^(-b) + ... + x^(-z) = 1` for `x`, given weights `[a, b, ..., z]`.
///
/// Weights are expected to be positive integers. With at most one weight the function
/// returns `1.0` without iterating. A single weight `a` has the exact root `x = 1`, and
/// the empty input is mapped to `1.0` by convention.
///
/// The root is found with Newton's method starting from `x = 1.0`. Iteration stops as
/// soon as a Newton step is smaller than `1e-4` in magnitude.
///
/// # Panics
///
/// Panics if Newton's method does not converge within 100 iterations (for example
/// when a zero weight makes the equation unsolvable).
pub fn tau_function(vec: &[usize]) -> f32 {
    if vec.len() <= 1 {
        return 1.0;
    }

    let mut x: f32 = 1.0;
    for _ in 0..100 {
        // g(x) = sum(x^(-w_i)) - 1
        let g = vec
            .iter()
            .fold(-1.0_f32, |acc, &w| acc + x.powi(-(w as i32)));

        // g'(x) = sum(-w_i * x^(-w_i - 1))
        let gp = vec
            .iter()
            .fold(0.0_f32, |acc, &w| acc - (w as f32) * x.powi(-(w as i32) - 1));

        let step = g / gp;
        x -= step;

        if step.abs() < 1e-4 {
            return x;
        }
    }
    panic!("Newton's method fails to converge after 100 iterations!")
}
181
/// Checks `tau_function` on degenerate inputs and on equations with known roots.
#[test]
fn test_tau_function() {
    // Zero or one weight short-circuits to exactly 1.0.
    let trivial: [Vec<usize>; 3] = [vec![], vec![1], vec![2]];
    for weights in &trivial {
        assert_eq!(tau_function(weights), 1.0);
    }

    // Otherwise the Newton iteration must land within 1e-4 of the known root.
    // SQRT_2 replaces the literal 1.41421356, which trips clippy's deny-by-default
    // `approx_constant` lint; both round to the same f32.
    let cases: Vec<(Vec<usize>, f32)> = vec![
        (vec![1, 1], 2.0),
        (vec![1, 1, 1], 3.0),
        (vec![2, 2], std::f32::consts::SQRT_2),
        (vec![1, 2], 1.618_034),   // golden ratio
        (vec![1, 2, 3], 1.839_287), // tribonacci constant
        (vec![5, 5], 1.148_698),   // 2^(1/5)
    ];
    for (weights, expected) in &cases {
        let x = tau_function(weights);
        assert!(
            (x - *expected).abs() < 1e-4,
            "tau_function({weights:?}) = {x}, expected {expected}"
        );
    }
}