geo_aid_script/
math.rs

//! The math stage is responsible for most, if not all, of the optimizations the compiler performs.
//! At this point rules are analyzed, expressions are normalized, and patterns that can be
//! optimized are optimized. It is the final and most important stage of compilation.
4
5use crate::figure::Item;
6use crate::math::optimizations::ZeroLineDst;
7use crate::token::number::{CompExponent, ProcNum};
8use crate::unroll::figure::Node;
9use crate::unroll::flags::Flag;
10use derive_recursive::Recursive;
11use geo_aid_figure::{EntityIndex as EntityId, VarIndex};
12use num_traits::{FromPrimitive, One, Zero};
13use serde::Serialize;
14use std::any::Any;
15use std::cell::OnceCell;
16use std::cmp::Ordering;
17use std::collections::{hash_map, HashMap, HashSet};
18use std::fmt::{Debug, Display, Formatter};
19use std::hash::Hash;
20use std::iter::Peekable;
21use std::mem;
22use std::ops::{Deref, DerefMut};
23use std::rc::Rc;
24
25use self::optimizations::{EqExpressions, EqPointDst, RightAngle};
26
27use super::unroll::GetData;
28use super::{
29    figure::Figure,
30    unroll::{
31        self, Circle as UnrolledCircle, Displayed, Expr as Unrolled, Line as UnrolledLine,
32        NumberData as UnrolledNumber, Point as UnrolledPoint, UnrolledRule, UnrolledRuleKind,
33    },
34    ComplexUnit, Error, SimpleUnit,
35};
36
37mod optimizations;
38
/// The `optimizations` flag group. Currently empty.
/// Has nothing to do with the [`optimizations`] module.
#[derive(Debug, Clone, Default)]
pub struct Optimizations {}

/// Compiler flags.
///
/// The default value uses the default (empty) [`Optimizations`] group and disables
/// point inequality rules.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    /// The `optimizations` flag group.
    pub optimizations: Optimizations,
    /// Whether to include point inequality rules. Defaults to `false`.
    pub point_inequalities: bool,
}
61
62/// Helper trait for getting the type of the expression.
63pub trait GetMathType {
64    #[must_use]
65    fn get_math_type() -> ExprType;
66}
67
68impl GetMathType for UnrolledPoint {
69    fn get_math_type() -> ExprType {
70        ExprType::Point
71    }
72}
73
74impl GetMathType for UnrolledLine {
75    fn get_math_type() -> ExprType {
76        ExprType::Line
77    }
78}
79
80impl GetMathType for UnrolledCircle {
81    fn get_math_type() -> ExprType {
82        ExprType::Circle
83    }
84}
85
86impl GetMathType for unroll::Number {
87    fn get_math_type() -> ExprType {
88        ExprType::Number
89    }
90}
91
92/// Deep clone a mathematical expression. This is different from `Clone` in that
93/// it works with the math stage's flattened memory model and deep clones all of
94/// an expression's dependencies (subexpressions).
95pub trait DeepClone {
96    /// Perform a deep clone.
97    #[must_use]
98    fn deep_clone(&self, math: &mut Math) -> Self;
99}
100
101impl DeepClone for ProcNum {
102    fn deep_clone(&self, _math: &mut Math) -> Self {
103        self.clone()
104    }
105}
106
107impl DeepClone for CompExponent {
108    fn deep_clone(&self, _math: &mut Math) -> Self {
109        *self
110    }
111}
112
113impl<T: DeepClone> DeepClone for Vec<T> {
114    fn deep_clone(&self, math: &mut Math) -> Self {
115        self.iter().map(|x| x.deep_clone(math)).collect()
116    }
117}
118
119/// Compare two things using the math context.
120trait Compare {
121    /// Compare two things.
122    #[must_use]
123    fn compare(&self, other: &Self, math: &Math) -> Ordering;
124}
125
126impl<T: Compare> Compare for Vec<T> {
127    fn compare(&self, other: &Self, math: &Math) -> Ordering {
128        self.iter()
129            .zip(other)
130            .map(|(a, b)| a.compare(b, math))
131            .find(|x| x.is_ne())
132            .unwrap_or_else(|| self.len().cmp(&other.len()))
133    }
134}
135
136/// Check if an expression contains an entity.
137trait ContainsEntity {
138    /// Checks if an expression contains the specified entity.
139    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool;
140}
141
142impl ContainsEntity for CompExponent {
143    fn contains_entity(&self, _entity: EntityId, _math: &Math) -> bool {
144        false
145    }
146}
147
148impl ContainsEntity for ProcNum {
149    fn contains_entity(&self, _entity: EntityId, _math: &Math) -> bool {
150        false
151    }
152}
153
154impl<T: ContainsEntity> ContainsEntity for Box<T> {
155    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
156        self.as_ref().contains_entity(entity, math)
157    }
158}
159
160impl<T: ContainsEntity> ContainsEntity for Vec<T> {
161    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
162        self.iter().any(|item| item.contains_entity(entity, math))
163    }
164}
165
/// Defines how the reconstruction process should affect a given entity.
///
/// Indexed by entity id in [`ReconstructCtx`]; consumed by `reconstruct_entity`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EntityBehavior {
    /// Replace references to this entity with references to another entity id.
    /// The target entity is itself reconstructed (once) when first encountered.
    MapEntity(EntityId),
    /// Replace references to this entity with the expression stored at this index
    /// in the old expression list; the expression's kind is cloned and reconstructed.
    MapVar(VarIndex),
}
174
175/// A context for the recosntruction process.
176pub struct ReconstructCtx<'r> {
177    /// Beheaviors of each entity
178    entity_replacement: &'r [EntityBehavior],
179    /// Old expressions
180    old_vars: &'r [Expr<()>],
181    /// New, reconstructed expressions.
182    new_vars: Vec<Expr<()>>,
183    /// Old entities.
184    old_entities: &'r [EntityKind],
185    /// New, reconstructed entities.
186    new_entities: Vec<Option<EntityKind>>,
187}
188
189impl<'r> ReconstructCtx<'r> {
190    /// Create a new context
191    #[must_use]
192    fn new(
193        entity_replacement: &'r [EntityBehavior],
194        old_vars: &'r [Expr<()>],
195        old_entities: &'r [EntityKind],
196    ) -> Self {
197        let mut ctx = Self {
198            entity_replacement,
199            old_vars,
200            new_vars: Vec::new(),
201            old_entities,
202            new_entities: vec![None; old_entities.len()],
203        };
204
205        for i in 0..old_entities.len() {
206            ctx.reconstruct_entity(EntityId(i));
207        }
208
209        ctx
210    }
211
212    /// Reconstruct an entity.
213    fn reconstruct_entity(&mut self, id: EntityId) {
214        if self.new_entities[id.0].is_none() {
215            self.new_entities[id.0] = Some(self.old_entities[id.0].clone().reconstruct(self));
216        }
217    }
218}
219
220/// The main actor in reconstruction process.
221///
222/// The point of a reconstruction is to remove all potential forward references
223/// for the sake of later processing. This means that if a reconstructed expression
224/// requires some other expressions to be computed before it can be computed itself,
225/// those expressions will have a smaller index in the expression vector.
226pub trait Reconstruct {
227    /// Reconstruct the value.
228    #[must_use]
229    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self;
230}
231
232/// Helper for reconstructing entities.
233fn reconstruct_entity(entity_id: EntityId, ctx: &mut ReconstructCtx) -> ExprKind {
234    match &ctx.entity_replacement[entity_id.0] {
235        EntityBehavior::MapEntity(id) => {
236            ctx.reconstruct_entity(*id);
237            ExprKind::Entity { id: *id }
238        }
239        EntityBehavior::MapVar(index) => ctx.old_vars[index.0].kind.clone().reconstruct(ctx),
240    }
241}
242
243impl Reconstruct for ProcNum {
244    fn reconstruct(self, _ctx: &mut ReconstructCtx) -> Self {
245        self
246    }
247}
248
249impl Reconstruct for CompExponent {
250    fn reconstruct(self, _ctx: &mut ReconstructCtx) -> Self {
251        self
252    }
253}
254
255impl<T: Reconstruct> Reconstruct for Option<T> {
256    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
257        self.map(|v| v.reconstruct(ctx))
258    }
259}
260
261impl<T: Reconstruct> Reconstruct for Box<T> {
262    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
263        Self::new((*self).reconstruct(ctx))
264    }
265}
266
267impl<T: Reconstruct> Reconstruct for Vec<T> {
268    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
269        self.into_iter().map(|x| x.reconstruct(ctx)).collect()
270    }
271}
272
273/// Helper for finding entities inside rules and expressions.
274trait FindEntities {
275    /// Find all entities in this expression based on the entities
276    /// found in all previous expressions.
277    fn find_entities(
278        &self,
279        previous: &[HashSet<EntityId>],
280        entities: &[EntityKind],
281    ) -> HashSet<EntityId>;
282}
283
284impl FindEntities for EntityId {
285    fn find_entities(
286        &self,
287        previous: &[HashSet<EntityId>],
288        entities: &[EntityKind],
289    ) -> HashSet<EntityId> {
290        entities[self.0].find_entities(previous, entities)
291    }
292}
293
294impl FindEntities for Vec<VarIndex> {
295    fn find_entities(
296        &self,
297        previous: &[HashSet<EntityId>],
298        _entities: &[EntityKind],
299    ) -> HashSet<EntityId> {
300        self.iter()
301            .flat_map(|x| previous[x.0].iter().copied())
302            .collect()
303    }
304}
305
/// Helper trait for converting unrolled expressions into math expressions.
pub trait FromUnrolled<T: Displayed> {
    /// Load the unrolled expression `expr`, using `math` as the loading context.
    fn load(expr: &Unrolled<T>, math: &mut Expand) -> Self;
}

