kodept_interpret/
operator_desugaring.rs

1use std::convert::Infallible;
2
3use kodept_ast::{Acc, Appl, BinExpr, BinaryExpressionKind, BitKind, ComparisonKind, EqKind, Expression, Identifier, LogicKind, MathKind, Operation, Ref, ReferenceContext, Term, UnExpr, UnaryExpressionKind};
4use kodept_ast::graph::{Change, ChangeSet, AnyNode, tags};
5use kodept_ast::traits::Identifiable;
6use kodept_ast::utils::Execution;
7use kodept_ast::visit_side::{VisitGuard, VisitSide};
8use kodept_macros::Macro;
9use kodept_macros::traits::Context;
10
/// Desugaring pass that rewrites binary-operator expressions (`BinExpr`) into
/// calls to `::Prelude` intrinsics.
///
/// The pass carries no state, so it is freely `Copy`able.
#[derive(Debug, Default, Clone, Copy)]
pub struct BinaryOperatorExpander;

/// Desugaring pass that rewrites unary-operator expressions (`UnExpr`) into
/// calls to `::Prelude` intrinsics.
///
/// The pass carries no state, so it is freely `Copy`able.
#[derive(Debug, Default, Clone, Copy)]
pub struct UnaryOperatorExpander;

/// Desugaring pass that rewrites access expressions (`Acc`) into a call to a
/// `::Prelude` function.
///
/// The pass carries no state, so it is freely `Copy`able.
#[derive(Debug, Default, Clone, Copy)]
pub struct AccessExpander;
19
20fn replace_with<N: Identifiable + Into<AnyNode>>(
21    replaced: &N,
22    function_name: &'static str,
23) -> ChangeSet {
24    // ::Prelude::<function_name>(<left>, <right>)
25    let id = replaced.get_id().widen();
26
27    ChangeSet::from_iter([
28        Change::replace(id, Appl::uninit()),
29        Change::add::<_, _, { tags::PRIMARY }>(
30            id.narrow::<Appl>(),
31            Ref::uninit(
32                ReferenceContext::global(["Prelude"]),
33                Identifier::Reference {
34                    name: function_name.to_string(),
35                },
36            )
37            .map_into::<Term>()
38            .map_into::<Expression>()
39            .map_into::<Operation>(),
40        ),
41    ])
42}
43
44impl BinaryOperatorExpander {
45    pub fn new() -> Self {
46        Self
47    }
48}
49
50impl UnaryOperatorExpander {
51    pub fn new() -> Self {
52        Self
53    }
54}
55
56impl AccessExpander {
57    pub fn new() -> Self {
58        Self
59    }
60}
61
62impl Macro for BinaryOperatorExpander {
63    type Error = Infallible;
64    type Node = BinExpr;
65
66    fn transform(
67        &mut self,
68        guard: VisitGuard<Self::Node>,
69        _context: &mut impl Context,
70    ) -> Execution<Self::Error, ChangeSet> {
71        let node = guard.allow_only(VisitSide::Entering)?;
72
73        Execution::Completed(match &node.kind {
74            BinaryExpressionKind::Math(x) => match x {
75                MathKind::Add => replace_with(&*node, "__add_internal"),
76                MathKind::Sub => replace_with(&*node, "__sub_internal"),
77                MathKind::Mul => replace_with(&*node, "__mul_internal"),
78                MathKind::Pow => replace_with(&*node, "__pow_internal"),
79                MathKind::Div => replace_with(&*node, "__div_internal"),
80                MathKind::Mod => replace_with(&*node, "__mod_internal"),
81            },
82            BinaryExpressionKind::Cmp(x) => match x {
83                ComparisonKind::Less => replace_with(&*node, "__less_internal"),
84                ComparisonKind::LessEq => replace_with(&*node, "__less_eq_internal"),
85                ComparisonKind::Greater => replace_with(&*node, "__greater_internal"),
86                ComparisonKind::GreaterEq => replace_with(&*node, "__greater_internal"),
87            },
88            BinaryExpressionKind::Eq(x) => match x {
89                EqKind::Eq => replace_with(&*node, "__eq_internal"),
90                EqKind::NEq => replace_with(&*node, "__neq_internal")
91            },
92            BinaryExpressionKind::Bit(x) => match x {
93                BitKind::Or => replace_with(&*node, "__or_internal"),
94                BitKind::And => replace_with(&*node, "__and_internal"),
95                BitKind::Xor => replace_with(&*node, "__xor_internal"),
96            },
97            BinaryExpressionKind::Logic(x) => match x {
98                LogicKind::Disj => replace_with(&*node, "__dis_internal"),
99                LogicKind::Conj => replace_with(&*node, "__con_internal"),
100            },
101            BinaryExpressionKind::ComplexComparison => replace_with(&*node, "__cmp_internal"),
102            BinaryExpressionKind::Assign => replace_with(&*node, "__assign_internal")
103        })
104    }
105}
106
107impl Macro for UnaryOperatorExpander {
108    type Error = Infallible;
109    type Node = UnExpr;
110
111    fn transform(
112        &mut self,
113        guard: VisitGuard<Self::Node>,
114        _: &mut impl Context,
115    ) -> Execution<Self::Error, ChangeSet> {
116        let node = guard.allow_only(VisitSide::Entering)?;
117
118        Execution::Completed(match node.kind {
119            UnaryExpressionKind::Neg => replace_with(&*node, "__neg_internal"),
120            UnaryExpressionKind::Not => replace_with(&*node, "__not_internal"),
121            UnaryExpressionKind::Inv => replace_with(&*node, "__inv_internal"),
122            UnaryExpressionKind::Plus => replace_with(&*node, "__plus_internal"),
123        })
124    }
125}
126
127impl Macro for AccessExpander {
128    type Error = Infallible;
129    type Node = Acc;
130
131    fn transform(&mut self, guard: VisitGuard<Self::Node>, _: &mut impl Context) -> Execution<Self::Error, ChangeSet> {
132        let node = guard.allow_only(VisitSide::Entering)?;
133        
134        Execution::Completed(replace_with(&*node, "compose"))
135    }
136}