1use crate::figure::Item;
6use crate::math::optimizations::ZeroLineDst;
7use crate::token::number::{CompExponent, ProcNum};
8use crate::unroll::figure::Node;
9use crate::unroll::flags::Flag;
10use derive_recursive::Recursive;
11use geo_aid_figure::{EntityIndex as EntityId, VarIndex};
12use num_traits::{FromPrimitive, One, Zero};
13use serde::Serialize;
14use std::any::Any;
15use std::cell::OnceCell;
16use std::cmp::Ordering;
17use std::collections::{hash_map, HashMap, HashSet};
18use std::fmt::{Debug, Display, Formatter};
19use std::hash::Hash;
20use std::iter::Peekable;
21use std::mem;
22use std::ops::{Deref, DerefMut};
23use std::rc::Rc;
24
25use self::optimizations::{EqExpressions, EqPointDst, RightAngle};
26
27use super::unroll::GetData;
28use super::{
29 figure::Figure,
30 unroll::{
31 self, Circle as UnrolledCircle, Displayed, Expr as Unrolled, Line as UnrolledLine,
32 NumberData as UnrolledNumber, Point as UnrolledPoint, UnrolledRule, UnrolledRuleKind,
33 },
34 ComplexUnit, Error, SimpleUnit,
35};
36
37mod optimizations;
38
/// Optimization toggles for the math stage.
///
/// Currently carries no options.
#[derive(Debug, Clone, Default)]
pub struct Optimizations {}

/// Flags controlling the math stage.
///
/// `Default` is derived: empty [`Optimizations`] and `point_inequalities == false`.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    /// Optimization settings.
    pub optimizations: Optimizations,
    /// Toggles point-inequality handling; defaults to `false`.
    /// NOTE(review): the exact effect is defined where this flag is read (not visible here).
    pub point_inequalities: bool,
}
61
62pub trait GetMathType {
64 #[must_use]
65 fn get_math_type() -> ExprType;
66}
67
68impl GetMathType for UnrolledPoint {
69 fn get_math_type() -> ExprType {
70 ExprType::Point
71 }
72}
73
74impl GetMathType for UnrolledLine {
75 fn get_math_type() -> ExprType {
76 ExprType::Line
77 }
78}
79
80impl GetMathType for UnrolledCircle {
81 fn get_math_type() -> ExprType {
82 ExprType::Circle
83 }
84}
85
86impl GetMathType for unroll::Number {
87 fn get_math_type() -> ExprType {
88 ExprType::Number
89 }
90}
91
92pub trait DeepClone {
96 #[must_use]
98 fn deep_clone(&self, math: &mut Math) -> Self;
99}
100
101impl DeepClone for ProcNum {
102 fn deep_clone(&self, _math: &mut Math) -> Self {
103 self.clone()
104 }
105}
106
107impl DeepClone for CompExponent {
108 fn deep_clone(&self, _math: &mut Math) -> Self {
109 *self
110 }
111}
112
113impl<T: DeepClone> DeepClone for Vec<T> {
114 fn deep_clone(&self, math: &mut Math) -> Self {
115 self.iter().map(|x| x.deep_clone(math)).collect()
116 }
117}
118
119trait Compare {
121 #[must_use]
123 fn compare(&self, other: &Self, math: &Math) -> Ordering;
124}
125
126impl<T: Compare> Compare for Vec<T> {
127 fn compare(&self, other: &Self, math: &Math) -> Ordering {
128 self.iter()
129 .zip(other)
130 .map(|(a, b)| a.compare(b, math))
131 .find(|x| x.is_ne())
132 .unwrap_or_else(|| self.len().cmp(&other.len()))
133 }
134}
135
136trait ContainsEntity {
138 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool;
140}
141
142impl ContainsEntity for CompExponent {
143 fn contains_entity(&self, _entity: EntityId, _math: &Math) -> bool {
144 false
145 }
146}
147
148impl ContainsEntity for ProcNum {
149 fn contains_entity(&self, _entity: EntityId, _math: &Math) -> bool {
150 false
151 }
152}
153
154impl<T: ContainsEntity> ContainsEntity for Box<T> {
155 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
156 self.as_ref().contains_entity(entity, math)
157 }
158}
159
160impl<T: ContainsEntity> ContainsEntity for Vec<T> {
161 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
162 self.iter().any(|item| item.contains_entity(entity, math))
163 }
164}
165
166#[derive(Debug, Clone, PartialEq, Eq)]
168enum EntityBehavior {
169 MapEntity(EntityId),
171 MapVar(VarIndex),
173}
174
175pub struct ReconstructCtx<'r> {
177 entity_replacement: &'r [EntityBehavior],
179 old_vars: &'r [Expr<()>],
181 new_vars: Vec<Expr<()>>,
183 old_entities: &'r [EntityKind],
185 new_entities: Vec<Option<EntityKind>>,
187}
188
189impl<'r> ReconstructCtx<'r> {
190 #[must_use]
192 fn new(
193 entity_replacement: &'r [EntityBehavior],
194 old_vars: &'r [Expr<()>],
195 old_entities: &'r [EntityKind],
196 ) -> Self {
197 let mut ctx = Self {
198 entity_replacement,
199 old_vars,
200 new_vars: Vec::new(),
201 old_entities,
202 new_entities: vec![None; old_entities.len()],
203 };
204
205 for i in 0..old_entities.len() {
206 ctx.reconstruct_entity(EntityId(i));
207 }
208
209 ctx
210 }
211
212 fn reconstruct_entity(&mut self, id: EntityId) {
214 if self.new_entities[id.0].is_none() {
215 self.new_entities[id.0] = Some(self.old_entities[id.0].clone().reconstruct(self));
216 }
217 }
218}
219
220pub trait Reconstruct {
227 #[must_use]
229 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self;
230}
231
232fn reconstruct_entity(entity_id: EntityId, ctx: &mut ReconstructCtx) -> ExprKind {
234 match &ctx.entity_replacement[entity_id.0] {
235 EntityBehavior::MapEntity(id) => {
236 ctx.reconstruct_entity(*id);
237 ExprKind::Entity { id: *id }
238 }
239 EntityBehavior::MapVar(index) => ctx.old_vars[index.0].kind.clone().reconstruct(ctx),
240 }
241}
242
243impl Reconstruct for ProcNum {
244 fn reconstruct(self, _ctx: &mut ReconstructCtx) -> Self {
245 self
246 }
247}
248
249impl Reconstruct for CompExponent {
250 fn reconstruct(self, _ctx: &mut ReconstructCtx) -> Self {
251 self
252 }
253}
254
255impl<T: Reconstruct> Reconstruct for Option<T> {
256 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
257 self.map(|v| v.reconstruct(ctx))
258 }
259}
260
261impl<T: Reconstruct> Reconstruct for Box<T> {
262 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
263 Self::new((*self).reconstruct(ctx))
264 }
265}
266
267impl<T: Reconstruct> Reconstruct for Vec<T> {
268 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
269 self.into_iter().map(|x| x.reconstruct(ctx)).collect()
270 }
271}
272
273trait FindEntities {
275 fn find_entities(
278 &self,
279 previous: &[HashSet<EntityId>],
280 entities: &[EntityKind],
281 ) -> HashSet<EntityId>;
282}
283
284impl FindEntities for EntityId {
285 fn find_entities(
286 &self,
287 previous: &[HashSet<EntityId>],
288 entities: &[EntityKind],
289 ) -> HashSet<EntityId> {
290 entities[self.0].find_entities(previous, entities)
291 }
292}
293
294impl FindEntities for Vec<VarIndex> {
295 fn find_entities(
296 &self,
297 previous: &[HashSet<EntityId>],
298 _entities: &[EntityKind],
299 ) -> HashSet<EntityId> {
300 self.iter()
301 .flat_map(|x| previous[x.0].iter().copied())
302 .collect()
303 }
304}
305
306pub trait FromUnrolled<T: Displayed> {
308 fn load(expr: &Unrolled<T>, math: &mut Expand) -> Self;
310}
311
312trait Normalize {
314 fn normalize(&mut self, math: &mut Math);
319}
320
321impl Reconstruct for VarIndex {
322 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
323 let expr = ctx.old_vars[self.0].clone();
324 let kind = expr.kind.reconstruct(ctx);
325 ctx.new_vars.push(Expr::new(kind, expr.ty));
326 VarIndex(ctx.new_vars.len() - 1)
327 }
328}
329
330impl DeepClone for VarIndex {
331 fn deep_clone(&self, math: &mut Math) -> Self {
332 let ty = math.at(self).ty;
335 let expr = math.at(self).kind.clone().deep_clone(math);
336 math.store(expr, ty)
337 }
338}
339
340impl Compare for VarIndex {
341 fn compare(&self, other: &Self, math: &Math) -> Ordering {
342 math.at(self).kind.compare(&math.at(other).kind, math)
343 }
344}
345
346impl Reindex for VarIndex {
347 fn reindex(&mut self, map: &IndexMap) {
348 self.0 = map.get(self.0);
349 }
350}
351
352impl ContainsEntity for VarIndex {
353 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
354 math.at(self).contains_entity(entity, math)
355 }
356}
357
/// The kind of value a math-stage expression produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash, Serialize)]
pub enum ExprType {
    /// A number. Complex-valued expressions such as `PointToComplex` are also
    /// typed as `Number` (see `ExprKind::get_type`).
    Number,
    /// A point. This is the `Default` variant.
    #[default]
    Point,
    /// A line.
    Line,
    /// A circle.
    Circle,
}
367
/// A single math-stage expression node. Operands are variable indices into the
/// math context rather than nested expressions.
///
/// Note that the derived `PartialOrd`/`Ord` depend on variant declaration order;
/// the structural ordering used for normalization is `ExprKind::compare`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Recursive, Hash, Serialize)]
// The `#[recursive(...)]` attributes below ask the `Recursive` derive
// (`derive_recursive`) to generate the listed trait impls by recursing into every
// field. `aggregate` appears to say how per-field results are combined
// (`||` for booleans, `{}` for rebuilding the variant) — NOTE(review): confirm
// against the derive_recursive documentation.
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||
        }
    }
)]
#[recursive(
    impl Reindex for Self {
        fn reindex(&mut self, map: &IndexMap) {
            aggregate = _
        }
    }
)]
#[recursive(
    impl DeepClone for Self {
        fn deep_clone(&self, math: &mut Math) -> Self {
            aggregate = {}
        }
    }
)]
// `override_marker` lets individual variants substitute their own function for
// the generated field-wise reconstruction (used by `Entity` below).
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {},
            override_marker = override_reconstruct
        }
    }
)]
pub enum ExprKind {
    /// A reference to an entity. Reconstruction goes through `reconstruct_entity`.
    #[recursive(override_reconstruct = reconstruct_entity)]
    Entity { id: EntityId },