/// Helper trait for normalizing expressions.
trait Normalize {
    /// Bring `self` into normalized form, in place.
    ///
    /// Normalization is a crucial step for further processing.
    /// Usually normalized expressions and rules have their parameters
    /// ordered. There can also be extra requirements on certain kinds,
    /// but they are documented in their respective places.
    fn normalize(&mut self, math: &mut Math);
}
320
321impl Reconstruct for VarIndex {
322    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
323        let expr = ctx.old_vars[self.0].clone();
324        let kind = expr.kind.reconstruct(ctx);
325        ctx.new_vars.push(Expr::new(kind, expr.ty));
326        VarIndex(ctx.new_vars.len() - 1)
327    }
328}
329
330impl DeepClone for VarIndex {
331    fn deep_clone(&self, math: &mut Math) -> Self {
332        // The clone here is necessary to satisfy the borrow checker.
333        // Looks ugly. but otherwise, we'd borrow `math` both mutably and immutably.
334        let ty = math.at(self).ty;
335        let expr = math.at(self).kind.clone().deep_clone(math);
336        math.store(expr, ty)
337    }
338}
339
340impl Compare for VarIndex {
341    fn compare(&self, other: &Self, math: &Math) -> Ordering {
342        math.at(self).kind.compare(&math.at(other).kind, math)
343    }
344}
345
346impl Reindex for VarIndex {
347    fn reindex(&mut self, map: &IndexMap) {
348        self.0 = map.get(self.0);
349    }
350}
351
352impl ContainsEntity for VarIndex {
353    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
354        math.at(self).contains_entity(entity, math)
355    }
356}
357
/// The primitive type of an expression.
///
/// The default type is [`ExprType::Point`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash, Serialize)]
pub enum ExprType {
    /// A (possibly complex) number.
    Number,
    /// A point.
    #[default]
    Point,
    /// A line.
    Line,
    /// A circle.
    Circle,
}
367
/// A mathematical expression with a flattened memory model.
///
/// Operands are [`VarIndex`] references into the expression list rather than nested
/// values. The derived `Ord` compares those indices as plain values (it has no access
/// to [`Math`]); use [`ExprKind::compare`] for a structural ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Recursive, Hash, Serialize)]
// The `#[recursive(...)]` attributes below are consumed by the `Recursive` derive to
// generate the listed trait impls. The `Entity` variant opts out of the generated
// `Reconstruct` logic via `override_reconstruct` and uses `reconstruct_entity` instead.
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||
        }
    }
)]
#[recursive(
    impl Reindex for Self {
        fn reindex(&mut self, map: &IndexMap) {
            aggregate = _
        }
    }
)]
#[recursive(
    impl DeepClone for Self {
        fn deep_clone(&self, math: &mut Math) -> Self {
            aggregate = {}
        }
    }
)]
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {},
            override_marker = override_reconstruct
        }
    }
)]
pub enum ExprKind {
    /// An entity reference.
    #[recursive(override_reconstruct = reconstruct_entity)]
    Entity { id: EntityId },

    // POINT
    /// An intersection of two lines: `k` and `l`.
    LineLineIntersection { k: VarIndex, l: VarIndex },
    /// The arithmetic mean of the given point expressions.
    AveragePoint { items: Vec<VarIndex> },
    /// Center of a circle.
    CircleCenter { circle: VarIndex },
    /// Convert a complex number to a point (no-op).
    ComplexToPoint { number: VarIndex },

    // NUMBER
    /// Sum of numbers.
    ///
    /// A normalized sum must be sorted and must not contain other sums or negations.
    /// An aggregated constant, if any, must be at the end.
    Sum {
        /// Items to add.
        plus: Vec<VarIndex>,
        /// Items to subtract.
        minus: Vec<VarIndex>,
    },
    /// Product of numbers.
    ///
    /// A normalized product must be sorted and must not contain other products.
    /// An aggregated constant, if any, must be at the end.
    Product {
        /// Items to multiply by.
        times: Vec<VarIndex>,
        /// Items to divide by.
        by: Vec<VarIndex>,
    },
    /// A constant.
    Const { value: ProcNum },
    /// Raising `value` to the power `exponent`.
    Exponentiation {
        /// The base.
        value: VarIndex,
        /// The exponent.
        exponent: CompExponent,
    },
    /// A distance between two points.
    PointPointDistance { p: VarIndex, q: VarIndex },
    /// A distance of a point from a line.
    PointLineDistance { point: VarIndex, line: VarIndex },
    /// An angle defined by three points (arm, origin, arm).
    ThreePointAngle {
        /// Angle's one arm.
        p: VarIndex,
        /// Angle's origin.
        q: VarIndex,
        /// Angle's other arm.
        r: VarIndex,
    },
    /// A directed angle defined by three points (arm, origin, arm).
    ThreePointAngleDir {
        /// Angle's first arm.
        p: VarIndex,
        /// Angle's origin.
        q: VarIndex,
        /// Angle's second arm.
        r: VarIndex,
    },
    /// The angle described by two lines.
    TwoLineAngle { k: VarIndex, l: VarIndex },
    /// The real part of a point.
    PointX { point: VarIndex },
    /// The imaginary part of a point.
    PointY { point: VarIndex },
    /// Convert a point to a complex number (no-op).
    PointToComplex { point: VarIndex },
    /// Real part of a number.
    Real { number: VarIndex },
    /// Imaginary part of a number.
    Imaginary { number: VarIndex },
    /// Natural logarithm.
    Log { number: VarIndex },
    /// Exponential function (`e` raised to `number`).
    Exp { number: VarIndex },
    /// Sine of an angle.
    Sin { angle: VarIndex },
    /// Cosine of an angle.
    Cos { angle: VarIndex },
    /// The two-argument arctangent, `atan2(y, x)`.
    Atan2 { y: VarIndex, x: VarIndex },
    /// A line's direction vector (typed as a number, see [`ExprKind::get_type`]).
    DirectionVector { line: VarIndex },

    // LINE
    /// A line through two points.
    PointPoint { p: VarIndex, q: VarIndex },
    /// The angle bisector line.
    // NOTE(review): presumably (arm, origin, arm) like `ThreePointAngle` — confirm.
    AngleBisector {
        p: VarIndex,
        q: VarIndex,
        r: VarIndex,
    },
    /// A line parallel to another `line` going through a `point`.
    ParallelThrough { point: VarIndex, line: VarIndex },
    /// A line perpendicular to another `line` going through a `point`.
    PerpendicularThrough { point: VarIndex, line: VarIndex },
    /// A line made from a point and a direction vector.
    PointVector { point: VarIndex, vector: VarIndex },

