1use itertools::Itertools;
25use once_cell::sync::Lazy;
26use polytype::{ptp, tp};
27use rand::Rng;
28use std::collections::HashMap;
29
30use crate::lambda::{
31 task_by_evaluation, Evaluator as EvaluatorT, Expression, Language, LiftedFunction,
32};
33use crate::Task;
34
35pub fn dsl() -> Language {
70 let mut dsl = Language::uniform(vec![
71 ("0", ptp!(int)),
72 ("+1", ptp!(@arrow[tp!(int), tp!(int)])),
73 ("-1", ptp!(@arrow[tp!(int), tp!(int)])),
74 ("len", ptp!(@arrow[tp!(str), tp!(int)])),
75 ("empty_str", ptp!(str)),
76 ("lower", ptp!(@arrow[tp!(str), tp!(str)])),
77 ("upper", ptp!(@arrow[tp!(str), tp!(str)])),
78 ("concat", ptp!(@arrow[tp!(str), tp!(str), tp!(str)])),
79 (
80 "slice",
81 ptp!(@arrow[tp!(int), tp!(int), tp!(str), tp!(str)]),
82 ),
83 ("nth", ptp!(@arrow[tp!(int), tp!(list(tp!(str))), tp!(str)])),
84 (
85 "map",
86 ptp!(0, 1; @arrow[tp!(@arrow[tp!(0), tp!(1)]), tp!(list(tp!(0))), tp!(list(tp!(1)))]),
87 ),
88 ("strip", ptp!(@arrow[tp!(str), tp!(str)])),
89 (
90 "split",
91 ptp!(@arrow[tp!(char), tp!(str), tp!(list(tp!(str)))]),
92 ),
93 (
94 "join",
95 ptp!(@arrow[tp!(str), tp!(list(tp!(str))), tp!(str)]),
96 ),
97 ("char->str", ptp!(@arrow[tp!(char), tp!(str)])),
98 ("space", ptp!(char)),
99 (".", ptp!(char)),
100 (",", ptp!(char)),
101 ("<", ptp!(char)),
102 (">", ptp!(char)),
103 ("/", ptp!(char)),
104 ("@", ptp!(char)),
105 ("-", ptp!(char)),
106 ("|", ptp!(char)),
107 ]);
108 dsl.add_symmetry_violation(1, 0, 2);
110 dsl.add_symmetry_violation(2, 0, 1);
111 dsl.add_symmetry_violation(3, 0, 4);
113 dsl.add_symmetry_violation(3, 0, 5);
114 dsl.add_symmetry_violation(3, 0, 6);
115 dsl.add_symmetry_violation(3, 0, 14);
116 dsl.add_symmetry_violation(7, 0, 7);
118 dsl.add_symmetry_violation(7, 0, 4);
120 dsl.add_symmetry_violation(7, 1, 4);
121 dsl
122}
123
124use self::Space::*;
/// Runtime values manipulated by programs in the DSL built by `dsl`.
#[derive(Clone)]
pub enum Space {
    /// An `int` value.
    Num(i32),
    /// A `char` value (e.g. a delimiter primitive such as `space`).
    Char(char),
    /// A `str` value.
    Str(String),
    /// A `list(..)` value.
    List(Vec<Space>),
    /// A function value produced by `Evaluator::lift`; consumed by `map`.
    Func(LiftedFunction<Space, Evaluator>),
}
134impl std::fmt::Debug for Space {
135 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
136 match self {
137 Num(x) => write!(f, "Num({:?})", x),
138 Char(x) => write!(f, "Char({:?})", x),
139 Str(x) => write!(f, "Str({:?})", x),
140 List(x) => write!(f, "List({:?})", x),
141 Func(_) => write!(f, "<function>"),
142 }
143 }
144}
145impl PartialEq for Space {
146 fn eq(&self, other: &Self) -> bool {
147 match (self, other) {
148 (Num(x), Num(y)) => x == y,
149 (Char(x), Char(y)) => x == y,
150 (Str(x), Str(y)) => x == y,
151 (List(xs), List(ys)) => xs == ys,
152 _ => false,
153 }
154 }
155}
156
/// Evaluator for the primitives of the DSL built by `dsl`. It carries no
/// state (unit struct), so it is freely copyable.
#[derive(Copy, Clone)]
pub struct Evaluator;
196impl EvaluatorT for Evaluator {
197 type Space = Space;
198 type Error = ();
199 fn evaluate(&self, name: &str, inps: &[Self::Space]) -> Result<Self::Space, Self::Error> {
200 match &OPERATIONS[name] {
201 Op::Zero => Ok(Num(0)),
202 Op::Incr => match inps {
203 [Num(x)] => Ok(Num(x + 1)),
204 _ => unreachable!(),
205 },
206 Op::Decr => match inps {
207 [Num(x)] => Ok(Num(x - 1)),
208 _ => unreachable!(),
209 },
210 Op::Len => match inps {
211 [Str(s)] => Ok(Num(s.len() as i32)),
212 _ => unreachable!(),
213 },
214 Op::Empty => Ok(Str(String::new())),
215 Op::Lower => match inps {
216 [Str(s)] => Ok(Str(s.to_lowercase())),
217 _ => unreachable!(),
218 },
219 Op::Upper => match inps {
220 [Str(s)] => Ok(Str(s.to_uppercase())),
221 _ => unreachable!(),
222 },
223 Op::Concat => match inps {
224 [Str(x), Str(y)] => {
225 let mut s = x.to_string();
226 s.push_str(y);
227 Ok(Str(s))
228 }
229 _ => unreachable!(),
230 },
231 Op::Slice => match inps {
232 [Num(x), Num(y), Str(s)] => {
233 if *x as usize > s.len() || y < x {
234 Err(())
235 } else {
236 Ok(Str(s
237 .chars()
238 .skip(*x as usize)
239 .take((y - x) as usize)
240 .collect()))
241 }
242 }
243 _ => unreachable!(),
244 },
245 Op::Nth => match inps {
246 [Num(x), List(ss)] => ss.get(*x as usize).cloned().ok_or(()),
247 _ => unreachable!(),
248 },
249 Op::Map => match inps {
250 [Func(f), List(xs)] => Ok(List(
251 xs.iter()
252 .map(|x| f.eval(&[x.clone()]).map_err(|_| ()))
253 .collect::<Result<_, _>>()?,
254 )),
255 _ => unreachable!(),
256 },
257 Op::Strip => match inps {
258 [Str(s)] => Ok(Str(s.trim().to_owned())),
259 _ => unreachable!(),
260 },
261 Op::Split => match inps {
262 [Char(c), Str(s)] => Ok(List(s.split(*c).map(|s| Str(s.to_owned())).collect())),
263 _ => unreachable!(),
264 },
265 Op::Join => match inps {
266 [Str(delim), List(ss)] => Ok(Str(ss
267 .iter()
268 .map(|s| match s {
269 Str(s) => s,
270 _ => unreachable!(),
271 })
272 .join(delim))),
273 _ => unreachable!(),
274 },
275 Op::CharToStr => match inps {
276 [Char(c)] => Ok(Str(c.to_string())),
277 _ => unreachable!(),
278 },
279 Op::CharSpace => Ok(Char(' ')),
280 Op::CharDot => Ok(Char('.')),
281 Op::CharComma => Ok(Char(',')),
282 Op::CharLess => Ok(Char('<')),
283 Op::CharGreater => Ok(Char('>')),
284 Op::CharSlash => Ok(Char('/')),
285 Op::CharAt => Ok(Char('@')),
286 Op::CharDash => Ok(Char('-')),
287 Op::CharPipe => Ok(Char('|')),
288 }
289 }
290 fn lift(&self, f: LiftedFunction<Self::Space, Self>) -> Option<Self::Space> {
291 Some(Func(f))
292 }
293}
294
295#[allow(clippy::type_complexity)]
302pub fn make_tasks<R: Rng>(
303 rng: &mut R,
304 count: usize,
305 n_examples: usize,
306) -> Vec<impl Task<[(Vec<Space>, Space)], Representation = Language, Expression = Expression>> {
307 (0..=count / 1467) .flat_map(|_| gen::make_examples(rng, n_examples))
309 .take(count)
310 .map(|(_name, tp, examples)| task_by_evaluation(Evaluator, tp, examples))
311 .collect()
312}
313
/// Tag for each DSL primitive, looked up by name through `OPERATIONS` so
/// that `Evaluator::evaluate` can dispatch with a single `match`.
enum Op {
    /// `0`
    Zero,
    /// `+1`
    Incr,
    /// `-1`
    Decr,
    /// `len`
    Len,
    /// `empty_str`
    Empty,
    /// `lower`
    Lower,
    /// `upper`
    Upper,
    /// `concat`
    Concat,
    /// `slice`
    Slice,
    /// `nth`
    Nth,
    /// `map`
    Map,
    /// `strip`
    Strip,
    /// `split`
    Split,
    /// `join`
    Join,
    /// `char->str`
    CharToStr,
    /// `space`
    CharSpace,
    /// `.`
    CharDot,
    /// `,`
    CharComma,
    /// `<`
    CharLess,
    /// `>`
    CharGreater,
    /// `/`
    CharSlash,
    /// `@`
    CharAt,
    /// `-`
    CharDash,
    /// `|`
    CharPipe,
}
341
342static OPERATIONS: Lazy<HashMap<&'static str, Op>> = Lazy::new(|| {
343 HashMap::from([
344 ("0", Op::Zero),
345 ("+1", Op::Incr),
346 ("-1", Op::Decr),
347 ("len", Op::Len),
348 ("empty_str", Op::Empty),
349 ("lower", Op::Lower),
350 ("upper", Op::Upper),
351 ("concat", Op::Concat),
352 ("slice", Op::Slice),
353 ("nth", Op::Nth),
354 ("map", Op::Map),
355 ("strip", Op::Strip),
356 ("split", Op::Split),
357 ("join", Op::Join),
358 ("char->str", Op::CharToStr),
359 ("space", Op::CharSpace),
360 (".", Op::CharDot),
361 (",", Op::CharComma),
362 ("<", Op::CharLess),
363 (">", Op::CharGreater),
364 ("/", Op::CharSlash),
365 ("@", Op::CharAt),
366 ("-", Op::CharDash),
367 ("|", Op::CharPipe),
368 ])
369});
370
371mod gen {
372 use itertools::Itertools;
373 use polytype::{ptp, tp, TypeScheme};
374 use rand::distributions::{Distribution, Uniform};
375 use rand::{self, Rng};
376
377 use super::Space::{self, *};
378
379 static DELIMITERS: [char; 9] = ['.', ',', ' ', '<', '>', '/', '@', '-', '|'];
380
381 fn character<R: Rng>(rng: &mut R) -> char {
382 let c: u8 = Uniform::from(0..26u8).sample(rng);
383 let c = c + if rng.gen() { 65 } else { 97 };
384 c as char
385 }
386
387 fn word<R: Rng>(rng: &mut R) -> String {
388 let size = Uniform::from(3..6).sample(rng);
389 (0..size).map(|_| character(rng)).collect()
390 }
391 fn words<R: Rng>(delim: char, rng: &mut R) -> String {
392 let size = Uniform::from(2..5).sample(rng);
393 (0..size).map(|_| word(rng)).join(&delim.to_string())
394 }
395
396 fn white_word<R: Rng>(rng: &mut R) -> String {
397 let size = Uniform::from(4..7).sample(rng);
398 let mut s: String = (0..size).map(|_| character(rng)).collect();
399 let n_spaces = Uniform::from(0..3).sample(rng);
400 for _ in 0..n_spaces {
401 let j = Uniform::from(1..s.len()).sample(rng);
402 s.insert(j, ' ');
403 }
404 let between = Uniform::from(0..3usize);
405 let mut starting = 0;
406 let mut ending = 0;
407 while starting == 0 && ending == 0 {
408 starting = between.sample(rng);
409 ending = between.sample(rng);
410 }
411 s.insert_str(0, &" ".repeat(starting));
412 let len = s.len();
413 s.insert_str(len, &" ".repeat(ending));
414 s
415 }
416 fn white_words<R: Rng>(delim: char, rng: &mut R) -> String {
417 let size = Uniform::from(2..5).sample(rng);
418 (0..size).map(|_| white_word(rng)).join(&delim.to_string())
419 }
420
421 #[allow(clippy::cognitive_complexity)]
422 #[allow(clippy::redundant_closure_call)]
423 #[allow(clippy::type_complexity)]
424 pub fn make_examples<R: Rng>(
425 rng: &mut R,
426 n_examples: usize,
427 ) -> Vec<(&'static str, TypeScheme, Vec<(Vec<Space>, Space)>)> {
428 let mut tasks = Vec::new();
429
430 macro_rules! t {
431 ($name:expr, $tp:expr, $body:block) => {
432 let examples = (0..n_examples)
433 .map(|_| {
434 let (i, o) = $body;
435 (vec![i], o)
436 })
437 .collect();
438 tasks.push(($name, $tp, examples));
439 };
440 }
441 t!(
442 "map strip",
443 ptp!(@arrow[tp!(list(tp!(str))), tp!(list(tp!(str)))]),
444 {
445 let n_words = Uniform::from(1..5).sample(rng);
446 let xs: Vec<_> = (0..n_words).map(|_| white_word(rng)).collect();
447 let ys = xs.iter().map(|s| Str(s.trim().to_owned())).collect();
448 let xs = xs.into_iter().map(Str).collect();
449 (List(xs), List(ys))
450 }
451 );
452 t!("strip", ptp!(@arrow[tp!(str), tp!(str)]), {
453 let x = white_word(rng);
454 let y = x.trim().to_owned();
455 (Str(x), Str(y))
456 });
457 for d in &DELIMITERS {
458 let d: char = *d;
459 t!(
460 "map strip after splitting on d",
461 ptp!(@arrow[tp!(str), tp!(list(tp!(str)))]),
462 {
463 let x = words(d, rng);
464 let ys = x.split(d).map(|s| Str(s.trim().to_owned())).collect();
465 (Str(x), List(ys))
466 }
467 );
468 t!(
469 "map strip and then join with d",
470 ptp!(@arrow[tp!(list(tp!(str))), tp!(str)]),
471 {
472 let n_words = Uniform::from(1..5).sample(rng);
473 let xs: Vec<_> = (0..n_words).map(|_| word(rng)).collect();
474 let y = xs.iter().map(|s| s.trim().to_owned()).join(&d.to_string());
475 let xs = xs.into_iter().map(Str).collect();
476 (List(xs), Str(y))
477 }
478 );
479 t!("delete delimiter d", ptp!(@arrow[tp!(str), tp!(str)]), {
480 let x = words(d, rng);
481 let y = x.replace(d, "");
482 (Str(x), Str(y))
483 });
484 t!(
485 "extract prefix up to d, exclusive",
486 ptp!(@arrow[tp!(str), tp!(str)]),
487 {
488 let y = word(rng);
489 let x = format!("{}{}{}", y, d, word(rng));
490 (Str(x), Str(y))
491 }
492 );
493 t!(
494 "extract prefix up to d, inclusive",
495 ptp!(@arrow[tp!(str), tp!(str)]),
496 {
497 let mut y = word(rng);
498 y.push(d);
499 let x = format!("{}{}{}", y, d, word(rng));
500 (Str(x), Str(y))
501 }
502 );
503 t!(
504 "extract suffix up to d, exclusive",
505 ptp!(@arrow[tp!(str), tp!(str)]),
506 {
507 let y = word(rng);
508 let x = format!("{}{}{}", word(rng), y, d);
509 (Str(x), Str(y))
510 }
511 );
512 t!(
513 "extract suffix up to d, inclusive",
514 ptp!(@arrow[tp!(str), tp!(str)]),
515 {
516 let y = format!("{}{}", word(rng), d);
517 let x = format!("{}{}{}", word(rng), y, d);
518 (Str(x), Str(y))
519 }
520 );
521 let d1 = d;
522 for d2 in &DELIMITERS {
523 let d2: char = *d2;
524 t!(
525 "extract delimited by d1, d2",
526 ptp!(@arrow[tp!(str), tp!(str)]),
527 {
528 let y = word(rng);
529 let x = format!("{}{}{}{}{}", word(rng), d1, y, d2, word(rng));
530 (Str(x), Str(y))
531 }
532 );
533 t!(
534 "extract delimited by d1 (incl), d2",
535 ptp!(@arrow[tp!(str), tp!(str)]),
536 {
537 let y = format!("{}{}{}", d1, word(rng), d2);
538 let x = format!("{}{}{}", word(rng), y, word(rng));
539 (Str(x), Str(y))
540 }
541 );
542 t!(
543 "extract delimited by d1 (incl), d2 (incl)",
544 ptp!(@arrow[tp!(str), tp!(str)]),
545 {
546 let y = format!("{}{}", d1, word(rng));
547 let x = format!("{}{}{}{}", word(rng), y, d2, word(rng));
548 (Str(x), Str(y))
549 }
550 );
551 if d1 != ' ' {
552 t!(
553 "strip delimited by d1 from inp delimited by d2",
554 ptp!(@arrow[tp!(str), tp!(str)]),
555 {
556 let x = white_words(d1, rng);
557 let y = x
558 .split(d1)
559 .map(|s| s.trim().to_owned())
560 .join(&d2.to_string());
561 (Str(x), Str(y))
562 }
563 );
564 if d2 != ' ' {
565 t!(
566 "strip from inp delimited by d1, d2",
567 ptp!(@arrow[tp!(str), tp!(str)]),
568 {
569 let y = white_word(rng);
570 let x = format!("{}{}{}{}{}", word(rng), d1, y, d2, word(rng));
571 (Str(x), Str(y))
572 }
573 );
574 }
575 }
576 if d1 != d2 {
577 t!(
578 "replace delimiter d1 with d2",
579 ptp!(@arrow[tp!(str), tp!(str)]),
580 {
581 let x = words(d1, rng);
582 let y = x.replace(d1, &d2.to_string());
583 (Str(x), Str(y))
584 }
585 );
586 }
587 }
588 }
589
590 macro_rules! single_word_edit {
591 ($name:expr, $f:expr) => {
592 t!(
593 concat!($name, " strip"),
594 ptp!(@arrow[tp!(str), tp!(str)]),
595 {
596 let x = white_word(rng);
597 let y = ($f)(&x);
598 (Str(x), Str(y))
599 }
600 );
601 t!(
602 concat!("map ", $name),
603 ptp!(@arrow[tp!(list(tp!(str))), tp!(list(tp!(str)))]),
604 {
605 let n_words = Uniform::from(1..5).sample(rng);
606 let xs: Vec<_> = (0..n_words).map(|_| word(rng)).collect();
607 let ys = xs.iter().map(|s| Str(($f)(s))).collect();
608 let xs = xs.into_iter().map(Str).collect();
609 (List(xs), List(ys))
610 }
611 );
612 for d in &DELIMITERS {
613 let d: char = *d;
614 t!(
615 concat!("map ", $name, " after splitting"),
616 ptp!(@arrow[tp!(str), tp!(list(tp!(str)))]),
617 {
618 let x = words(d, rng);
619 let ys = x.split(d).map(|s| Str(($f)(s))).collect();
620 (Str(x), List(ys))
621 }
622 );
623 t!(
624 concat!("map ", $name, " then join"),
625 ptp!(@arrow[tp!(list(tp!(str))), tp!(str)]),
626 {
627 let n_words = Uniform::from(1..5).sample(rng);
628 let xs: Vec<_> = (0..n_words).map(|_| word(rng)).collect();
629 let y = xs.iter().map(|s| ($f)(s)).join(&d.to_string());
630 let xs = xs.into_iter().map(Str).collect();
631 (List(xs), Str(y))
632 }
633 );
634 if $name != "lowercase" && $name != "uppercase" {
635 t!(
636 concat!($name, " of delimited inp"),
637 ptp!(@arrow[tp!(str), tp!(str)]),
638 {
639 let x = words(d, rng);
640 let y = x.split(d).map(|s| ($f)(s)).join("");
641 (Str(x), Str(y))
642 }
643 );
644 }
645 let d1 = d;
646 for d2 in &DELIMITERS {
647 let d2: char = *d2;
648 t!(
649 concat!($name, " of doubly delimited inp"),
650 ptp!(@arrow[tp!(str), tp!(str)]),
651 {
652 let y = word(rng);
653 let x = format!("{}{}{}{}{}", word(rng), d1, y, d2, word(rng));
654 (Str(x), Str(($f)(&y)))
655 }
656 );
657 if d1 != d2 && $name != "lowercase" && $name != "uppercase" {
658 t!(
659 concat!("delimited ", $name, " of delimited inp"),
660 ptp!(@arrow[tp!(str), tp!(str)]),
661 {
662 let x = words(d, rng);
663 let y = x.split(d).map(|s| ($f)(s)).join(&d2.to_string());
664 (Str(x), Str(y))
665 }
666 );
667 }
668 }
669 }
670 let importance = if $name != "lowercase" && $name != "uppercase" {
671 2
672 } else {
673 1
674 };
675 for _ in 0..importance {
676 t!($name, ptp!(@arrow[tp!(str), tp!(str)]), {
677 let x = word(rng);
678 let y = ($f)(&x);
679 (Str(x), Str(y))
680 });
681 }
682 };
683 }
684
685 single_word_edit!("lowercase", |s: &str| -> String { s.to_lowercase() });
686 single_word_edit!("uppercase", |s: &str| -> String { s.to_uppercase() });
687 single_word_edit!("capitalize", |s: &str| -> String {
688 let mut s = s.to_owned();
689 if let Some(c) = s.get_mut(..1) {
690 c.make_ascii_uppercase()
691 }
692 s
693 });
694 single_word_edit!("double", |s: &str| -> String { format!("{}{}", s, s) });
695 single_word_edit!("first character", |s: &str| -> String {
696 s.chars().next().unwrap().to_string()
697 });
698 single_word_edit!("drop first character", |s: &str| -> String {
699 s.chars().skip(1).collect()
700 });
701
702 macro_rules! word_edit_pair {
703 ($name1:expr, $f1:expr, $name2:expr, $f2:expr) => {
704 t!(
705 concat!($name1, " . ", $name2),
706 ptp!(@arrow[tp!(str), tp!(str)]),
707 {
708 let x = word(rng);
709 let y = ($f2)(&(($f1)(&x)));
710 (Str(x), Str(y))
711 }
712 )
713 };
714 }
715
716 word_edit_pair!(
717 "lowercase",
718 |s: &str| -> String { s.to_lowercase() },
719 "first character",
720 |s: &str| -> String { s.chars().next().unwrap().to_string() }
721 );
722 word_edit_pair!(
723 "lowercase",
724 |s: &str| -> String { s.to_lowercase() },
725 "drop first character",
726 |s: &str| -> String { s.chars().skip(1).collect() }
727 );
728 word_edit_pair!(
729 "uppercase",
730 |s: &str| -> String { s.to_uppercase() },
731 "first character",
732 |s: &str| -> String { s.chars().next().unwrap().to_string() }
733 );
734 word_edit_pair!(
735 "uppercase",
736 |s: &str| -> String { s.to_uppercase() },
737 "drop first character",
738 |s: &str| -> String { s.chars().skip(1).collect() }
739 );
740 word_edit_pair!(
741 "double",
742 |s: &str| -> String { format!("{}{}", s, s) },
743 "first character",
744 |s: &str| -> String { s.chars().next().unwrap().to_string() }
745 );
746 word_edit_pair!(
747 "double",
748 |s: &str| -> String { format!("{}{}", s, s) },
749 "drop first character",
750 |s: &str| -> String { s.chars().skip(1).collect() }
751 );
752 word_edit_pair!(
753 "double",
754 |s: &str| -> String { format!("{}{}", s, s) },
755 "capitalize",
756 |s: &str| -> String {
757 let mut s = s.to_owned();
758 if let Some(c) = s.get_mut(..1) {
759 c.make_ascii_uppercase()
760 }
761 s
762 }
763 );
764 word_edit_pair!(
765 "first character",
766 |s: &str| -> String { s.chars().next().unwrap().to_string() },
767 "drop first character",
768 |s: &str| -> String { s.chars().skip(1).collect() }
769 );
770 word_edit_pair!(
771 "drop first character",
772 |s: &str| -> String { s.chars().skip(1).collect() },
773 "drop first character",
774 |s: &str| -> String { s.chars().skip(1).collect() }
775 );
776 word_edit_pair!(
777 "drop first character",
778 |s: &str| -> String { s.chars().skip(1).collect() },
779 "double",
780 |s: &str| -> String { format!("{}{}", s, s) }
781 );
782 word_edit_pair!(
783 "capitalize",
784 |s: &str| -> String {
785 let mut s = s.to_owned();
786 if let Some(c) = s.get_mut(..1) {
787 c.make_ascii_uppercase()
788 }
789 s
790 },
791 "double",
792 |s: &str| -> String { format!("{}{}", s, s) }
793 );
794
795 tasks
796 }
797}