1use crossbeam_channel::bounded;
2use itertools::Itertools;
3use polytype::{Context, Type, TypeScheme};
4use rayon::join;
5use rayon::prelude::*;
6use std::borrow::Cow;
7use std::collections::{HashMap, VecDeque};
8use std::rc::Rc;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11
12use super::{Expression, Language, LinkedList};
13use crate::{ECFrontier, Task};
14
/// Parameters for grammar induction (compression).
///
/// The whole value is handed to `Language::score` when candidate grammars are compared, and
/// individual fields steer frontier rescoring and fragment proposal in this module.
pub struct CompressionParams {
    /// Added to every observed usage count (and used as the fallback denominator for
    /// productions that were never eligible) when re-estimating log-probabilities.
    pub pseudocounts: u64,
    /// Number of best-ranked expressions kept from each frontier when frontiers are rescored
    /// before proposing fragments.
    pub topk: usize,
    /// Structure-penalty weight; presumably read by `Language::score` (not visible here).
    pub structure_penalty: f64,
    /// When `true`, the `topk` ranking uses only each expression's log-likelihood; when
    /// `false`, it uses log-prior plus log-likelihood.
    pub topk_use_only_likelihood: bool,
    /// AIC weight, passed to `Language::score`. If it is not finite, `induce` performs no
    /// compression iterations at all.
    pub aic: f64,
    /// Upper bound on how many parts of an expression may be turned into fragment variables
    /// when proposing fragments.
    pub arity: u32,
}
impl Default for CompressionParams {
    /// `pseudocounts: 5`, `topk: 2`, `topk_use_only_likelihood: false`,
    /// `structure_penalty: 1.0`, `aic: 1.0`, `arity: 2`.
    fn default() -> Self {
        Self {
            arity: 2,
            aic: 1.0,
            pseudocounts: 5,
            structure_penalty: 1.0,
            topk: 2,
            topk_use_only_likelihood: false,
        }
    }
}
69
70#[allow(clippy::too_many_arguments)]
112pub fn induce<O, T, I, P, D, F, R>(
113 dsl: &Language,
114 params: &CompressionParams,
115 tasks: &[T],
116 mut original_frontiers: Vec<ECFrontier<Expression>>,
117 state: I,
118 proposer: P,
119 proposal_to_dsl: D,
120 defragment: F,
121 rewrite_frontiers: R,
122) -> (Language, Vec<ECFrontier<Expression>>)
123where
124 O: ?Sized,
125 T: Task<O, Representation = Language, Expression = Expression>,
126 I: Sync,
127 P: Fn(
128 &I,
129 &Language,
130 &[(TypeScheme, Vec<(Expression, f64, f64)>)],
131 &CompressionParams,
132 &mut Vec<T::Expression>,
133 ) + Sync,
134 D: Fn(
135 &I,
136 &T::Expression,
137 &mut Language,
138 &[(TypeScheme, Vec<(Expression, f64, f64)>)],
139 &CompressionParams,
140 ) -> Option<f64>
141 + Sync,
142 F: Fn(Expression) -> Expression,
143 R: Fn(
144 &I,
145 T::Expression,
146 Expression,
147 &Language,
148 &mut Vec<(TypeScheme, Vec<(Expression, f64, f64)>)>,
149 &CompressionParams,
150 ),
151{
152 let mut dsl = dsl.clone();
153 let mut frontiers: Vec<RescoredFrontier> = tasks
154 .par_iter()
155 .map(|t| t.tp().clone())
156 .zip(&original_frontiers)
157 .filter(|&(_, f)| !f.is_empty())
158 .map(|(tp, f)| (tp, f.0.clone()))
159 .collect();
160
161 let joint_mdl = dsl.inside_outside(&frontiers, params.pseudocounts);
162 let mut best_score = dsl.score(joint_mdl, params);
163
164 if cfg!(feature = "verbose") {
165 eprintln!("COMPRESSION: starting score: {}", best_score)
166 }
167 if params.aic.is_finite() {
168 loop {
169 let (candidate, fragment_expr) = {
170 let rescored_frontiers: Vec<_> = frontiers
171 .par_iter()
172 .cloned()
173 .map(|f| dsl.rescore_frontier(f, params.topk, params.topk_use_only_likelihood))
174 .collect();
175 let mut proposals = Vec::new();
176 proposer(&state, &dsl, &rescored_frontiers, params, &mut proposals);
177 if cfg!(feature = "verbose") {
178 eprintln!("COMPRESSION: proposed {} fragments", proposals.len())
179 }
180 let best_proposal = proposals
181 .into_par_iter()
182 .filter_map(|candidate| {
183 let mut dsl = dsl.clone();
184 let joint_mdl = match proposal_to_dsl(
185 &state,
186 &candidate,
187 &mut dsl,
188 &rescored_frontiers,
189 params,
190 ) {
191 None => {
192 if cfg!(feature = "verbose") {
193 eprintln!("COMPRESSION: dropped invalid proposal");
194 }
195 return None;
196 }
197 Some(joint_mdl) => joint_mdl,
198 };
199 let s = dsl.score(joint_mdl, params);
200 if s.is_finite() {
201 Some((dsl, candidate, s))
202 } else {
203 None
204 }
205 })
206 .max_by(|(_, _, x), (_, _, y)| x.partial_cmp(y).unwrap());
207 if best_proposal.is_none() {
208 if cfg!(feature = "verbose") {
209 eprintln!("COMPRESSION: no sufficient proposals")
210 }
211 break;
212 }
213 let (new_dsl, candidate, new_score) = best_proposal.unwrap();
214 if new_score <= best_score {
215 if cfg!(feature = "verbose") {
216 eprintln!("COMPRESSION: score did not improve")
217 }
218 break;
219 }
220 dsl = new_dsl;
221 best_score = new_score;
222
223 let (fragment_expr, _, log_prior) = dsl.invented.pop().unwrap();
224 let inv = defragment(fragment_expr.clone());
225 if cfg!(feature = "verbose") {
226 eprintln!(
227 "COMPRESSION: score improved to {} with invention {} (defragmented from candidate expr {})",
228 best_score,
229 dsl.display(&inv),
230 dsl.display(&fragment_expr)
231 )
232 }
233 dsl.invent(inv, log_prior).expect("invalid invention");
234 (candidate, fragment_expr)
235 };
236 rewrite_frontiers(
237 &state,
238 candidate,
239 fragment_expr,
240 &dsl,
241 &mut frontiers,
242 params,
243 )
244 }
245 }
246 frontiers.reverse();
247 for f in &mut original_frontiers {
248 if !f.is_empty() {
249 f.0 = frontiers.pop().unwrap().1;
250 }
251 }
252 (dsl, original_frontiers)
253}
254
/// A frontier in the form used during compression: the task's request type and, for each
/// expression, a tuple `(expr, logprior, loglikelihood)`.
pub type RescoredFrontier = (TypeScheme, Vec<(Expression, f64, f64)>);
257
258pub fn joint_mdl(dsl: &Language, frontiers: &[RescoredFrontier]) -> f64 {
259 frontiers
260 .par_iter()
261 .map(|(t, f)| {
262 f.iter()
263 .map(|e| e.2 + dsl.likelihood(t, &e.0))
264 .fold(f64::NEG_INFINITY, f64::max)
265 })
266 .sum::<f64>()
267}
268
269pub fn inside_outside(
272 dsl: &mut Language,
273 frontiers: &[RescoredFrontier],
274 pseudocounts: u64,
275) -> f64 {
276 dsl.inside_outside_internal(frontiers, pseudocounts)
277}
278
279pub fn induce_fragment_grammar<Observation: ?Sized>(
280 dsl: &Language,
281 params: &CompressionParams,
282 tasks: &[impl Task<Observation, Representation = Language, Expression = Expression>],
283 original_frontiers: Vec<ECFrontier<Expression>>,
284) -> (Language, Vec<ECFrontier<Expression>>) {
285 induce(
286 dsl,
287 params,
288 tasks,
289 original_frontiers,
290 (),
291 |_, dsl, rescored_frontiers, params, proposals| {
292 dsl.propose_inventions(rescored_frontiers, params.arity, proposals)
293 },
294 |_, expr, dsl, rescored_frontiers, params| {
295 if dsl.invent(expr.clone(), 0.).is_ok() {
296 Some(dsl.inside_outside(rescored_frontiers, params.pseudocounts))
297 } else {
298 None
299 }
300 },
301 proposals::defragment,
302 |_, fragment_expr, _, dsl, frontiers, _| {
303 let i = dsl.invented.len() - 1;
304 for f in frontiers {
305 dsl.rewrite_frontier_with_fragment_expression(f, i, &fragment_expr);
306 }
307 },
308 )
309}
310
311impl Language {
313 fn rescore_frontier(
314 &self,
315 f: RescoredFrontier,
316 topk: usize,
317 topk_use_only_likelihood: bool,
318 ) -> RescoredFrontier {
319 let xs =
320 f.1.iter()
321 .map(|&(ref expr, _, loglikelihood)| {
322 let logprior = self.uses(&f.0, expr).0;
323 (expr, logprior, loglikelihood, logprior + loglikelihood)
324 })
325 .sorted_by(|(_, _, xl, xpost), (_, _, yl, ypost)| {
326 if topk_use_only_likelihood {
327 yl.partial_cmp(xl).unwrap()
328 } else {
329 ypost.partial_cmp(xpost).unwrap()
330 }
331 })
332 .take(topk)
333 .map(|(expr, logprior, loglikelihood, _)| (expr.clone(), logprior, loglikelihood))
334 .collect();
335 (f.0, xs)
336 }
337
338 fn reset_uniform(&mut self) {
339 for x in &mut self.primitives {
340 x.2 = 0f64;
341 }
342 for x in &mut self.invented {
343 x.2 = 0f64;
344 }
345 self.variable_logprob = 0f64;
346 }
347
348 fn inside_outside_internal(
349 &mut self,
350 frontiers: &[RescoredFrontier],
351 pseudocounts: u64,
352 ) -> f64 {
353 self.reset_uniform();
354 let pseudocounts = pseudocounts as f64;
355 let (joint_mdl, u) = self.all_uses(frontiers);
356 self.variable_logprob = (u.actual_vars + pseudocounts).ln() - u.possible_vars.ln();
357 if !self.variable_logprob.is_finite() {
358 self.variable_logprob = u.actual_vars.max(1f64).ln()
359 }
360 for (i, prim) in self.primitives.iter_mut().enumerate() {
361 let obs = u.actual_prims[i] + pseudocounts;
362 let pot = u.possible_prims[i];
363 let pot = if pot == 0f64 { pseudocounts } else { pot };
364 prim.2 = obs.ln() - pot.ln();
365 }
366 for (i, inv) in self.invented.iter_mut().enumerate() {
367 let obs = u.actual_invented[i];
368 let pot = u.possible_invented[i];
369 inv.2 = obs.ln() - pot.ln();
370 }
371 joint_mdl
372 }
373
374 fn all_uses(&self, frontiers: &[RescoredFrontier]) -> (f64, Uses) {
375 let (tx, rx) = bounded(frontiers.len());
376 let u = frontiers
377 .par_iter()
378 .flat_map(|f| {
379 let lu =
380 f.1.iter()
381 .map(|&(ref expr, _logprior, loglikelihood)| {
382 let (logprior, u) = self.uses(&f.0, expr);
383 (logprior + loglikelihood, u)
384 })
385 .collect::<Vec<_>>();
386 let largest = lu.iter().fold(f64::NEG_INFINITY, |acc, &(l, _)| acc.max(l));
387 tx.send(largest).expect("send on closed channel");
388 let z = largest
389 + lu.iter()
390 .map(|&(l, _)| (l - largest).exp())
391 .sum::<f64>()
392 .ln();
393 lu.into_par_iter().map(move |(l, mut u)| {
394 u.scale((l - z).exp());
395 u
396 })
397 })
398 .reduce(
399 || Uses::new(self),
400 |mut u, nu| {
401 u.merge(nu);
402 u
403 },
404 );
405 let joint_mdl = rx.into_iter().take(frontiers.len()).sum();
406 (joint_mdl, u)
407 }
408
409 fn uses(&self, request: &TypeScheme, expr: &Expression) -> (f64, Uses) {
412 let mut ctx = Context::default();
413 let tp = request.clone().instantiate_owned(&mut ctx);
414 let env = Rc::new(LinkedList::default());
415 self.likelihood_uses(&tp, expr, &ctx, &env)
416 }
417
418 fn likelihood_uses(
421 &self,
422 request: &Type,
423 expr: &Expression,
424 ctx: &Context,
425 env: &Rc<LinkedList<Type>>,
426 ) -> (f64, Uses) {
427 if let Some((arg, ret)) = request.as_arrow() {
428 let env = LinkedList::prepend(env, arg.clone());
429 if let Expression::Abstraction(ref body) = *expr {
430 self.likelihood_uses(ret, body, ctx, &env)
431 } else {
432 (f64::NEG_INFINITY, Uses::new(self)) }
434 } else {
435 let candidates = self.candidates(request, ctx, &env.as_vecdeque());
436 let mut possible_vars = 0f64;
437 let mut possible_prims = vec![0f64; self.primitives.len()];
438 let mut possible_invented = vec![0f64; self.invented.len()];
439 for (_, expr, _, _) in &candidates {
440 match *expr {
441 Expression::Primitive(num) => possible_prims[num] = 1f64,
442 Expression::Invented(num) => possible_invented[num] = 1f64,
443 Expression::Index(_) => possible_vars = 1f64,
444 _ => unreachable!(),
445 }
446 }
447 let mut total_likelihood = f64::NEG_INFINITY;
448 let mut weighted_uses: Vec<(f64, Uses)> = Vec::new();
449 let mut f = expr;
450 let mut xs: VecDeque<&Expression> = VecDeque::new();
451 loop {
452 for &(mut l, ref expr, ref tp, ref cctx) in &candidates {
455 let mut ctx = Cow::Borrowed(cctx);
456 let mut tp = Cow::Borrowed(tp);
457 let mut bindings = HashMap::new();
458 if let Expression::Index(_) = *expr {
460 if expr != f {
461 continue;
462 }
463 } else if let Some(mut frag_tp) =
464 TreeMatcher::do_match(self, ctx.to_mut(), expr, f, &mut bindings, xs.len())
465 {
466 let mut template = VecDeque::with_capacity(xs.len() + 1);
467 template.push_front(request.clone());
468 for _ in 0..xs.len() {
469 template.push_front(ctx.to_mut().new_variable())
470 }
471 if ctx
473 .to_mut()
474 .unify(&frag_tp, &Type::from(template.clone()))
475 .is_err()
476 {
477 eprintln!(
478 "WARNING (please report to programinduction devs): likelihood unification failure against expr={} (tp={}) for f={} frag_tp={} tmpl_tp={} xs={:?}",
479 self.display(expr),
480 tp,
481 self.display(f),
482 frag_tp,
483 Type::from(template),
484 xs.iter().map(|x| self.display(x)).collect::<Vec<_>>(),
485 );
486 continue;
487 }
488 frag_tp.apply_mut(&ctx);
489 tp = Cow::Owned(frag_tp);
490 } else {
491 continue;
492 }
493
494 let arg_tps: VecDeque<&Type> = tp.args().unwrap_or_default();
495 if xs.len() != arg_tps.len() {
496 eprintln!(
497 "WARNING (please report to programinduction devs): xs and arg_tps did not correspond: expr={} (arg_tps={:?}) f={} xs={:?}",
498 self.display(expr),
499 arg_tps.iter().map(std::string::ToString::to_string).collect::<Vec<_>>(),
500 self.display(f),
501 xs.iter().map(|x| self.display(x)).collect::<Vec<_>>(),
502 );
503 continue;
504 }
505
506 let mut u = Uses {
507 actual_vars: 0f64,
508 actual_prims: vec![0f64; self.primitives.len()],
509 actual_invented: vec![0f64; self.invented.len()],
510 possible_vars,
511 possible_prims: possible_prims.clone(),
512 possible_invented: possible_invented.clone(),
513 };
514 match *expr {
515 Expression::Primitive(num) => u.actual_prims[num] = 1f64,
516 Expression::Invented(num) => u.actual_invented[num] = 1f64,
517 Expression::Index(_) => u.actual_vars = 1f64,
518 _ => unreachable!(),
519 }
520
521 for (free_tp, free_expr) in bindings
522 .iter()
523 .map(|(_, (tp, expr))| (tp, expr))
524 .chain(arg_tps.into_iter().zip(xs.iter().cloned()))
525 {
526 let mut free_tp = free_tp.clone();
527 loop {
528 let free_tp_new = free_tp.apply(&ctx);
529 if free_tp_new != free_tp {
530 free_tp = free_tp_new;
531 } else {
532 break;
533 }
534 }
535 let n = self.likelihood_uses(&free_tp, free_expr, &ctx, env);
536 if n.0.is_infinite() {
537 l = f64::NEG_INFINITY;
538 break;
539 }
540 l += n.0;
541 u.merge(n.1);
542 }
543
544 if l.is_infinite() {
545 continue;
546 }
547 weighted_uses.push((l, u));
548 total_likelihood = if total_likelihood > l {
549 total_likelihood + (1f64 + (l - total_likelihood).exp()).ln()
550 } else {
551 l + (1f64 + (total_likelihood - l).exp()).ln()
552 };
553 }
554
555 if let Expression::Application(ref ff, ref x) = *f {
556 f = ff;
557 xs.push_front(x);
558 } else {
559 break;
560 }
561 }
562
563 let mut u = Uses::new(self);
564 if total_likelihood.is_finite() && !weighted_uses.is_empty() {
565 u.join_from(total_likelihood, weighted_uses)
566 }
567 (total_likelihood, u)
568 }
569 }
570
571 fn rewrite_frontier_with_fragment_expression(
573 &self,
574 f: &mut RescoredFrontier,
575 i: usize,
576 expr: &Expression,
577 ) -> bool {
578 let results: Vec<_> =
579 f.1.iter_mut()
580 .map(|x| self.rewrite_expression(&mut x.0, i, expr, 0))
581 .collect();
582 results.iter().any(|&x| x)
583 }
584 fn rewrite_expression(
585 &self,
586 expr: &mut Expression,
587 inv_n: usize,
588 inv: &Expression,
589 n_args: usize,
590 ) -> bool {
591 let mut rewrote = false;
592 let do_rewrite = match *expr {
593 Expression::Application(ref mut f, ref mut x) => {
594 rewrote |= self.rewrite_expression(f, inv_n, inv, n_args + 1);
595 rewrote |= self.rewrite_expression(x, inv_n, inv, 0);
596 true
597 }
598 Expression::Abstraction(ref mut body) => {
599 rewrote |= self.rewrite_expression(body, inv_n, inv, 0);
600 true
601 }
602 _ => false,
603 };
604 if do_rewrite {
605 let mut bindings = HashMap::new();
606 let mut ctx = Context::default();
607 let matches =
608 TreeMatcher::do_match(self, &mut ctx, inv, expr, &mut bindings, n_args).is_some();
609 if matches {
610 let mut new_expr = Expression::Invented(inv_n);
611 for j in (0..bindings.len()).rev() {
612 let (_, b) = &bindings[&j];
613 let inner = Box::new(new_expr);
614 new_expr = Expression::Application(inner, Box::new(b.clone()));
615 }
616 *expr = new_expr;
617 rewrote = true
618 }
619 }
620 rewrote
621 }
622
623 fn propose_inventions(
625 &self,
626 frontiers: &[RescoredFrontier],
627 arity: u32,
628 proposals: &mut Vec<Expression>,
629 ) {
630 let (tx, rx) = bounded(100);
631 join(
632 move || {
633 let findings = Arc::new(RwLock::new(HashMap::new()));
634 frontiers
635 .par_iter()
636 .flat_map(|f| &f.1)
637 .flat_map(|(expr, _, _)| proposals::from_expression(expr, arity))
638 .filter(|fragment_expr| {
639 let expr = proposals::defragment(fragment_expr.clone());
640 !self.invented.iter().any(|(x, _, _)| x == &expr)
641 })
642 .for_each(|fragment_expr| {
643 let res = {
644 let h = findings.read().expect("hashmap was poisoned");
645 h.get(&fragment_expr)
646 .map(|x: &AtomicUsize| x.fetch_add(1, Ordering::SeqCst))
647 };
648 match res {
649 Some(2) if self.infer(&fragment_expr).is_ok() => tx
650 .send(fragment_expr)
651 .expect("failed to send fragment proposal"),
652 None => {
653 let mut h = findings.write().expect("hashmap was poisoned");
654 let count = h
655 .entry(fragment_expr.clone())
656 .or_insert_with(|| AtomicUsize::new(0));
657 if 2 == count.fetch_add(1, Ordering::SeqCst)
658 && self.infer(&fragment_expr).is_ok()
659 {
660 tx.send(fragment_expr)
661 .expect("failed to send fragment proposal")
662 }
663 }
664 _ => (),
665 }
666 })
667 },
668 move || proposals.extend(rx),
669 );
670 }
671}
672
/// Structural matcher of a fragment (which may contain free variables and invented
/// expressions) against a concrete expression. It infers the fragment's type and records
/// what each free variable is bound to.
struct TreeMatcher<'a> {
    dsl: &'a Language,
    // Type context for fresh type variables and unification during matching.
    ctx: &'a mut Context,
    // Free-variable index (relative to the fragment's top level) -> (type, bound expression).
    bindings: &'a mut HashMap<usize, (Type, Expression)>,
}
impl<'a> TreeMatcher<'a> {
    /// Matches `fragment` against `concrete`, returning the fragment's type on success and
    /// filling `bindings`.
    ///
    /// The cheap structural check `might_match` runs first. `n_args` is forwarded to
    /// `execute`, where a free variable's binding is eta-expanded over that many arguments.
    fn do_match(
        dsl: &Language,
        ctx: &mut Context,
        fragment: &Expression,
        concrete: &Expression,
        bindings: &mut HashMap<usize, (Type, Expression)>,
        n_args: usize,
    ) -> Option<Type> {
        if !Self::might_match(dsl, fragment, concrete, 0) {
            None
        } else {
            let mut tm = TreeMatcher { dsl, ctx, bindings };
            tm.execute(fragment, concrete, &Rc::new(LinkedList::default()), n_args)
        }
    }