    // CIRCLE
    /// A circle constructed from its center and radius.
    ConstructCircle { center: VarIndex, radius: VarIndex },
}
509
510impl ExprKind {
511    /// Get the id of a variant. Used for comparison
512    #[must_use]
513    pub fn variant_id(&self) -> usize {
514        match self {
515            Self::Entity { .. } => 0,
516            Self::Const { .. } => 1,
517            Self::LineLineIntersection { .. } => 2,
518            Self::AveragePoint { .. } => 3,
519            Self::CircleCenter { .. } => 4,
520            Self::Sum { .. } => 5,
521            Self::Product { .. } => 6,
522            Self::Exponentiation { .. } => 7,
523            Self::PointPointDistance { .. } => 8,
524            Self::PointLineDistance { .. } => 9,
525            Self::ThreePointAngle { .. } => 10,
526            Self::ThreePointAngleDir { .. } => 11,
527            Self::TwoLineAngle { .. } => 12,
528            Self::PointX { .. } => 13,
529            Self::PointY { .. } => 14,
530            Self::PointPoint { .. } => 15,
531            Self::AngleBisector { .. } => 16,
532            Self::ParallelThrough { .. } => 17,
533            Self::PerpendicularThrough { .. } => 18,
534            Self::ConstructCircle { .. } => 19,
535            Self::PointToComplex { .. } => 20,
536            Self::ComplexToPoint { .. } => 21,
537            Self::Real { .. } => 22,
538            Self::Imaginary { .. } => 23,
539            Self::Sin { .. } => 24,
540            Self::Cos { .. } => 25,
541            Self::Atan2 { .. } => 26,
542            Self::Log { .. } => 27,
543            Self::Exp { .. } => 28,
544            Self::DirectionVector { .. } => 29,
545            Self::PointVector { .. } => 30,
546        }
547    }
548
549    /// Compare two expressions.
550    #[must_use]
551    #[allow(clippy::too_many_lines)]
552    pub fn compare(&self, other: &Self, math: &Math) -> Ordering {
553        self.variant_id()
554            .cmp(&other.variant_id())
555            .then_with(|| match (self, other) {
556                (Self::Entity { id: self_id }, Self::Entity { id: other_id }) => {
557                    self_id.cmp(other_id)
558                }
559                (
560                    Self::LineLineIntersection {
561                        k: self_a,
562                        l: self_b,
563                    },
564                    Self::LineLineIntersection {
565                        k: other_a,
566                        l: other_b,
567                    },
568                )
569                | (
570                    Self::PointPointDistance {
571                        p: self_a,
572                        q: self_b,
573                    },
574                    Self::PointPointDistance {
575                        p: other_a,
576                        q: other_b,
577                    },
578                )
579                | (
580                    Self::PointLineDistance {
581                        point: self_a,
582                        line: self_b,
583                    },
584                    Self::PointLineDistance {
585                        point: other_a,
586                        line: other_b,
587                    },
588                )
589                | (
590                    Self::TwoLineAngle {
591                        k: self_a,
592                        l: self_b,
593                    },
594                    Self::TwoLineAngle {
595                        k: other_a,
596                        l: other_b,
597                    },
598                )
599                | (
600                    Self::PointPoint {
601                        p: self_a,
602                        q: self_b,
603                    },
604                    Self::PointPoint {
605                        p: other_a,
606                        q: other_b,
607                    },
608                )
609                | (
610                    Self::ParallelThrough {
611                        point: self_a,
612                        line: self_b,
613                    },
614                    Self::ParallelThrough {
615                        point: other_a,
616                        line: other_b,
617                    },
618                )
619                | (
620                    Self::PerpendicularThrough {
621                        point: self_a,
622                        line: self_b,
623                    },
624                    Self::PerpendicularThrough {
625                        point: other_a,
626                        line: other_b,
627                    },
628                )
629                | (
630                    Self::ConstructCircle {
631                        center: self_a,
632                        radius: self_b,
633                    },
634                    Self::ConstructCircle {
635                        center: other_a,
636                        radius: other_b,
637                    },
638                ) => self_a
639                    .compare(other_a, math)
640                    .then_with(|| self_b.compare(other_b, math)),
641                (
642                    Self::AveragePoint { items: self_items },
643                    Self::AveragePoint { items: other_items },
644                ) => self_items.compare(other_items, math),
645                (Self::CircleCenter { circle: self_x }, Self::CircleCenter { circle: other_x })
646                | (Self::PointX { point: self_x }, Self::PointX { point: other_x })
647                | (Self::PointY { point: self_x }, Self::PointY { point: other_x }) => {
648                    self_x.compare(other_x, math)
649                }
650                (
651                    Self::Sum {
652                        plus: self_v,
653                        minus: self_u,
654                    },
655                    Self::Sum {
656                        plus: other_v,
657                        minus: other_u,
658                    },
659                )
660                | (
661                    Self::Product {
662                        times: self_v,
663                        by: self_u,
664                    },
665                    Self::Product {
666                        times: other_v,
667                        by: other_u,
668                    },
669                ) => self_v
670                    .compare(other_v, math)
671                    .then_with(|| self_u.compare(other_u, math)),
672                (Self::Const { value: self_v }, Self::Const { value: other_v }) => {
673                    self_v.cmp(other_v)
674                }
675                (
676                    Self::Exponentiation {
677                        value: self_v,
678                        exponent: self_exp,
679                    },
680                    Self::Exponentiation {
681                        value: other_v,
682                        exponent: other_exp,
683                    },
684                ) => self_v
685                    .compare(other_v, math)
686                    .then_with(|| self_exp.cmp(other_exp)),
687                (
688                    Self::ThreePointAngle {
689                        p: self_p,
690                        q: self_q,
691                        r: self_r,
692                    },
693                    Self::ThreePointAngle {
694                        p: other_p,
695                        q: other_q,
696                        r: other_r,
697                    },
698                )
699                | (
700                    Self::ThreePointAngleDir {
701                        p: self_p,
702                        q: self_q,
703                        r: self_r,
704                    },
705                    Self::ThreePointAngleDir {
706                        p: other_p,
707                        q: other_q,
708                        r: other_r,
709                    },
710                )
711                | (
712                    Self::AngleBisector {
713                        p: self_p,
714                        q: self_q,
715                        r: self_r,
716                    },
717                    Self::AngleBisector {
718                        p: other_p,
719                        q: other_q,
720                        r: other_r,
721                    },
722                ) => self_p
723                    .compare(other_p, math)
724                    .then_with(|| self_q.compare(other_q, math))
725                    .then_with(|| self_r.compare(other_r, math)),
726                (_, _) => Ordering::Equal,
727            })
728    }
729
730    /// Get the expression's type.
731    #[must_use]
732    pub fn get_type<M>(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
733        match self {
734            Self::Entity { id } => entities[id.0].get_type(expressions, entities),
735            Self::LineLineIntersection { .. }
736            | Self::AveragePoint { .. }
737            | Self::CircleCenter { .. }
738            | Self::ComplexToPoint { .. } => ExprType::Point,
739            Self::Sum { .. }
740            | Self::Product { .. }
741            | Self::Const { .. }
742            | Self::Exponentiation { .. }
743            | Self::PointPointDistance { .. }
744            | Self::PointLineDistance { .. }
745            | Self::ThreePointAngle { .. }
746            | Self::ThreePointAngleDir { .. }
747            | Self::TwoLineAngle { .. }
748            | Self::Sin { .. }
749            | Self::Cos { .. }
750            | Self::Atan2 { .. }
751            | Self::PointX { .. }
752            | Self::PointY { .. }
753            | Self::Real { .. }
754            | Self::Imaginary { .. }
755            | Self::Log { .. }
756            | Self::Exp { .. }
757            | Self::DirectionVector { .. }
758            | Self::PointToComplex { .. } => ExprType::Number,
759            Self::PointPoint { .. }
760            | Self::AngleBisector { .. }
761            | Self::ParallelThrough { .. }
762            | Self::PerpendicularThrough { .. }
763            | Self::PointVector { .. } => ExprType::Line,
764            Self::ConstructCircle { .. } => ExprType::Circle,
765        }
766    }
767}
768
/// Converts a math-stage expression into the figure crate's expression representation.
///
/// Operand indices are passed through unchanged.
/// - `Const` values are converted with `to_complex()`.
/// - `Exponentiation` exponents are converted with `into()`.
/// - A few variants are renamed:
///   - `Exponentiation` becomes `Power`;
///   - `PointPoint` becomes `PointPointLine`;
///   - `PointVector` becomes `PointVectorLine`;
///   - the three-point angles' `p, q, r` become `a, b, c`.
impl From<ExprKind> for geo_aid_figure::ExpressionKind {
    fn from(value: ExprKind) -> Self {
        match value {
            ExprKind::Entity { id } => Self::Entity { id },
            ExprKind::LineLineIntersection { k, l } => Self::LineLineIntersection { k, l },
            ExprKind::AveragePoint { items } => Self::AveragePoint { items },
            ExprKind::CircleCenter { circle } => Self::CircleCenter { circle },
            ExprKind::ComplexToPoint { number } => Self::ComplexToPoint { number },
            ExprKind::Sum { plus, minus } => Self::Sum { plus, minus },
            ExprKind::Product { times, by } => Self::Product { times, by },
            ExprKind::Const { value } => Self::Const {
                value: value.to_complex().into(),
            },
            ExprKind::Exponentiation { value, exponent } => Self::Power {
                value,
                exponent: exponent.into(),
            },
            ExprKind::PointPointDistance { p, q } => Self::PointPointDistance { p, q },
            ExprKind::PointLineDistance { point, line } => Self::PointLineDistance { point, line },
            ExprKind::ThreePointAngle { p, q, r } => Self::ThreePointAngle { a: p, b: q, c: r },
            ExprKind::ThreePointAngleDir { p, q, r } => {
                Self::ThreePointAngleDir { a: p, b: q, c: r }
            }
            ExprKind::TwoLineAngle { k, l } => Self::TwoLineAngle { k, l },
            ExprKind::Sin { angle } => Self::Sin { angle },
            ExprKind::Cos { angle } => Self::Cos { angle },
            ExprKind::Atan2 { y, x } => Self::Atan2 { y, x },
            ExprKind::DirectionVector { line } => Self::DirectionVector { line },
            ExprKind::PointX { point } => Self::PointX { point },
            ExprKind::PointY { point } => Self::PointY { point },
            ExprKind::PointToComplex { point } => Self::PointToComplex { point },
            ExprKind::Real { number } => Self::Real { number },
            ExprKind::Imaginary { number } => Self::Imaginary { number },
            ExprKind::Log { number } => Self::Log { number },
            ExprKind::Exp { number } => Self::Exp { number },
            ExprKind::PointPoint { p, q } => Self::PointPointLine { p, q },
            ExprKind::PointVector { point, vector } => Self::PointVectorLine { point, vector },
            ExprKind::AngleBisector { p, q, r } => Self::AngleBisector { p, q, r },
            ExprKind::ParallelThrough { point, line } => Self::ParallelThrough { point, line },
            ExprKind::PerpendicularThrough { point, line } => {
                Self::PerpendicularThrough { point, line }
            }
            ExprKind::ConstructCircle { center, radius } => {
                Self::ConstructCircle { center, radius }
            }
        }
    }
}
817
818impl FindEntities for ExprKind {
819    fn find_entities(
820        &self,
821        previous: &[HashSet<EntityId>],
822        entities: &[EntityKind],
823    ) -> HashSet<EntityId> {
824        let mut set = HashSet::new();
825
826        match self {
827            Self::Entity { id } => {
828                set.insert(*id);
829                set.extend(id.find_entities(previous, entities));
830            }
831            Self::AveragePoint { items } => {
832                set.extend(items.iter().flat_map(|x| previous[x.0].iter().copied()));
833            }
834            Self::CircleCenter { circle: x }
835            | Self::PointX { point: x }
836            | Self::PointY { point: x }
837            | Self::Sin { angle: x }
838            | Self::Cos { angle: x }
839            | Self::Exponentiation { value: x, .. }
840            | Self::PointToComplex { point: x }
841            | Self::ComplexToPoint { number: x }
842            | Self::Log { number: x }
843            | Self::Exp { number: x }
844            | Self::DirectionVector { line: x }
845            | Self::Real { number: x }
846            | Self::Imaginary { number: x } => {
847                set.extend(previous[x.0].iter().copied());
848            }
849            Self::Sum {
850                plus: v1,
851                minus: v2,
852            }
853            | Self::Product { times: v1, by: v2 } => {
854                set.extend(v1.iter().flat_map(|x| previous[x.0].iter().copied()));
855                set.extend(v2.iter().flat_map(|x| previous[x.0].iter().copied()));
856            }
857            Self::PointPointDistance { p: a, q: b }
858            | Self::PointLineDistance { point: a, line: b }
859            | Self::TwoLineAngle { k: a, l: b }
860            | Self::LineLineIntersection { k: a, l: b }
861            | Self::ParallelThrough { point: a, line: b }
862            | Self::PerpendicularThrough { point: a, line: b }
863            | Self::PointPoint { p: a, q: b }
864            | Self::Atan2 { y: a, x: b }
865            | Self::PointVector {
866                point: a,
867                vector: b,
868            }
869            | Self::ConstructCircle {
870                center: a,
871                radius: b,
872            } => {
873                set.extend(previous[a.0].iter().copied());
874                set.extend(previous[b.0].iter().copied());
875            }
876            Self::ThreePointAngle { p, q, r }
877            | Self::ThreePointAngleDir { p, q, r }
878            | Self::AngleBisector { p, q, r } => {
879                set.extend(previous[p.0].iter().copied());
880                set.extend(previous[q.0].iter().copied());
881                set.extend(previous[r.0].iter().copied());
882            }
883            Self::Const { .. } => {}
884        }
885
886        set
887    }
888}
889
890impl Default for ExprKind {
891    fn default() -> Self {
892        Self::Entity { id: EntityId(0) }
893    }
894}
895
896impl FromUnrolled<UnrolledPoint> for ExprKind {
897    fn load(expr: &Unrolled<UnrolledPoint>, math: &mut Expand) -> Self {
898        let mut kind = match expr.get_data() {
899            UnrolledPoint::LineLineIntersection(a, b) => ExprKind::LineLineIntersection {
900                k: math.load(a),
901                l: math.load(b),
902            },
903            UnrolledPoint::Average(exprs) => ExprKind::AveragePoint {
904                items: exprs.iter().map(|x| math.load(x)).collect(),
905            },
906            UnrolledPoint::CircleCenter(circle) => match circle.get_data() {
907                UnrolledCircle::Circle(center, _) => return math.load_no_store(center),
908                UnrolledCircle::Generic(_) => unreachable!(),
909            },
910            UnrolledPoint::Free => ExprKind::Entity {
911                id: math.add_point(),
912            },
913            UnrolledPoint::FromComplex(number) => ExprKind::ComplexToPoint {
914                number: math.load(number),
915            },
916            UnrolledPoint::Generic(g) => {
917                unreachable!("Expression shouldn't have reached math stage: {g}")
918            }
919        };
920
921        kind.normalize(math);
922        kind
923    }
924}
925
impl FromUnrolled<unroll::Number> for ExprKind {
    /// Lower an unrolled number expression into an [`ExprKind`].
    ///
    /// Operands are loaded through `math.load` in the order the fields are written.
    /// The resulting kind is normalized before being returned. The exceptions are the
    /// `Number` and `SetUnit` arms: they route through [`fix_dst`] and `return` early.
    fn load(expr: &Unrolled<unroll::Number>, math: &mut Expand) -> Self {
        let mut kind = match &expr.get_data().data {
            UnrolledNumber::Add(a, b) => ExprKind::Sum {
                plus: vec![math.load(a), math.load(b)],
                minus: Vec::new(),
            },
            UnrolledNumber::Subtract(a, b) => ExprKind::Sum {
                plus: vec![math.load(a)],
                minus: vec![math.load(b)],
            },
            UnrolledNumber::Multiply(a, b) => ExprKind::Product {
                times: vec![math.load(a), math.load(b)],
                by: Vec::new(),
            },
            UnrolledNumber::Divide(a, b) => ExprKind::Product {
                times: vec![math.load(a)],
                by: vec![math.load(b)],
            },
            UnrolledNumber::Average(exprs) => {
                // avg(x1, ..., xn) = (x1 + ... + xn) / n.
                // The sum and the constant n are stored as separate expressions.
                let times = ExprKind::Sum {
                    plus: exprs.iter().map(|x| math.load(x)).collect(),
                    minus: Vec::new(),
                };
                let by = ExprKind::Const {
                    value: ProcNum::from_usize(exprs.len()).unwrap(),
                };

                ExprKind::Product {
                    times: vec![math.store(times, ExprType::Number)],
                    by: vec![math.store(by, ExprType::Number)],
                }
            }
            UnrolledNumber::CircleRadius(circle) => match circle.get_data() {
                // The radius is the expression the circle was constructed with.
                UnrolledCircle::Circle(_, radius) => math.load_no_store(radius),
                UnrolledCircle::Generic(_) => unreachable!(),
            },
            // A free number becomes a fresh real entity.
            UnrolledNumber::Free => ExprKind::Entity {
                id: math.add_real(),
            },
            UnrolledNumber::Number(x) => {
                // Scale by the distance unit (if the unit has a distance component).
                // NOTE(review): this early return skips the `normalize` call below, so
                // the `Product` built by `fix_dst` is not normalized here — confirm
                // this is intended.
                return fix_dst(ExprKind::Const { value: x.clone() }, expr.data.unit, math);
            }
            // Distance literals are taken as-is; no `fix_dst` scaling is applied.
            UnrolledNumber::DstLiteral(x) => ExprKind::Const { value: x.clone() },
            UnrolledNumber::SetUnit(x, unit) => {
                // Attach an explicit unit; like `Number`, this returns early without
                // going through `normalize` below.
                return fix_dst(math.load_no_store(x), Some(*unit), math)
            }
            UnrolledNumber::PointPointDistance(p, q) => ExprKind::PointPointDistance {
                p: math.load(p),
                q: math.load(q),
            },
            UnrolledNumber::PointLineDistance(point, line) => ExprKind::PointLineDistance {
                point: math.load(point),
                line: math.load(line),
            },
            // -x is represented as a sum with no positive terms.
            UnrolledNumber::Negate(x) => ExprKind::Sum {
                plus: Vec::new(),
                minus: vec![math.load(x)],
            },
            UnrolledNumber::ThreePointAngle(p, q, r) => ExprKind::ThreePointAngle {
                p: math.load(p),
                q: math.load(q),
                r: math.load(r),
            },
            UnrolledNumber::ThreePointAngleDir(p, q, r) => ExprKind::ThreePointAngleDir {
                p: math.load(p),
                q: math.load(q),
                r: math.load(r),
            },
            UnrolledNumber::TwoLineAngle(k, l) => ExprKind::TwoLineAngle {
                k: math.load(k),
                l: math.load(l),
            },
            UnrolledNumber::Pow(base, exponent) => ExprKind::Exponentiation {
                value: math.load(base),
                exponent: *exponent,
            },
            UnrolledNumber::PointX(point) => ExprKind::PointX {
                point: math.load(point),
            },
            UnrolledNumber::PointY(point) => ExprKind::PointY {
                point: math.load(point),
            },
            UnrolledNumber::Sin(angle) => ExprKind::Sin {
                angle: math.load(angle),
            },
            UnrolledNumber::Cos(angle) => ExprKind::Cos {
                angle: math.load(angle),
            },
            UnrolledNumber::Atan2(y, x) => ExprKind::Atan2 {
                y: math.load(y),
                x: math.load(x),
            },
            UnrolledNumber::Direction(line) => ExprKind::DirectionVector {
                line: math.load(line),
            },
            UnrolledNumber::FromPoint(point) => ExprKind::PointToComplex {
                point: math.load(point),
            },
            UnrolledNumber::Real(number) => ExprKind::Real {
                number: math.load(number),
            },
            UnrolledNumber::Imaginary(number) => ExprKind::Imaginary {
                number: math.load(number),
            },
            UnrolledNumber::Log(number) => ExprKind::Log {
                number: math.load(number),
            },
            UnrolledNumber::Exp(number) => ExprKind::Exp {
                number: math.load(number),
            },
            // Generic expressions must have been resolved by the unroll stage.
            UnrolledNumber::Generic(_) => unreachable!(),
        };

        kind.normalize(math);
        kind
    }
}
1044
impl FromUnrolled<UnrolledLine> for ExprKind {
    /// Lower an unrolled line expression into a normalized [`ExprKind`].
    ///
    /// Operands are loaded through `math.load` in the order the fields are written.
    /// A parallel/perpendicular construction whose base line is itself one level of
    /// parallel/perpendicular construction is collapsed onto the inner base line. In
    /// that case the intermediate line is never loaded.
    fn load(expr: &Unrolled<UnrolledLine>, math: &mut Expand) -> Self {
        let mut kind = match expr.get_data() {
            UnrolledLine::LineFromPoints(a, b) => Self::PointPoint {
                p: math.load(a),
                q: math.load(b),
            },
            UnrolledLine::AngleBisector(a, b, c) => Self::AngleBisector {
                p: math.load(a),
                q: math.load(b),
                r: math.load(c),
            },
            UnrolledLine::PerpendicularThrough(k, p) => {
                // Remove unnecessary intermediates:
                // * perpendicular to (perpendicular to l) => parallel to l,
                // * perpendicular to (parallel to l)      => perpendicular to l.
                match k.get_data() {
                    UnrolledLine::PerpendicularThrough(l, _) => Self::ParallelThrough {
                        point: math.load(p),
                        line: math.load(l),
                    },
                    UnrolledLine::ParallelThrough(l, _) => Self::PerpendicularThrough {
                        point: math.load(p),
                        line: math.load(l),
                    },
                    _ => Self::PerpendicularThrough {
                        point: math.load(p),
                        line: math.load(k),
                    },
                }
            }
            UnrolledLine::ParallelThrough(k, p) => {
                // Remove unnecessary intermediates:
                // * parallel to (perpendicular to l) => perpendicular to l,
                // * parallel to (parallel to l)      => parallel to l.
                match k.get_data() {
                    UnrolledLine::PerpendicularThrough(l, _) => Self::PerpendicularThrough {
                        point: math.load(p),
                        line: math.load(l),
                    },
                    UnrolledLine::ParallelThrough(l, _) => Self::ParallelThrough {
                        point: math.load(p),
                        line: math.load(l),
                    },
                    _ => Self::ParallelThrough {
                        point: math.load(p),
                        line: math.load(k),
                    },
                }
            }
            UnrolledLine::PointVector(point, vector) => Self::PointVector {
                point: math.load(point),
                vector: math.load(vector),
            },
            // Generic expressions must have been resolved by the unroll stage.
            UnrolledLine::Generic(_) => unreachable!(),
        };

        kind.normalize(math);
        kind
    }
}
1102
1103impl FromUnrolled<UnrolledCircle> for ExprKind {
1104    fn load(expr: &Unrolled<UnrolledCircle>, math: &mut Expand) -> Self {
1105        let mut kind = match expr.data.as_ref() {
1106            UnrolledCircle::Circle(center, radius) => Self::ConstructCircle {
1107                center: math.load(center),
1108                radius: math.load(radius),
1109            },
1110            UnrolledCircle::Generic(_) => unreachable!(),
1111        };
1112
1113        kind.normalize(math);
1114        kind
1115    }
1116}
1117
impl Normalize for ExprKind {
    /// Bring this expression kind into its canonical form, in place.
    ///
    /// * Kinds whose two operands may be exchanged get them ordered by `math.compare`.
    /// * Average points get their items sorted by `math.compare`.
    /// * Sums and products are flattened via [`normalize_sum`] / [`normalize_product`].
    ///   A single positive term then replaces the whole expression.
    /// * Parallel/perpendicular constructions built on another such construction are
    ///   re-expressed against the inner base line.
    /// * All other kinds are left untouched.
    fn normalize(&mut self, math: &mut Math) {
        let cmp_and_swap = |a: &mut VarIndex, b: &mut VarIndex| {
            if math.compare(a, b) == Ordering::Greater {
                mem::swap(a, b);
            }
        };
        let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);
        // Replacement for `self`, applied at the end once all borrows of `self` are done.
        let mut new_self = None;

