programinduction/lambda/
compression.rs

use crossbeam_channel::{bounded, unbounded};
use itertools::Itertools;
use polytype::{Context, Type, TypeScheme};
use rayon::join;
use rayon::prelude::*;
use std::borrow::Cow;
use std::collections::{HashMap, VecDeque};
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

use super::{Expression, Language, LinkedList};
use crate::{ECFrontier, Task};
14
/// Parameters for grammar induction.
///
/// A proposed grammar is scored as
/// `likelihood - aic * #primitives - structure_penalty * #nodes`.
/// On top of that, `pseudocounts` smooths the likelihood calculation, while `topk` and `arity`
/// restrict which fragments may be proposed.
pub struct CompressionParams {
    /// Smoothing pseudocounts added to the observed usage counts of primitives and invented
    /// expressions when inside-outside re-estimates their probabilities.
    pub pseudocounts: u64,
    /// Only the `topk` best-ranked expressions of each frontier are kept when frontiers are
    /// rescored, and only those are used to propose fragments.
    pub topk: usize,
    /// Penalty on the total number of nodes in every [`Expression`] among the grammar's
    /// primitives and invented expressions.
    ///
    /// [`Expression`]: enum.Expression.html
    pub structure_penalty: f64,
    /// When `true`, frontier entries are ranked for `topk` selection by likelihood alone;
    /// when `false`, by their maximum a-posteriori value (log prior + log likelihood).
    /// Leave this `false` unless you know what you are doing.
    pub topk_use_only_likelihood: bool,
    /// Penalty per parameter of the grammar, i.e. per primitive and invented expression.
    /// When this is not finite, no compression is attempted.
    pub aic: f64,
    /// The largest applicative depth of an expression that may be manipulated when proposing
    /// a fragment.
    pub arity: u32,
}
42impl Default for CompressionParams {
43    /// The default params prevent completely discarding of primives by having non-zero
44    /// pseudocounts.
45    ///
46    /// ```
47    /// # use programinduction::lambda::CompressionParams;
48    /// CompressionParams {
49    ///     pseudocounts: 5,
50    ///     topk: 2,
51    ///     topk_use_only_likelihood: false,
52    ///     structure_penalty: 1f64,
53    ///     aic: 1f64,
54    ///     arity: 2,
55    /// }
56    /// # ;
57    /// ```
58    fn default() -> Self {
59        CompressionParams {
60            pseudocounts: 5,
61            topk: 2,
62            topk_use_only_likelihood: false,
63            structure_penalty: 1f64,
64            aic: 1f64,
65            arity: 2,
66        }
67    }
68}
69
70/// This function makes it easier to write your own compression scheme.
71/// It takes the role of a single compression step, separating it into sub-steps that are decent to
72/// implement in isolation.
73///
74/// This is a sophisticated higher-order function — tread carefully.
75/// - `state: I` can be mutated by making it `Arc<RwLock<_>>`, though it will often just be `()`
76///   unless you really need it.
77/// - type `X` is for a _candidate_, something which can be used to update a dsl.
78/// - `proposer` pushes candidates to the given vector.
79/// - `proposal_to_dsl` adds the candidate to the dsl and returns a new joint minimum description
80///   length for the dsl. For example, this may be set to:
81///
82///   ```compile_fails
83///   |_state, expr, dsl, frontiers, params| {
84///       if dsl.invent(expr.clone(), 0.).is_ok() {
85///           Some(dsl.inside_outside(frontiers, params.pseudocounts))
86///       } else {
87///           None
88///       }
89///   }
90///   ```
91/// - `defragment` is most often a no-op and can be set to `|x| x`. It allows you to effectively
92///   change the output of `proposal_to_expr` after scoring has been done. This is useful for
93///   fragment grammar compression, because scoring with inventions that have free variables (i.e.
94///   non-closed expressions) will let inside-outside capture those uses without having to rewrite
95///   the frontiers.
96/// - `rewrite_frontiers` finally takes the highest-scoring dsl, which is guaranteed to have a
97///   single latest invention (accessible via `dsl.invented.last().unwrap()`) equal to
98///   `defragment(proposal_to_expr(_, proposal))` for some generated proposal. The
99///   [`lambda::Expression`] it is supplied is the non-defragmented proposal.
100///   There's no need to rescore the frontiers, that's done automatically.
101///
102/// We recommended to make a function that adapts this into a four-argument `induce_my_algo`
103/// function by filling in the higher-order functions. See the source code of this project to find
104/// the particular use of this function that gives [`Language::compress`] using a
105/// fragment-grammar-like compression scheme.
106///
107/// [`Language::compress`]: struct.Language.html#method.compress
108/// [`lambda::CompressionParams`]: struct.CompressionParams.html
109/// [`lambda::Expression`]: enum.Expression.html
110/// [`lambda::Language`]: struct.Language.html
111#[allow(clippy::too_many_arguments)]
112pub fn induce<O, T, I, P, D, F, R>(
113    dsl: &Language,
114    params: &CompressionParams,
115    tasks: &[T],
116    mut original_frontiers: Vec<ECFrontier<Expression>>,
117    state: I,
118    proposer: P,
119    proposal_to_dsl: D,
120    defragment: F,
121    rewrite_frontiers: R,
122) -> (Language, Vec<ECFrontier<Expression>>)
123where
124    O: ?Sized,
125    T: Task<O, Representation = Language, Expression = Expression>,
126    I: Sync,
127    P: Fn(
128            &I,
129            &Language,
130            &[(TypeScheme, Vec<(Expression, f64, f64)>)],
131            &CompressionParams,
132            &mut Vec<T::Expression>,
133        ) + Sync,
134    D: Fn(
135            &I,
136            &T::Expression,
137            &mut Language,
138            &[(TypeScheme, Vec<(Expression, f64, f64)>)],
139            &CompressionParams,
140        ) -> Option<f64>
141        + Sync,
142    F: Fn(Expression) -> Expression,
143    R: Fn(
144        &I,
145        T::Expression,
146        Expression,
147        &Language,
148        &mut Vec<(TypeScheme, Vec<(Expression, f64, f64)>)>,
149        &CompressionParams,
150    ),
151{
152    let mut dsl = dsl.clone();
153    let mut frontiers: Vec<RescoredFrontier> = tasks
154        .par_iter()
155        .map(|t| t.tp().clone())
156        .zip(&original_frontiers)
157        .filter(|&(_, f)| !f.is_empty())
158        .map(|(tp, f)| (tp, f.0.clone()))
159        .collect();
160
161    let joint_mdl = dsl.inside_outside(&frontiers, params.pseudocounts);
162    let mut best_score = dsl.score(joint_mdl, params);
163
164    if cfg!(feature = "verbose") {
165        eprintln!("COMPRESSION: starting score: {}", best_score)
166    }
167    if params.aic.is_finite() {
168        loop {
169            let (candidate, fragment_expr) = {
170                let rescored_frontiers: Vec<_> = frontiers
171                    .par_iter()
172                    .cloned()
173                    .map(|f| dsl.rescore_frontier(f, params.topk, params.topk_use_only_likelihood))
174                    .collect();
175                let mut proposals = Vec::new();
176                proposer(&state, &dsl, &rescored_frontiers, params, &mut proposals);
177                if cfg!(feature = "verbose") {
178                    eprintln!("COMPRESSION: proposed {} fragments", proposals.len())
179                }
180                let best_proposal = proposals
181                    .into_par_iter()
182                    .filter_map(|candidate| {
183                        let mut dsl = dsl.clone();
184                        let joint_mdl = match proposal_to_dsl(
185                            &state,
186                            &candidate,
187                            &mut dsl,
188                            &rescored_frontiers,
189                            params,
190                        ) {
191                            None => {
192                                if cfg!(feature = "verbose") {
193                                    eprintln!("COMPRESSION: dropped invalid proposal");
194                                }
195                                return None;
196                            }
197                            Some(joint_mdl) => joint_mdl,
198                        };
199                        let s = dsl.score(joint_mdl, params);
200                        if s.is_finite() {
201                            Some((dsl, candidate, s))
202                        } else {
203                            None
204                        }
205                    })
206                    .max_by(|(_, _, x), (_, _, y)| x.partial_cmp(y).unwrap());
207                if best_proposal.is_none() {
208                    if cfg!(feature = "verbose") {
209                        eprintln!("COMPRESSION: no sufficient proposals")
210                    }
211                    break;
212                }
213                let (new_dsl, candidate, new_score) = best_proposal.unwrap();
214                if new_score <= best_score {
215                    if cfg!(feature = "verbose") {
216                        eprintln!("COMPRESSION: score did not improve")
217                    }
218                    break;
219                }
220                dsl = new_dsl;
221                best_score = new_score;
222
223                let (fragment_expr, _, log_prior) = dsl.invented.pop().unwrap();
224                let inv = defragment(fragment_expr.clone());
225                if cfg!(feature = "verbose") {
226                    eprintln!(
227                        "COMPRESSION: score improved to {} with invention {} (defragmented from candidate expr {})",
228                        best_score,
229                        dsl.display(&inv),
230                        dsl.display(&fragment_expr)
231                    )
232                }
233                dsl.invent(inv, log_prior).expect("invalid invention");
234                (candidate, fragment_expr)
235            };
236            rewrite_frontiers(
237                &state,
238                candidate,
239                fragment_expr,
240                &dsl,
241                &mut frontiers,
242                params,
243            )
244        }
245    }
246    frontiers.reverse();
247    for f in &mut original_frontiers {
248        if !f.is_empty() {
249            f.0 = frontiers.pop().unwrap().1;
250        }
251    }
252    (dsl, original_frontiers)
253}
254
255/// A convenient frontier representation.
256pub type RescoredFrontier = (TypeScheme, Vec<(Expression, f64, f64)>);
257
258pub fn joint_mdl(dsl: &Language, frontiers: &[RescoredFrontier]) -> f64 {
259    frontiers
260        .par_iter()
261        .map(|(t, f)| {
262            f.iter()
263                .map(|e| e.2 + dsl.likelihood(t, &e.0))
264                .fold(f64::NEG_INFINITY, f64::max)
265        })
266        .sum::<f64>()
267}
268
269/// Runs a variant of the inside outside algorithm to assign production probabilities for all
270/// primitives and invented expressions. The joint minimum description length is returned.
271pub fn inside_outside(
272    dsl: &mut Language,
273    frontiers: &[RescoredFrontier],
274    pseudocounts: u64,
275) -> f64 {
276    dsl.inside_outside_internal(frontiers, pseudocounts)
277}
278
279pub fn induce_fragment_grammar<Observation: ?Sized>(
280    dsl: &Language,
281    params: &CompressionParams,
282    tasks: &[impl Task<Observation, Representation = Language, Expression = Expression>],
283    original_frontiers: Vec<ECFrontier<Expression>>,
284) -> (Language, Vec<ECFrontier<Expression>>) {
285    induce(
286        dsl,
287        params,
288        tasks,
289        original_frontiers,
290        (),
291        |_, dsl, rescored_frontiers, params, proposals| {
292            dsl.propose_inventions(rescored_frontiers, params.arity, proposals)
293        },
294        |_, expr, dsl, rescored_frontiers, params| {
295            if dsl.invent(expr.clone(), 0.).is_ok() {
296                Some(dsl.inside_outside(rescored_frontiers, params.pseudocounts))
297            } else {
298                None
299            }
300        },
301        proposals::defragment,
302        |_, fragment_expr, _, dsl, frontiers, _| {
303            let i = dsl.invented.len() - 1;
304            for f in frontiers {
305                dsl.rewrite_frontier_with_fragment_expression(f, i, &fragment_expr);
306            }
307        },
308    )
309}
310
311/// Extend the Language in our scope so we can do useful compression things.
312impl Language {
313    fn rescore_frontier(
314        &self,
315        f: RescoredFrontier,
316        topk: usize,
317        topk_use_only_likelihood: bool,
318    ) -> RescoredFrontier {
319        let xs =
320            f.1.iter()
321                .map(|&(ref expr, _, loglikelihood)| {
322                    let logprior = self.uses(&f.0, expr).0;
323                    (expr, logprior, loglikelihood, logprior + loglikelihood)
324                })
325                .sorted_by(|(_, _, xl, xpost), (_, _, yl, ypost)| {
326                    if topk_use_only_likelihood {
327                        yl.partial_cmp(xl).unwrap()
328                    } else {
329                        ypost.partial_cmp(xpost).unwrap()
330                    }
331                })
332                .take(topk)
333                .map(|(expr, logprior, loglikelihood, _)| (expr.clone(), logprior, loglikelihood))
334                .collect();
335        (f.0, xs)
336    }
337
338    fn reset_uniform(&mut self) {
339        for x in &mut self.primitives {
340            x.2 = 0f64;
341        }
342        for x in &mut self.invented {
343            x.2 = 0f64;
344        }
345        self.variable_logprob = 0f64;
346    }
347
348    fn inside_outside_internal(
349        &mut self,
350        frontiers: &[RescoredFrontier],
351        pseudocounts: u64,
352    ) -> f64 {
353        self.reset_uniform();
354        let pseudocounts = pseudocounts as f64;
355        let (joint_mdl, u) = self.all_uses(frontiers);
356        self.variable_logprob = (u.actual_vars + pseudocounts).ln() - u.possible_vars.ln();
357        if !self.variable_logprob.is_finite() {
358            self.variable_logprob = u.actual_vars.max(1f64).ln()
359        }
360        for (i, prim) in self.primitives.iter_mut().enumerate() {
361            let obs = u.actual_prims[i] + pseudocounts;
362            let pot = u.possible_prims[i];
363            let pot = if pot == 0f64 { pseudocounts } else { pot };
364            prim.2 = obs.ln() - pot.ln();
365        }
366        for (i, inv) in self.invented.iter_mut().enumerate() {
367            let obs = u.actual_invented[i];
368            let pot = u.possible_invented[i];
369            inv.2 = obs.ln() - pot.ln();
370        }
371        joint_mdl
372    }
373
374    fn all_uses(&self, frontiers: &[RescoredFrontier]) -> (f64, Uses) {
375        let (tx, rx) = bounded(frontiers.len());
376        let u = frontiers
377            .par_iter()
378            .flat_map(|f| {
379                let lu =
380                    f.1.iter()
381                        .map(|&(ref expr, _logprior, loglikelihood)| {
382                            let (logprior, u) = self.uses(&f.0, expr);
383                            (logprior + loglikelihood, u)
384                        })
385                        .collect::<Vec<_>>();
386                let largest = lu.iter().fold(f64::NEG_INFINITY, |acc, &(l, _)| acc.max(l));
387                tx.send(largest).expect("send on closed channel");
388                let z = largest
389                    + lu.iter()
390                        .map(|&(l, _)| (l - largest).exp())
391                        .sum::<f64>()
392                        .ln();
393                lu.into_par_iter().map(move |(l, mut u)| {
394                    u.scale((l - z).exp());
395                    u
396                })
397            })
398            .reduce(
399                || Uses::new(self),
400                |mut u, nu| {
401                    u.merge(nu);
402                    u
403                },
404            );
405        let joint_mdl = rx.into_iter().take(frontiers.len()).sum();
406        (joint_mdl, u)
407    }
408
409    /// This is similar to `enumerator::likelihood` but it does a lot more work to determine
410    /// _outside_ counts.
411    fn uses(&self, request: &TypeScheme, expr: &Expression) -> (f64, Uses) {
412        let mut ctx = Context::default();
413        let tp = request.clone().instantiate_owned(&mut ctx);
414        let env = Rc::new(LinkedList::default());
415        self.likelihood_uses(&tp, expr, &ctx, &env)
416    }
417
418    /// This is similar to `enumerator::likelihood_internal` but it does a lot more work to
419    /// determine _outside_ counts.
420    fn likelihood_uses(
421        &self,
422        request: &Type,
423        expr: &Expression,
424        ctx: &Context,
425        env: &Rc<LinkedList<Type>>,
426    ) -> (f64, Uses) {
427        if let Some((arg, ret)) = request.as_arrow() {
428            let env = LinkedList::prepend(env, arg.clone());
429            if let Expression::Abstraction(ref body) = *expr {
430                self.likelihood_uses(ret, body, ctx, &env)
431            } else {
432                (f64::NEG_INFINITY, Uses::new(self)) // invalid expression
433            }
434        } else {
435            let candidates = self.candidates(request, ctx, &env.as_vecdeque());
436            let mut possible_vars = 0f64;
437            let mut possible_prims = vec![0f64; self.primitives.len()];
438            let mut possible_invented = vec![0f64; self.invented.len()];
439            for (_, expr, _, _) in &candidates {
440                match *expr {
441                    Expression::Primitive(num) => possible_prims[num] = 1f64,
442                    Expression::Invented(num) => possible_invented[num] = 1f64,
443                    Expression::Index(_) => possible_vars = 1f64,
444                    _ => unreachable!(),
445                }
446            }
447            let mut total_likelihood = f64::NEG_INFINITY;
448            let mut weighted_uses: Vec<(f64, Uses)> = Vec::new();
449            let mut f = expr;
450            let mut xs: VecDeque<&Expression> = VecDeque::new();
451            loop {
452                // if we're dealing with an Application, we reiterate for every applicable f/xs
453                // combination. (see the end of this block.)
454                for &(mut l, ref expr, ref tp, ref cctx) in &candidates {
455                    let mut ctx = Cow::Borrowed(cctx);
456                    let mut tp = Cow::Borrowed(tp);
457                    let mut bindings = HashMap::new();
458                    // skip this iteration if candidate expr and f don't match:
459                    if let Expression::Index(_) = *expr {
460                        if expr != f {
461                            continue;
462                        }
463                    } else if let Some(mut frag_tp) =
464                        TreeMatcher::do_match(self, ctx.to_mut(), expr, f, &mut bindings, xs.len())
465                    {
466                        let mut template = VecDeque::with_capacity(xs.len() + 1);
467                        template.push_front(request.clone());
468                        for _ in 0..xs.len() {
469                            template.push_front(ctx.to_mut().new_variable())
470                        }
471                        // unification cannot fail, so we can safely unwrap:
472                        if ctx
473                            .to_mut()
474                            .unify(&frag_tp, &Type::from(template.clone()))
475                            .is_err()
476                        {
477                            eprintln!(
478                                "WARNING (please report to programinduction devs): likelihood unification failure against expr={} (tp={}) for f={} frag_tp={} tmpl_tp={} xs={:?}",
479                                self.display(expr),
480                                tp,
481                                self.display(f),
482                                frag_tp,
483                                Type::from(template),
484                                xs.iter().map(|x| self.display(x)).collect::<Vec<_>>(),
485                            );
486                            continue;
487                        }
488                        frag_tp.apply_mut(&ctx);
489                        tp = Cow::Owned(frag_tp);
490                    } else {
491                        continue;
492                    }
493
494                    let arg_tps: VecDeque<&Type> = tp.args().unwrap_or_default();
495                    if xs.len() != arg_tps.len() {
496                        eprintln!(
497                            "WARNING (please report to programinduction devs): xs and arg_tps did not correspond: expr={} (arg_tps={:?}) f={} xs={:?}",
498                            self.display(expr),
499                            arg_tps.iter().map(std::string::ToString::to_string).collect::<Vec<_>>(),
500                            self.display(f),
501                            xs.iter().map(|x| self.display(x)).collect::<Vec<_>>(),
502                        );
503                        continue;
504                    }
505
506                    let mut u = Uses {
507                        actual_vars: 0f64,
508                        actual_prims: vec![0f64; self.primitives.len()],
509                        actual_invented: vec![0f64; self.invented.len()],
510                        possible_vars,
511                        possible_prims: possible_prims.clone(),
512                        possible_invented: possible_invented.clone(),
513                    };
514                    match *expr {
515                        Expression::Primitive(num) => u.actual_prims[num] = 1f64,
516                        Expression::Invented(num) => u.actual_invented[num] = 1f64,
517                        Expression::Index(_) => u.actual_vars = 1f64,
518                        _ => unreachable!(),
519                    }
520
521                    for (free_tp, free_expr) in bindings
522                        .iter()
523                        .map(|(_, (tp, expr))| (tp, expr))
524                        .chain(arg_tps.into_iter().zip(xs.iter().cloned()))
525                    {
526                        let mut free_tp = free_tp.clone();
527                        loop {
528                            let free_tp_new = free_tp.apply(&ctx);
529                            if free_tp_new != free_tp {
530                                free_tp = free_tp_new;
531                            } else {
532                                break;
533                            }
534                        }
535                        let n = self.likelihood_uses(&free_tp, free_expr, &ctx, env);
536                        if n.0.is_infinite() {
537                            l = f64::NEG_INFINITY;
538                            break;
539                        }
540                        l += n.0;
541                        u.merge(n.1);
542                    }
543
544                    if l.is_infinite() {
545                        continue;
546                    }
547                    weighted_uses.push((l, u));
548                    total_likelihood = if total_likelihood > l {
549                        total_likelihood + (1f64 + (l - total_likelihood).exp()).ln()
550                    } else {
551                        l + (1f64 + (total_likelihood - l).exp()).ln()
552                    };
553                }
554
555                if let Expression::Application(ref ff, ref x) = *f {
556                    f = ff;
557                    xs.push_front(x);
558                } else {
559                    break;
560                }
561            }
562
563            let mut u = Uses::new(self);
564            if total_likelihood.is_finite() && !weighted_uses.is_empty() {
565                u.join_from(total_likelihood, weighted_uses)
566            }
567            (total_likelihood, u)
568        }
569    }
570
571    /// returns whether the frontier was rewritten
572    fn rewrite_frontier_with_fragment_expression(
573        &self,
574        f: &mut RescoredFrontier,
575        i: usize,
576        expr: &Expression,
577    ) -> bool {
578        let results: Vec<_> =
579            f.1.iter_mut()
580                .map(|x| self.rewrite_expression(&mut x.0, i, expr, 0))
581                .collect();
582        results.iter().any(|&x| x)
583    }
584    fn rewrite_expression(
585        &self,
586        expr: &mut Expression,
587        inv_n: usize,
588        inv: &Expression,
589        n_args: usize,
590    ) -> bool {
591        let mut rewrote = false;
592        let do_rewrite = match *expr {
593            Expression::Application(ref mut f, ref mut x) => {
594                rewrote |= self.rewrite_expression(f, inv_n, inv, n_args + 1);
595                rewrote |= self.rewrite_expression(x, inv_n, inv, 0);
596                true
597            }
598            Expression::Abstraction(ref mut body) => {
599                rewrote |= self.rewrite_expression(body, inv_n, inv, 0);
600                true
601            }
602            _ => false,
603        };
604        if do_rewrite {
605            let mut bindings = HashMap::new();
606            let mut ctx = Context::default();
607            let matches =
608                TreeMatcher::do_match(self, &mut ctx, inv, expr, &mut bindings, n_args).is_some();
609            if matches {
610                let mut new_expr = Expression::Invented(inv_n);
611                for j in (0..bindings.len()).rev() {
612                    let (_, b) = &bindings[&j];
613                    let inner = Box::new(new_expr);
614                    new_expr = Expression::Application(inner, Box::new(b.clone()));
615                }
616                *expr = new_expr;
617                rewrote = true
618            }
619        }
620        rewrote
621    }
622
623    /// Yields expressions that may have free variables.
624    fn propose_inventions(
625        &self,
626        frontiers: &[RescoredFrontier],
627        arity: u32,
628        proposals: &mut Vec<Expression>,
629    ) {
630        let (tx, rx) = bounded(100);
631        join(
632            move || {
633                let findings = Arc::new(RwLock::new(HashMap::new()));
634                frontiers
635                    .par_iter()
636                    .flat_map(|f| &f.1)
637                    .flat_map(|(expr, _, _)| proposals::from_expression(expr, arity))
638                    .filter(|fragment_expr| {
639                        let expr = proposals::defragment(fragment_expr.clone());
640                        !self.invented.iter().any(|(x, _, _)| x == &expr)
641                    })
642                    .for_each(|fragment_expr| {
643                        let res = {
644                            let h = findings.read().expect("hashmap was poisoned");
645                            h.get(&fragment_expr)
646                                .map(|x: &AtomicUsize| x.fetch_add(1, Ordering::SeqCst))
647                        };
648                        match res {
649                            Some(2) if self.infer(&fragment_expr).is_ok() => tx
650                                .send(fragment_expr)
651                                .expect("failed to send fragment proposal"),
652                            None => {
653                                let mut h = findings.write().expect("hashmap was poisoned");
654                                let count = h
655                                    .entry(fragment_expr.clone())
656                                    .or_insert_with(|| AtomicUsize::new(0));
657                                if 2 == count.fetch_add(1, Ordering::SeqCst)
658                                    && self.infer(&fragment_expr).is_ok()
659                                {
660                                    tx.send(fragment_expr)
661                                        .expect("failed to send fragment proposal")
662                                }
663                            }
664                            _ => (),
665                        }
666                    })
667            },
668            move || proposals.extend(rx),
669        );
670    }
671}
672
673struct TreeMatcher<'a> {
674    dsl: &'a Language,
675    ctx: &'a mut Context,
676    bindings: &'a mut HashMap<usize, (Type, Expression)>,
677}
678impl<'a> TreeMatcher<'a> {
679    /// If the trees (`fragment` against `concrete`) match, this appropriately updates the context
680    /// and gets the type for `fragment`.  Also gives bindings for indices. This may modify the
681    /// context even upon failure.
682    fn do_match(
683        dsl: &Language,
684        ctx: &mut Context,
685        fragment: &Expression,
686        concrete: &Expression,
687        bindings: &mut HashMap<usize, (Type, Expression)>,
688        n_args: usize,
689    ) -> Option<Type> {
690        if !Self::might_match(dsl, fragment, concrete, 0) {
691            None
692        } else {
693            let mut tm = TreeMatcher { dsl, ctx, bindings };
694            tm.execute(fragment, concrete, &Rc::new(LinkedList::default()), n_args)
695        }
696    }
697
698    /// Small tree comparison that doesn't update any bindings.
699    fn might_match(
700        dsl: &Language,
701        fragment: &Expression,
702        concrete: &Expression,
703        depth: usize,
704    ) -> bool {
705        match *fragment {
706            Expression::Index(i) if i >= depth => true,
707            Expression::Abstraction(ref f_body) => {
708                if let Expression::Abstraction(ref e_body) = *concrete {
709                    Self::might_match(dsl, f_body, e_body, depth + 1)
710                } else {
711                    false
712                }
713            }
714            Expression::Application(ref f_f, ref f_x) => {
715                if let Expression::Application(ref c_f, ref c_x) = *concrete {
716                    Self::might_match(dsl, f_x, c_x, depth)
717                        && Self::might_match(dsl, f_f, c_f, depth)
718                } else {
719                    false
720                }
721            }
722            Expression::Invented(f_num) => {
723                if let Expression::Invented(c_num) = *concrete {
724                    f_num == c_num
725                } else {
726                    Self::might_match(dsl, &dsl.invented[f_num].0, concrete, depth)
727                }
728            }
729            _ => fragment == concrete,
730        }
731    }
732
733    fn execute(
734        &mut self,
735        fragment: &Expression,
736        concrete: &Expression,
737        env: &Rc<LinkedList<Type>>,
738        n_args: usize,
739    ) -> Option<Type> {
740        match (fragment, concrete) {
741            (Expression::Application(f_f, f_x), Expression::Application(c_f, c_x)) => {
742                let ft = self.execute(f_f, c_f, env, n_args)?;
743                let xt = self.execute(f_x, c_x, env, n_args)?;
744                let ret = self.ctx.new_variable();
745                if self.ctx.unify(&ft, &Type::arrow(xt, ret.clone())).is_ok() {
746                    Some(ret.apply(self.ctx))
747                } else {
748                    None
749                }
750            }
751            (&Expression::Primitive(f_num), &Expression::Primitive(c_num)) if f_num == c_num => {
752                let tp = self.dsl.primitives[f_num].1.clone();
753                Some(tp.instantiate_owned(self.ctx))
754            }
755            (&Expression::Invented(f_num), &Expression::Invented(c_num)) => {
756                if f_num == c_num {
757                    let tp = self.dsl.invented[f_num].1.clone();
758                    Some(tp.instantiate_owned(self.ctx))
759                } else {
760                    None
761                }
762            }
763            (&Expression::Invented(f_num), _) => {
764                let inv = &self.dsl.invented[f_num].0;
765                self.execute(inv, concrete, env, n_args)
766            }
767            (Expression::Abstraction(f_body), Expression::Abstraction(c_body)) => {
768                let arg = self.ctx.new_variable();
769                let env = LinkedList::prepend(env, arg.clone());
770                let ret = self.execute(f_body, c_body, &env, 0)?;
771                Some(Type::arrow(arg, ret))
772            }
773            (&Expression::Index(i), _) if i < env.len() => {
774                // bound variable
775                if fragment == concrete {
776                    let mut tp = env[i].clone();
777                    tp.apply_mut(self.ctx);
778                    Some(tp)
779                } else {
780                    None
781                }
782            }
783            (&Expression::Index(i), _) => {
784                // free variable
785                let i = i - env.len();
786                // make sure index bindings don't reach beyond fragment
787                let mut concrete = concrete.clone();
788                if concrete.shift(-(env.len() as i64)) {
789                    // wrap in abstracted applications for eta-long form
790                    if n_args > 0 {
791                        concrete.shift(n_args as i64);
792                        for j in 0..n_args {
793                            concrete = Expression::Application(
794                                Box::new(concrete),
795                                Box::new(Expression::Index(j)),
796                            );
797                        }
798                        for _ in 0..n_args {
799                            concrete = Expression::Abstraction(Box::new(concrete));
800                        }
801                    }
802                    // update bindings
803                    if let Some((tp, binding)) = self.bindings.get(&i) {
804                        return if binding == &concrete {
805                            Some(tp.clone())
806                        } else {
807                            None
808                        };
809                    }
810                    let tp = self.ctx.new_variable();
811                    self.bindings.insert(i, (tp.clone(), concrete));
812                    Some(tp)
813                } else {
814                    None
815                }
816            }
817            _ => None,
818        }
819    }
820}
821
/// Floating-point tallies of "actual" versus "possible" uses of variables, primitives and
/// invented expressions.
///
/// NOTE(review): "actual"/"possible" presumably mean how often a production was chosen versus
/// how often it was available while scoring programs. Confirm this against the code that fills
/// these in. The tallies are `f64` because `Uses::join_from` rescales and sums them as weights.
#[derive(Debug, Clone)]
struct Uses {
    /// "Actual" tally for variables (de Bruijn indices).
    actual_vars: f64,
    /// "Possible" tally for variables (de Bruijn indices).
    possible_vars: f64,
    /// "Actual" tally per primitive; sized from `dsl.primitives` in `Uses::new`.
    actual_prims: Vec<f64>,
    /// "Possible" tally per primitive; sized from `dsl.primitives` in `Uses::new`.
    possible_prims: Vec<f64>,
    /// "Actual" tally per invented expression; sized from `dsl.invented` in `Uses::new`.
    actual_invented: Vec<f64>,
    /// "Possible" tally per invented expression; sized from `dsl.invented` in `Uses::new`.
    possible_invented: Vec<f64>,
}
831impl Uses {
832    fn new(dsl: &Language) -> Uses {
833        let n_primitives = dsl.primitives.len();
834        let n_invented = dsl.invented.len();
835        Uses {
836            actual_vars: 0f64,
837            possible_vars: 0f64,
838            actual_prims: vec![0f64; n_primitives],
839            possible_prims: vec![0f64; n_primitives],
840            actual_invented: vec![0f64; n_invented],
841            possible_invented: vec![0f64; n_invented],
842        }
843    }
844    fn scale(&mut self, s: f64) {
845        self.actual_vars *= s;
846        self.possible_vars *= s;
847        self.actual_prims.iter_mut().for_each(|x| *x *= s);
848        self.possible_prims.iter_mut().for_each(|x| *x *= s);
849        self.actual_invented.iter_mut().for_each(|x| *x *= s);
850        self.possible_invented.iter_mut().for_each(|x| *x *= s);
851    }
852    fn merge(&mut self, other: Uses) {
853        self.actual_vars += other.actual_vars;
854        self.possible_vars += other.possible_vars;
855        self.actual_prims
856            .iter_mut()
857            .zip(other.actual_prims)
858            .for_each(|(a, b)| *a += b);
859        self.possible_prims
860            .iter_mut()
861            .zip(other.possible_prims)
862            .for_each(|(a, b)| *a += b);
863        self.actual_invented
864            .iter_mut()
865            .zip(other.actual_invented)
866            .for_each(|(a, b)| *a += b);
867        self.possible_invented
868            .iter_mut()
869            .zip(other.possible_invented)
870            .for_each(|(a, b)| *a += b);
871    }
872    /// self must be freshly created via `Uses::new()`, `z` must be finite and `weighted_uses` must
873    /// be non-empty.
874    fn join_from(&mut self, z: f64, mut weighted_uses: Vec<(f64, Uses)>) {
875        for &mut (l, ref mut u) in &mut weighted_uses {
876            u.scale((l - z).exp());
877        }
878        self.actual_vars = weighted_uses
879            .iter()
880            .map(|(_, u)| u.actual_vars)
881            .sum::<f64>();
882        self.possible_vars = weighted_uses
883            .iter()
884            .map(|(_, u)| u.possible_vars)
885            .sum::<f64>();
886        self.actual_prims.iter_mut().enumerate().for_each(|(i, c)| {
887            *c = weighted_uses
888                .iter()
889                .map(|(_, u)| u.actual_prims[i])
890                .sum::<f64>()
891        });
892        self.possible_prims
893            .iter_mut()
894            .enumerate()
895            .for_each(|(i, c)| {
896                *c = weighted_uses
897                    .iter()
898                    .map(|(_, u)| u.possible_prims[i])
899                    .sum::<f64>()
900            });
901        self.actual_invented
902            .iter_mut()
903            .enumerate()
904            .for_each(|(i, c)| {
905                *c = weighted_uses
906                    .iter()
907                    .map(|(_, u)| u.actual_invented[i])
908                    .sum::<f64>()
909            });
910        self.possible_invented
911            .iter_mut()
912            .enumerate()
913            .for_each(|(i, c)| {
914                *c = weighted_uses
915                    .iter()
916                    .map(|(_, u)| u.possible_invented[i])
917                    .sum::<f64>()
918            });
919    }
920}
921
922mod proposals {
923    //! Proposals, or "fragment expressions" (written `fragment_expr` where applicable) are
924    //! expressions with free variables.
925
926    use super::super::Expression;
927    use super::expression_count_kinds;
928    use itertools::Itertools;
929    use std::collections::HashMap;
930    use std::iter;
931
932    #[derive(Clone, Debug)]
933    enum Fragment {
934        Variable,
935        Application(Box<Fragment>, Box<Fragment>),
936        Abstraction(Box<Fragment>),
937        Expression(Expression),
938    }
939    impl Fragment {
940        fn fragvars(&self) -> usize {
941            match self {
942                Fragment::Expression(_) => 0,
943                Fragment::Application(f, x) => f.fragvars() + x.fragvars(),
944                Fragment::Abstraction(body) => body.fragvars(),
945                Fragment::Variable => 1,
946            }
947        }
948        fn n_free(&self, depth: usize) -> usize {
949            match self {
950                Fragment::Expression(expr) => Fragment::n_free_expr(expr, depth),
951                Fragment::Application(f, x) => f.n_free(depth) + x.n_free(depth),
952                Fragment::Abstraction(body) => body.n_free(depth + 1),
953                Fragment::Variable => 0,
954            }
955        }
956        fn n_free_expr(expr: &Expression, depth: usize) -> usize {
957            match expr {
958                Expression::Application(f, x) => {
959                    Fragment::n_free_expr(f, depth) + Fragment::n_free_expr(x, depth)
960                }
961                Expression::Abstraction(body) => Fragment::n_free_expr(body, depth + 1),
962                Expression::Index(i) if *i >= depth => 1,
963                _ => 0,
964            }
965        }
966        fn canonicalize(self) -> impl Iterator<Item = Expression> {
967            let fragvars = self.fragvars();
968            let n_free = self.n_free(0);
969            // 000 001 010 100 011 101 110 ~111~
970            iter::repeat(0..fragvars)
971                .take(fragvars)
972                .multi_cartesian_product()
973                .filter(|xs| {
974                    if let Some(x) = xs.iter().max() {
975                        if *x == 0 {
976                            true
977                        } else {
978                            (0..*x).all(|y| xs.contains(&y))
979                        }
980                    } else {
981                        true
982                    }
983                })
984                .pad_using(1, |_| Vec::new())
985                .map(move |mut assignment| {
986                    for x in &mut assignment {
987                        *x += n_free
988                    }
989                    let mut c = Canonicalizer::new(assignment);
990                    let mut frag = self.clone();
991                    c.canonicalize(&mut frag, 0);
992                    frag.into_expression()
993                })
994        }
995        fn into_expression(self) -> Expression {
996            match self {
997                Fragment::Expression(expr) => expr,
998                Fragment::Application(f, x) => Expression::Application(
999                    Box::new(f.into_expression()),
1000                    Box::new(x.into_expression()),
1001                ),
1002                Fragment::Abstraction(body) => {
1003                    Expression::Abstraction(Box::new(body.into_expression()))
1004                }
1005                _ => panic!("cannot convert fragment that still has variables"),
1006            }
1007        }
1008    }
1009    /// remove free variables from an expression by introducing abstractions.
1010    pub fn defragment(mut fragment_expr: Expression) -> Expression {
1011        let reach = free_reach(&fragment_expr, 0);
1012        for _ in 0..reach {
1013            let body = Box::new(fragment_expr);
1014            fragment_expr = Expression::Abstraction(body);
1015        }
1016        fragment_expr
1017    }
1018
1019    struct Canonicalizer {
1020        assignment: Vec<usize>,
1021        elapsed: usize,
1022        free: usize,
1023        mapping: HashMap<usize, usize>,
1024    }
1025    impl Canonicalizer {
1026        fn new(assignment: Vec<usize>) -> Canonicalizer {
1027            Canonicalizer {
1028                assignment,
1029                elapsed: 0,
1030                free: 0,
1031                mapping: HashMap::default(),
1032            }
1033        }
1034        fn canonicalize(&mut self, fr: &mut Fragment, depth: usize) {
1035            match *fr {
1036                Fragment::Expression(ref mut expr) => self.canonicalize_expr(expr, depth),
1037                Fragment::Application(ref mut f, ref mut x) => {
1038                    self.canonicalize(f, depth);
1039                    self.canonicalize(x, depth);
1040                }
1041                Fragment::Abstraction(ref mut body) => {
1042                    self.canonicalize(body, depth + 1);
1043                }
1044                Fragment::Variable => {
1045                    *fr = Fragment::Expression(Expression::Index(
1046                        self.assignment[self.elapsed] + depth,
1047                    ));
1048                    self.elapsed += 1;
1049                }
1050            }
1051        }
1052        fn canonicalize_expr(&mut self, expr: &mut Expression, depth: usize) {
1053            match *expr {
1054                Expression::Application(ref mut f, ref mut x) => {
1055                    self.canonicalize_expr(f, depth);
1056                    self.canonicalize_expr(x, depth);
1057                }
1058                Expression::Abstraction(ref mut body) => self.canonicalize_expr(body, depth + 1),
1059                Expression::Index(ref mut i) if *i >= depth => {
1060                    let j = i.checked_sub(depth).unwrap();
1061                    if let Some(k) = self.mapping.get(&j) {
1062                        *i = k + depth;
1063                        return;
1064                    }
1065                    self.mapping.insert(j, self.free);
1066                    *i = self.free + depth;
1067                    self.free += 1;
1068                }
1069                _ => (),
1070            }
1071        }
1072    }
1073
1074    /// main entry point for proposals
1075    pub fn from_expression(expr: &Expression, arity: u32) -> Vec<Expression> {
1076        (0..=arity)
1077            .flat_map(move |b| from_subexpression(expr, b))
1078            .flat_map(Fragment::canonicalize)
1079            .filter(|fragment_expr| {
1080                // determine if nontrivial
1081                let (n_prims, n_free, n_bound) = expression_count_kinds(fragment_expr, 0);
1082                n_prims >= 1 && ((n_prims as f64) + 0.5 * ((n_free + n_bound) as f64) > 1.5)
1083            })
1084            .flat_map(to_inventions)
1085            .collect()
1086    }
1087    fn from_subexpression(expr: &Expression, arity: u32) -> impl Iterator<Item = Fragment> + '_ {
1088        let rst: Box<dyn Iterator<Item = Fragment>> = match *expr {
1089            Expression::Application(ref f, ref x) => {
1090                Box::new(from_subexpression(f, arity).chain(from_subexpression(x, arity)))
1091            }
1092            Expression::Abstraction(ref body) => Box::new(from_subexpression(body, arity)),
1093            _ => Box::new(iter::empty()),
1094        };
1095        from_particular(expr, arity, true).chain(rst)
1096    }
1097    fn from_particular<'a>(
1098        expr: &'a Expression,
1099        arity: u32,
1100        toplevel: bool,
1101    ) -> Box<dyn Iterator<Item = Fragment> + 'a> {
1102        if arity == 0 {
1103            return Box::new(iter::once(Fragment::Expression(expr.clone())));
1104        }
1105        let rst: Box<dyn Iterator<Item = Fragment> + 'a> = match *expr {
1106            Expression::Application(ref f, ref x) => Box::new((0..=arity).flat_map(move |fa| {
1107                let xa = (arity as i32 - fa as i32) as u32;
1108                from_particular(f, fa, false)
1109                    .zip(iter::repeat(
1110                        from_particular(x, xa, false).collect::<Vec<_>>(),
1111                    ))
1112                    .flat_map(|(f, xs)| {
1113                        xs.into_iter()
1114                            .map(move |x| Fragment::Application(Box::new(f.clone()), Box::new(x)))
1115                    })
1116            })),
1117            Expression::Abstraction(ref body) if !toplevel => Box::new(
1118                from_particular(body, arity, false).map(|e| Fragment::Abstraction(Box::new(e))),
1119            ),
1120            _ => Box::new(iter::empty()),
1121        };
1122        Box::new(iter::once(Fragment::Variable).chain(rst))
1123    }
1124    fn to_inventions(expr: Expression) -> impl Iterator<Item = Expression> {
1125        // for any common subtree within the expression, replace with new index.
1126        let reach = free_reach(&expr, 0);
1127        let mut counts = HashMap::new();
1128        subtrees(expr.clone(), &mut counts);
1129        counts.remove(&expr);
1130        let fst = iter::once(expr.clone());
1131        let rst = counts
1132            .into_iter()
1133            .filter(|&(_, count)| count >= 2)
1134            .filter(|(expr, _)| is_closed(expr))
1135            .map(move |(subtree, _)| {
1136                let mut expr = expr.clone();
1137                substitute(&mut expr, &subtree, &Expression::Index(reach));
1138                expr
1139            });
1140        fst.chain(rst)
1141    }
1142
1143    /// How far out does the furthest reaching index go, excluding internal abstractions?
1144    ///
1145    /// For examples, the free reach of `(+ $0 (λ + 1 $0))` is 1, because there need to be one
1146    /// abstraction around the expression for it to make sense.
1147    fn free_reach(expr: &Expression, depth: usize) -> usize {
1148        match *expr {
1149            Expression::Application(ref f, ref x) => free_reach(f, depth).max(free_reach(x, depth)),
1150            Expression::Abstraction(ref body) => free_reach(body, depth + 1),
1151            Expression::Index(i) if i >= depth => 1 + i.checked_sub(depth).unwrap(),
1152            _ => 0,
1153        }
1154    }
1155
1156    /// Counts occurrences for every subtree of expr.
1157    fn subtrees(expr: Expression, counts: &mut HashMap<Expression, usize>) {
1158        match expr.clone() {
1159            Expression::Application(f, x) => {
1160                subtrees(*f, counts);
1161                subtrees(*x, counts);
1162                counts.entry(expr).or_insert(0);
1163            }
1164            Expression::Abstraction(body) => {
1165                subtrees(*body, counts);
1166                counts.entry(expr).or_insert(0);
1167            }
1168            Expression::Index(_) => (),
1169            Expression::Primitive(num) => {
1170                counts.entry(Expression::Primitive(num)).or_insert(0);
1171            }
1172            Expression::Invented(num) => {
1173                counts.entry(Expression::Invented(num)).or_insert(0);
1174            }
1175        }
1176    }
1177
1178    /// Whether every `Expression::Index` is bound within expr.
1179    fn is_closed(expr: &Expression) -> bool {
1180        free_reach(expr, 0) == 0
1181    }
1182
1183    /// Replace all occurrences of subtree in expr with replacement.
1184    fn substitute(expr: &mut Expression, subtree: &Expression, replacement: &Expression) {
1185        if expr == subtree {
1186            *expr = replacement.clone()
1187        } else {
1188            match *expr {
1189                Expression::Application(ref mut f, ref mut x) => {
1190                    substitute(f, subtree, replacement);
1191                    substitute(x, subtree, replacement);
1192                }
1193                Expression::Abstraction(ref mut body) => substitute(body, subtree, replacement),
1194                _ => (),
1195            }
1196        }
1197    }
1198}
1199
1200/// Counts of prims, free, bound
1201pub fn expression_count_kinds(expr: &Expression, abstraction_depth: usize) -> (u64, u64, u64) {
1202    match *expr {
1203        Expression::Primitive(_) | Expression::Invented(_) => (1, 0, 0),
1204        Expression::Index(i) => {
1205            if i < abstraction_depth {
1206                (0, 0, 1)
1207            } else {
1208                (0, 1, 0)
1209            }
1210        }
1211        Expression::Abstraction(ref b) => expression_count_kinds(b, abstraction_depth + 1),
1212        Expression::Application(ref l, ref r) => {
1213            let (l1, f1, b1) = expression_count_kinds(l, abstraction_depth);
1214            let (l2, f2, b2) = expression_count_kinds(r, abstraction_depth);
1215            (l1 + l2, f1 + f2, b1 + b2)
1216        }
1217    }
1218}