    /// Type-free structural pre-check.
    ///
    /// A free index (`i >= depth`) matches anything. Abstractions and applications must line
    /// up node by node. An invented fragment matches the same invention, or else its body is
    /// compared with `concrete`. Every other node must be equal.
    fn might_match(
        dsl: &Language,
        fragment: &Expression,
        concrete: &Expression,
        depth: usize,
    ) -> bool {
        match *fragment {
            Expression::Index(i) if i >= depth => true,
            Expression::Abstraction(ref f_body) => {
                if let Expression::Abstraction(ref e_body) = *concrete {
                    Self::might_match(dsl, f_body, e_body, depth + 1)
                } else {
                    false
                }
            }
            Expression::Application(ref f_f, ref f_x) => {
                if let Expression::Application(ref c_f, ref c_x) = *concrete {
                    Self::might_match(dsl, f_x, c_x, depth)
                        && Self::might_match(dsl, f_f, c_f, depth)
                } else {
                    false
                }
            }
            Expression::Invented(f_num) => {
                if let Expression::Invented(c_num) = *concrete {
                    f_num == c_num
                } else {
                    Self::might_match(dsl, &dsl.invented[f_num].0, concrete, depth)
                }
            }
            _ => fragment == concrete,
        }
    }

    /// Full match with type inference; `env` holds the types of the lambdas entered so far.
    fn execute(
        &mut self,
        fragment: &Expression,
        concrete: &Expression,
        env: &Rc<LinkedList<Type>>,
        n_args: usize,
    ) -> Option<Type> {
        match (fragment, concrete) {
            (Expression::Application(f_f, f_x), Expression::Application(c_f, c_x)) => {
                // The function's type must unify with `arg -> ret`; the result is `ret`.
                let ft = self.execute(f_f, c_f, env, n_args)?;
                let xt = self.execute(f_x, c_x, env, n_args)?;
                let ret = self.ctx.new_variable();
                if self.ctx.unify(&ft, &Type::arrow(xt, ret.clone())).is_ok() {
                    Some(ret.apply(self.ctx))
                } else {
                    None
                }
            }
            (&Expression::Primitive(f_num), &Expression::Primitive(c_num)) if f_num == c_num => {
                let tp = self.dsl.primitives[f_num].1.clone();
                Some(tp.instantiate_owned(self.ctx))
            }
            (&Expression::Invented(f_num), &Expression::Invented(c_num)) => {
                if f_num == c_num {
                    let tp = self.dsl.invented[f_num].1.clone();
                    Some(tp.instantiate_owned(self.ctx))
                } else {
                    None
                }
            }
            // An invented fragment against anything else: match its body instead.
            (&Expression::Invented(f_num), _) => {
                let inv = &self.dsl.invented[f_num].0;
                self.execute(inv, concrete, env, n_args)
            }
            (Expression::Abstraction(f_body), Expression::Abstraction(c_body)) => {
                let arg = self.ctx.new_variable();
                let env = LinkedList::prepend(env, arg.clone());
                let ret = self.execute(f_body, c_body, &env, 0)?;
                Some(Type::arrow(arg, ret))
            }
            // A variable bound inside the fragment must be the very same index in `concrete`.
            (&Expression::Index(i), _) if i < env.len() => {
                if fragment == concrete {
                    let mut tp = env[i].clone();
                    tp.apply_mut(self.ctx);
                    Some(tp)
                } else {
                    None
                }
            }
            // A free variable of the fragment: bind it to `concrete`. A repeated variable must
            // bind to an identical expression.
            (&Expression::Index(i), _) => {
                let i = i - env.len();
                let mut concrete = concrete.clone();
                // Move `concrete` out of the fragment's local binders; presumably `shift`
                // returns false when `concrete` refers to one of those binders.
                if concrete.shift(-(env.len() as i64)) {
                    if n_args > 0 {
                        // Eta-expand the binding over `n_args` arguments.
                        // NOTE(review): arguments are applied as `$0`, `$1`, … in that order,
                        // i.e. the innermost binder becomes the first argument; confirm this
                        // ordering is intended.
                        concrete.shift(n_args as i64);
                        for j in 0..n_args {
                            concrete = Expression::Application(
                                Box::new(concrete),
                                Box::new(Expression::Index(j)),
                            );
                        }
                        for _ in 0..n_args {
                            concrete = Expression::Abstraction(Box::new(concrete));
                        }
                    }
                    if let Some((tp, binding)) = self.bindings.get(&i) {
                        return if binding == &concrete {
                            Some(tp.clone())
                        } else {
                            None
                        };
                    }
                    let tp = self.ctx.new_variable();
                    self.bindings.insert(i, (tp.clone(), concrete));
                    Some(tp)
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}
821
822#[derive(Debug, Clone)]
823struct Uses {
824 actual_vars: f64,
825 possible_vars: f64,
826 actual_prims: Vec<f64>,
827 possible_prims: Vec<f64>,
828 actual_invented: Vec<f64>,
829 possible_invented: Vec<f64>,
830}
831impl Uses {
832 fn new(dsl: &Language) -> Uses {
833 let n_primitives = dsl.primitives.len();
834 let n_invented = dsl.invented.len();
835 Uses {
836 actual_vars: 0f64,
837 possible_vars: 0f64,
838 actual_prims: vec![0f64; n_primitives],
839 possible_prims: vec![0f64; n_primitives],
840 actual_invented: vec![0f64; n_invented],
841 possible_invented: vec![0f64; n_invented],
842 }
843 }
844 fn scale(&mut self, s: f64) {
845 self.actual_vars *= s;
846 self.possible_vars *= s;
847 self.actual_prims.iter_mut().for_each(|x| *x *= s);
848 self.possible_prims.iter_mut().for_each(|x| *x *= s);
849 self.actual_invented.iter_mut().for_each(|x| *x *= s);
850 self.possible_invented.iter_mut().for_each(|x| *x *= s);
851 }
852 fn merge(&mut self, other: Uses) {
853 self.actual_vars += other.actual_vars;
854 self.possible_vars += other.possible_vars;
855 self.actual_prims
856 .iter_mut()
857 .zip(other.actual_prims)
858 .for_each(|(a, b)| *a += b);
859 self.possible_prims
860 .iter_mut()
861 .zip(other.possible_prims)
862 .for_each(|(a, b)| *a += b);
863 self.actual_invented
864 .iter_mut()
865 .zip(other.actual_invented)
866 .for_each(|(a, b)| *a += b);
867 self.possible_invented
868 .iter_mut()
869 .zip(other.possible_invented)
870 .for_each(|(a, b)| *a += b);
871 }
872 fn join_from(&mut self, z: f64, mut weighted_uses: Vec<(f64, Uses)>) {
875 for &mut (l, ref mut u) in &mut weighted_uses {
876 u.scale((l - z).exp());
877 }
878 self.actual_vars = weighted_uses
879 .iter()
880 .map(|(_, u)| u.actual_vars)
881 .sum::<f64>();
882 self.possible_vars = weighted_uses
883 .iter()
884 .map(|(_, u)| u.possible_vars)
885 .sum::<f64>();
886 self.actual_prims.iter_mut().enumerate().for_each(|(i, c)| {
887 *c = weighted_uses
888 .iter()
889 .map(|(_, u)| u.actual_prims[i])
890 .sum::<f64>()
891 });
892 self.possible_prims
893 .iter_mut()
894 .enumerate()
895 .for_each(|(i, c)| {
896 *c = weighted_uses
897 .iter()
898 .map(|(_, u)| u.possible_prims[i])
899 .sum::<f64>()
900 });
901 self.actual_invented
902 .iter_mut()
903 .enumerate()
904 .for_each(|(i, c)| {
905 *c = weighted_uses
906 .iter()
907 .map(|(_, u)| u.actual_invented[i])
908 .sum::<f64>()
909 });
910 self.possible_invented
911 .iter_mut()
912 .enumerate()
913 .for_each(|(i, c)| {
914 *c = weighted_uses
915 .iter()
916 .map(|(_, u)| u.possible_invented[i])
917 .sum::<f64>()
918 });
919 }
920}
921
922mod proposals {
923 use super::super::Expression;
927 use super::expression_count_kinds;
928 use itertools::Itertools;
929 use std::collections::HashMap;
930 use std::iter;
931
932 #[derive(Clone, Debug)]
933 enum Fragment {
934 Variable,
935 Application(Box<Fragment>, Box<Fragment>),
936 Abstraction(Box<Fragment>),
937 Expression(Expression),
938 }
939 impl Fragment {
940 fn fragvars(&self) -> usize {
941 match self {
942 Fragment::Expression(_) => 0,
943 Fragment::Application(f, x) => f.fragvars() + x.fragvars(),
944 Fragment::Abstraction(body) => body.fragvars(),
945 Fragment::Variable => 1,
946 }
947 }
948 fn n_free(&self, depth: usize) -> usize {
949 match self {
950 Fragment::Expression(expr) => Fragment::n_free_expr(expr, depth),
951 Fragment::Application(f, x) => f.n_free(depth) + x.n_free(depth),
952 Fragment::Abstraction(body) => body.n_free(depth + 1),
953 Fragment::Variable => 0,
954 }
955 }
956 fn n_free_expr(expr: &Expression, depth: usize) -> usize {
957 match expr {
958 Expression::Application(f, x) => {
959 Fragment::n_free_expr(f, depth) + Fragment::n_free_expr(x, depth)
960 }
961 Expression::Abstraction(body) => Fragment::n_free_expr(body, depth + 1),
962 Expression::Index(i) if *i >= depth => 1,
963 _ => 0,
964 }
965 }
966 fn canonicalize(self) -> impl Iterator<Item = Expression> {
967 let fragvars = self.fragvars();
968 let n_free = self.n_free(0);
969 iter::repeat(0..fragvars)
971 .take(fragvars)
972 .multi_cartesian_product()
973 .filter(|xs| {
974 if let Some(x) = xs.iter().max() {
975 if *x == 0 {
976 true
977 } else {
978 (0..*x).all(|y| xs.contains(&y))
979 }
980 } else {
981 true
982 }
983 })
984 .pad_using(1, |_| Vec::new())
985 .map(move |mut assignment| {
986 for x in &mut assignment {
987 *x += n_free
988 }
989 let mut c = Canonicalizer::new(assignment);
990 let mut frag = self.clone();
991 c.canonicalize(&mut frag, 0);
992 frag.into_expression()
993 })
994 }
995 fn into_expression(self) -> Expression {
996 match self {
997 Fragment::Expression(expr) => expr,
998 Fragment::Application(f, x) => Expression::Application(
999 Box::new(f.into_expression()),
1000 Box::new(x.into_expression()),
1001 ),
1002 Fragment::Abstraction(body) => {
1003 Expression::Abstraction(Box::new(body.into_expression()))
1004 }
1005 _ => panic!("cannot convert fragment that still has variables"),
1006 }
1007 }
1008 }
1009 pub fn defragment(mut fragment_expr: Expression) -> Expression {
1011 let reach = free_reach(&fragment_expr, 0);
1012 for _ in 0..reach {
1013 let body = Box::new(fragment_expr);
1014 fragment_expr = Expression::Abstraction(body);
1015 }
1016 fragment_expr
1017 }
1018
1019 struct Canonicalizer {
1020 assignment: Vec<usize>,
1021 elapsed: usize,
1022 free: usize,
1023 mapping: HashMap<usize, usize>,
1024 }
1025 impl Canonicalizer {
1026 fn new(assignment: Vec<usize>) -> Canonicalizer {
1027 Canonicalizer {
1028 assignment,
1029 elapsed: 0,
1030 free: 0,
1031 mapping: HashMap::default(),
1032 }
1033 }
1034 fn canonicalize(&mut self, fr: &mut Fragment, depth: usize) {
1035 match *fr {
1036 Fragment::Expression(ref mut expr) => self.canonicalize_expr(expr, depth),
1037 Fragment::Application(ref mut f, ref mut x) => {
1038 self.canonicalize(f, depth);
1039 self.canonicalize(x, depth);
1040 }
1041 Fragment::Abstraction(ref mut body) => {
1042 self.canonicalize(body, depth + 1);
1043 }
1044 Fragment::Variable => {
1045 *fr = Fragment::Expression(Expression::Index(
1046 self.assignment[self.elapsed] + depth,
1047 ));
1048 self.elapsed += 1;
1049 }
1050 }
1051 }
1052 fn canonicalize_expr(&mut self, expr: &mut Expression, depth: usize) {
1053 match *expr {
1054 Expression::Application(ref mut f, ref mut x) => {
1055 self.canonicalize_expr(f, depth);
1056 self.canonicalize_expr(x, depth);
1057 }
1058 Expression::Abstraction(ref mut body) => self.canonicalize_expr(body, depth + 1),
1059 Expression::Index(ref mut i) if *i >= depth => {
1060 let j = i.checked_sub(depth).unwrap();
1061 if let Some(k) = self.mapping.get(&j) {
1062 *i = k + depth;
1063 return;
1064 }
1065 self.mapping.insert(j, self.free);
1066 *i = self.free + depth;
1067 self.free += 1;
1068 }
1069 _ => (),
1070 }
1071 }
1072 }
1073
1074 pub fn from_expression(expr: &Expression, arity: u32) -> Vec<Expression> {
1076 (0..=arity)
1077 .flat_map(move |b| from_subexpression(expr, b))
1078 .flat_map(Fragment::canonicalize)
1079 .filter(|fragment_expr| {
1080 let (n_prims, n_free, n_bound) = expression_count_kinds(fragment_expr, 0);
1082 n_prims >= 1 && ((n_prims as f64) + 0.5 * ((n_free + n_bound) as f64) > 1.5)
1083 })
1084 .flat_map(to_inventions)
1085 .collect()
1086 }
1087 fn from_subexpression(expr: &Expression, arity: u32) -> impl Iterator<Item = Fragment> + '_ {
1088 let rst: Box<dyn Iterator<Item = Fragment>> = match *expr {
1089 Expression::Application(ref f, ref x) => {
1090 Box::new(from_subexpression(f, arity).chain(from_subexpression(x, arity)))
1091 }
1092 Expression::Abstraction(ref body) => Box::new(from_subexpression(body, arity)),
1093 _ => Box::new(iter::empty()),
1094 };
1095 from_particular(expr, arity, true).chain(rst)
1096 }
1097 fn from_particular<'a>(
1098 expr: &'a Expression,
1099 arity: u32,
1100 toplevel: bool,
1101 ) -> Box<dyn Iterator<Item = Fragment> + 'a> {
1102 if arity == 0 {
1103 return Box::new(iter::once(Fragment::Expression(expr.clone())));
1104 }
1105 let rst: Box<dyn Iterator<Item = Fragment> + 'a> = match *expr {
1106 Expression::Application(ref f, ref x) => Box::new((0..=arity).flat_map(move |fa| {
1107 let xa = (arity as i32 - fa as i32) as u32;
1108 from_particular(f, fa, false)
1109 .zip(iter::repeat(
1110 from_particular(x, xa, false).collect::<Vec<_>>(),
1111 ))
1112 .flat_map(|(f, xs)| {
1113 xs.into_iter()
1114 .map(move |x| Fragment::Application(Box::new(f.clone()), Box::new(x)))
1115 })
1116 })),
1117 Expression::Abstraction(ref body) if !toplevel => Box::new(
1118 from_particular(body, arity, false).map(|e| Fragment::Abstraction(Box::new(e))),
1119 ),
1120 _ => Box::new(iter::empty()),
1121 };
1122 Box::new(iter::once(Fragment::Variable).chain(rst))
1123 }
1124 fn to_inventions(expr: Expression) -> impl Iterator<Item = Expression> {
1125 let reach = free_reach(&expr, 0);
1127 let mut counts = HashMap::new();
1128 subtrees(expr.clone(), &mut counts);
1129 counts.remove(&expr);
1130 let fst = iter::once(expr.clone());
1131 let rst = counts
1132 .into_iter()
1133 .filter(|&(_, count)| count >= 2)
1134 .filter(|(expr, _)| is_closed(expr))
1135 .map(move |(subtree, _)| {
1136 let mut expr = expr.clone();
1137 substitute(&mut expr, &subtree, &Expression::Index(reach));
1138 expr
1139 });
1140 fst.chain(rst)
1141 }
1142
1143 fn free_reach(expr: &Expression, depth: usize) -> usize {
1148 match *expr {
1149 Expression::Application(ref f, ref x) => free_reach(f, depth).max(free_reach(x, depth)),
1150 Expression::Abstraction(ref body) => free_reach(body, depth + 1),
1151 Expression::Index(i) if i >= depth => 1 + i.checked_sub(depth).unwrap(),
1152 _ => 0,
1153 }
1154 }
1155
1156 fn subtrees(expr: Expression, counts: &mut HashMap<Expression, usize>) {
1158 match expr.clone() {
1159 Expression::Application(f, x) => {
1160 subtrees(*f, counts);
1161 subtrees(*x, counts);
1162 counts.entry(expr).or_insert(0);
1163 }
1164 Expression::Abstraction(body) => {
1165 subtrees(*body, counts);
1166 counts.entry(expr).or_insert(0);
1167 }
1168 Expression::Index(_) => (),
1169 Expression::Primitive(num) => {
1170 counts.entry(Expression::Primitive(num)).or_insert(0);
1171 }
1172 Expression::Invented(num) => {
1173 counts.entry(Expression::Invented(num)).or_insert(0);
1174 }
1175 }
1176 }
1177
1178 fn is_closed(expr: &Expression) -> bool {
1180 free_reach(expr, 0) == 0
1181 }
1182
1183 fn substitute(expr: &mut Expression, subtree: &Expression, replacement: &Expression) {
1185 if expr == subtree {
1186 *expr = replacement.clone()
1187 } else {
1188 match *expr {
1189 Expression::Application(ref mut f, ref mut x) => {
1190 substitute(f, subtree, replacement);
1191 substitute(x, subtree, replacement);
1192 }
1193 Expression::Abstraction(ref mut body) => substitute(body, subtree, replacement),
1194 _ => (),
1195 }
1196 }
1197 }
1198}
1199
1200pub fn expression_count_kinds(expr: &Expression, abstraction_depth: usize) -> (u64, u64, u64) {
1202 match *expr {
1203 Expression::Primitive(_) | Expression::Invented(_) => (1, 0, 0),
1204 Expression::Index(i) => {
1205 if i < abstraction_depth {
1206 (0, 0, 1)
1207 } else {
1208 (0, 1, 0)
1209 }
1210 }
1211 Expression::Abstraction(ref b) => expression_count_kinds(b, abstraction_depth + 1),
1212 Expression::Application(ref l, ref r) => {
1213 let (l1, f1, b1) = expression_count_kinds(l, abstraction_depth);
1214 let (l2, f2, b2) = expression_count_kinds(r, abstraction_depth);
1215 (l1 + l2, f1 + f2, b1 + b2)
1216 }
1217 }
1218}