        match self {
            Self::CircleCenter { .. }
            | Self::PointLineDistance { .. }
            | Self::PointX { .. }
            | Self::PointY { .. }
            | Self::Sin { .. }
            | Self::Cos { .. }
            | Self::DirectionVector { .. }
            | Self::Atan2 { .. }
            | Self::Exponentiation { .. }
            | Self::ConstructCircle { .. }
            | Self::Const { .. }
            | Self::ThreePointAngleDir { .. } // DO NOT NORMALIZE DIRECTED ANGLES
            | Self::Entity { .. }
            | Self::ComplexToPoint { .. }
            | Self::PointToComplex { .. }
            | Self::Log { .. }
            | Self::Exp { .. }
            | Self::PointVector { .. }
            | Self::Real { .. }
            | Self::Imaginary { .. } => (),
            // These kinds are treated as symmetric in the two bound operands, so the
            // operands are put into `math.compare` order.
            Self::LineLineIntersection { k: a, l: b }
            | Self::PointPoint { p: a, q: b }
            | Self::TwoLineAngle { k: a, l: b }
            | Self::AngleBisector { p: a, r: b, .. }
            | Self::ThreePointAngle { p: a, r: b, .. }
            | Self::PointPointDistance { p: a, q: b } => {
                cmp_and_swap(a, b);
            }
            Self::AveragePoint { items } => {
                items.sort_by(&cmp);
            }
            Self::Sum { plus, minus } => {
                normalize_sum(plus, minus, math);
                // A sum of exactly one positive term is that term.
                if plus.len() == 1 && minus.is_empty() {
                    new_self = Some(math.at(&plus[0]).kind.clone());
                }
            }
            Self::Product { times, by } => {
                normalize_product(times, by, math);
                // A product of exactly one factor (with no divisors) is that factor.
                if times.len() == 1 && by.is_empty() {
                    new_self = Some(math.at(&times[0]).kind.clone());
                }
            }
            Self::ParallelThrough { point, line } => {
                // Parallel to (parallel to l) => parallel to l;
                // parallel to (perpendicular to l) => perpendicular to l.
                // This is technically a move, although ugly, so we clone.
                let point = point.clone();
                new_self = Some(match &math.at(line).kind {
                    Self::ParallelThrough { line, .. } => Self::ParallelThrough { point, line: line.clone() },
                    Self::PerpendicularThrough { line, .. } => Self::PerpendicularThrough { point, line: line.clone() },
                    _ => Self::ParallelThrough { point, line: line.clone() }
                });
            }
            Self::PerpendicularThrough { point, line } => {
                // Perpendicular to (parallel to l) => perpendicular to l;
                // perpendicular to (perpendicular to l) => parallel to l.
                // This is technically a move, although ugly, so we clone.
                let point = point.clone();
                new_self = Some(match &math.at(line).kind {
                    Self::ParallelThrough { line, .. } => Self::PerpendicularThrough { point, line: line.clone() },
                    Self::PerpendicularThrough { line, .. } => Self::ParallelThrough { point, line: line.clone() },
                    _ => Self::PerpendicularThrough { point, line: line.clone() }
                });
            }
        }

