1use std::collections::{BTreeMap, BTreeSet}; use array_tool::vec::{Uniq, Union};
3use crate::cats::*;
4use crate::cats::Colour::*;
5use crate::cats::GridCategory::*;
6use crate::data::*;
7use crate::grid::*;
8use crate::shape::*;
9
10#[derive(Debug, Clone, PartialEq)]
14pub struct Example {
15 pub input: QualifiedGrid,
16 pub output: QualifiedGrid,
17 pub cat: BTreeSet<GridCategory>,
18 pub pairs: Vec<(Shape, Shape, bool)>,
19 pub coloured_pairs: Vec<(Shape, Shape, bool)>,
20 }
24
25impl Example {
26 pub fn new(data: &IO) -> Self {
27 let input = QualifiedGrid::new(&data.input);
28 let output = match &data.output {
29 Some(output) => QualifiedGrid::new(output),
30 None => QualifiedGrid::trivial(),
31 };
32 let cat = Example::categorise_grid(&input, &output);
33 let pairs = Vec::new();
34 let coloured_pairs = Vec::new();
35
36 Example { input, output, cat, pairs, coloured_pairs }
37 }
38
39 pub fn new_cons(data: &IO) -> Self {
40 let input = QualifiedGrid::new_cons(&data.input);
41 let output = match &data.output {
42 Some(output) => QualifiedGrid::new_cons(output),
43 None => QualifiedGrid::trivial(),
44 };
45 let cat = Example::categorise_grid(&input, &output);
46 let pairs = Vec::new();
47 let coloured_pairs = Vec::new();
48
49 Example { input, output, cat, pairs, coloured_pairs }
50 }
51
52 pub fn transform(&self, trans: Transformation, input: bool) -> Self {
53 let mut example = self.clone();
54
55 example.transform_mut(trans, input);
56
57 example
58 }
59
60 pub fn transform_mut(&mut self, trans: Transformation, input: bool) {
61 let qgrid = if input { &mut self.input } else { &mut self.output };
62
63 qgrid.grid = qgrid.grid.transform(trans);
64
65 qgrid.shapes = if qgrid.bg == NoColour {
66 qgrid.grid.to_shapes()
67 } else {
68 qgrid.grid.to_shapes_bg(qgrid.bg)
69 };
70 qgrid.coloured_shapes = if qgrid.bg == NoColour {
71 qgrid.grid.to_shapes_coloured()
72 } else {
73 qgrid.grid.to_shapes_coloured_bg(qgrid.bg)
74 };
75 if !qgrid.black.is_empty() {
76 qgrid.black = qgrid.grid.find_black_patches();
77 }
78 }
79
80 pub fn is_equal(&self) -> bool {
109 self.input.grid.cells.columns == self.output.grid.cells.columns && self.input.grid.cells.rows == self.output.grid.cells.rows
110 }
111
112 pub fn is_bigger(&self) -> bool {
113 self.input.grid.cells.columns * self.input.grid.cells.rows > self.output.grid.cells.columns * self.output.grid.cells.rows
114 }
115
116 pub fn is_smaller(&self) -> bool {
117 self.input.grid.cells.columns * self.input.grid.cells.rows < self.output.grid.cells.columns * self.output.grid.cells.rows
118 }
119
120 pub fn diff(&self) -> Option<Grid> {
121 self.input.grid.diff(&self.output.grid)
122 }
123
124 pub fn categorise_grid(input: &QualifiedGrid, output: &QualifiedGrid ) -> BTreeSet<GridCategory> {
125 let mut cats: BTreeSet<GridCategory> = BTreeSet::new();
126
127 if output.grid.size() == 0 {
128 cats.insert(EmptyOutput);
129 }
131
132 let in_dim = input.grid.dimensions();
133 let out_dim = output.grid.dimensions();
134
135 if input.grid.is_empty() {
136 cats.insert(InEmpty);
137 }
138 if in_dim.0 > 1 && in_dim.0 == in_dim.1 && out_dim.0 == out_dim.1 && in_dim == out_dim {
139 cats.insert(InOutSquareSameSize);
140 if in_dim.0 % 2 == 0 {
142 cats.insert(InOutSquareSameSizeEven);
143 } else {
144 cats.insert(InOutSquareSameSizeOdd);
145 }
146
147 let output_grid_json = output.grid.to_json();
148
149 if input.grid.rotated_90(1).to_json() == output_grid_json {
150 cats.insert(Rot90);
151 }
152 if input.grid.rotated_90(2).to_json() == output_grid_json {
153 cats.insert(Rot180);
154 }
155 if input.grid.rotated_270(1).to_json() == output_grid_json {
156 cats.insert(Rot270);
157 }
158 if input.grid.transposed().to_json() == output_grid_json {
159 cats.insert(Transpose);
160 }
161 if input.grid.mirrored_rows().to_json() == output_grid_json {
167 cats.insert(MirroredR);
168 }
169 if input.grid.mirrored_cols().to_json() == output_grid_json {
170 cats.insert(MirroredC);
171 }
172 } else {
173 if in_dim.0 == out_dim.0 && in_dim.1 == out_dim.1 {
174 cats.insert(InOutSameSize);
175 }
176 if in_dim.0 > 1 && out_dim.0 > 1 && in_dim.0 == in_dim.1 && out_dim.0 == out_dim.1 {
177 cats.insert(InOutSquare);
178 cats.insert(NxNIn(in_dim.0));
179 cats.insert(NxNOut(out_dim.0));
180 if in_dim.0 * in_dim.0 == out_dim.0 {
181 cats.insert(InToSquaredOut);
182 }
183 } else if in_dim.0 > 1 && in_dim.0 == in_dim.1 {
184 cats.insert(InSquare);
185 cats.insert(NxNIn(in_dim.0));
186 } else if out_dim.0 > 1 && out_dim.0 == out_dim.1 {
187 cats.insert(OutSquare);
188 cats.insert(NxNOut(out_dim.0));
189 }
190 }
191 if out_dim.0 > 0 {
192 if in_dim.0 % out_dim.0 == 0 && in_dim.0 != out_dim.0 {
193 cats.insert(OutRInWidth(in_dim.0 / out_dim.0));
194 } else if in_dim.1 % out_dim.1 == 0 && in_dim.1 != out_dim.1 {
195 cats.insert(OutRInHeight(in_dim.1 / out_dim.1));
196 }
197 if out_dim.0 % in_dim.0 == 0 && out_dim.0 != in_dim.0 {
198 cats.insert(InROutWidth(out_dim.0 / in_dim.0));
199 } else if out_dim.1 % in_dim.1 == 0 && out_dim.1 != in_dim.1 {
200 cats.insert(InROutHeight(out_dim.1 / in_dim.1));
201 }
202 }
203 if out_dim.0 == 1 && out_dim.1 == 1 {
204 cats.insert(SinglePixelOut);
205 }
206 if in_dim.0 >= out_dim.0 && in_dim.1 > out_dim.1 || in_dim.0 > out_dim.0 && in_dim.1 >= out_dim.1 {
207 cats.insert(OutLessThanIn);
208 } else if in_dim.0 <= out_dim.0 && in_dim.1 < out_dim.1 || in_dim.0 < out_dim.0 && in_dim.1 <= out_dim.1 {
209 cats.insert(InLessThanOut);
210 }
211 let is_mirror_rows_in = input.grid.is_mirror_rows();
212 let is_mirror_cols_in = input.grid.is_mirror_cols();
213 let is_mirror_rows_out = output.grid.is_mirror_rows();
214 let is_mirror_cols_out = output.grid.is_mirror_cols();
215 if is_mirror_rows_in && is_mirror_cols_in {
216 cats.insert(SymmetricIn);
217 } else if is_mirror_rows_in {
218 cats.insert(SymmetricInUD);
219 } else if is_mirror_cols_in {
220 cats.insert(SymmetricInLR);
221 }
222 if is_mirror_rows_out && is_mirror_cols_out {
223 cats.insert(SymmetricOut);
224 } else if is_mirror_rows_out {
225 cats.insert(SymmetricOutUD);
226 } else if is_mirror_cols_out {
227 cats.insert(SymmetricOutLR);
228 }
229
230 let in_is_mirror_x = input.grid.is_mirror_rows();
231 let in_is_mirror_y = input.grid.is_mirror_cols();
232 let out_is_mirror_x = output.grid.is_mirror_rows();
233 let out_is_mirror_y = output.grid.is_mirror_cols();
234 if in_is_mirror_x {
235 cats.insert(MirrorRIn);
236 }
237 if in_is_mirror_y {
238 cats.insert(MirrorCIn);
239 }
240 if out_is_mirror_x {
241 cats.insert(MirrorROut);
242 }
243 if out_is_mirror_y {
244 cats.insert(MirrorCOut);
245 }
246 if input.grid.has_bg_grid() == Black {
276 cats.insert(BGGridInBlack);
277 }
278 if output.grid.has_bg_grid() == Black {
279 cats.insert(BGGridOutBlack);
280 }
281 if input.grid.has_bg_grid_coloured() != NoColour {
282 cats.insert(BGGridInColoured);
283 }
284 if output.grid.has_bg_grid_coloured() != NoColour {
285 cats.insert(BGGridOutColoured);
286 }
287 if input.grid.is_panelled_rows() {
288 cats.insert(IsPanelledRIn);
289 }
290 if output.grid.is_panelled_rows() {
291 cats.insert(IsPanelledROut);
292 }
293 if input.grid.is_panelled_cols() {
294 cats.insert(IsPanelledCIn);
295 }
296 if output.grid.is_panelled_cols() {
297 cats.insert(IsPanelledCOut);
298 }
299 let in_no_colours = input.grid.no_colours();
300 let out_no_colours = output.grid.no_colours();
301 if in_no_colours == 0 {
302 cats.insert(BlankIn);
303 }
304 if out_no_colours == 0 {
305 cats.insert(BlankOut);
306 }
307 if in_no_colours == 1 {
308 cats.insert(SingleColourIn);
309 }
310 if out_no_colours == 1 {
311 cats.insert(SingleColourOut);
312 }
313 if input.grid.colour == output.grid.colour && input.grid.colour != Mixed {
314 cats.insert(SameColour);
315 }
316if input.shapes.len() == 1 {
329 cats.insert(SingleShapeIn);
330 } else if input.coloured_shapes.len() == 1 {
331 cats.insert(SingleColouredShapeIn);
332 }
333 if output.shapes.len() == 1 {
334 cats.insert(SingleShapeOut);
335 } else if output.coloured_shapes.len() == 1 {
336 cats.insert(SingleColouredShapeOut);
337 }
338 if input.shapes.len() > 1 && input.shapes.len() == output.shapes.len() {
339 cats.insert(InSameCountOut);
340 } else if input.shapes.len() > 1 && input.coloured_shapes.len() == output.coloured_shapes.len() {
341 cats.insert(InSameCountOutColoured);
342 } else if input.shapes.len() < output.shapes.len() {
343 cats.insert(InLessCountOut);
344 } else if input.coloured_shapes.len() < output.coloured_shapes.len() {
345 cats.insert(InLessCountOutColoured);
346 } else if input.shapes.len() < output.shapes.len() {
347 cats.insert(OutLessCountIn);
348 } else if input.coloured_shapes.len() > output.coloured_shapes.len() {
349 cats.insert(OutLessCountInColoured);
350 }
351 let in_border_top = input.grid.border_top();
352 let in_border_bottom = input.grid.border_bottom();
353 let in_border_left = input.grid.border_left();
354 let in_border_right = input.grid.border_right();
355 if in_border_top {
356 cats.insert(BorderTopIn);
357 }
358 if in_border_bottom {
359 cats.insert(BorderBottomIn);
360 }
361 if in_border_left {
362 cats.insert(BorderLeftIn);
363 }
364 if in_border_right {
365 cats.insert(BorderRightIn);
366 }
367 if output.grid.size() > 0 {
368 let out_border_top = output.grid.border_top();
369 let out_border_bottom = output.grid.border_bottom();
370 let out_border_left = output.grid.border_left();
371 let out_border_right = output.grid.border_right();
372 if out_border_top {
373 cats.insert(BorderTopOut);
374 }
375 if out_border_bottom {
376 cats.insert(BorderBottomOut);
377 }
378 if out_border_left {
379 cats.insert(BorderLeftOut);
380 }
381 if out_border_right {
382 cats.insert(BorderRightOut);
383 }
384 }
385 if input.grid.even_rows() {
386 cats.insert(EvenRowsIn);
387 }
388 if output.grid.even_rows() {
389 cats.insert(EvenRowsOut);
390 }
391 if input.grid.is_full() {
392 cats.insert(FullyPopulatedIn);
393 } else if input.shapes.shapes.len() == 1 && input.shapes.shapes[0].bare_corners() {
394 cats.insert(BareCornersIn);
395 }
396 if output.grid.is_full() {
397 cats.insert(FullyPopulatedOut);
398 } else if output.shapes.shapes.len() == 1 && output.shapes.shapes[0].bare_corners() {
399 cats.insert(BareCornersOut);
400 }
401 if !input.grid.has_gravity_down() && output.grid.has_gravity_down() {
402 cats.insert(GravityDown);
403 } else if !input.grid.has_gravity_up() && output.grid.has_gravity_up() {
404 cats.insert(GravityUp);
405 } else if !input.grid.has_gravity_left() && output.grid.has_gravity_left() {
406 cats.insert(GravityLeft);
407 } else if !input.grid.has_gravity_right() && output.grid.has_gravity_right() {
408 cats.insert(GravityRight);
409 }
410 if input.grid.is_3x3() {
411 cats.insert(Is3x3In);
412 }
413 if output.grid.is_3x3() {
414 cats.insert(Is3x3Out);
415 }
416 if input.grid.div9() {
417 cats.insert(Div9In);
418 }
419 if output.grid.div9() {
420 cats.insert(Div9Out);
421 }
422 if in_dim.0 * 2 == out_dim.0 && in_dim.1 * 2 == out_dim.1 {
423 cats.insert(Double);
424 }
425 if input.shapes.shapes.len() == output.shapes.shapes.len() {
426 cats.insert(InOutShapeCount);
427 }
428 if input.coloured_shapes.shapes.len() == output.coloured_shapes.shapes.len() {
429 cats.insert(InOutShapeCountColoured);
430 }
431 if !input.black.shapes.is_empty() {
432 cats.insert(BlackPatches);
433 }
434 if input.has_bg_shape() && output.has_bg_shape() {
435 cats.insert(HasBGShape);
436 }
437 if input.has_bg_coloured_shape() && output.has_bg_coloured_shape() {
438 cats.insert(HasBGShapeColoured);
439 }
440 let hin = input.grid.cell_colour_cnt_map();
441 let hout = output.grid.cell_colour_cnt_map();
442 if hin == hout {
443 cats.insert(IdenticalColours);
444 } else if hin.len() == hout.len() {
445 cats.insert(IdenticalNoColours);
446 } else {
447 let inp: usize = hin.values().sum();
448 let outp: usize = hout.values().sum();
449
450 if inp == outp {
451 cats.insert(IdenticalNoPixels);
452 }
453 }
454 let hin_colours: usize = hin.values().sum();
455 let hout_colours: usize = hout.values().sum();
456 if hin.len() == 1 {
457 cats.insert(SingleColourCountIn(hin_colours));
458 }
459 if hout.len() == 1 {
460 cats.insert(SingleColourCountOut(hout_colours));
461 }
462 if hin_colours == hout_colours * 2 {
463 cats.insert(SingleColourIn2xOut);
464 }
465 if hin_colours == hout_colours * 4 {
466 cats.insert(SingleColourIn2xOut);
467 }
468 if hin_colours * 2 == hout_colours {
469 cats.insert(SingleColourOut2xIn);
470 }
471 if hin_colours * 4 == hout_colours {
472 cats.insert(SingleColourOut2xIn);
473 }
474 if output.grid.is_diag_origin() {
475 cats.insert(DiagonalOutOrigin);
476 } else if output.grid.is_diag_not_origin() {
477 cats.insert(DiagonalOutNotOrigin);
478 }
479 if input.grid.colour == Mixed {
480 cats.insert(NoColouredShapesIn(input.coloured_shapes.len()));
481 }
482 if output.grid.colour == Mixed {
483 cats.insert(NoColouredShapesOut(output.coloured_shapes.len()));
484 }
485 if input.shapes.overlay_shapes_same_colour() {
486 cats.insert(OverlayInSame);
487 }
488 if output.shapes.overlay_shapes_same_colour() {
489 cats.insert(OverlayOutSame);
490 }
491 if input.shapes.overlay_shapes_diff_colour() {
492 cats.insert(OverlayInDiff);
493 }
494 if output.shapes.overlay_shapes_diff_colour() {
495 cats.insert(OverlayOutDiff);
496 }
497 if input.shapes.len() == 1 && input.shapes.shapes[0].is_line() {
498 cats.insert(InLine);
499 }
500 if output.shapes.len() == 1 && output.shapes.shapes[0].is_line() {
501 cats.insert(OutLine);
502 }
503 if input.shapes.len() == input.coloured_shapes.len() {
504 cats.insert(NoNetColouredShapesIn);
505 }
506 if output.shapes.len() == output.coloured_shapes.len() {
507 cats.insert(NoNetColouredShapesOut);
508 }
509 cats.insert(NoShapesIn(input.shapes.len()));
510 cats.insert(NoShapesOut(output.shapes.len()));
511 if input.shapes.is_square_same() {
512 cats.insert(SquareShapeSide(input.shapes.shapes[0].cells.rows));
513 cats.insert(SquareShapeSize(input.shapes.shapes[0].size()));
514 }
515 let mut cc = input.shapes.colour_cnt();
516 if cc.len() == 2 {
517 let first = if let Some(first) = cc.pop_first() {
518 first.1
519 } else {
520 0
521 };
522 let second = if let Some(second) = cc.pop_first() {
523 second.1
524 } else {
525 0
526 };
527
528 cats.insert(ShapeMinCntIn(first.min(second)));
529 cats.insert(ShapeMaxCntIn(first.max(second)));
530 }
531 let mut cc = output.shapes.colour_cnt();
532 if cc.len() == 2 {
533 let first = if let Some(first) = cc.pop_first() {
534 first.1
535 } else {
536 0
537 };
538 let second = if let Some(second) = cc.pop_first() {
539 second.1
540 } else {
541 0
542 };
543
544 cats.insert(ShapeMinCntOut(first.min(second)));
545 cats.insert(ShapeMaxCntOut(first.max(second)));
546 }
547
548 cats
549 }
550
551 pub fn single_shape_in(&self) -> usize {
552 self.input.shapes.len()
553 }
554
555 pub fn single_shape_out(&self) -> usize {
556 self.output.shapes.len()
557 }
558
559 pub fn single_coloured_shape_in(&self) -> usize {
560 self.input.coloured_shapes.len()
561 }
562
563 pub fn single_coloured_shape_out(&self) -> usize {
564 self.output.coloured_shapes.len()
565 }
566
567 pub fn io_colour_diff(&self) -> Colour {
568 let in_colours = self.input.grid.cell_colour_cnt_map();
569 let out_colours = self.output.grid.cell_colour_cnt_map();
570
571 let remainder: Vec<_> = if out_colours.len() > in_colours.len() {
572 out_colours.keys().filter(|k| !in_colours.contains_key(k)).collect()
573 } else {
574 in_colours.keys().filter(|k| !out_colours.contains_key(k)).collect()
575 };
576
577 if remainder.len() == 1 {
578 *remainder[0]
579 } else {
580 NoColour
581 }
582 }
583
584 pub fn colour_shape_map(&self, out: bool) -> BTreeMap<Colour, Shape> {
585 let mut bt: BTreeMap<Colour, Shape> = BTreeMap::new();
586 let io = if out {
587 &self.output.shapes.shapes
588 } else {
589 &self.input.shapes.shapes
590 };
591
592 for s in io.iter() {
593 bt.insert(s.colour, s.clone());
594 }
595
596 bt
597 }
598
599 pub fn colour_attachment_map(&self, out: bool) -> BTreeMap<Colour, bool> {
600 let mut bt: BTreeMap<Colour, bool> = BTreeMap::new();
601 let io = if out {
602 &self.output.shapes.shapes
603 } else {
604 &self.input.shapes.shapes
605 };
606
607 let mut prev = &Shape::trivial();
608
609 for s in io.iter() {
610 if *prev != Shape::trivial() {
611 bt.insert(s.colour, s.ocol == prev.ocol || s.cells.columns == 1);
612 }
613 prev = &s;
614 }
615
616 bt
617 }
618
619 pub fn shape_pixels_to_colour(&self) -> BTreeMap<usize, Colour> {
620 let mut spc: BTreeMap<usize, Colour> = BTreeMap::new();
621
622 for s in self.output.shapes.shapes.iter() {
623 spc.insert(s.pixels(), s.colour);
624 }
625
626 spc
627 }
628
629 pub fn shape_adjacency_map(&self) -> BTreeMap<Shape, Colour> {
630 let in_shapes = &self.input.shapes;
631 let out_shapes = &self.output.shapes;
632 let mut sam: BTreeMap<Shape, Colour> = BTreeMap::new();
633 let ind_colour = in_shapes.smallest().colour;
634
635 for si in in_shapes.shapes.iter() {
636 for so in out_shapes.shapes.iter() {
637 if si.colour != ind_colour && si.equal_shape(&so) {
638 sam.insert(so.clone(), NoColour);
639 } else if si.colour == ind_colour && si.touching(&so) {
640 sam.insert(si.clone(), so.colour);
641 }
642 }
643 }
644
645 let mut map: BTreeMap<Shape, Colour> = BTreeMap::new();
646
647 for (s1, colour1) in sam.iter() {
648 if *colour1 != NoColour {
649 for (s2, colour2) in sam.iter() {
650 if *colour2 == NoColour && s1.touching(&s2) {
651 map.insert(s1.to_origin(), s2.colour);
652 }
653 }
654 }
655 }
656map
659 }
660
661 pub fn some(&self, isout: bool, f: &dyn Fn(&Shapes) -> Shape) -> Shape {
662 let s = if isout {
663 &self.output.shapes
664 } else {
665 &self.input.shapes
666 };
667
668 f(&s)
669 }
670
671 pub fn all(&self, isout: bool) -> Vec<Shape> {
672 let s = if isout {
673 &self.output.shapes
674 } else {
675 &self.input.shapes
676 };
677
678 s.shapes.clone()
679 }
680
681 pub fn some_coloured(&self, isout: bool, f: &dyn Fn(&Shapes) -> Shape) -> Shape {
682 let s = if isout {
683 &self.output.coloured_shapes
684 } else {
685 &self.input.coloured_shapes
686 };
687
688 f(&s)
689 }
690
691 pub fn all_coloured(&self, isout: bool) -> Vec<Shape> {
692 let s = if isout {
693 &self.output.coloured_shapes
694 } else {
695 &self.input.coloured_shapes
696 };
697
698 s.shapes.clone()
699 }
700
701 pub fn map_coloured_shapes_to_shape(&self, _shapes: Vec<Shape>) -> Vec<Shapes> {
702 Vec::new()
704 }
705
706 pub fn colour_cnt_diff(&self, inc: bool) -> Colour {
707 let incc = self.input.grid.cell_colour_cnt_map();
708 let outcc = self.output.grid.cell_colour_cnt_map();
709
710 if incc.len() != outcc.len() {
711 return NoColour;
712 }
713
714 for ((icol, icnt), (ocol, ocnt)) in incc.iter().zip(outcc.iter()) {
715 if icol != ocol {
716 return NoColour;
717 }
718 if inc && icnt < ocnt || !inc && icnt > ocnt {
719 return *icol;
720 }
721 }
722
723 NoColour
724 }
725
726 pub fn colour_cnt_inc(&self) -> Colour {
727 self.colour_cnt_diff(true)
728 }
729
730 pub fn colour_cnt_dec(&self) -> Colour {
731 self.colour_cnt_diff(false)
732 }
733
734 pub fn split_n_map_horizontal(&self, n: usize) -> BTreeMap<Grid, Grid> {
735 let ins: Vec<Grid> = self.input.grid.split_n_horizontal(n);
736 let outs: Vec<Grid> = self.output.grid.split_n_horizontal(n);
737 let mut bt: BTreeMap<Grid, Grid> = BTreeMap::new();
738
739 if ins.len() != outs.len() {
740 return bt;
741 }
742
743 for (is, os) in ins.iter().zip(outs.iter()) {
744 bt.insert(is.to_origin(), os.to_origin());
745 }
746
747 bt
748 }
749
750 pub fn split_n_map_vertical(&self, n: usize) -> BTreeMap<Grid, Grid> {
751 let ins: Vec<Grid> = self.input.grid.split_n_vertical(n);
752 let outs: Vec<Grid> = self.output.grid.split_n_vertical(n);
753 let mut bt: BTreeMap<Grid, Grid> = BTreeMap::new();
754
755 if ins.len() != outs.len() {
756 return bt;
757 }
758
759 for (is, os) in ins.iter().zip(outs.iter()) {
760 bt.insert(is.to_origin(), os.to_origin());
761 }
762
763 bt
764 }
765
766 pub fn majority_dimensions(&self) -> (usize, usize){
767 if self.input.shapes.shapes.is_empty() {
768 return (0, 0);
769 }
770
771 let mut sc: BTreeMap<(usize, usize), usize> = BTreeMap::new();
772
773 for s in self.input.shapes.shapes.iter() {
774 *sc.entry(s.dimensions()).or_insert(0) += 1;
775 }
776
777 if let Some((_, dim)) = sc.iter().map(|(k, v)| (v, k)).max() {
778 *dim
779 } else {
780 (0, 0)
781 }
782 }
783}
784
785#[derive(Debug, Clone, PartialEq)]
786pub struct Examples {
787 pub examples: Vec<Example>,
788 pub tests: Vec<Example>,
789 pub cat: BTreeSet<GridCategory>,
790}
791
792impl Examples {
793 pub fn new(data: &Data) -> Self {
794 let mut examples: Vec<Example> = data.train.iter()
795 .map(Example::new)
796 .collect();
797
798 let tests: Vec<Example> = data.test.iter()
799 .map(Example::new)
800 .collect();
801
802 let cat = Self::categorise_grids(&mut examples);
803
804 Examples { examples, tests, cat }
805 }
806
807 pub fn new_cons(data: &Data) -> Self {
808 let mut examples: Vec<Example> = data.train.iter()
809 .map(Example::new_cons)
810 .collect();
811
812 let tests: Vec<Example> = data.test.iter()
813 .map(Example::new_cons)
814 .collect();
815
816 let cat = Self::categorise_grids(&mut examples);
817
818 Examples { examples, tests, cat }
819 }
820
821 pub fn transformation(&self, trans: Transformation) -> Self {
822 let mut examples = self.clone();
823
824 examples.transformation_mut(trans);
825
826 examples
827 }
828
829 pub fn transformation_mut(&mut self, trans: Transformation) {
830 self.examples.iter_mut().for_each(|ex| ex.transform_mut(trans, true));
831 }
834
835 pub fn inverse_transformation(&self, trans: Transformation) -> Self {
836 let mut examples = self.clone();
837
838 examples.inverse_transformation_mut(trans);
839
840 examples
841 }
842
843 pub fn inverse_transformation_mut(&mut self, trans: Transformation) {
844 let trans = Transformation::inverse(&trans);
845
846 self.transformation(trans);
847 }
848
849 pub fn match_shapes(&self) -> BTreeMap<Shape, Shape> {
850 let mut mapping: BTreeMap<Shape, Shape> = BTreeMap::new();
851
852 for shapes in &self.examples {
853 for (si, so) in shapes.input.coloured_shapes.shapes.iter().zip(shapes.output.coloured_shapes.shapes.iter()) {
854 if so.is_contained(si) {
855 let (si, so) = so.normalise(si);
856
857 mapping.insert(si.clone(), so.clone());
858 }
859 }
860 }
861
862 mapping
863 }
864
865 pub fn full_shapes(&self, input: bool, sq: bool) -> Shapes {
866 let mut shapes: Vec<Shape> = Vec::new();
867
868 for ex in self.examples.iter() {
869 let ss = if input {
870 &ex.input.shapes
871 } else {
872 &ex.output.shapes
873 };
874
875 let it = if sq {
876 ss.full_shapes()
877 } else {
878 ss.full_shapes_sq()
879 };
880
881 for s in it.iter() {
882 let s = s.to_origin();
883
884 if !shapes.contains(&s) {
885 shapes.push(s);
886 }
887 }
888 }
889
890 Shapes::new_shapes(&shapes)
891 }
892
893 pub fn full_shapes_in(&self) -> Shapes {
894 self.full_shapes(true, false)
895 }
896
897 pub fn full_shapes_out(&self) -> Shapes {
898 self.full_shapes(false, false)
899 }
900
901 pub fn full_shapes_in_sq(&self) -> Shapes {
902 self.full_shapes(true, true)
903 }
904
905 pub fn full_shapes_out_sq(&self) -> Shapes {
906 self.full_shapes(false, true)
907 }
908
909 pub fn common(&self, input: bool) -> Grid {
910 let mut grid = Grid::trivial();
911
912 for ex in self.examples.iter() {
913 let g = if input {
914 &ex.input.grid
915 } else {
916 &ex.output.grid
917 };
918
919 if grid == Grid::trivial() {
920 grid = g.clone();
921 } else if grid.dimensions() != g.dimensions() {
922 return Grid::trivial();
923 } else {
924 for (c1, c2) in g.cells.values().zip(grid.cells.values_mut()) {
925 c2.colour = c1.colour.and(c2.colour);
926 }
927 }
928 }
929
930 grid
931 }
932
933 pub fn all_shapes(&self, input: bool, sq: bool) -> Shapes {
934 let mut shapes: Vec<Shape> = Vec::new();
935
936 for ex in self.examples.iter() {
937 let ss = if input {
938 &ex.input.shapes
939 } else {
940 &ex.output.shapes
941 };
942
943 let it = if sq {
944 ss.all_shapes()
945 } else {
946 ss.all_shapes_sq()
947 };
948
949 for s in it.iter() {
950 let s = s.to_origin();
951
952 shapes.push(s);
953 }
954 }
955
956 Shapes::new_shapes(&shapes)
957 }
958
959 pub fn all_shapes_in(&self) -> Shapes {
960 self.all_shapes(true, false)
961 }
962
963 pub fn all_shapes_out(&self) -> Shapes {
964 self.all_shapes(false, false)
965 }
966
967 pub fn all_shapes_in_sq(&self) -> Shapes {
968 self.all_shapes(true, true)
969 }
970
971 pub fn all_shapes_out_sq(&self) -> Shapes {
972 self.all_shapes(false, true)
973 }
974
975 pub fn some(&self, isout: bool, f: &dyn Fn(&Shapes) -> Shape) -> Vec<Shape> {
976 let mut s: Vec::<Shape> = Vec::new();
977
978 for ex in self.examples.iter() {
979 s.push(ex.some(isout, f));
980 }
981
982 s
983 }
984
985 pub fn all(&self, isout: bool) -> Vec<Shape> {
986 let mut s: Vec::<Shape> = Vec::new();
987
988 for ex in self.examples.iter() {
989 let vs = ex.all(isout);
990
991 for ex2 in vs.iter() {
992 s.push(ex2.clone());
993 }
994 }
995
996 s
997 }
998
999 pub fn some_coloured(&self, isout: bool, f: &dyn Fn(&Shapes) -> Shape) -> Vec<Shape> {
1000 let mut s: Vec::<Shape> = Vec::new();
1001
1002 for ex in self.examples.iter() {
1003 s.push(ex.some_coloured(isout, f));
1004 }
1005
1006 s
1007 }
1008
1009 pub fn all_coloured(&self, isout: bool) -> Vec<Shape> {
1010 let mut s: Vec::<Shape> = Vec::new();
1011
1012 for ex in self.examples.iter() {
1013 let vs = ex.all_coloured(isout);
1014
1015 for ex2 in vs.iter() {
1016 s.push(ex2.clone());
1017 }
1018 }
1019
1020 s
1021 }
1022
1023 pub fn categorise_grids(examples: &mut [Example]) -> BTreeSet<GridCategory> {
1024 let mut cats: BTreeSet<GridCategory> = BTreeSet::new();
1025 let mut extra: BTreeSet<GridCategory> = BTreeSet::new();
1026
1027 for ex in examples.iter_mut() {
1028 let cat = Example::categorise_grid(&ex.input, &ex.output);
1029
1030 if cat.contains(&InOutSquareSameSize) {
1031 extra.insert(InOutSameSize);
1032 }
1033 if cat.contains(&OverlayInSame) {
1034 extra.insert(OverlayInSame);
1035 }
1036 if cat.contains(&OverlayOutSame) {
1037 extra.insert(OverlayOutSame);
1038 }
1039 if cat.contains(&OverlayInDiff) {
1040 extra.insert(OverlayInDiff);
1041 }
1042 if cat.contains(&OverlayOutDiff) {
1043 extra.insert(OverlayOutDiff);
1044 }
1045 ex.pairs = ex.input.shapes.pair_shapes(&ex.output.shapes, true);
1046 if !ex.pairs.is_empty() {
1047 extra.insert(InOutSameShapes);
1048 }
1049 ex.coloured_pairs = ex.input.coloured_shapes.pair_shapes(&ex.output.coloured_shapes, true);
1050 if !ex.coloured_pairs.is_empty() {
1051 extra.insert(InOutSameShapesColoured);
1052 }
1053
1054 if cats.is_empty() {
1055 cats = cat;
1056 } else {
1057 cats = cats.intersection(&cat).cloned().collect();
1058 }
1059 }
1060
1061 if cats.contains(&InOutSquareSameSize) && cats.contains(&InOutSameSize) {
1062 cats.remove(&InOutSameSize);
1063 }
1064
1065 cats = cats.union(&extra).cloned().collect();
1066
1067 cats
1068 }
1069
1070 pub fn find_input_colours(&self) -> Vec<Colour> {
1071 let mut common = Colour::all_colours();
1072
1073 for ex in self.examples.iter() {
1074 let h = ex.input.grid.cell_colour_cnt_map();
1075 let v: BTreeSet<Colour> = h.keys().copied().collect();
1076
1077 common = common.intersection(&v).copied().collect();
1078 }
1079
1080 Vec::from_iter(common)
1081 }
1082
1083 pub fn find_output_colours(&self) -> Vec<Colour> {
1084 let mut common = Colour::all_colours();
1085
1086 for ex in self.examples.iter() {
1087 let h = ex.output.grid.cell_colour_cnt_map();
1088 let v: BTreeSet<Colour> = h.keys().copied().collect();
1089
1090 common = common.intersection(&v).copied().collect();
1091 }
1092
1093 Vec::from_iter(common)
1094 }
1095
1096 pub fn find_all_output_colours(&self) -> Vec<Colour> {
1097 let mut common = Vec::new();
1098
1099 for ex in self.examples.iter() {
1100 let h = ex.output.grid.cell_colour_cnt_map();
1101 let v: Vec<Colour> = h.keys().copied().collect();
1102
1103 common = Union::union(&common, v);
1104 }
1105
1106 Vec::from_iter(common)
1107 }
1108
1109 pub fn find_all_input_colours(&self) -> Vec<Colour> {
1110 let mut common = Vec::new();
1111
1112 for ex in self.examples.iter() {
1113 let h = ex.input.grid.cell_colour_cnt_map();
1114 let v: Vec<Colour> = h.keys().copied().collect();
1115
1116 common = Union::union(&common, v);
1117 }
1118
1119 Vec::from_iter(common)
1120 }
1121
1122 pub fn io_colour_diff(&self) -> Vec<Colour> {
1123 let in_colours = self.find_all_input_colours();
1124 let out_colours = self.find_all_output_colours();
1125
1126 Uniq::uniq(&out_colours, in_colours)
1127 }
1128
1129 pub fn io_all_colour_diff(&self) -> Vec<Colour> {
1130 let in_colours = self.find_all_input_colours();
1131 let out_colours = self.find_all_output_colours();
1132
1133 if out_colours.len() > in_colours.len() {
1134 Uniq::uniq(&out_colours, in_colours)
1135 } else {
1136 Uniq::uniq(&in_colours, out_colours)
1137 }
1138 }
1139
1140 pub fn io_colour_common(&self) -> Vec<Colour> {
1141 let in_colours = self.find_input_colours();
1142 let out_colours = self.find_output_colours();
1143
1144 Union::union(&out_colours, in_colours)
1145 }
1146
1147 pub fn io_common_row_colour(&self) -> Colour {
1148 let mut colour = NoColour;
1149
1150 for ex in self.examples.iter() {
1151 for (i, o) in ex.input.shapes.shapes.iter().zip(ex.input.shapes.shapes.iter()) {
1152 if colour == NoColour && i.orow == o.orow && i.colour == o.colour {
1153 colour = i.colour;
1154
1155 break;
1156 } else if i.orow == o.orow && i.colour != o.colour {
1157 return NoColour;
1158 }
1159 }
1160 }
1161
1162 colour
1163 }
1164
1165 pub fn io_common_col_colour(&self) -> Colour {
1166 let mut colour = NoColour;
1167
1168 for ex in self.examples.iter() {
1169 for (i, o) in ex.input.shapes.shapes.iter().zip(ex.input.shapes.shapes.iter()) {
1170 if colour == NoColour && i.ocol == o.ocol && i.colour == o.colour {
1171 colour = i.colour;
1172
1173 break;
1174 } else if i.ocol == o.ocol && i.colour != o.colour {
1175 return NoColour;
1176 }
1177 }
1178 }
1179
1180 colour
1181 }
1182
1183 pub fn find_hollow_cnt_colour_map(&self) -> BTreeMap<usize, Colour> {
1184 let mut ccm: BTreeMap<usize, Colour> = BTreeMap::new();
1185
1186 for ex in self.examples.iter() {
1187 let h = ex.output.shapes.hollow_cnt_colour_map();
1188
1189 for (k, v) in &h {
1190 ccm.insert(*k, *v);
1191 }
1192 }
1193
1194 ccm
1195 }
1196
1197 pub fn find_colour_io_map(&self) -> BTreeMap<Colour, Colour> {
1198 let mut h: BTreeMap<Colour, Colour> = BTreeMap::new();
1199
1200 for ex in self.examples.iter() {
1201 if ex.input.shapes.shapes.len() != ex.output.shapes.shapes.len() {
1202 return h;
1203 }
1204
1205 for (si, so) in ex.input.shapes.shapes.iter().zip(ex.output.shapes.shapes.iter()) {
1206 h.insert(si.colour, so.colour);
1207 }
1208 }
1209
1210 h
1211 }
1212
1213 pub fn largest_shape_colour(&self) -> Colour {
1214 let mut colour = NoColour;
1215
1216 for ex in self.examples.iter() {
1217 let s = ex.output.shapes.largest();
1218
1219 if colour == NoColour {
1220 colour = s.colour;
1221 } else if colour != s.colour {
1222 return NoColour; }
1224 }
1225
1226 colour
1227 }
1228
1229 pub fn bleached_io_map(&self) -> BTreeMap<String, Grid> {
1230 let mut h: BTreeMap<String, Grid> = BTreeMap::new();
1231
1232 for ex in self.examples.iter() {
1233 h.insert(ex.input.grid.bleach().to_json(), ex.output.grid.clone());
1234 }
1235
1236 h
1237 }
1238
1239 pub fn in_max_size(&self) -> (usize, usize) {
1240 let mut rs = 0;
1241 let mut cs = 0;
1242
1243 for ex in self.examples.iter() {
1244 let (r, c) = ex.input.grid.dimensions();
1245
1246 (rs, cs) = (rs, cs).max((r, c));
1247 }
1248
1249 (rs, cs)
1250 }
1251
1252 pub fn derive_missing_rule(&self) -> Grid {
1253 let mut i_rs = 0;
1254 let mut i_cs = 0;
1255 let mut in_grid = Grid::trivial();
1256 let mut out_grid = Grid::trivial();
1257
1258 for ex in self.examples.iter() {
1260 let grid = &ex.input.grid;
1261
1262 if grid.is_square() && grid.pixels() == 1 && grid.cells[(grid.cells.rows / 2, grid.cells.columns / 2)].colour != Black {
1263 (i_rs, i_cs) = (i_rs, i_cs).max((grid.cells.rows, grid.cells.columns));
1264
1265 in_grid = ex.input.grid.clone();
1266 out_grid = ex.output.grid.clone();
1267 }
1268 }
1269
1270 in_grid.derive_missing_rule(&out_grid)
1271 }
1272
1273 pub fn shape_pixels_to_colour(&self) -> BTreeMap<usize, Colour> {
1274 let mut spc: BTreeMap<usize, Colour> = BTreeMap::new();
1275
1276 for ex in self.examples.iter() {
1277 spc.extend(ex.shape_pixels_to_colour());
1278 }
1279
1280 spc
1281 }
1282
1283 pub fn shape_adjacency_map(&self) -> BTreeMap<Shape, Colour> {
1284 let mut sam: BTreeMap<Shape, Colour> = BTreeMap::new();
1285
1286 for ex in self.examples.iter() {
1287 sam.extend(ex.shape_adjacency_map());
1288 }
1289
1290 sam
1291 }
1292
1293 pub fn colour_shape_map(&self, out: bool) -> BTreeMap<Colour, Shape> {
1294 let mut bt: BTreeMap<Colour, Shape> = BTreeMap::new();
1295
1296 for ex in self.examples.iter() {
1297 bt.extend(ex.colour_shape_map(out));
1298 }
1299
1300 bt
1301 }
1302
1303 pub fn colour_attachment_map(&self, out: bool) -> BTreeMap<Colour, bool> {
1304 let mut bt: BTreeMap<Colour, bool> = BTreeMap::new();
1305
1306 for ex in self.examples.iter() {
1307 bt.extend(ex.colour_attachment_map(out));
1308 }
1309
1310 bt
1311 }
1312
1313 pub fn colour_cnt_diff(&self, inc: bool) -> Colour {
1314 let mut colour = NoColour;
1315
1316 for ex in self.examples.iter() {
1317 let new_col = ex.colour_cnt_diff(inc);
1318
1319 if colour == NoColour {
1320 colour = new_col;
1321 } else if colour != new_col {
1322 return NoColour;
1323 }
1324 }
1325
1326 colour
1327 }
1328
1329 pub fn colour_cnt_inc(&self) -> Colour {
1330 self.colour_cnt_diff(true)
1331 }
1332
1333 pub fn colour_cnt_dec(&self) -> Colour {
1334 self.colour_cnt_diff(false)
1335 }
1336
1337 pub fn colour_diffs(&self, inc: bool) -> Vec<Colour> {
1338 let mut cc: Vec<Colour> = Vec::new();
1339
1340 for ex in self.examples.iter() {
1341 let new_col = ex.colour_cnt_diff(inc);
1342
1343 if new_col != NoColour {
1344 cc.push(new_col);
1345 }
1346 }
1347
1348 cc
1349 }
1350
1351 pub fn colour_incs(&self) -> Vec<Colour> {
1352 self.colour_diffs(true)
1353 }
1354
1355 pub fn colour_decs(&self) -> Vec<Colour> {
1356 self.colour_diffs(false)
1357 }
1358
1359 pub fn split_n_map_horizontal(&self, n: usize) -> BTreeMap<Grid, Grid> {
1360 let mut bt: BTreeMap<Grid, Grid> = BTreeMap::new();
1361
1362 for ex in self.examples.iter() {
1363 bt.extend(ex.split_n_map_horizontal(n));
1364 }
1365
1366 bt
1367 }
1368
1369 pub fn split_n_map_vertical(&self, n: usize) -> BTreeMap<Grid, Grid> {
1370 let mut bt: BTreeMap<Grid, Grid> = BTreeMap::new();
1371
1372 for ex in self.examples.iter() {
1373 bt.extend(ex.split_n_map_vertical(n));
1374 }
1375
1376 bt
1377 }
1378}