    // Points.
    /// Intersection of lines `k` and `l`.
    LineLineIntersection { k: VarIndex, l: VarIndex },
    /// Average of the given points.
    AveragePoint { items: Vec<VarIndex> },
    /// Center of `circle`.
    CircleCenter { circle: VarIndex },
    /// The point corresponding to a complex `number`.
    ComplexToPoint { number: VarIndex },

    // Numbers.
    /// Sum of `plus` minus the sum of `minus`.
    Sum {
        /// Terms added.
        plus: Vec<VarIndex>,
        /// Terms subtracted.
        minus: Vec<VarIndex>,
    },
    /// Product of `times` divided by the product of `by`.
    Product {
        /// Factors multiplied.
        times: Vec<VarIndex>,
        /// Factors divided by.
        by: Vec<VarIndex>,
    },
    /// A constant.
    Const { value: ProcNum },
    /// `value` raised to `exponent`.
    Exponentiation {
        value: VarIndex,
        exponent: CompExponent,
    },
    /// Distance between points `p` and `q`.
    PointPointDistance { p: VarIndex, q: VarIndex },
    /// Distance between `point` and `line`.
    PointLineDistance { point: VarIndex, line: VarIndex },
    /// Angle defined by `p`, `q`, `r` (presumably with vertex `q`); normalization
    /// treats `p` and `r` as interchangeable.
    ThreePointAngle {
        p: VarIndex,
        q: VarIndex,
        r: VarIndex,
    },
    /// Like `ThreePointAngle`, but orientation-sensitive: normalization never
    /// swaps `p` and `r`.
    ThreePointAngleDir {
        p: VarIndex,
        q: VarIndex,
        r: VarIndex,
    },
    /// Angle between lines `k` and `l`.
    TwoLineAngle { k: VarIndex, l: VarIndex },
    /// X coordinate of `point`.
    PointX { point: VarIndex },
    /// Y coordinate of `point`.
    PointY { point: VarIndex },
    /// `point` as a complex number.
    PointToComplex { point: VarIndex },
    /// Real part of `number`.
    Real { number: VarIndex },
    /// Imaginary part of `number`.
    Imaginary { number: VarIndex },
    /// Logarithm of `number`.
    Log { number: VarIndex },
    /// Exponential of `number`.
    Exp { number: VarIndex },
    /// Sine of `angle`.
    Sin { angle: VarIndex },
    /// Cosine of `angle`.
    Cos { angle: VarIndex },
    /// Two-argument arctangent of `y` and `x`.
    Atan2 { y: VarIndex, x: VarIndex },
    /// Direction of `line`, typed as a number.
    DirectionVector { line: VarIndex },

    // Lines.
    /// Line through points `p` and `q`.
    PointPoint { p: VarIndex, q: VarIndex },
    /// Bisector of the angle defined by `p`, `q`, `r`.
    AngleBisector {
        p: VarIndex,
        q: VarIndex,
        r: VarIndex,
    },
    /// Line through `point` parallel to `line`.
    ParallelThrough { point: VarIndex, line: VarIndex },
    /// Line through `point` perpendicular to `line`.
    PerpendicularThrough { point: VarIndex, line: VarIndex },
    /// Line through `point` along `vector` (a number-typed direction).
    PointVector { point: VarIndex, vector: VarIndex },