        if let Some(new_self) = new_self {
            *self = new_self;
        }
    }
}
1197
1198/// Distance units must have special treatment as they need to be represented by a special
1199/// entity. This function makes sure this entity is inserted and raised to the right power
1200/// corresponding to the unit.
1201fn fix_dst(expr: ExprKind, unit: Option<ComplexUnit>, math: &mut Expand) -> ExprKind {
1202    match unit {
1203        None => expr,
1204        Some(unit) => {
1205            if unit.0[SimpleUnit::Distance as usize].is_zero() {
1206                expr
1207            } else {
1208                let dst_var = math.get_dst_var();
1209                ExprKind::Product {
1210                    times: vec![
1211                        math.store(expr, ExprType::Number),
1212                        math.store(
1213                            ExprKind::Exponentiation {
1214                                value: dst_var,
1215                                exponent: unit.0[SimpleUnit::Distance as usize],
1216                            },
1217                            ExprType::Number,
1218                        ),
1219                    ],
1220                    by: Vec::new(),
1221                }
1222            }
1223        }
1224    }
1225}
1226
/// A utility iterator for merging two sorted iterators.
///
/// If both inputs are sorted according to the comparison function `f`, the merged
/// output is sorted as well. The head of the first iterator is yielded only when it
/// compares [`Ordering::Less`] than the head of the second one. On ties, the element
/// from the second iterator is yielded first.
#[derive(Debug, Clone)]
pub struct Merge<T, I, J, F>
where
    I: Iterator<Item = T>,
    J: Iterator<Item = T>,
    F: FnMut(&T, &T) -> Ordering,
{
    /// The first input.
    i: Peekable<I>,
    /// The second input; wins ties.
    j: Peekable<J>,
    /// The comparison function ordering the elements.
    f: F,
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: FnMut(&T, &T) -> Ordering>
    Merge<T, I, J, F>
{
    /// Create a new merger of `a` and `b`, ordered by `f`.
    #[must_use]
    pub fn new<A: IntoIterator<IntoIter = I>, B: IntoIterator<IntoIter = J>>(
        a: A,
        b: B,
        f: F,
    ) -> Self {
        Self {
            i: a.into_iter().peekable(),
            j: b.into_iter().peekable(),
            f,
        }
    }

    /// Merge this iterator with another iterator, using a clone of the same
    /// comparison function. `self` becomes the first input and `other` the second.
    #[must_use]
    pub fn merge_with<It: IntoIterator<Item = T>>(
        self,
        other: It,
    ) -> Merge<T, Self, It::IntoIter, F>
    where
        F: Clone,
    {
        let f_cloned = self.f.clone();
        Merge::new(self, other, f_cloned)
    }
}

impl<T, F: FnMut(&T, &T) -> Ordering>
    Merge<T, std::option::IntoIter<T>, std::option::IntoIter<T>, F>
{
    /// Create an empty merge iterator.
    #[must_use]
    pub fn empty(f: F) -> Self {
        Self::new(None, None, f)
    }
}

impl<T, I: Iterator<Item = T>, J: Iterator<Item = T>, F: FnMut(&T, &T) -> Ordering> Iterator
    for Merge<T, I, J, F>
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Decide which side to pull from while the heads are only peeked, then advance
        // that side. The first input is taken only when its head is strictly smaller.
        let take_from_i = match (self.i.peek(), self.j.peek()) {
            (Some(i_item), Some(j_item)) => (self.f)(i_item, j_item) == Ordering::Less,
            (Some(_), None) => true,
            (None, _) => false,
        };

        if take_from_i {
            self.i.next()
        } else {
            self.j.next()
        }
    }

    /// The merge yields exactly the elements of both inputs, so its bounds are the
    /// sums of the inputs' bounds. This lets `collect` preallocate its buffer.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (i_lower, i_upper) = self.i.size_hint();
        let (j_lower, j_upper) = self.j.size_hint();

        let lower = i_lower.saturating_add(j_lower);
        let upper = match (i_upper, j_upper) {
            (Some(a), Some(b)) => a.checked_add(b),
            _ => None,
        };

        (lower, upper)
    }
}
1302
/// Normalize a sum. Specific normalization rules are given in [`ExprKind::Sum`] documentation.
///
/// Concretely, this function:
/// * flattens nested [`ExprKind::Sum`] terms into this one. A subtracted nested sum
///   contributes its `minus` terms to `plus` and its `plus` terms to `minus`.
/// * folds [`ExprKind::Const`] terms that appear directly in `plus`/`minus` into a
///   single constant.
/// * merges every other term into `plus`/`minus` in the order given by `math.compare`.
/// * appends the folded constant to the end of `plus` when it is non-zero, or when
///   the sum would otherwise be empty. The result is therefore never an empty sum.
///
/// NOTE(review): constants found *inside* nested sums are merged as ordinary terms
/// rather than folded. The folded constant is pushed, not merged, so it is not
/// necessarily in `math.compare` order. Confirm both are intended.
fn normalize_sum(plus: &mut Vec<VarIndex>, minus: &mut Vec<VarIndex>, math: &mut Math) {
    let plus_v = mem::take(plus);
    let minus_v = mem::take(minus);

    let mut constant = ProcNum::zero();

    let mut plus_final = Vec::new();
    let mut minus_final = Vec::new();

    let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);

    for item in plus_v {
        match &math.at(&item).kind {
            ExprKind::Sum { plus, minus } => {
                // Yet again, we're technically moving, so we can clone
                plus_final = Merge::new(plus_final, plus.iter().cloned(), &cmp).collect();
                minus_final = Merge::new(minus_final, minus.iter().cloned(), &cmp).collect();
            }
            ExprKind::Const { value } => constant += value,
            _ => {
                plus_final = Merge::new(plus_final, Some(item), &cmp).collect();
            }
        }
    }

    for item in minus_v {
        match &math.at(&item).kind {
            ExprKind::Sum { plus, minus } => {
                // Subtracting a sum flips the sign of each of its terms.
                // Yet again, we're technically moving, so we can clone
                plus_final = Merge::new(plus_final, minus.iter().cloned(), &cmp).collect();
                minus_final = Merge::new(minus_final, plus.iter().cloned(), &cmp).collect();
            }
            ExprKind::Const { value } => constant -= value,
            _ => {
                minus_final = Merge::new(minus_final, Some(item), &cmp).collect();
            }
        }
    }

    // Keep the constant if it contributes, or if it is the only thing left.
    if !constant.is_zero() || (plus_final.is_empty() && minus_final.is_empty()) {
        plus_final.push(math.store(ExprKind::Const { value: constant }, ExprType::Number));
    }

    *plus = plus_final;
    *minus = minus_final;
}
1350
1351/// Normalize a sum. Specific normalization rules are given in [`ExprKindi::Product`] documentation.
1352fn normalize_product(times: &mut Vec<VarIndex>, by: &mut Vec<VarIndex>, math: &mut Math) {
1353    let times_v = mem::take(times);
1354    let by_v = mem::take(by);
1355
1356    let mut constant = ProcNum::one();
1357
1358    let mut times_final = Vec::new();
1359    let mut by_final = Vec::new();
1360
1361    let cmp = |a: &VarIndex, b: &VarIndex| math.compare(a, b);
1362
1363    for item in times_v {
1364        match &math.at(&item).kind {
1365            ExprKind::Product { times, by } => {
1366                // Yet again, we're technically moving, so we can clone
1367                times_final = Merge::new(times_final, times.iter().cloned(), &cmp).collect();
1368                by_final = Merge::new(by_final, by.iter().cloned(), &cmp).collect();
1369            }
1370            ExprKind::Const { value } => constant *= value,
1371            _ => {
1372                times_final = Merge::new(times_final, Some(item), &cmp).collect();
1373            }
1374        }
1375    }
1376
1377    for item in by_v {
1378        match &math.at(&item).kind {
1379            ExprKind::Product { times, by } => {
1380                // Yet again, we're technically moving, so we can clone
1381                times_final = Merge::new(times_final, by.iter().cloned(), &cmp).collect();
1382                by_final = Merge::new(by_final, times.iter().cloned(), &cmp).collect();
1383            }
1384            ExprKind::Const { value } => constant /= value,
1385            _ => {
1386                by_final = Merge::new(by_final, Some(item), &cmp).collect();
1387            }
1388        }
1389    }
1390
1391    if !constant.is_one() || (times_final.is_empty() && by_final.is_empty()) {
1392        times_final.push(math.store(ExprKind::Const { value: constant }, ExprType::Number));
1393    }
1394
1395    *times = times_final;
1396    *by = by_final;
1397}
1398
/// An expression with some metadata and a type.
///
/// `M` is the type of the attached metadata; expressions created through
/// [`Expr::new`] carry `()`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash, Serialize)]
pub struct Expr<M> {
    /// Metadata attached to this expression.
    pub meta: M,
    /// The expression itself.
    pub kind: ExprKind,
    /// The type of the expression.
    pub ty: ExprType,
}
1406
1407impl<M> Expr<M> {
1408    /// Get the expression's type.
1409    #[must_use]
1410    pub fn get_type(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
1411        self.kind.get_type(expressions, entities)
1412    }
1413}
1414
1415impl<M> FindEntities for Expr<M> {
1416    fn find_entities(
1417        &self,
1418        previous: &[HashSet<EntityId>],
1419        entities: &[EntityKind],
1420    ) -> HashSet<EntityId> {
1421        self.kind.find_entities(previous, entities)
1422    }
1423}
1424
1425impl<M> Reindex for Expr<M> {
1426    fn reindex(&mut self, map: &IndexMap) {
1427        self.kind.reindex(map);
1428    }
1429}
1430
1431impl<M> ContainsEntity for Expr<M> {
1432    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1433        self.kind.contains_entity(entity, math)
1434    }
1435}
1436
1437impl<M> Normalize for Expr<M> {
1438    fn normalize(&mut self, math: &mut Math) {
1439        self.kind.normalize(math);
1440    }
1441}
1442
1443impl Expr<()> {
1444    /// Create a new expression with a kind and a type.
1445    #[must_use]
1446    pub fn new(kind: ExprKind, ty: ExprType) -> Self {
1447        Self { kind, meta: (), ty }
1448    }
1449}
1450
/// Represents a rule of the figure.
/// Rules are normalized iff:
/// * their operands are normalized
/// * their operands are sorted in ascending order. This applies to the symmetric
///   [`RuleKind::PointEq`] / [`RuleKind::NumberEq`] operands and to the sub-rules of
///   [`RuleKind::Alternative`]. The operand order of [`RuleKind::Gt`] carries meaning
///   and is left untouched by normalization.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Recursive)]
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||,
            init = false
        }
    }
)]
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {}
        }
    }
)]
pub enum RuleKind {
    /// Equality of two points (distance of 0)
    PointEq(VarIndex, VarIndex),
    /// Equality of two numbers
    NumberEq(VarIndex, VarIndex),
    /// a > b
    Gt(VarIndex, VarIndex),
    /// At least one of the rules must be satisfied
    Alternative(Vec<RuleKind>),
    /// The inverse of a rule
    Invert(Box<RuleKind>),
    /// A special bias rule for making the entity more stable in certain engines.
    Bias,
}
1485
1486impl Display for RuleKind {
1487    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1488        match self {
1489            RuleKind::PointEq(a, b) | RuleKind::NumberEq(a, b) => write!(f, "{a} = {b}"),
1490            RuleKind::Gt(a, b) => write!(f, "{a} > {b}"),
1491            RuleKind::Alternative(v) => {
1492                for kind in v {
1493                    write!(f, "| {kind}")?;
1494                }
1495
1496                Ok(())
1497            }
1498            RuleKind::Invert(v) => write!(f, "not {v}"),
1499            RuleKind::Bias => write!(f, "bias"),
1500        }
1501    }
1502}
1503
1504impl FindEntities for RuleKind {
1505    #[allow(clippy::only_used_in_recursion)]
1506    fn find_entities(
1507        &self,
1508        previous: &[HashSet<EntityId>],
1509        entities: &[EntityKind],
1510    ) -> HashSet<EntityId> {
1511        let mut set = HashSet::new();
1512
1513        match self {
1514            Self::PointEq(a, b) | Self::NumberEq(a, b) | Self::Gt(a, b) => {
1515                set.extend(previous[a.0].iter().copied());
1516                set.extend(previous[b.0].iter().copied());
1517            }
1518            Self::Alternative(items) => {
1519                return items
1520                    .iter()
1521                    .flat_map(|x| x.find_entities(previous, entities).into_iter())
1522                    .collect();
1523            }
1524            Self::Invert(rule) => {
1525                return rule.find_entities(previous, entities);
1526            }
1527            Self::Bias => unreachable!(),
1528        }
1529
1530        set
1531    }
1532}
1533
/// A rule along with its aggregated information. Note that `entities`
/// is not filled until the final steps of compilation.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Rule {
    /// The kind of this rule
    pub kind: RuleKind,
    /// The rule's weight
    pub weight: ProcNum,
    /// Entities this rule affects. Left empty by `Rule::load`.
    pub entities: Vec<EntityId>,
}
1545
1546impl Display for Rule {
1547    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1548        write!(f, "{}", self.kind)
1549    }
1550}
1551
1552impl ContainsEntity for Rule {
1553    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1554        self.kind.contains_entity(entity, math)
1555    }
1556}
1557
1558impl Reconstruct for Rule {
1559    fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
1560        Self {
1561            kind: self.kind.reconstruct(ctx),
1562            ..self
1563        }
1564    }
1565}
1566
1567impl Reindex for RuleKind {
1568    fn reindex(&mut self, map: &IndexMap) {
1569        match self {
1570            Self::PointEq(a, b) | Self::NumberEq(a, b) | Self::Gt(a, b) => {
1571                a.reindex(map);
1572                b.reindex(map);
1573            }
1574            Self::Alternative(items) => {
1575                items.reindex(map);
1576            }
1577            Self::Invert(rule) => rule.reindex(map),
1578            Self::Bias => {}
1579        }
1580    }
1581}
1582
1583impl Reindex for Rule {
1584    fn reindex(&mut self, map: &IndexMap) {
1585        self.kind.reindex(map);
1586    }
1587}
1588
1589impl RuleKind {
1590    /// Load a rule from an [`UnrolledRule`].
1591    ///
1592    /// # Returns
1593    /// A normalized rule.
1594    fn load(rule: &UnrolledRule, math: &mut Expand) -> Self {
1595        let mut mathed = match &rule.kind {
1596            UnrolledRuleKind::PointEq(a, b) => Self::PointEq(math.load(a), math.load(b)),
1597            UnrolledRuleKind::NumberEq(a, b) => Self::NumberEq(math.load(a), math.load(b)),
1598            UnrolledRuleKind::Gt(a, b) => Self::Gt(math.load(a), math.load(b)),
1599            UnrolledRuleKind::Alternative(rules) => {
1600                Self::Alternative(rules.iter().map(|x| Self::load(x, math)).collect())
1601            }
1602            UnrolledRuleKind::Bias(_) => Self::Bias,
1603        };
1604
1605        mathed.normalize(math);
1606
1607        if rule.inverted {
1608            Self::Invert(Box::new(mathed))
1609        } else {
1610            mathed
1611        }
1612    }
1613}
1614
1615impl Rule {
1616    /// Load a rule from an [`UnrolledRule`].
1617    ///
1618    /// # Returns
1619    /// A normalized rule.
1620    fn load(rule: &UnrolledRule, math: &mut Expand) -> Self {
1621        Self {
1622            kind: RuleKind::load(rule, math),
1623            weight: rule.weight.clone(),
1624            entities: Vec::new(),
1625        }
1626    }
1627}
1628
1629impl Normalize for RuleKind {
1630    fn normalize(&mut self, math: &mut Math) {
1631        match self {
1632            Self::PointEq(a, b) | Self::NumberEq(a, b) => {
1633                if math.compare(a, b) == Ordering::Greater {
1634                    mem::swap(a, b);
1635                }
1636            }
1637            Self::Alternative(v) => v.sort(),
1638            Self::Invert(_) | Self::Bias | Self::Gt(_, _) => (),
1639        }
1640    }
1641}
1642
1643impl Normalize for Rule {
1644    fn normalize(&mut self, math: &mut Math) {
1645        self.kind.normalize(math);
1646    }
1647}
1648
/// The adjusted (optimized) part of IR.
#[derive(Debug)]
pub struct Adjusted {
    /// Expressions needed for rules and entities.
    pub variables: Vec<Expr<()>>,
    /// Rules binding the figure.
    pub rules: Vec<Rule>,
    /// Entities of the figure.
    pub entities: Vec<EntityKind>,
}
1659
/// The full Math IR, the ultimate result of the entire compiler.
#[derive(Debug)]
pub struct Intermediate {
    /// A figure IR.
    pub figure: Figure,
    /// The adjusted, later optimized part.
    pub adjusted: Adjusted,
    /// Compiler flags used when producing this IR.
    pub flags: Flags,
}
1670
/// An entity along with some metadata.
#[derive(Debug, Clone, Serialize)]
pub struct Entity<M> {
    /// What kind of entity this is.
    pub kind: EntityKind,
    /// Metadata of type `M` attached to the entity.
    pub meta: M,
}
1677
1678impl<M> Entity<M> {
1679    /// Get the entity's type.
1680    #[must_use]
1681    pub fn get_type(&self, expressions: &[Expr<M>], entities: &[Entity<M>]) -> ExprType {
1682        match &self.kind {
1683            EntityKind::FreePoint
1684            | EntityKind::PointOnLine { .. }
1685            | EntityKind::PointOnCircle { .. } => ExprType::Point,
1686            EntityKind::FreeReal | EntityKind::DistanceUnit => ExprType::Number,
1687            EntityKind::Bind(expr) => expressions[expr.0].get_type(expressions, entities),
1688        }
1689    }
1690}
1691
/// The kind of an entity.
///
/// Each kind describes the entity's degrees of freedom. The bound point kinds also
/// name the expression (by index) they are bound to.
#[derive(Debug, Clone, Recursive, Serialize)]
#[recursive(
    impl ContainsEntity for Self {
        fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
            aggregate = ||,
            init = false
        }
    }
)]
#[recursive(
    impl Reconstruct for Self {
        fn reconstruct(self, ctx: &mut ReconstructCtx) -> Self {
            aggregate = {}
        }
    }
)]
pub enum EntityKind {
    /// A free point with two degrees of freedom.
    FreePoint,
    /// A point bound to a specific line, only one relative degree of freedom.
    PointOnLine { line: VarIndex },
    /// A point bound to a specific circle, only one relative degree of freedom.
    PointOnCircle { circle: VarIndex },
    /// A free real with one degree of freedom
    FreeReal,
    /// A special distance unit entity, effectively a free real.
    DistanceUnit,
    /// A bind. Never shows up past the compilation stage. It serves as a temporary
    /// value in-between compilation steps.
    Bind(VarIndex),
}
1724
1725impl FindEntities for EntityKind {
1726    fn find_entities(
1727        &self,
1728        previous: &[HashSet<EntityId>],
1729        _entities: &[EntityKind],
1730    ) -> HashSet<EntityId> {
1731        match self {
1732            Self::PointOnLine { line: var } | Self::PointOnCircle { circle: var } => {
1733                previous[var.0].clone()
1734            }
1735            Self::FreePoint | Self::FreeReal | Self::DistanceUnit => HashSet::new(),
1736            Self::Bind(_) => unreachable!(),
1737        }
1738    }
1739}
1740
1741impl From<EntityKind> for geo_aid_figure::EntityKind {
1742    fn from(value: EntityKind) -> Self {
1743        match value {
1744            EntityKind::FreePoint => Self::FreePoint,
1745            EntityKind::PointOnLine { line } => Self::PointOnLine { line },
1746            EntityKind::PointOnCircle { circle } => Self::PointOnCircle { circle },
1747            EntityKind::FreeReal => Self::FreeReal,
1748            EntityKind::DistanceUnit => Self::DistanceUnit,
1749            EntityKind::Bind(_) => unreachable!(),
1750        }
1751    }
1752}
1753
1754impl Reindex for EntityKind {
1755    fn reindex(&mut self, map: &IndexMap) {
1756        match self {
1757            Self::FreePoint | Self::DistanceUnit | Self::FreeReal => {}
1758            Self::PointOnLine { line } => line.reindex(map),
1759            Self::PointOnCircle { circle } => circle.reindex(map),
1760            Self::Bind(_) => unreachable!("Should not appear"),
1761        }
1762    }
1763}
1764
1765impl Reindex for EntityId {
1766    fn reindex(&mut self, _map: &IndexMap) {}
1767}
1768
1769impl DeepClone for EntityId {
1770    fn deep_clone(&self, _math: &mut Math) -> Self {
1771        // The clone is not THAT deep
1772        *self
1773    }
1774}
1775
1776impl ContainsEntity for EntityId {
1777    fn contains_entity(&self, entity: EntityId, math: &Math) -> bool {
1778        *self == entity || math.entities[self.0].contains_entity(entity, math)
1779    }
1780}
1781
/// A context struct for loading unrolled data.
#[derive(Debug, Clone, Default)]
pub struct Expand {
    /// Cache of already-loaded expressions, keyed by the address of the unrolled data
    /// they were loaded from.
    pub expr_map: HashMap<usize, Expr<()>>,
    /// Processing context
    pub math: Math,
    /// Keeps alive the unrolled data whose addresses are used as `expr_map` keys.
    /// Without it, an address added to `expr_map` as a key could be freed and reused by a
    /// different allocation, causing a key collision (a false cache hit). This prevents that,
    /// at the cost of increased memory usage.
    /// It's an ugly solution, but it works. I'm most likely going to come back to this one with some
    /// new ideas for solving the issue.
    pub rc_keepalive: Vec<Rc<dyn Any>>,
}
1796
1797impl Deref for Expand {
1798    type Target = Math;
1799
1800    fn deref(&self) -> &Self::Target {
1801        &self.math
1802    }
1803}
1804
1805impl DerefMut for Expand {
1806    fn deref_mut(&mut self) -> &mut Self::Target {
1807        &mut self.math
1808    }
1809}
1810
/// A math stage context used across the whole compilation process.
#[derive(Debug, Clone, Default)]
pub struct Math {
    /// All of the figure's entities, indexed by [`EntityId`].
    pub entities: Vec<EntityKind>,
    /// The distance unit entity, if it has been created (it is created lazily by `get_dst_var`).
    pub dst_var: OnceCell<EntityId>,
    /// Collected expressions in flattened layout, indexed by [`VarIndex`].
    pub expr_record: Vec<Expr<()>>,
}
1821
1822impl Expand {
1823    /// Load an unrolled expression. Also stores it in the variable record.
1824    pub fn load<T: Displayed + GetMathType + Debug + GetData + 'static>(
1825        &mut self,
1826        unrolled: &Unrolled<T>,
1827    ) -> VarIndex
1828    where
1829        ExprKind: FromUnrolled<T>,
1830    {
1831        let expr = self.load_no_store(unrolled);
1832        self.store(expr, T::get_math_type())
1833    }
1834
1835    /// Load an unrolled expression without storing it.
1836    pub fn load_no_store<T: Displayed + GetMathType + GetData + Debug + 'static>(
1837        &mut self,
1838        unrolled: &Unrolled<T>,
1839    ) -> ExprKind
1840    where
1841        ExprKind: FromUnrolled<T>,
1842    {
1843        // Keep the smart pointer inside `unrolled` alive.
1844        self.rc_keepalive
1845            .push(Rc::clone(&unrolled.data) as Rc<dyn Any>);
1846
1847        let key = std::ptr::from_ref(unrolled.get_data()) as usize;
1848        let loaded = self.expr_map.get(&key).cloned();
1849
1850        if let Some(loaded) = loaded {
1851            loaded.kind.deep_clone(self)
1852        } else {
1853            // If expression has not been mathed yet, math it and put it into the record.
1854            let loaded = ExprKind::load(unrolled, self);
1855            self.expr_map
1856                .insert(key, Expr::new(loaded.clone(), T::get_math_type()));
1857            loaded
1858        }
1859    }
1860}
1861
1862impl Math {
1863    /// Store an expression in the expression record.
1864    #[must_use]
1865    pub fn store(&mut self, expr: ExprKind, ty: ExprType) -> VarIndex {
1866        self.expr_record.push(Expr::new(expr, ty));
1867        VarIndex(self.expr_record.len() - 1)
1868    }
1869
1870    /// Compare two expressions referenced by indices.
1871    #[must_use]
1872    pub fn compare(&self, a: &VarIndex, b: &VarIndex) -> Ordering {
1873        self.at(a).kind.compare(&self.at(b).kind, self)
1874    }
1875
1876    /// Get the distance unit and generate it if it doesn't exist.
1877    ///
1878    /// # Panics
1879    /// Will never.
1880    #[must_use]
1881    pub fn get_dst_var(&mut self) -> VarIndex {
1882        let id = self.dst_var.get();
1883        let is_some = id.is_some();
1884
1885        let id = *if is_some {
1886            id.unwrap()
1887        } else {
1888            let real = self.add_entity(EntityKind::DistanceUnit);
1889            self.dst_var.get_or_init(|| real)
1890        };
1891
1892        self.store(ExprKind::Entity { id }, ExprType::Number)
1893    }
1894
1895    /// Get the expression at given index.
1896    #[must_use]
1897    pub fn at(&self, index: &VarIndex) -> &Expr<()> {
1898        &self.expr_record[index.0]
1899    }
1900
1901    /// Add an entity to the record.
1902    fn add_entity(&mut self, entity: EntityKind) -> EntityId {
1903        self.entities.push(entity);
1904        EntityId(self.entities.len() - 1)
1905    }
1906
1907    /// Add a free point entity.
1908    pub fn add_point(&mut self) -> EntityId {
1909        self.add_entity(EntityKind::FreePoint)
1910    }
1911
1912    /// Add a free real entity.
1913    pub fn add_real(&mut self) -> EntityId {
1914        self.add_entity(EntityKind::FreeReal)
1915    }
1916}
1917
/// Used explicitly for figure IR building.
///
/// Collects the expressions the figure needs (through the inner [`Expand`]) together
/// with the items to be drawn.
#[derive(Debug, Clone, Default)]
pub struct Build {
    /// A loading context for unrolled data.
    expand: Expand,
    /// Aggregated items to be drawn on the figure.
    items: Vec<Item>,
}
1926
1927impl Build {
1928    /// Load an unrolled expression.
1929    pub fn load<T: Displayed + GetMathType + Debug + GetData + 'static>(
1930        &mut self,
1931        expr: &Unrolled<T>,
1932    ) -> VarIndex
1933    where
1934        ExprKind: FromUnrolled<T>,
1935    {
1936        self.expand.load(expr)
1937    }
1938
1939    pub fn add<I: Into<Item>>(&mut self, item: I) {
1940        let item = item.into();
1941        self.items.push(item);
1942    }
1943}
1944
1945/// Tries to transform the rules so that they are simpler to process for the generator.
1946///
1947/// # Returns
1948/// `true` if an optimization was performed. `false` otherwise.
1949fn optimize_rules(rules: &mut Vec<Option<Rule>>, math: &mut Math) -> bool {
1950    let mut performed = false;
1951
1952    for rule in rules.iter_mut() {
1953        let rule_performed = ZeroLineDst::process(rule, math)
1954            | RightAngle::process(rule, math)
1955            | EqPointDst::process(rule, math)
1956            | EqExpressions::process(rule, math);
1957
1958        performed |= rule_performed;
1959    }
1960
1961    if performed {
1962        rules.retain(Option::is_some);
1963    }
1964
1965    performed
1966}
1967
/// A composable map between two sets of indices, `A -> B`.
///
/// The map is stored as an ordered list of `(from, to)` pairs. A lookup walks the list
/// front to back and replaces the current value whenever it equals a pair's `from` side,
/// so later pairs apply on top of earlier ones. Indices mentioned by no pair map to
/// themselves.
#[derive(Debug, Clone, Default)]
pub struct IndexMap {
    /// Consecutive mappings, applied in order. Not incredibly efficient, but simple and infallible.
    mappings: Vec<(usize, usize)>,
}

