bitsy_lang/circuit/mlir.rs

1use super::*;
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5impl Package {
6    pub fn emit_mlir(&self) {
7        for moddef in self.moddefs() {
8            if let Component::Mod(_loc, _name, _children, _wires, _whens) = &*moddef {
9                self.emit_mlir_moddef(moddef);
10            }
11        }
12    }
13
14    fn emit_mlir_moddef(&self, moddef: Arc<Component>) {
15        let mut ports: Vec<(bool, String, Type)> = vec![];
16        let mut output_ports: Vec<String> = vec![];
17        let mut output_port_ssas: BTreeMap<String, String> = BTreeMap::new();
18        let mut output_port_types: Vec<Type> = vec![];
19
20        for (_path, port) in moddef.port_paths() {
21            let name = port.name().to_string();
22            let typ = self.type_of(port.clone()).unwrap();
23            ports.push((port.is_incoming_port(), name.clone(), typ.clone()));
24            if port.is_outgoing_port() {
25                output_ports.push(name);
26                output_port_types.push(typ);
27            }
28        }
29
30        let ctx = self.context_for(moddef.clone());
31
32        println!("hw.module @{}(", moddef.name());
33        self.emit_mlir_moddef_portlist(&ports);
34        println!(") {{");
35
36        for (i, Wire(_loc, target, expr, wire_type)) in moddef.wires().iter().enumerate() {
37            match wire_type {
38                WireType::Direct => {
39                    let ssa = expr.emit_mlir(format!("$comb{i}"), ctx.clone());
40                    let target_string = target.to_string();
41                    if output_ports.contains(&target_string) {
42                        output_port_ssas.insert(target_string, ssa);
43                    } else if let Some(Component::Node(_loc, name, typ)) = moddef.child(target).as_ref().map(|arc| &**arc) {
44                        let type_name = type_to_mlir(typ.clone());
45                        println!("    %{name} = comb.add {ssa} : {type_name}");
46                    }
47                },
48                WireType::Latch => {
49                    let next_ssa = expr.emit_mlir(format!("$comb{i}"), ctx.clone());
50                    let target_string = target.to_string();
51                    let reg = moddef.child(target).unwrap();
52                    let typ = reg.type_of().unwrap();
53                    let reset = reg.reset().unwrap();
54                    let reset_ssa = reset.emit_mlir(format!("$reset{i}"), ctx.clone());
55                    println!("    %{target_string} = seq.firreg {next_ssa} clock %_clock reset sync %_reset, {reset_ssa} : {}", type_to_mlir(typ));
56                },
57                _ => panic!(),
58            }
59        }
60
61        let output_port_ssas: Vec<&str> = output_ports.iter().map(|output_port| {
62            output_port_ssas[output_port].as_str()
63        }).collect();
64        let output_port_types: Vec<String> = output_port_types.iter().map(|typ| {
65            type_to_mlir(typ.clone())
66        }).collect();
67        if output_port_ssas.len() > 0 {
68            println!("    hw.output {} : {}", output_port_ssas.join(","), output_port_types.join(","));
69        }
70        println!("}}");
71    }
72
73    fn emit_mlir_moddef_portlist(&self, ports: &[(bool, String, Type)]) {
74        println!("    in %_clock : !seq.clock,");
75        print!("    in %_reset : i1");
76        if ports.len() > 0 {
77            println!(",");
78        } else {
79            println!();
80        }
81
82        for (i, (is_input, name, typ)) in ports.iter().enumerate() {
83            let typ_name = type_to_mlir(typ.clone());
84            if *is_input {
85                print!("    in %{name} : {typ_name}");
86            } else {
87                print!("    out {name} : {typ_name}");
88            }
89            if i + 1 < ports.len() {
90                println!(",");
91            } else {
92                println!();
93            }
94        }
95    }
96}
97
98impl Expr {
99    fn emit_mlir(&self, prefix: String, ctx: Context<Path, Type>) -> String {
100        let typ: Type = self.type_of();
101        let type_name = type_to_mlir(typ.clone());
102
103        match self {
104            Expr::Reference(_loc, _typ, name) => {
105                format!("%{name}")
106            },
107            Expr::Word(_loc, _typ, _w, n) => {
108                let name = format!("%{prefix}_word");
109                println!("    {name} = hw.constant {n} : {type_name}");
110                name
111            },
112            Expr::Enum(_loc, typ, _typedef, valname) => {
113                let name = format!("%{prefix}_enum");
114                let typedef = if let Type::Enum(typedef) = typ.get().unwrap() {
115                    typedef
116                } else {
117                    panic!();
118                };
119                let v = typedef.value_of(&*valname).unwrap();
120                println!("    {name} = hw.constant {v} : {type_name}");
121                name
122            },
123            Expr::ToWord(_loc, _typ, e1) => {
124                let name = format!("%{prefix}_toword");
125                let e1_ssa = e1.emit_mlir(format!("{prefix}_e1"), ctx.clone());
126                println!("    {name} = comb.add {e1_ssa} : {type_name}");
127                name
128            },
129            Expr::UnOp(_loc, _typ, UnOp::Not, e1) => {
130                let name = format!("%{prefix}_not");
131                let e1_ssa = e1.emit_mlir(format!("{prefix}_not_e1"), ctx.clone());
132                // %c-1_i8 = hw.constant -1 : i8
133                // %0 = comb.xor bin %a, %c-1_i8 : i8
134                println!("    %{prefix}_not_negone = hw.constant -1 : {type_name}");
135                println!("    {name} = comb.xor {e1_ssa}, %{prefix}_not_negone : {type_name}");
136                name
137            },
138            Expr::BinOp(_loc, _typ, BinOp::Add, e1, e2) => {
139                let name = format!("%{prefix}_add");
140                let e1_ssa = e1.emit_mlir(format!("{prefix}_add_e1"), ctx.clone());
141                let e2_ssa = e2.emit_mlir(format!("{prefix}_add_e2"), ctx.clone());
142                println!("    {name} = comb.add {e1_ssa}, {e2_ssa} : {type_name}");
143                name
144            },
145            Expr::BinOp(_loc, _typ, BinOp::Sub, e1, e2) => {
146                let name = format!("%{prefix}_sub");
147                let e1_ssa = e1.emit_mlir(format!("{prefix}_sub_e1"), ctx.clone());
148                let e2_ssa = e2.emit_mlir(format!("{prefix}_sub_e2"), ctx.clone());
149                println!("    {name} = comb.sub {e1_ssa}, {e2_ssa} : {type_name}");
150                name
151            },
152            Expr::BinOp(_loc, _typ, BinOp::And, e1, e2) => {
153                let name = format!("%{prefix}_and");
154                let e1_ssa = e1.emit_mlir(format!("{prefix}_and_e1"), ctx.clone());
155                let e2_ssa = e2.emit_mlir(format!("{prefix}_and_e2"), ctx.clone());
156                println!("    {name} = comb.and {e1_ssa}, {e2_ssa} : {type_name}");
157                name
158            },
159            Expr::BinOp(_loc, _typ, BinOp::Or, e1, e2) => {
160                let name = format!("%{prefix}_or");
161                let e1_ssa = e1.emit_mlir(format!("{prefix}_or_e1"), ctx.clone());
162                let e2_ssa = e2.emit_mlir(format!("{prefix}_or_e2"), ctx.clone());
163                println!("    {name} = comb.or {e1_ssa}, {e2_ssa} : {type_name}");
164                name
165            },
166            Expr::BinOp(_loc, _typ, BinOp::Xor, e1, e2) => {
167                let name = format!("%{prefix}_xor");
168                let e1_ssa = e1.emit_mlir(format!("{prefix}_or_e1"), ctx.clone());
169                let e2_ssa = e2.emit_mlir(format!("{prefix}_or_e2"), ctx.clone());
170                println!("    {name} = comb.xor {e1_ssa}, {e2_ssa} : {type_name}");
171                name
172            },
173            Expr::BinOp(_loc, _typ, BinOp::Eq, e1, e2) => {
174                let name = format!("%{prefix}_eq");
175                let e1_type_name = type_to_mlir(e1.type_of());
176                let e1_ssa = e1.emit_mlir(format!("{prefix}_eq_e1"), ctx.clone());
177                let e2_ssa = e2.emit_mlir(format!("{prefix}_eq_e2"), ctx.clone());
178                // %0 = comb.icmp bin eq %a, %b : i8
179                println!("    {name} = comb.icmp bin eq {e1_ssa}, {e2_ssa} : {e1_type_name}");
180                name
181            },
182            Expr::BinOp(_loc, _typ, BinOp::Lt, e1, e2) => {
183                let name = format!("%{prefix}_eq");
184                let e1_type_name = type_to_mlir(e1.type_of());
185                let e1_ssa = e1.emit_mlir(format!("{prefix}_lt_e1"), ctx.clone());
186                let e2_ssa = e2.emit_mlir(format!("{prefix}_lt_e2"), ctx.clone());
187                // %0 = comb.icmp bin ult %a, %b : i8
188                println!("    {name} = comb.icmp bin ult {e1_ssa}, {e2_ssa} : {e1_type_name}");
189                name
190            },
191            Expr::If(_loc, _typ, cond, e1, e2) => {
192                let name = format!("%{prefix}_if");
193                let cond_ssa = cond.emit_mlir(format!("{prefix}_if_cond"), ctx.clone());
194                let e1_ssa   =   e1.emit_mlir(format!("{prefix}_if_e1"),   ctx.clone());
195                let e2_ssa   =   e2.emit_mlir(format!("{prefix}_if_e2"),   ctx.clone());
196                // %0 = comb.mux bin %in, %a, %b : i8
197                println!("    {name} = comb.mux bin {cond_ssa}, {e1_ssa}, {e2_ssa} : {type_name}");
198                name
199            },
200            Expr::Mux(_loc, _typ, cond, e1, e2) => {
201                let name = format!("%{prefix}_mux");
202                let cond_ssa = cond.emit_mlir(format!("{prefix}_mux_cond"), ctx.clone());
203                let e1_ssa   =   e1.emit_mlir(format!("{prefix}_mux_e1"),   ctx.clone());
204                let e2_ssa   =   e2.emit_mlir(format!("{prefix}_mux_e2"),   ctx.clone());
205                // %0 = comb.mux bin %in, %a, %b : i8
206                println!("    {name} = comb.mux bin {cond_ssa}, {e1_ssa}, {e2_ssa} : {type_name}");
207                name
208            },
209            Expr::Cat(_loc, _typ, es) => {
210            let name = format!("%{prefix}_cat");
211                let mut es_ssas = vec![];
212                let mut es_typenames = vec![];
213                for (i, e) in es.iter().enumerate() {
214                    let ssa = e.emit_mlir(format!("{prefix}_cat_e{i}"),ctx.clone());
215                    let width = e.type_of().bitwidth();
216                    let type_name = format!("i{width}");
217                    es_ssas.push(ssa);
218                    es_typenames.push(type_name);
219                }
220
221                println!("    {name} = comb.concat {} : {}", es_ssas.join(", "), es_typenames.join(", "));
222                name
223            },
224            Expr::Sext(_loc, _typ, e1) => {
225                let name = format!("%{prefix}_sext");
226                match (typ, e1.type_of()) {
227                    (Type::Word(outer_width), Type::Word(inner_width)) => {
228                        assert!(outer_width >= inner_width);
229                        let extension_width = outer_width - inner_width;
230                        let e1_ssa = e1.emit_mlir(format!("{prefix}_sext_e1"),   ctx.clone());
231                        // %c0_i7 = hw.constant 0 : i7
232                        // %0 = comb.concat %c0_i7, %a : i7, i1
233                        println!("    %{prefix}_sext_zero = hw.constant 0 : i{extension_width}");
234                        println!("    {name} = comb.concat %{prefix}_sext_zero, {e1_ssa} : i{extension_width}, i{inner_width}");
235                        name
236                    },
237                    _ => panic!(),
238                }
239            },
240            /*
241            Expr::Let(loc, _typ, name, e, b) => {
242            let name = format!("%{prefix}_let");
243                let e_ssa = e.emit_mlir(format!("{prefix}_let_{name}"), ctx.clone());
244                let b_ssa = b.emit_mlir(format!("{prefix}_let_{name}"), new_ctx);
245
246                println!("comb.add %{b_ssa} : {type_name}");
247                name
248            },
249            */
250            Expr::Idx(_loc, _typ, e1, i) => {
251                let name = format!("%{prefix}_idx");
252                let e1_type_name = type_to_mlir(e1.type_of());
253                let e1_ssa = e1.emit_mlir(format!("{prefix}_idx_e1"), ctx.clone());
254                // %0 = comb.extract %b from 0 : (i8) -> i1
255                println!("    {name} = comb.extract {e1_ssa} from {i} : ({e1_type_name}) -> i1");
256                name
257            },
258            Expr::IdxRange(_loc, _typ, e1, j, i) => {
259                let name = format!("%{prefix}_idxrange");
260                let e1_type_name = type_to_mlir(e1.type_of());
261                let width = j - i;
262                let e1_ssa = e1.emit_mlir(format!("{prefix}_idxrange_e1"), ctx.clone());
263                // %0 = comb.extract %b from 0 : (i8) -> i3
264                println!("    {name} = comb.extract {e1_ssa} from {i} : ({e1_type_name}) -> i{width}");
265                name
266            }
267            _ => panic!("Can't lower expression {self:?}"),
268        }
269    }
270}
271
272fn type_to_mlir(typ: Type) -> String {
273    match typ {
274        Type::Word(n) => format!("i{n}"),
275        Type::Struct(typedef) => {
276            let typedef = typedef;
277            let n = typedef.bitwidth();
278            format!("i{n}")
279        },
280        Type::Enum(typedef) => {
281            let n = typedef.bitwidth();
282            format!("i{n}")
283        }
284        _ => panic!("Can't lower type to MLIR directly"),
285    }
286}