    // Circles.
    /// Circle with the given `center` and `radius`.
    ConstructCircle { center: VarIndex, radius: VarIndex },
}
509
510impl ExprKind {
511 #[must_use]
513 pub fn variant_id(&self) -> usize {
514 match self {
515 Self::Entity { .. } => 0,
516 Self::Const { .. } => 1,
517 Self::LineLineIntersection { .. } => 2,
518 Self::AveragePoint { .. } => 3,
519 Self::CircleCenter { .. } => 4,
520 Self::Sum { .. } => 5,
521 Self::Product { .. } => 6,
522 Self::Exponentiation { .. } => 7,
523 Self::PointPointDistance { .. } => 8,
524 Self::PointLineDistance { .. } => 9,
525 Self::ThreePointAngle { .. } => 10,
526 Self::ThreePointAngleDir { .. } => 11,
527 Self::TwoLineAngle { .. } => 12,
528 Self::PointX { .. } => 13,
529 Self::PointY { .. } => 14,
530 Self::PointPoint { .. } => 15,
531 Self::AngleBisector { .. } => 16,
532 Self::ParallelThrough { .. } => 17,
533 Self::PerpendicularThrough { .. } => 18,
534 Self::ConstructCircle { .. } => 19,
535 Self::PointToComplex { .. } => 20,
536 Self::ComplexToPoint { .. } => 21,
537 Self::Real { .. } => 22,
538 Self::Imaginary { .. } => 23,
539 Self::Sin { .. } => 24,
540 Self::Cos { .. } => 25,
541 Self::Atan2 { .. } => 26,
542 Self::Log { .. } => 27,
543 Self::Exp { .. } => 28,
544 Self::DirectionVector { .. } => 29,
545 Self::PointVector { .. } => 30,
546 }
547 }
548
549 #[must_use]
551 #[allow(clippy::too_many_lines)]
552 pub fn compare(&self, other: &Self, math: &Math) -> Ordering {
553 self.variant_id()
554 .cmp(&other.variant_id())
555 .then_with(|| match (self, other) {
556 (Self::Entity { id: self_id }, Self::Entity { id: other_id }) => {
557 self_id.cmp(other_id)
558 }
559 (
560 Self::LineLineIntersection {
561 k: self_a,
562 l: self_b,
563 },
564 Self::LineLineIntersection {
565 k: other_a,
566 l: other_b,
567 },
568 )
569 | (
570 Self::PointPointDistance {
571 p: self_a,
572 q: self_b,
573 },
574 Self::PointPointDistance {
575 p: other_a,
576 q: other_b,
577 },
578 )
579 | (
580 Self::PointLineDistance {
581 point: self_a,
582 line: self_b,
583 },
584 Self::PointLineDistance {
585 point: other_a,
586 line: other_b,
587 },
588 )
589 | (
590 Self::TwoLineAngle {
591 k: self_a,
592 l: self_b,
593 },
594 Self::TwoLineAngle {
595 k: other_a,
596 l: other_b,
597 },
598 )
599 | (
600 Self::PointPoint {
601 p: self_a,
602 q: self_b,
603 },
604 Self::PointPoint {
605 p: other_a,
606 q: other_b,
607 },
608 )
609 | (
610 Self::ParallelThrough {
611 point: self_a,
612 line: self_b,
613 },
614 Self::ParallelThrough {
615 point: other_a,
616 line: other_b,
617 },
618 )
619 | (
620 Self::PerpendicularThrough {
621 point: self_a,
622 line: self_b,
623 },
624 Self::PerpendicularThrough {
625 point: other_a,
626 line: other_b,
627 },
628 )
629 | (
630 Self::ConstructCircle {
631 center: self_a,
632 radius: self_b,
633 },
634 Self::ConstructCircle {
635 center: other_a,
636 radius: other_b,
637 },
638 ) => self_a
639 .compare(other_a, math)
640 .then_with(|| self_b.compare(other_b, math)),
641 (
642 Self::AveragePoint { items: self_items },
643 Self::AveragePoint { items: other_items },
644 ) => self_items.compare(other_items, math),
645 (Self::CircleCenter { circle: self_x }, Self::CircleCenter { circle: other_x })
646 | (Self::PointX { point: self_x }, Self::PointX { point: other_x })
647 | (Self::PointY { point: self_x }, Self::PointY { point: other_x }) => {
648 self_x.compare(other_x, math)
649 }
650 (
651 Self::Sum {
652 plus: self_v,
653 minus: self_u,
654 },
655 Self::Sum {
656 plus: other_v,
657 minus: other_u,
658 },
659 )
660 | (
661 Self::Product {
662 times: self_v,
663 by: self_u,
664 },
665 Self::Product {
666 times: other_v,
667 by: other_u,
668 },
669 ) => self_v
670 .compare(other_v, math)
671 .then_with(|| self_u.compare(other_u, math)),
672 (Self::Const { value: self_v }, Self::Const { value: other_v }) => {
673 self_v.cmp(other_v)
674 }
675 (
676 Self::Exponentiation {
677 value: self_v,
678 exponent: self_exp,
679 },
680 Self::Exponentiation {
681 value: other_v,
682 exponent: other_exp,
683 },
684 ) => self_v
685 .compare(other_v, math)
686 .then_with(|| self_exp.cmp(other_exp)),
687 (
688 Self::ThreePointAngle {
689 p: self_p,
690 q: self_q,
691 r: self_r,
692 },
693 Self::ThreePointAngle {
694 p: other_p,
695 q: other_q,
696 r: other_r,
697 },
698 )
699 | (
700 Self::ThreePointAngleDir {
701 p: self_p,
702 q: self_q,
703 r: self_r,
704 },
705 Self::ThreePointAngleDir {
706 p: other_p,
707 q: other_q,
708 r: other_r,
709 },
710 )
711 | (
712 Self::AngleBisector {
713 p: self_p,
714 q: self_q,
715 r: self_r,
716 },
717 Self::AngleBisector {
718 p: other_p,
719 q: other_q,
720 r: other_r,
721 },
722 ) => self_p
723 .compare(other_p, math)
724 .then_with(|| self_q.compare(other_q, math))
725 .then_with(|| self_r.compare(other_r, math)),
726 (_, _) => Ordering::Equal,
727 })
728 }
729
730 #[must_use]
732 pub fn get_type<M>(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
733 match self {
734 Self::Entity { id } => entities[id.0].get_type(expressions, entities),
735 Self::LineLineIntersection { .. }
736 | Self::AveragePoint { .. }
737 | Self::CircleCenter { .. }
738 | Self::ComplexToPoint { .. } => ExprType::Point,
739 Self::Sum { .. }
740 | Self::Product { .. }
741 | Self::Const { .. }
742 | Self::Exponentiation { .. }
743 | Self::PointPointDistance { .. }
744 | Self::PointLineDistance { .. }
745 | Self::ThreePointAngle { .. }
746 | Self::ThreePointAngleDir { .. }
747 | Self::TwoLineAngle { .. }
748 | Self::Sin { .. }
749 | Self::Cos { .. }
750 | Self::Atan2 { .. }
751 | Self::PointX { .. }
752 | Self::PointY { .. }
753 | Self::Real { .. }
754 | Self::Imaginary { .. }
755 | Self::Log { .. }
756 | Self::Exp { .. }
757 | Self::DirectionVector { .. }
758 | Self::PointToComplex { .. } => ExprType::Number,
759 Self::PointPoint { .. }
760 | Self::AngleBisector { .. }
761 | Self::ParallelThrough { .. }
762 | Self::PerpendicularThrough { .. }
763 | Self::PointVector { .. } => ExprType::Line,
764 Self::ConstructCircle { .. } => ExprType::Circle,
765 }
766 }
767}
768
769impl From<ExprKind> for geo_aid_figure::ExpressionKind {
770 fn from(value: ExprKind) -> Self {
771 match value {
772 ExprKind::Entity { id } => Self::Entity { id },
773 ExprKind::LineLineIntersection { k, l } => Self::LineLineIntersection { k, l },
774 ExprKind::AveragePoint { items } => Self::AveragePoint { items },
775 ExprKind::CircleCenter { circle } => Self::CircleCenter { circle },
776 ExprKind::ComplexToPoint { number } => Self::ComplexToPoint { number },
777 ExprKind::Sum { plus, minus } => Self::Sum { plus, minus },
778 ExprKind::Product { times, by } => Self::Product { times, by },
779 ExprKind::Const { value } => Self::Const {
780 value: value.to_complex().into(),
781 },
782 ExprKind::Exponentiation { value, exponent } => Self::Power {
783 value,
784 exponent: exponent.into(),
785 },
786 ExprKind::PointPointDistance { p, q } => Self::PointPointDistance { p, q },
787 ExprKind::PointLineDistance { point, line } => Self::PointLineDistance { point, line },
788 ExprKind::ThreePointAngle { p, q, r } => Self::ThreePointAngle { a: p, b: q, c: r },
789 ExprKind::ThreePointAngleDir { p, q, r } => {
790 Self::ThreePointAngleDir { a: p, b: q, c: r }
791 }
792 ExprKind::TwoLineAngle { k, l } => Self::TwoLineAngle { k, l },
793 ExprKind::Sin { angle } => Self::Sin { angle },
794 ExprKind::Cos { angle } => Self::Cos { angle },
795 ExprKind::Atan2 { y, x } => Self::Atan2 { y, x },
796 ExprKind::DirectionVector { line } => Self::DirectionVector { line },
797 ExprKind::PointX { point } => Self::PointX { point },
798 ExprKind::PointY { point } => Self::PointY { point },
799 ExprKind::PointToComplex { point } => Self::PointToComplex { point },
800 ExprKind::Real { number } => Self::Real { number },
801 ExprKind::Imaginary { number } => Self::Imaginary { number },
802 ExprKind::Log { number } => Self::Log { number },
803 ExprKind::Exp { number } => Self::Exp { number },
804 ExprKind::PointPoint { p, q } => Self::PointPointLine { p, q },
805 ExprKind::PointVector { point, vector } => Self::PointVectorLine { point, vector },
806 ExprKind::AngleBisector { p, q, r } => Self::AngleBisector { p, q, r },
807 ExprKind::ParallelThrough { point, line } => Self::ParallelThrough { point, line },
808 ExprKind::PerpendicularThrough { point, line } => {
809 Self::PerpendicularThrough { point, line }
810 }
811 ExprKind::ConstructCircle { center, radius } => {
812 Self::ConstructCircle { center, radius }
813 }
814 }
815 }
816}
817
818impl FindEntities for ExprKind {
819 fn find_entities(
820 &self,
821 previous: &[HashSet<EntityId>],
822 entities: &[EntityKind],
823 ) -> HashSet<EntityId> {
824 let mut set = HashSet::new();
825
826 match self {
827 Self::Entity { id } => {
828 set.insert(*id);
829 set.extend(id.find_entities(previous, entities));
830 }
831 Self::AveragePoint { items } => {
832 set.extend(items.iter().flat_map(|x| previous[x.0].iter().copied()));
833 }
834 Self::CircleCenter { circle: x }
835 | Self::PointX { point: x }
836 | Self::PointY { point: x }
837 | Self::Sin { angle: x }
838 | Self::Cos { angle: x }
839 | Self::Exponentiation { value: x, .. }
840 | Self::PointToComplex { point: x }
841 | Self::ComplexToPoint { number: x }
842 | Self::Log { number: x }
843 | Self::Exp { number: x }
844 | Self::DirectionVector { line: x }
845 | Self::Real { number: x }
846 | Self::Imaginary { number: x } => {
847 set.extend(previous[x.0].iter().copied());
848 }
849 Self::Sum {
850 plus: v1,
851 minus: v2,
852 }
853 | Self::Product { times: v1, by: v2 } => {
854 set.extend(v1.iter().flat_map(|x| previous[x.0].iter().copied()));
855 set.extend(v2.iter().flat_map(|x| previous[x.0].iter().copied()));
856 }
857 Self::PointPointDistance { p: a, q: b }
858 | Self::PointLineDistance { point: a, line: b }
859 | Self::TwoLineAngle { k: a, l: b }
860 | Self::LineLineIntersection { k: a, l: b }
861 | Self::ParallelThrough { point: a, line: b }
862 | Self::PerpendicularThrough { point: a, line: b }
863 | Self::PointPoint { p: a, q: b }
864 | Self::Atan2 { y: a, x: b }
865 | Self::PointVector {
866 point: a,
867 vector: b,
868 }
869 | Self::ConstructCircle {
870 center: a,
871 radius: b,
872 } => {
873 set.extend(previous[a.0].iter().copied());
874 set.extend(previous[b.0].iter().copied());
875 }
876 Self::ThreePointAngle { p, q, r }
877 | Self::ThreePointAngleDir { p, q, r }
878 | Self::AngleBisector { p, q, r } => {
879 set.extend(previous[p.0].iter().copied());
880 set.extend(previous[q.0].iter().copied());
881 set.extend(previous[r.0].iter().copied());
882 }
883 Self::Const { .. } => {}
884 }
885
886 set
887 }
888}
889
890impl Default for ExprKind {
891 fn default() -> Self {
892 Self::Entity { id: EntityId(0) }
893 }
894}
895
896impl FromUnrolled<UnrolledPoint> for ExprKind {
897 fn load(expr: &Unrolled<UnrolledPoint>, math: &mut Expand) -> Self {
898 let mut kind = match expr.get_data() {
899 UnrolledPoint::LineLineIntersection(a, b) => ExprKind::LineLineIntersection {
900 k: math.load(a),
901 l: math.load(b),
902 },
903 UnrolledPoint::Average(exprs) => ExprKind::AveragePoint {
904 items: exprs.iter().map(|x| math.load(x)).collect(),
905 },
906 UnrolledPoint::CircleCenter(circle) => match circle.get_data() {
907 UnrolledCircle::Circle(center, _) => return math.load_no_store(center),
908 UnrolledCircle::Generic(_) => unreachable!(),
909 },
910 UnrolledPoint::Free => ExprKind::Entity {
911 id: math.add_point(),
912 },
913 UnrolledPoint::FromComplex(number) => ExprKind::ComplexToPoint {
914 number: math.load(number),
915 },
916 UnrolledPoint::Generic(g) => {
917 unreachable!("Expression shouldn't have reached math stage: {g}")
918 }
919 };
920
921 kind.normalize(math);
922 kind
923 }
924}
925
impl FromUnrolled<unroll::Number> for ExprKind {
    /// Loads an unrolled number expression into its math-stage form and normalizes it.
    ///
    /// Arithmetic is expressed with `Sum { plus, minus }` and `Product { times, by }`.
    /// Operands are loaded in the order they appear, which determines the order in
    /// which their variables are stored.
    ///
    /// # Panics
    /// - On generic numbers and generic circles, which must not reach the math stage.
    /// - If `ProcNum::from_usize` cannot represent an `Average`'s operand count
    ///   (presumably impossible — TODO confirm).
    fn load(expr: &Unrolled<unroll::Number>, math: &mut Expand) -> Self {
        let mut kind = match &expr.get_data().data {
            // a + b
            UnrolledNumber::Add(a, b) => ExprKind::Sum {
                plus: vec![math.load(a), math.load(b)],
                minus: Vec::new(),
            },
            // a - b
            UnrolledNumber::Subtract(a, b) => ExprKind::Sum {
                plus: vec![math.load(a)],
                minus: vec![math.load(b)],
            },
            // a * b
            UnrolledNumber::Multiply(a, b) => ExprKind::Product {
                times: vec![math.load(a), math.load(b)],
                by: Vec::new(),
            },
            // a / b
            UnrolledNumber::Divide(a, b) => ExprKind::Product {
                times: vec![math.load(a)],
                by: vec![math.load(b)],
            },
            // (x1 + ... + xn) / n, with the sum and the count stored as separate variables.
            // NOTE(review): the inner `Sum` is stored without a `normalize` call, unlike
            // sums produced by the other arms — confirm whether `store` normalizes or
            // whether this is intended.
            UnrolledNumber::Average(exprs) => {
                let times = ExprKind::Sum {
                    plus: exprs.iter().map(|x| math.load(x)).collect(),
                    minus: Vec::new(),
                };
                let by = ExprKind::Const {
                    value: ProcNum::from_usize(exprs.len()).unwrap(),
                };

                ExprKind::Product {
                    times: vec![math.store(times, ExprType::Number)],
                    by: vec![math.store(by, ExprType::Number)],
                }
            }
            // The radius is taken straight from the circle's own radius expression.
            UnrolledNumber::CircleRadius(circle) => match circle.get_data() {
                UnrolledCircle::Circle(_, radius) => math.load_no_store(radius),
                UnrolledCircle::Generic(_) => unreachable!(),
            },
            // A free real becomes a new entity.
            UnrolledNumber::Free => ExprKind::Entity {
                id: math.add_real(),
            },
            // Literal: `fix_dst` scales it by the distance variable according to the
            // distance exponent of the expression's unit. Returns early, so the result
            // does not go through `normalize` below.
            UnrolledNumber::Number(x) => {
                return fix_dst(ExprKind::Const { value: x.clone() }, expr.data.unit, math);
            }
            // Distance literal: kept as a plain constant (not scaled by `fix_dst`).
            UnrolledNumber::DstLiteral(x) => ExprKind::Const { value: x.clone() },
            // Explicit unit change: load the inner value, then scale it for `unit`.
            // Also returns early, skipping `normalize`.
            UnrolledNumber::SetUnit(x, unit) => {
                return fix_dst(math.load_no_store(x), Some(*unit), math)
            }
            UnrolledNumber::PointPointDistance(p, q) => ExprKind::PointPointDistance {
                p: math.load(p),
                q: math.load(q),
            },
            UnrolledNumber::PointLineDistance(point, line) => ExprKind::PointLineDistance {
                point: math.load(point),
                line: math.load(line),
            },
            // -x is a sum with a single subtracted term.
            UnrolledNumber::Negate(x) => ExprKind::Sum {
                plus: Vec::new(),
                minus: vec![math.load(x)],
            },
            UnrolledNumber::ThreePointAngle(p, q, r) => ExprKind::ThreePointAngle {
                p: math.load(p),
                q: math.load(q),
                r: math.load(r),
            },
            UnrolledNumber::ThreePointAngleDir(p, q, r) => ExprKind::ThreePointAngleDir {
                p: math.load(p),
                q: math.load(q),
                r: math.load(r),
            },
            UnrolledNumber::TwoLineAngle(k, l) => ExprKind::TwoLineAngle {
                k: math.load(k),
                l: math.load(l),
            },
            UnrolledNumber::Pow(base, exponent) => ExprKind::Exponentiation {
                value: math.load(base),
                exponent: *exponent,
            },
            UnrolledNumber::PointX(point) => ExprKind::PointX {
                point: math.load(point),
            },
            UnrolledNumber::PointY(point) => ExprKind::PointY {
                point: math.load(point),
            },
            UnrolledNumber::Sin(angle) => ExprKind::Sin {
                angle: math.load(angle),
            },
            UnrolledNumber::Cos(angle) => ExprKind::Cos {
                angle: math.load(angle),
            },
            UnrolledNumber::Atan2(y, x) => ExprKind::Atan2 {
                y: math.load(y),
                x: math.load(x),
            },
            UnrolledNumber::Direction(line) => ExprKind::DirectionVector {
                line: math.load(line),
            },
            UnrolledNumber::FromPoint(point) => ExprKind::PointToComplex {
                point: math.load(point),
            },
            UnrolledNumber::Real(number) => ExprKind::Real {
                number: math.load(number),
            },
            UnrolledNumber::Imaginary(number) => ExprKind::Imaginary {
                number: math.load(number),
            },
            UnrolledNumber::Log(number) => ExprKind::Log {
                number: math.load(number),
            },
            UnrolledNumber::Exp(number) => ExprKind::Exp {
                number: math.load(number),
            },
            UnrolledNumber::Generic(_) => unreachable!(),
        };

        kind.normalize(math);
        kind
    }
}
1043}
1044
1045impl FromUnrolled<UnrolledLine> for ExprKind {
1046 fn load(expr: &Unrolled<UnrolledLine>, math: &mut Expand) -> Self {
1047 let mut kind = match expr.get_data() {
1048 UnrolledLine::LineFromPoints(a, b) => Self::PointPoint {
1049 p: math.load(a),
1050 q: math.load(b),
1051 },
1052 UnrolledLine::AngleBisector(a, b, c) => Self::AngleBisector {
1053 p: math.load(a),
1054 q: math.load(b),
1055 r: math.load(c),
1056 },
1057 UnrolledLine::PerpendicularThrough(k, p) => {
1058 match k.get_data() {
1060 UnrolledLine::PerpendicularThrough(l, _) => Self::ParallelThrough {
1061 point: math.load(p),
1062 line: math.load(l),
1063 },
1064 UnrolledLine::ParallelThrough(l, _) => Self::PerpendicularThrough {
1065 point: math.load(p),
1066 line: math.load(l),
1067 },
1068 _ => Self::PerpendicularThrough {
1069 point: math.load(p),
1070 line: math.load(k),
1071 },
1072 }
1073 }
1074 UnrolledLine::ParallelThrough(k, p) => {
1075 match k.get_data() {
1077 UnrolledLine::PerpendicularThrough(l, _) => Self::PerpendicularThrough {
1078 point: math.load(p),
1079 line: math.load(l),
1080 },
1081 UnrolledLine::ParallelThrough(l, _) => Self::ParallelThrough {
1082 point: math.load(p),
1083 line: math.load(l),
1084 },
1085 _ => Self::ParallelThrough {
1086 point: math.load(p),
1087 line: math.load(k),
1088 },
1089 }
1090 }
1091 UnrolledLine::PointVector(point, vector) => Self::PointVector {
1092 point: math.load(point),
1093 vector: math.load(vector),
1094 },
1095 UnrolledLine::Generic(_) => unreachable!(),
1096 };
1097
1098 kind.normalize(math);
1099 kind
1100 }
1101}
1102
1103impl FromUnrolled<UnrolledCircle> for ExprKind {
1104 fn load(expr: &Unrolled<UnrolledCircle>, math: &mut Expand) -> Self {
1105 let mut kind = match expr.data.as_ref() {
1106 UnrolledCircle::Circle(center, radius) => Self::ConstructCircle {
1107 center: math.load(center),
1108 radius: math.load(radius),
1109 },
1110 UnrolledCircle::Generic(_) => unreachable!(),
1111 };
1112
1113 kind.normalize(math);
1114 kind
1115 }
1116}
1117
1118impl Normalize for ExprKind {
1119 fn normalize(&mut self, math: &mut Math) {
1120 let cmp_and_swap = |a: &mut VarIndex, b: &mut VarIndex| {
1121 if math.compare(a, b) == Ordering::Greater {
1122 mem::swap(a, b);
1123 }
1124 };
1125 let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);
1126 let mut new_self = None;
1127
1128 match self {
1129 Self::CircleCenter { .. }
1130 | Self::PointLineDistance { .. }
1131 | Self::PointX { .. }
1132 | Self::PointY { .. }
1133 | Self::Sin { .. }
1134 | Self::Cos { .. }
1135 | Self::DirectionVector { .. }
1136 | Self::Atan2 { .. }
1137 | Self::Exponentiation { .. }
1138 | Self::ConstructCircle { .. }
1139 | Self::Const { .. }
1140 | Self::ThreePointAngleDir { .. } | Self::Entity { .. }
1142 | Self::ComplexToPoint { .. }
1143 | Self::PointToComplex { .. }
1144 | Self::Log { .. }
1145 | Self::Exp { .. }
1146 | Self::PointVector { .. }
1147 | Self::Real { .. }
1148 | Self::Imaginary { .. } => (),
1149 Self::LineLineIntersection { k: a, l: b }
1150 | Self::PointPoint { p: a, q: b }
1151 | Self::TwoLineAngle { k: a, l: b }
1152 | Self::AngleBisector { p: a, r: b, .. }
1153 | Self::ThreePointAngle { p: a, r: b, .. }
1154 | Self::PointPointDistance { p: a, q: b } => {
1155 cmp_and_swap(a, b);
1156 }
1157 Self::AveragePoint { items } => {
1158 items.sort_by(&cmp);
1159 }
1160 Self::Sum { plus, minus } => {
1161 normalize_sum(plus, minus, math);
1162 if plus.len() == 1 && minus.is_empty() {
1163 new_self = Some(math.at(&plus[0]).kind.clone());
1164 }
1165 }
1166 Self::Product { times, by } => {
1167 normalize_product(times, by, math);
1168 if times.len() == 1 && by.is_empty() {
1169 new_self = Some(math.at(×[0]).kind.clone());
1170 }
1171 }
1172 Self::ParallelThrough { point, line } => {
1173 let point = point.clone();
1175 new_self = Some(match &math.at(line).kind {
1176 Self::ParallelThrough { line, .. } => Self::ParallelThrough { point, line: line.clone() },
1177 Self::PerpendicularThrough { line, .. } => Self::PerpendicularThrough { point, line: line.clone() },
1178 _ => Self::ParallelThrough { point, line: line.clone() }
1179 });
1180 }
1181 Self::PerpendicularThrough { point, line } => {
1182 let point = point.clone();
1184 new_self = Some(match &math.at(line).kind {
1185 Self::ParallelThrough { line, .. } => Self::PerpendicularThrough { point, line: line.clone() },
1186 Self::PerpendicularThrough { line, .. } => Self::ParallelThrough { point, line: line.clone() },
1187 _ => Self::PerpendicularThrough { point, line: line.clone() }
1188 });
1189 }
1190 }
1191
1192 if let Some(new_self) = new_self {
1193 *self = new_self;
1194 }
1195 }
1196}
1197
1198fn fix_dst(expr: ExprKind, unit: Option<ComplexUnit>, math: &mut Expand) -> ExprKind {
1202 match unit {
1203 None => expr,
1204 Some(unit) => {
1205 if unit.0[SimpleUnit::Distance as usize].is_zero() {
1206 expr
1207 } else {
1208 let dst_var = math.get_dst_var();
1209 ExprKind::Product {
1210 times: vec![
1211 math.store(expr, ExprType::Number),
1212 math.store(
1213 ExprKind::Exponentiation {
1214 value: dst_var,
1215 exponent: unit.0[SimpleUnit::Distance as usize],
1216 },
1217 ExprType::Number,
1218 ),
1219 ],
1220 by: Vec::new(),
1221 }
1222 }
1223 }
1224 }
1225}
1226
/// Lazily merges two iterators, each already sorted according to `f`, into one
/// sorted iterator.
///
/// At each step the front items are compared with `f`. The item from the first
/// iterator is yielded only if `f` returns [`Ordering::Less`]. On ties, the second
/// iterator's item comes first.
#[derive(Debug, Clone)]
pub struct Merge<T, I, J, F>
where
    I: Iterator<Item = T>,
    J: Iterator<Item = T>,
    F: FnMut(&T, &T) -> Ordering,
{
    /// First input.
    i: Peekable<I>,
    /// Second input.
    j: Peekable<J>,
    /// Comparator deciding which front item comes next.
    f: F,
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: FnMut(&T, &T) -> Ordering>
    Merge<T, I, J, F>
{
    /// Creates a merge of `a` and `b`, ordered by `f`.
    #[must_use]
    pub fn new<A: IntoIterator<IntoIter = I>, B: IntoIterator<IntoIter = J>>(
        a: A,
        b: B,
        f: F,
    ) -> Self {
        Self {
            i: a.into_iter().peekable(),
            j: b.into_iter().peekable(),
            f,
        }
    }

    /// Merges the output of `self` with another sorted sequence, using a clone
    /// of this merge's comparator.
    #[must_use]
    pub fn merge_with<It: IntoIterator<Item = T>>(
        self,
        other: It,
    ) -> Merge<T, Self, It::IntoIter, F>
    where
        F: Clone,
    {
        let f_cloned = self.f.clone();
        Merge::new(self, other, f_cloned)
    }
}

impl<T, F: FnMut(&T, &T) -> Ordering>
    Merge<T, std::option::IntoIter<T>, std::option::IntoIter<T>, F>
{
    /// A merge of two empty inputs; yields nothing.
    #[must_use]
    pub fn empty(f: F) -> Self {
        Self::new(None, None, f)
    }
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: FnMut(&T, &T) -> Ordering> Iterator
    for Merge<T, I, J, F>
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(i_item) = self.i.peek() {
            if let Some(j_item) = self.j.peek() {
                if (self.f)(i_item, j_item) == Ordering::Less {
                    self.i.next()
                } else {
                    self.j.next()
                }
            } else {
                self.i.next()
            }
        } else {
            self.j.next()
        }
    }

    /// A merge yields exactly the items of both inputs, so the bounds are the sums
    /// of the inputs' bounds. This lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (i_low, i_high) = self.i.size_hint();
        let (j_low, j_high) = self.j.size_hint();
        let high = match (i_high, j_high) {
            (Some(a), Some(b)) => a.checked_add(b),
            _ => None,
        };
        (i_low.saturating_add(j_low), high)
    }
}
1302
1303fn normalize_sum(plus: &mut Vec<VarIndex>, minus: &mut Vec<VarIndex>, math: &mut Math) {
1305 let plus_v = mem::take(plus);
1306 let minus_v = mem::take(minus);
1307
1308 let mut constant = ProcNum::zero();
1309
1310 let mut plus_final = Vec::new();
1311 let mut minus_final = Vec::new();
1312
1313 let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);
1314
1315 for item in plus_v {
1316 match &math.at(&item).kind {
1317 ExprKind::Sum { plus, minus } => {
1318 plus_final = Merge::new(plus_final, plus.iter().cloned(), &cmp).collect();
1320 minus_final = Merge::new(minus_final, minus.iter().cloned(), &cmp).collect();
1321 }
1322 ExprKind::Const { value } => constant += value,
1323 _ => {
1324 plus_final = Merge::new(plus_final, Some(item), &cmp).collect();
1325 }
1326 }
1327 }
1328
1329 for item in minus_v {
1330 match &math.at(&item).kind {
1331 ExprKind::Sum { plus, minus } => {
1332 plus_final = Merge::new(plus_final, minus.iter().cloned(), &cmp).collect();
1334 minus_final = Merge::new(minus_final, plus.iter().cloned(), &cmp).collect();
1335 }
1336 ExprKind::Const { value } => constant -= value,
1337 _ => {
1338 minus_final = Merge::new(minus_final, Some(item), &cmp).collect();
1339 }
1340 }
1341 }
1342
1343 if !constant.is_zero() || (plus_final.is_empty() && minus_final.is_empty()) {
1344 plus_final.push(math.store(ExprKind::Const { value: constant }, ExprType::Number));
1345 }
1346
1347 *plus = plus_final;
1348 *minus = minus_final;
1349}
1350
1351fn normalize_product(times: &mut Vec<VarIndex>, by: &mut Vec<VarIndex>, math: &mut Math) {
1353 let times_v = mem::take(times);
1354 let by_v = mem::take(by);
1355
1356 let mut constant = ProcNum::one();
1357
1358 let mut times_final = Vec::new();
1359 let mut by_final = Vec::new();
1360
1361 let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);
1362
1363 for item in times_v {
1364 match &math.at(&item).kind {
1365 ExprKind::Product { times, by } => {
1366 times_final = Merge::new(times_final, times.iter().cloned(), &cmp).collect();
1368 by_final = Merge::new(by_final, by.iter().cloned(), &cmp).collect();
1369 }
1370 ExprKind::Const { value } => constant *= value,
1371 _ => {
1372 times_final = Merge::new(times_final, Some(item), &cmp).collect();
1373 }
1374 }
1375 }
1376
1377 for item in by_v {
1378 match &math.at(&item).kind {
1379 ExprKind::Product { times, by } => {
1380 times_final = Merge::new(times_final, by.iter().cloned(), &cmp).collect();
1382 by_final = Merge::new(by_final, times.iter().cloned(), &cmp).collect();
1383 }
1384 ExprKind::Const { value } => constant /= value,
1385 _ => {
1386 by_final = Merge::new(by_final, Some(item), &cmp).collect();
1387 }
1388 }
1389 }
1390
1391 if !constant.is_one() || (times_final.is_empty() && by_final.is_empty()) {
1392 times_final.push(math.store(ExprKind::Const { value: constant }, ExprType::Number));
1393 }
1394
1395 *times = times_final;
1396 *by = by_final;
1397}
1398
1399#[derive(Debug, Clone, PartialEq, Eq, Default, Hash, Serialize)]
1401pub struct Expr<M> {
1402 pub meta: M,
1403 pub kind: ExprKind,
1404 pub ty: ExprType,
1405}
1406
1407impl<M> Expr<M> {
1408 #[must_use]
1410 pub fn get_type(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
1411 self.kind.get_type(expressions, entities)
1412 }
1413}
1414
1415impl<M> FindEntities for Expr<M> {
1416 fn find_entities(
1417 &self,
1418 previous: &[HashSet<EntityId>],
1419 entities: &[EntityKind],
1420 ) -> HashSet<EntityId> {
1421 self.kind.find_entities(previous, entities)
1422 }
1423}
1424
1425impl<M> Reindex for Expr<M> {
1426 fn reindex(&mut self, map: &IndexMap) {
1427 self.kind.reindex(map);
1428 }
1429}
1430
1431impl<M> ContainsEntity for Expr<M> {
1432 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1433 self.kind.contains_entity(entity, math)
1434 }
1435}
1436
1437impl<M> Normalize for Expr<M> {
1438 fn normalize(&mut self, math: &mut Math) {
1439 self.kind.normalize(math);
1440 }
1441}
1442
1443impl Expr<()> {
1444 #[must_use]
1446 pub fn new(kind: ExprKind, ty: ExprType) -> Self {
1447 Self { kind, meta: (), ty }
1448 }
1449}
1450
/// A rule in expression form; each `VarIndex` operand refers to an expression in
/// `Math::expr_record`.
// NOTE: variant order matters. The derived `Ord` is used to sort `Alternative`
// branches during normalization.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Recursive)]
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||,
            init = false
        }
    }
)]
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {}
        }
    }
)]
pub enum RuleKind {
    /// The two point expressions are equal (displayed as `a = b`).
    PointEq(VarIndex, VarIndex),
    /// The two number expressions are equal (displayed as `a = b`).
    NumberEq(VarIndex, VarIndex),
    /// The first number expression is greater than the second (displayed as `a > b`).
    Gt(VarIndex, VarIndex),
    /// An alternative of several rules (each branch displayed with a `| ` prefix).
    Alternative(Vec<RuleKind>),
    /// The negation of a rule (displayed as `not ...`).
    Invert(Box<RuleKind>),
    /// A bias rule; it carries no operands.
    Bias,
}
1485
1486impl Display for RuleKind {
1487 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1488 match self {
1489 RuleKind::PointEq(a, b) | RuleKind::NumberEq(a, b) => write!(f, "{a} = {b}"),
1490 RuleKind::Gt(a, b) => write!(f, "{a} > {b}"),
1491 RuleKind::Alternative(v) => {
1492 for kind in v {
1493 write!(f, "| {kind}")?;
1494 }
1495
1496 Ok(())
1497 }
1498 RuleKind::Invert(v) => write!(f, "not {v}"),
1499 RuleKind::Bias => write!(f, "bias"),
1500 }
1501 }
1502}
1503
1504impl FindEntities for RuleKind {
1505 #[allow(clippy::only_used_in_recursion)]
1506 fn find_entities(
1507 &self,
1508 previous: &[HashSet<EntityId>],
1509 entities: &[EntityKind],
1510 ) -> HashSet<EntityId> {
1511 let mut set = HashSet::new();
1512
1513 match self {
1514 Self::PointEq(a, b) | Self::NumberEq(a, b) | Self::Gt(a, b) => {
1515 set.extend(previous[a.0].iter().copied());
1516 set.extend(previous[b.0].iter().copied());
1517 }
1518 Self::Alternative(items) => {
1519 return items
1520 .iter()
1521 .flat_map(|x| x.find_entities(previous, entities).into_iter())
1522 .collect();
1523 }
1524 Self::Invert(rule) => {
1525 return rule.find_entities(previous, entities);
1526 }
1527 Self::Bias => unreachable!(),
1528 }
1529
1530 set
1531 }
1532}
1533
1534#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
1537pub struct Rule {
1538 pub kind: RuleKind,
1540 pub weight: ProcNum,
1542 pub entities: Vec<EntityId>,
1544}
1545
1546impl Display for Rule {
1547 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1548 write!(f, "{}", self.kind)
1549 }
1550}
1551
1552impl ContainsEntity for Rule {
1553 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1554 self.kind.contains_entity(entity, math)
1555 }
1556}
1557
1558impl Reconstruct for Rule {
1559 fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
1560 Self {
1561 kind: self.kind.reconstruct(ctx),
1562 ..self
1563 }
1564 }
1565}
1566
1567impl Reindex for RuleKind {
1568 fn reindex(&mut self, map: &IndexMap) {
1569 match self {
1570 Self::PointEq(a, b) | Self::NumberEq(a, b) | Self::Gt(a, b) => {
1571 a.reindex(map);
1572 b.reindex(map);
1573 }
1574 Self::Alternative(items) => {
1575 items.reindex(map);
1576 }
1577 Self::Invert(rule) => rule.reindex(map),
1578 Self::Bias => {}
1579 }
1580 }
1581}
1582
1583impl Reindex for Rule {
1584 fn reindex(&mut self, map: &IndexMap) {
1585 self.kind.reindex(map);
1586 }
1587}
1588
1589impl RuleKind {
1590 fn load(rule: &UnrolledRule, math: &mut Expand) -> Self {
1595 let mut mathed = match &rule.kind {
1596 UnrolledRuleKind::PointEq(a, b) => Self::PointEq(math.load(a), math.load(b)),
1597 UnrolledRuleKind::NumberEq(a, b) => Self::NumberEq(math.load(a), math.load(b)),
1598 UnrolledRuleKind::Gt(a, b) => Self::Gt(math.load(a), math.load(b)),
1599 UnrolledRuleKind::Alternative(rules) => {
1600 Self::Alternative(rules.iter().map(|x| Self::load(x, math)).collect())
1601 }
1602 UnrolledRuleKind::Bias(_) => Self::Bias,
1603 };
1604
1605 mathed.normalize(math);
1606
1607 if rule.inverted {
1608 Self::Invert(Box::new(mathed))
1609 } else {
1610 mathed
1611 }
1612 }
1613}
1614
1615impl Rule {
1616 fn load(rule: &UnrolledRule, math: &mut Expand) -> Self {
1621 Self {
1622 kind: RuleKind::load(rule, math),
1623 weight: rule.weight.clone(),
1624 entities: Vec::new(),
1625 }
1626 }
1627}
1628
1629impl Normalize for RuleKind {
1630 fn normalize(&mut self, math: &mut Math) {
1631 match self {
1632 Self::PointEq(a, b) | Self::NumberEq(a, b) => {
1633 if math.compare(a, b) == Ordering::Greater {
1634 mem::swap(a, b);
1635 }
1636 }
1637 Self::Alternative(v) => v.sort(),
1638 Self::Invert(_) | Self::Bias | Self::Gt(_, _) => (),
1639 }
1640 }
1641}
1642
1643impl Normalize for Rule {
1644 fn normalize(&mut self, math: &mut Math) {
1645 self.kind.normalize(math);
1646 }
1647}
1648
/// The rule-side output of `load_script`: folded expressions, the rules over them (with
/// their entity dependencies computed) and the entities they refer to.
#[derive(Debug)]
pub struct Adjusted {
    /// Expressions after duplicate folding; rules index into this vector.
    pub variables: Vec<Expr<()>>,
    /// All rules, including generated point-inequality rules when that flag is on.
    pub rules: Vec<Rule>,
    /// Entities, reindexed to match `variables`.
    pub entities: Vec<EntityKind>,
}