impl IndexMap {
    /// Create an empty map, i.e. the identity.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Get the index `a` maps to.
    #[must_use]
    pub fn get(&self, a: usize) -> usize {
        self.mappings
            .iter()
            .fold(a, |current, &(from, to)| if current == from { to } else { current })
    }

    /// Append a mapping from `a` to `b`; it applies after all previously added mappings.
    /// Identity mappings (`a == b`) are not recorded.
    pub fn map(&mut self, a: usize, b: usize) {
        if a == b {
            return;
        }
        self.mappings.push((a, b));
    }

    /// Composes two index maps in place.
    ///
    /// Afterwards, `rhs.get(i)` equals `lhs.get(old_rhs.get(i))`: the original `rhs` is
    /// applied first, then `lhs`.
    pub fn compose(lhs: Self, rhs: &mut Self) {
        let Self { mappings } = lhs;
        rhs.mappings.extend(mappings);
    }
}
2008
/// Helper trait for the reindexing process.
///
/// Reindexing rewrites the expression indices held by a value according to an
/// [`IndexMap`]. It is used to redirect references after expression duplicates are
/// eliminated (see `fold`).
/// It must be performed on a structure with no forward references.
pub trait Reindex {
    /// Reindex the expression/rule according to the given map.
    fn reindex(&mut self, map: &IndexMap);
}
2017
2018impl Reindex for CompExponent {
2019    fn reindex(&mut self, _map: &IndexMap) {}
2020}
2021
2022impl Reindex for ProcNum {
2023    fn reindex(&mut self, _map: &IndexMap) {}
2024}
2025
2026impl<T: Reindex> Reindex for Box<T> {
2027    fn reindex(&mut self, map: &IndexMap) {
2028        self.as_mut().reindex(map);
2029    }
2030}
2031
2032impl<T: Reindex> Reindex for Vec<T> {
2033    fn reindex(&mut self, map: &IndexMap) {
2034        for item in self {
2035            item.reindex(map);
2036        }
2037    }
2038}
2039
2040/// Folds the given matrix:
2041///     - We reindex the element
2042///     - We put it into a hashmap
2043///     - If the element repeats, we replace it and add an entry into the index map
2044///     - Otherwise, we rejoice
2045/// Returns the index map created by the folding.
2046fn fold(matrix: &mut Vec<Expr<()>>) -> IndexMap {
2047    let mut target = Vec::new();
2048    let mut final_map = IndexMap::new();
2049    let mut record = HashMap::new();
2050
2051    loop {
2052        // println!("Folding...");
2053        let mut map = IndexMap::new();
2054        let mut folded = false;
2055        for (i, expr) in matrix.iter_mut().enumerate() {
2056            let mut expr = mem::take(expr);
2057            // print!("Found {expr:?}. Remapping to ");
2058            expr.reindex(&map);
2059            // println!("{expr:?}");
2060            match record.entry(expr) {
2061                hash_map::Entry::Vacant(entry) => {
2062                    target.push(entry.key().clone());
2063                    let new_i = target.len() - 1;
2064                    // println!("Not recorded, mapping {i} -> {new_i}");
2065                    map.map(i, new_i);
2066                    entry.insert(new_i);
2067                }
2068                hash_map::Entry::Occupied(entry) => {
2069                    // We have to update the index map. No push into target happens.
2070                    let j = *entry.get();
2071                    map.map(i, j);
2072                    // println!("Already recorded at {j}. Mapping {i} -> {j}");
2073                    folded = true;
2074                }
2075            }
2076        }
2077        // println!();
2078        // println!("We've build a map: {map:#?}");
2079
2080        // We have to also build the final map.
2081        IndexMap::compose(map, &mut final_map);
2082        // println!("After composition, it became {final_map:#?}");
2083
2084        // Swap target with matrix before next loop.
2085        mem::swap(matrix, &mut target);
2086        // And clear aux variables.
2087        target.clear();
2088        record.clear();
2089
2090        if !folded {
2091            break final_map;
2092        }
2093    }
2094}
2095
2096fn read_flags(flags: &HashMap<&'static str, Flag>) -> Flags {
2097    Flags {
2098        optimizations: Optimizations {},
2099        point_inequalities: flags["point_inequalities"].as_bool().unwrap(),
2100    }
2101}
2102
2103/// Optimize, Normalize, Repeat
2104fn optimize_cycle(rules: &mut Vec<Option<Rule>>, math: &mut Math, items: &mut Vec<Item>) {
2105    let mut entity_map = Vec::new();
2106    loop {
2107        if !optimize_rules(rules, math) {
2108            break;
2109        }
2110
2111        // Create the entity reduction map:
2112        // If an entity is a bind, it should be turned into an expression
2113        // Remaining entities should be appropriately offset.
2114
2115        let mut offset = 0;
2116        entity_map.clear();
2117
2118        // We'll need mutable access to expand in the loop, so we take the entities out of it.
2119        let mut entities = mem::take(&mut math.entities);
2120
2121        for (i, entity) in entities.iter().enumerate() {
2122            entity_map.push(match entity {
2123                EntityKind::FreePoint
2124                | EntityKind::FreeReal
2125                | EntityKind::DistanceUnit
2126                | EntityKind::PointOnCircle { .. }
2127                | EntityKind::PointOnLine { .. } => EntityBehavior::MapEntity(EntityId(i - offset)),
2128                EntityKind::Bind(expr) => {
2129                    offset += 1;
2130                    EntityBehavior::MapVar(expr.clone()) // Technically moving
2131                }
2132            });
2133        }
2134
2135        entities.retain(|x| !matches!(x, EntityKind::Bind(_)));
2136
2137        // The entities are now corrected, but the rules, entities (sic) and variables don't know that.
2138        // The easiest way to fix it is to reconstruct the loaded variable vector recursively.
2139        // This way, we can also fix forward referencing and remove potentially unused expressions.
2140        // We'll also have to update items. Otherwise, they will have misguided indices.
2141
2142        let old_vars = mem::take(&mut math.expr_record);
2143        let mut ctx = ReconstructCtx::new(&entity_map, &old_vars, &entities);
2144        let old_items = mem::take(items);
2145        *items = old_items.reconstruct(&mut ctx);
2146        let old_rules = mem::take(rules);
2147        *rules = old_rules.reconstruct(&mut ctx);
2148        math.expr_record = ctx.new_vars;
2149        math.entities = ctx.new_entities.into_iter().map(Option::unwrap).collect();
2150
2151        // After reconstruction, all forward referencing is gone.
2152
2153        // Normalize expressions. Unfortunately due to borrow checking this requires quite the mess.
2154        // First, we clone the entire expression vector (shallow, this time) and normalize it.
2155        let v = math.expr_record.clone();
2156        for (i, mut var) in v.into_iter().enumerate() {
2157            var.normalize(math);
2158            // Then, set each corresponding var.
2159            math.expr_record[i] = var;
2160        }
2161
2162        // Normalize all rules now
2163        for rule in rules.iter_mut().flatten() {
2164            rule.normalize(math);
2165        }
2166    }
2167}
2168
2169/// Loads a `GeoScript` script and compiles it into Math IR. Encapsulates the entire compiler's
2170/// work.
2171///
2172/// # Errors
2173/// Returns an error if the script is not a valid one.
2174/// Any errors should result from tokenizing, parsing and unrolling, not mathing.
2175pub fn load_script(input: &str) -> Result<Intermediate, Vec<Error>> {
2176    // Unroll script
2177    // Expand rules & figure maximally, normalize them
2178    // ---
2179    // Optimize rules and entities
2180    // Reduce entities
2181    // --- repeat
2182    // Turn entities into adjustables
2183    // Split rules & figure
2184    // Fold rules & figure separately
2185    // Assign reserved registers to figure expressions
2186    // Return
2187
2188    // Unroll script
2189    let (mut unrolled, nodes) = unroll::unroll(input)?;
2190
2191    // for rule in unrolled.rules.borrow().iter() {
2192    //     println!("{rule}");
2193    // }
2194
2195    // Expand & normalize figure
2196    let mut build = Build::default();
2197    Box::new(nodes).build(&mut build);
2198
2199    // Move expand base
2200    let mut expand = build.expand;
2201
2202    // for (i, v) in expand.expr_record.iter().enumerate() {
2203    //     println!("[{i}] = {:?}", v.kind);
2204    // }
2205    //
2206    // println!("{:#?}", build.items);
2207
2208    // Expand & normalize rules
2209    let mut rules = Vec::new();
2210
2211    for rule in unrolled.take_rules() {
2212        rules.push(Some(Rule::load(&rule, &mut expand)));
2213    }
2214
2215    // for (i, ent) in expand.entities.iter().enumerate() {
2216    //     println!("[{i}] = {ent:?}");
2217    // }
2218
2219    // Get the math out of the `Expand`.
2220    let mut math = expand.math;
2221
2222    optimize_cycle(&mut rules, &mut math, &mut build.items);
2223
2224    // Now everything that could be normalized is normalized.
2225    // Unfortunately, normalization can introduce forward referencing, which is not what we want.
2226    // This means we have to fix it. And the easiest way to fix it is to reconstruct it once more.
2227
2228    let old_entities = mem::take(&mut math.entities);
2229    let entity_map: Vec<_> = (0..old_entities.len())
2230        .map(|i| EntityBehavior::MapEntity(EntityId(i)))
2231        .collect();
2232
2233    let old_vars = mem::take(&mut math.expr_record);
2234    let mut ctx = ReconstructCtx::new(&entity_map, &old_vars, &old_entities);
2235    build.items = build.items.reconstruct(&mut ctx);
2236    rules = rules.reconstruct(&mut ctx);
2237    math.expr_record = ctx.new_vars;
2238    let new_entities: Vec<_> = ctx.new_entities.into_iter().map(Option::unwrap).collect();
2239
2240    // for (i, ent) in new_entities.iter().enumerate() {
2241    //     println!("[{i}] = {ent:?}");
2242    // }
2243    //
2244    // for (i, var) in math.expr_record.iter().enumerate() {
2245    //     println!("[{i}] = {:?}", var.kind);
2246    // }
2247
2248    // We can also finalize rules:
2249    let mut rules: Vec<_> = rules.into_iter().flatten().collect();
2250
2251    let flags = read_flags(&unrolled.flags);
2252
2253    // And add point inequalities
2254    if flags.point_inequalities {
2255        for i in new_entities
2256            .iter()
2257            .enumerate()
2258            .filter(|ent| {
2259                matches!(
2260                    ent.1,
2261                    EntityKind::PointOnLine { .. }
2262                        | EntityKind::FreePoint
2263                        | EntityKind::PointOnCircle { .. }
2264                )
2265            })
2266            .map(|x| x.0)
2267        {
2268            for j in new_entities
2269                .iter()
2270                .enumerate()
2271                .skip(i + 1)
2272                .filter(|ent| {
2273                    matches!(
2274                        ent.1,
2275                        EntityKind::PointOnLine { .. }
2276                            | EntityKind::FreePoint
2277                            | EntityKind::PointOnCircle { .. }
2278                    )
2279                })
2280                .map(|x| x.0)
2281            {
2282                let ent1 = math.store(ExprKind::Entity { id: EntityId(i) }, ExprType::Point);
2283                let ent2 = math.store(ExprKind::Entity { id: EntityId(j) }, ExprType::Point);
2284                rules.push(Rule {
2285                    weight: ProcNum::one(),
2286                    entities: Vec::new(),
2287                    kind: RuleKind::Invert(Box::new(RuleKind::PointEq(ent1, ent2))),
2288                });
2289            }
2290        }
2291    }
2292
2293    math.entities = new_entities;
2294
2295    // THE FOLDING STEP
2296    // We perform similar expression elimination multiple times
2297    //      - We reindex the element
2298    //      - We put it into a hashmap
2299    //      - If the element repeats, we replace it and add an entry into the index map
2300    //      - Otherwise, we rejoice
2301    // When the process ends NO REORGANIZATION HAPPENS AS NONE IS NECESSARY
2302    // Rules and entities are updated using the obtained index map.
2303
2304    let mut entities = math.entities;
2305    let mut variables = math.expr_record;
2306    let mut fig_variables = variables.clone();
2307    let mut fig_entities = entities.clone();
2308
2309    let index_map = fold(&mut variables);
2310    entities.reindex(&index_map);
2311    rules.reindex(&index_map);
2312
2313    // Find entities affected by specific rules
2314    let mut found_entities = Vec::new();
2315    for expr in &variables {
2316        let found = expr.find_entities(&found_entities, &entities);
2317        found_entities.push(found);
2318    }
2319
2320    for rule in &mut rules {
2321        let entities = rule.kind.find_entities(&found_entities, &entities);
2322        rule.entities = entities.into_iter().collect();
2323    }
2324
2325    // Fold figure variables
2326    // println!("PRE-FOLD");
2327    //
2328    // for (i, v) in fig_variables.iter().enumerate() {
2329    //     println!("[{i}] = {:?}", v.kind);
2330    // }
2331
2332    // println!("{:#?}", build.items);
2333
2334    let index_map = fold(&mut fig_variables);
2335    // println!("POST-FOLD");
2336
2337    // println!("{index_map:#?}");
2338
2339    // for (i, v) in fig_variables.iter().enumerate() {
2340    //     println!("[{i}] = {:?}", v.kind);
2341    // }
2342    let mut items = build.items;
2343    items.reindex(&index_map);
2344    fig_entities.reindex(&index_map);
2345
2346    // for (i, v) in fig_variables.iter().enumerate() {
2347    //     println!("[{i}] = {:?}", v.kind);
2348    // }
2349    //
2350    // for rule in &rules {
2351    //     println!("\n{:?}", rule.kind);
2352    // }
2353
2354    Ok(Intermediate {
2355        adjusted: Adjusted {
2356            variables,
2357            rules,
2358            entities,
2359        },
2360        figure: Figure {
2361            entities: fig_entities,
2362            variables: fig_variables,
2363            items,
2364        },
2365        flags,
2366    })
2367}