1use std::convert::Infallible;
2
3use kodept_ast::{Acc, Appl, BinExpr, BinaryExpressionKind, BitKind, ComparisonKind, EqKind, Expression, Identifier, LogicKind, MathKind, Operation, Ref, ReferenceContext, Term, UnExpr, UnaryExpressionKind};
4use kodept_ast::graph::{Change, ChangeSet, AnyNode, tags};
5use kodept_ast::traits::Identifiable;
6use kodept_ast::utils::Execution;
7use kodept_ast::visit_side::{VisitGuard, VisitSide};
8use kodept_macros::Macro;
9use kodept_macros::traits::Context;
10
/// Macro that desugars binary operator expressions (`BinExpr`) into applications of
/// `Prelude` intrinsics such as `__add_internal`.
///
/// Stateless, so it is cheap to copy and trivially debuggable.
#[derive(Debug, Default, Clone, Copy)]
pub struct BinaryOperatorExpander;
13
/// Macro that desugars unary operator expressions (`UnExpr`) into applications of
/// `Prelude` intrinsics such as `__neg_internal`.
///
/// Stateless, so it is cheap to copy and trivially debuggable.
#[derive(Debug, Default, Clone, Copy)]
pub struct UnaryOperatorExpander;
16
/// Macro that rewrites access expressions (`Acc`) into an application of the
/// `Prelude` function `compose`.
///
/// Stateless, so it is cheap to copy and trivially debuggable.
#[derive(Debug, Default, Clone, Copy)]
pub struct AccessExpander;
19
20fn replace_with<N: Identifiable + Into<AnyNode>>(
21 replaced: &N,
22 function_name: &'static str,
23) -> ChangeSet {
24 let id = replaced.get_id().widen();
26
27 ChangeSet::from_iter([
28 Change::replace(id, Appl::uninit()),
29 Change::add::<_, _, { tags::PRIMARY }>(
30 id.narrow::<Appl>(),
31 Ref::uninit(
32 ReferenceContext::global(["Prelude"]),
33 Identifier::Reference {
34 name: function_name.to_string(),
35 },
36 )
37 .map_into::<Term>()
38 .map_into::<Expression>()
39 .map_into::<Operation>(),
40 ),
41 ])
42}
43
44impl BinaryOperatorExpander {
45 pub fn new() -> Self {
46 Self
47 }
48}
49
50impl UnaryOperatorExpander {
51 pub fn new() -> Self {
52 Self
53 }
54}
55
56impl AccessExpander {
57 pub fn new() -> Self {
58 Self
59 }
60}
61
62impl Macro for BinaryOperatorExpander {
63 type Error = Infallible;
64 type Node = BinExpr;
65
66 fn transform(
67 &mut self,
68 guard: VisitGuard<Self::Node>,
69 _context: &mut impl Context,
70 ) -> Execution<Self::Error, ChangeSet> {
71 let node = guard.allow_only(VisitSide::Entering)?;
72
73 Execution::Completed(match &node.kind {
74 BinaryExpressionKind::Math(x) => match x {
75 MathKind::Add => replace_with(&*node, "__add_internal"),
76 MathKind::Sub => replace_with(&*node, "__sub_internal"),
77 MathKind::Mul => replace_with(&*node, "__mul_internal"),
78 MathKind::Pow => replace_with(&*node, "__pow_internal"),
79 MathKind::Div => replace_with(&*node, "__div_internal"),
80 MathKind::Mod => replace_with(&*node, "__mod_internal"),
81 },
82 BinaryExpressionKind::Cmp(x) => match x {
83 ComparisonKind::Less => replace_with(&*node, "__less_internal"),
84 ComparisonKind::LessEq => replace_with(&*node, "__less_eq_internal"),
85 ComparisonKind::Greater => replace_with(&*node, "__greater_internal"),
86 ComparisonKind::GreaterEq => replace_with(&*node, "__greater_internal"),
87 },
88 BinaryExpressionKind::Eq(x) => match x {
89 EqKind::Eq => replace_with(&*node, "__eq_internal"),
90 EqKind::NEq => replace_with(&*node, "__neq_internal")
91 },
92 BinaryExpressionKind::Bit(x) => match x {
93 BitKind::Or => replace_with(&*node, "__or_internal"),
94 BitKind::And => replace_with(&*node, "__and_internal"),
95 BitKind::Xor => replace_with(&*node, "__xor_internal"),
96 },
97 BinaryExpressionKind::Logic(x) => match x {
98 LogicKind::Disj => replace_with(&*node, "__dis_internal"),
99 LogicKind::Conj => replace_with(&*node, "__con_internal"),
100 },
101 BinaryExpressionKind::ComplexComparison => replace_with(&*node, "__cmp_internal"),
102 BinaryExpressionKind::Assign => replace_with(&*node, "__assign_internal")
103 })
104 }
105}
106
107impl Macro for UnaryOperatorExpander {
108 type Error = Infallible;
109 type Node = UnExpr;
110
111 fn transform(
112 &mut self,
113 guard: VisitGuard<Self::Node>,
114 _: &mut impl Context,
115 ) -> Execution<Self::Error, ChangeSet> {
116 let node = guard.allow_only(VisitSide::Entering)?;
117
118 Execution::Completed(match node.kind {
119 UnaryExpressionKind::Neg => replace_with(&*node, "__neg_internal"),
120 UnaryExpressionKind::Not => replace_with(&*node, "__not_internal"),
121 UnaryExpressionKind::Inv => replace_with(&*node, "__inv_internal"),
122 UnaryExpressionKind::Plus => replace_with(&*node, "__plus_internal"),
123 })
124 }
125}
126
127impl Macro for AccessExpander {
128 type Error = Infallible;
129 type Node = Acc;
130
131 fn transform(&mut self, guard: VisitGuard<Self::Node>, _: &mut impl Context) -> Execution<Self::Error, ChangeSet> {
132 let node = guard.allow_only(VisitSide::Entering)?;
133
134 Execution::Completed(replace_with(&*node, "compose"))
135 }
136}