/// Everything produced by `load_script`.
#[derive(Debug)]
pub struct Intermediate {
    /// The figure: items plus their own, separately folded, expressions and entities.
    pub figure: Figure,
    /// The rule-side data (see `Adjusted`).
    pub adjusted: Adjusted,
    /// Flags read from the script.
    pub flags: Flags,
}
1670
1671#[derive(Debug, Clone, Serialize)]
1673pub struct Entity<M> {
1674 pub kind: EntityKind,
1675 pub meta: M,
1676}
1677
1678impl<M> Entity<M> {
1679 #[must_use]
1681 pub fn get_type(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
1682 match &self.kind {
1683 EntityKind::FreePoint
1684 | EntityKind::PointOnLine { .. }
1685 | EntityKind::PointOnCircle { .. } => ExprType::Point,
1686 EntityKind::FreeReal | EntityKind::DistanceUnit => ExprType::Number,
1687 EntityKind::Bind(expr) => expressions[expr.0].get_type(expressions, entities),
1688 }
1689 }
1690}
1691
/// The kind of an entity, i.e. an independent value of the figure.
#[derive(Debug, Clone, Recursive, Serialize)]
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||,
            init = false
        }
    }
)]
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {}
        }
    }
)]
pub enum EntityKind {
    /// A free point.
    FreePoint,
    /// A point on the line given by expression `line`.
    PointOnLine { line: VarIndex },
    /// A point on the circle given by expression `circle`.
    PointOnCircle { circle: VarIndex },
    /// A free real number.
    FreeReal,
    /// The distance unit, created on demand by `Math::get_dst_var`.
    DistanceUnit,
    /// An entity bound to an expression. `optimize_cycle` removes these, and several
    /// impls (`FindEntities`, `Reindex`, conversion to the figure type) treat it as
    /// unreachable.
    Bind(VarIndex),
}
1724
1725impl FindEntities for EntityKind {
1726 fn find_entities(
1727 &self,
1728 previous: &[HashSet<EntityId>],
1729 _entities: &[EntityKind],
1730 ) -> HashSet<EntityId> {
1731 match self {
1732 Self::PointOnLine { line: var } | Self::PointOnCircle { circle: var } => {
1733 previous[var.0].clone()
1734 }
1735 Self::FreePoint | Self::FreeReal | Self::DistanceUnit => HashSet::new(),
1736 Self::Bind(_) => unreachable!(),
1737 }
1738 }
1739}
1740
1741impl From<EntityKind> for geo_aid_figure::EntityKind {
1742 fn from(value: EntityKind) -> Self {
1743 match value {
1744 EntityKind::FreePoint => Self::FreePoint,
1745 EntityKind::PointOnLine { line } => Self::PointOnLine { line },
1746 EntityKind::PointOnCircle { circle } => Self::PointOnCircle { circle },
1747 EntityKind::FreeReal => Self::FreeReal,
1748 EntityKind::DistanceUnit => Self::DistanceUnit,
1749 EntityKind::Bind(_) => unreachable!(),
1750 }
1751 }
1752}
1753
1754impl Reindex for EntityKind {
1755 fn reindex(&mut self, map: &IndexMap) {
1756 match self {
1757 Self::FreePoint | Self::DistanceUnit | Self::FreeReal => {}
1758 Self::PointOnLine { line } => line.reindex(map),
1759 Self::PointOnCircle { circle } => circle.reindex(map),
1760 Self::Bind(_) => unreachable!("Should not appear"),
1761 }
1762 }
1763}
1764
1765impl Reindex for EntityId {
1766 fn reindex(&mut self, _map: &IndexMap) {}
1767}
1768
1769impl DeepClone for EntityId {
1770 fn deep_clone(&self, _math: &mut Math) -> Self {
1771 *self
1773 }
1774}
1775
1776impl ContainsEntity for EntityId {
1777 fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1778 *self == entity || math.entities[self.0].contains_entity(entity, math)
1779 }
1780}
1781
/// Expression-loading context: a `Math` store plus a cache of already-loaded unrolled
/// expressions. Dereferences to the inner `Math`.
#[derive(Debug, Clone, Default)]
pub struct Expand {
    /// Loaded expressions, keyed by the address of the unrolled expression's data
    /// (see `Expand::load_no_store`).
    pub expr_map: HashMap<usize, Expr<()>>,
    /// The underlying entity and expression storage.
    pub math: Math,
    /// Clones of the `Rc` of every unrolled expression passed to `load_no_store`.
    // NOTE(review): presumably kept so the data stays allocated and the addresses used as
    // `expr_map` keys cannot be reused by later allocations. Confirm.
    pub rc_keepalive: Vec<Rc<dyn Any>>,
}
1796
1797impl Deref for Expand {
1798 type Target = Math;
1799
1800 fn deref(&self) -> &Self::Target {
1801 &self.math
1802 }
1803}
1804
1805impl DerefMut for Expand {
1806 fn deref_mut(&mut self) -> &mut Self::Target {
1807 &mut self.math
1808 }
1809}
1810
/// Storage for entities and expressions.
#[derive(Debug, Clone, Default)]
pub struct Math {
    /// All entities; an `EntityId` indexes into this vector.
    pub entities: Vec<EntityKind>,
    /// The distance-unit entity, created on first use by `Math::get_dst_var`.
    pub dst_var: OnceCell<EntityId>,
    /// All stored expressions; a `VarIndex` indexes into this vector.
    pub expr_record: Vec<Expr<()>>,
}
1821
1822impl Expand {
1823 pub fn load<T: Displayed + GetMathType + Debug + GetData + 'static>(
1825 &mut self,
1826 unrolled: &Unrolled<T>,
1827 ) -> VarIndex
1828 where
1829 ExprKind: FromUnrolled<T>,
1830 {
1831 let expr = self.load_no_store(unrolled);
1832 self.store(expr, T::get_math_type())
1833 }
1834
1835 pub fn load_no_store<T: Displayed + GetMathType + GetData + Debug + 'static>(
1837 &mut self,
1838 unrolled: &Unrolled<T>,
1839 ) -> ExprKind
1840 where
1841 ExprKind: FromUnrolled<T>,
1842 {
1843 self.rc_keepalive
1845 .push(Rc::clone(&unrolled.data) as Rc<dyn Any>);
1846
1847 let key = std::ptr::from_ref(unrolled.get_data()) as usize;
1848 let loaded = self.expr_map.get(&key).cloned();
1849
1850 if let Some(loaded) = loaded {
1851 loaded.kind.deep_clone(self)
1852 } else {
1853 let loaded = ExprKind::load(unrolled, self);
1855 self.expr_map
1856 .insert(key, Expr::new(loaded.clone(), T::get_math_type()));
1857 loaded
1858 }
1859 }
1860}
1861
1862impl Math {
1863 #[must_use]
1865 pub fn store(&mut self, expr: ExprKind, ty: ExprType) -> VarIndex {
1866 self.expr_record.push(Expr::new(expr, ty));
1867 VarIndex(self.expr_record.len() - 1)
1868 }
1869
1870 #[must_use]
1872 pub fn compare(&self, a: &VarIndex, b: &VarIndex) -> Ordering {
1873 self.at(a).kind.compare(&self.at(b).kind, self)
1874 }
1875
1876 #[must_use]
1881 pub fn get_dst_var(&mut self) -> VarIndex {
1882 let id = self.dst_var.get();
1883 let is_some = id.is_some();
1884
1885 let id = *if is_some {
1886 id.unwrap()
1887 } else {
1888 let real = self.add_entity(EntityKind::DistanceUnit);
1889 self.dst_var.get_or_init(|| real)
1890 };
1891
1892 self.store(ExprKind::Entity { id }, ExprType::Number)
1893 }
1894
1895 #[must_use]
1897 pub fn at(&self, index: &VarIndex) -> &Expr<()> {
1898 &self.expr_record[index.0]
1899 }
1900
1901 fn add_entity(&mut self, entity: EntityKind) -> EntityId {
1903 self.entities.push(entity);
1904 EntityId(self.entities.len() - 1)
1905 }
1906
1907 pub fn add_point(&mut self) -> EntityId {
1909 self.add_entity(EntityKind::FreePoint)
1910 }
1911
1912 pub fn add_real(&mut self) -> EntityId {
1914 self.add_entity(EntityKind::FreeReal)
1915 }
1916}
1917
1918#[derive(Debug, Clone, Default)]
1920pub struct Build {
1921 expand: Expand,
1923 items: Vec<Item>,
1925}
1926
1927impl Build {
1928 pub fn load<T: Displayed + GetMathType + Debug + GetData + 'static>(
1930 &mut self,
1931 expr: &Unrolled<T>,
1932 ) -> VarIndex
1933 where
1934 ExprKind: FromUnrolled<T>,
1935 {
1936 self.expand.load(expr)
1937 }
1938
1939 pub fn add<I: Into<Item>>(&mut self, item: I) {
1940 let item = item.into();
1941 self.items.push(item);
1942 }
1943}
1944
1945fn optimize_rules(rules: &mut Vec<Option<Rule>>, math: &mut Math) -> bool {
1950 let mut performed = false;
1951
1952 for rule in rules.iter_mut() {
1953 let rule_performed = ZeroLineDst::process(rule, math)
1954 | RightAngle::process(rule, math)
1955 | EqPointDst::process(rule, math)
1956 | EqExpressions::process(rule, math);
1957
1958 performed |= rule_performed;
1959 }
1960
1961 if performed {
1962 rules.retain(Option::is_some);
1963 }
1964
1965 performed
1966}
1967
/// An ordered list of index remappings `from -> to`.
///
/// Mappings are applied one after another in recording order, so a later mapping sees
/// the result of an earlier one: `1 -> 2` followed by `2 -> 3` sends `1` to `3`.
#[derive(Debug, Clone, Default)]
pub struct IndexMap {
    /// `(from, to)` pairs in recording order.
    mappings: Vec<(usize, usize)>,
}

impl IndexMap {
    /// Creates a map that leaves every index unchanged.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Passes `a` through every recorded mapping in order and returns the result.
    /// Whenever the current value equals a mapping's source, it becomes that mapping's target.
    #[must_use]
    pub fn get(&self, a: usize) -> usize {
        self.mappings
            .iter()
            .fold(a, |current, &(from, to)| if current == from { to } else { current })
    }

    /// Records that index `a` becomes `b`. Identity mappings are not recorded.
    pub fn map(&mut self, a: usize, b: usize) {
        if a == b {
            return;
        }

        self.mappings.push((a, b));
    }

    /// Appends the mappings of `lhs` after those of `rhs`, so that `rhs` applies first and
    /// `lhs` afterwards.
    pub fn compose(lhs: Self, rhs: &mut Self) {
        let IndexMap { mappings } = lhs;
        rhs.mappings.extend(mappings);
    }
}
2008
2009pub trait Reindex {
2014 fn reindex(&mut self, map: &IndexMap);
2016}
2017
2018impl Reindex for CompExponent {
2019 fn reindex(&mut self, _map: &IndexMap) {}
2020}
2021
2022impl Reindex for ProcNum {
2023 fn reindex(&mut self, _map: &IndexMap) {}
2024}
2025
2026impl<T: Reindex> Reindex for Box<T> {
2027 fn reindex(&mut self, map: &IndexMap) {
2028 self.as_mut().reindex(map);
2029 }
2030}
2031
2032impl<T: Reindex> Reindex for Vec<T> {
2033 fn reindex(&mut self, map: &IndexMap) {
2034 for item in self {
2035 item.reindex(map);
2036 }
2037 }
2038}
2039
2040fn fold(matrix: &mut Vec<Expr<()>>) -> IndexMap {
2047 let mut target = Vec::new();
2048 let mut final_map = IndexMap::new();
2049 let mut record = HashMap::new();
2050
2051 loop {
2052 let mut map = IndexMap::new();
2054 let mut folded = false;
2055 for (i, expr) in matrix.iter_mut().enumerate() {
2056 let mut expr = mem::take(expr);
2057 expr.reindex(&map);
2059 match record.entry(expr) {
2061 hash_map::Entry::Vacant(entry) => {
2062 target.push(entry.key().clone());
2063 let new_i = target.len() - 1;
2064 map.map(i, new_i);
2066 entry.insert(new_i);
2067 }
2068 hash_map::Entry::Occupied(entry) => {
2069 let j = *entry.get();
2071 map.map(i, j);
2072 folded = true;
2074 }
2075 }
2076 }
2077 IndexMap::compose(map, &mut final_map);
2082 mem::swap(matrix, &mut target);
2086 target.clear();
2088 record.clear();
2089
2090 if !folded {
2091 break final_map;
2092 }
2093 }
2094}
2095
2096fn read_flags(flags: &HashMap<&'static str, Flag>) -> Flags {
2097 Flags {
2098 optimizations: Optimizations {},
2099 point_inequalities: flags["point_inequalities"].as_bool().unwrap(),
2100 }
2101}
2102
/// Runs rule optimizations until a round makes no change.
///
/// After every round that changed something, `Bind` entities are eliminated:
/// - every entity gets an `EntityBehavior`: non-`Bind` entities keep their relative order
///   under compacted ids, and `Bind` entities map to their bound expression;
/// - items, rules, the expression record and the entity list are rebuilt through a
///   `ReconstructCtx`;
/// - all expressions and rules are re-normalized.
fn optimize_cycle(rules: &mut Vec<Option<Rule>>, math: &mut Math, items: &mut Vec<Item>) {
    // Reused across rounds; entry `i` says how old entity `i` is rebuilt.
    let mut entity_map = Vec::new();
    loop {
        if !optimize_rules(rules, math) {
            break;
        }

        // Number of `Bind` entities seen so far; subtracting it compacts the ids of the
        // entities that remain.
        let mut offset = 0;
        entity_map.clear();

        let mut entities = mem::take(&mut math.entities);

        for (i, entity) in entities.iter().enumerate() {
            entity_map.push(match entity {
                EntityKind::FreePoint
                | EntityKind::FreeReal
                | EntityKind::DistanceUnit
                | EntityKind::PointOnCircle { .. }
                | EntityKind::PointOnLine { .. } => EntityBehavior::MapEntity(EntityId(i - offset)),
                EntityKind::Bind(expr) => {
                    offset += 1;
                    EntityBehavior::MapVar(expr.clone())
                }
            });
        }

        // Drop the `Bind` entities; the remaining order matches the compacted ids above.
        entities.retain(|x| !matches!(x, EntityKind::Bind(_)));

        // Rebuild items, then rules, against the old expression record.
        let old_vars = mem::take(&mut math.expr_record);
        let mut ctx = ReconstructCtx::new(&entity_map, &old_vars, &entities);
        let old_items = mem::take(items);
        *items = old_items.reconstruct(&mut ctx);
        let old_rules = mem::take(rules);
        *rules = old_rules.reconstruct(&mut ctx);
        math.expr_record = ctx.new_vars;
        // Every entity slot is expected to be filled by reconstruction; `unwrap` panics otherwise.
        math.entities = ctx.new_entities.into_iter().map(Option::unwrap).collect();

        // Normalize from a snapshot: `normalize` needs `&mut Math` (and may store new
        // expressions), so the record cannot be borrowed while it is iterated.
        let v = math.expr_record.clone();
        for (i, mut var) in v.into_iter().enumerate() {
            var.normalize(math);
            math.expr_record[i] = var;
        }

        for rule in rules.iter_mut().flatten() {
            rule.normalize(math);
        }
    }
}
2168
/// Compiles the script `input` into an `Intermediate`.
///
/// Steps:
/// 1. Unroll the script and build its figure nodes.
/// 2. Load and optimize the rules, then rebuild everything once more.
/// 3. Optionally add point-inequality rules.
/// 4. Fold duplicate expressions, separately for the rule-side (`Adjusted`) data and for
///    the figure.
/// 5. Compute the entities each rule depends on.
///
/// # Errors
/// Returns the errors reported by `unroll::unroll` if the script cannot be unrolled.
///
/// # Panics
/// Panics (via `read_flags`) if the `point_inequalities` flag is missing or not a boolean.
pub fn load_script(input: &str) -> Result<Intermediate, Vec<Error>> {
    let (mut unrolled, nodes) = unroll::unroll(input)?;

    // Build the figure nodes into `build`; this fills `build.items`, which is used below.
    let mut build = Build::default();
    Box::new(nodes).build(&mut build);

    // Rules are loaded into the same expression context as the figure.
    let mut expand = build.expand;

    let mut rules = Vec::new();

    for rule in unrolled.take_rules() {
        rules.push(Some(Rule::load(&rule, &mut expand)));
    }

    let mut math = expand.math;

    optimize_cycle(&mut rules, &mut math, &mut build.items);

    // Rebuild once more with an identity entity mapping.
    // NOTE(review): presumably this compacts the expression record down to what items and
    // rules still reference after optimization. Confirm against `ReconstructCtx`.
    let old_entities = mem::take(&mut math.entities);
    let entity_map: Vec<_> = (0..old_entities.len())
        .map(|i| EntityBehavior::MapEntity(EntityId(i)))
        .collect();

    let old_vars = mem::take(&mut math.expr_record);
    let mut ctx = ReconstructCtx::new(&entity_map, &old_vars, &old_entities);
    build.items = build.items.reconstruct(&mut ctx);
    rules = rules.reconstruct(&mut ctx);
    math.expr_record = ctx.new_vars;
    let new_entities: Vec<_> = ctx.new_entities.into_iter().map(Option::unwrap).collect();

    // Drop rule slots emptied by optimization.
    let mut rules: Vec<_> = rules.into_iter().flatten().collect();

    let flags = read_flags(&unrolled.flags);

    if flags.point_inequalities {
        // Require every pair of distinct point entities (free points and points on
        // lines/circles) not to be equal. `skip(i + 1)` makes each unordered pair appear once.
        for i in new_entities
            .iter()
            .enumerate()
            .filter(|ent| {
                matches!(
                    ent.1,
                    EntityKind::PointOnLine { .. }
                        | EntityKind::FreePoint
                        | EntityKind::PointOnCircle { .. }
                )
            })
            .map(|x| x.0)
        {
            for j in new_entities
                .iter()
                .enumerate()
                .skip(i + 1)
                .filter(|ent| {
                    matches!(
                        ent.1,
                        EntityKind::PointOnLine { .. }
                            | EntityKind::FreePoint
                            | EntityKind::PointOnCircle { .. }
                    )
                })
                .map(|x| x.0)
            {
                let ent1 = math.store(ExprKind::Entity { id: EntityId(i) }, ExprType::Point);
                let ent2 = math.store(ExprKind::Entity { id: EntityId(j) }, ExprType::Point);
                rules.push(Rule {
                    weight: ProcNum::one(),
                    entities: Vec::new(),
                    kind: RuleKind::Invert(Box::new(RuleKind::PointEq(ent1, ent2))),
                });
            }
        }
    }

    math.entities = new_entities;

    // Two independent copies of the data: `variables`/`entities` for the adjusted output
    // and `fig_variables`/`fig_entities` for the figure. Each copy is folded on its own,
    // and the returned index map rewrites whatever refers into that copy.
    let mut entities = math.entities;
    let mut variables = math.expr_record;
    let mut fig_variables = variables.clone();
    let mut fig_entities = entities.clone();

    let index_map = fold(&mut variables);
    entities.reindex(&index_map);
    rules.reindex(&index_map);

    // `found_entities[k]` is the set of entities expression `k` depends on. It is filled
    // in index order and passed as `previous` when handling later expressions.
    let mut found_entities = Vec::new();
    for expr in &variables {
        let found = expr.find_entities(&found_entities, &entities);
        found_entities.push(found);
    }

    // Each rule records the entities its operands depend on.
    for rule in &mut rules {
        let entities = rule.kind.find_entities(&found_entities, &entities);
        rule.entities = entities.into_iter().collect();
    }

    let index_map = fold(&mut fig_variables);
    let mut items = build.items;
    items.reindex(&index_map);
    fig_entities.reindex(&index_map);

    Ok(Intermediate {
        adjusted: Adjusted {
            variables,
            rules,
            entities,
        },
        figure: Figure {
            entities: fig_entities,
            variables: fig_variables,
            items,
        },
        flags,
    })
}