Skip to main content

marlin_spade_macro/
lib.rs

// Copyright (C) 2024 Ethan Uppal.
//
// This Source Code Form is subject to the terms of the Mozilla Public License,
// v. 2.0. If a copy of the MPL was not distributed with this file, You can
// obtain one at https://mozilla.org/MPL/2.0/.
6
7use std::{env, fs};
8
9use camino::Utf8PathBuf;
10use marlin_verilator::PortDirection;
11use marlin_verilog_macro_builder::{MacroArgs, build_verilated_struct};
12use proc_macro::TokenStream;
13use spade_parser::logos::Logos;
14
15fn search_for_swim_toml(mut start: Utf8PathBuf) -> Option<Utf8PathBuf> {
16    while start.parent().is_some() {
17        if start.join("swim.toml").is_file() {
18            return Some(start.join("swim.toml"));
19        }
20        start.pop();
21    }
22    None
23}
24
25#[proc_macro_attribute]
26pub fn spade(args: TokenStream, item: TokenStream) -> TokenStream {
27    let args = syn::parse_macro_input!(args as MacroArgs);
28
29    let manifest_directory = Utf8PathBuf::from(
30        env::var("CARGO_MANIFEST_DIR").expect("Please use CARGO"),
31    );
32    let Some(swim_toml) = search_for_swim_toml(manifest_directory) else {
33        return syn::Error::new_spanned(
34            args.source_path,
35            "Could not find swim.toml",
36        )
37        .into_compile_error()
38        .into();
39    };
40
41    let verilog_source_path = {
42        let mut source_path = swim_toml.clone();
43        source_path.pop();
44        source_path.push("build/spade.sv");
45        syn::LitStr::new(source_path.as_str(), args.source_path.span())
46    };
47
48    let spade_source_path = {
49        let mut spade_source_path = swim_toml.clone();
50        spade_source_path.pop();
51        spade_source_path.join(args.source_path.value())
52    };
53    let source_code = match fs::read_to_string(&spade_source_path) {
54        Ok(contents) => contents,
55        Err(error) => {
56            return syn::Error::new_spanned(
57                &args.source_path,
58                format!(
59                    "Failed to read source code file at {spade_source_path}: {error}"
60                ),
61            )
62            .into_compile_error()
63            .into();
64        }
65    };
66
67    let lexer = <spade_parser::lexer::TokenKind as Logos>::lexer(&source_code);
68    let mut parser = spade_parser::Parser::new(lexer, 0);
69    let top_level = match parser.top_level_module_body() {
70        Ok(body) => body,
71        Err(_error) => {
72            return syn::Error::new_spanned(
73                args.source_path,
74                "Failed to parse Spade code: run the Spade compiler for more details",
75            )
76            .into_compile_error()
77            .into();
78        }
79    };
80
81    let Some(unit_head) =
82        top_level.members.iter().find_map(|item| match item {
83            spade_ast::Item::Unit(unit)
84                if unit.head.name.as_str() == args.name.value().as_str() =>
85            {
86                Some(unit.head.clone())
87            }
88            _ => None,
89        })
90    else {
91        let names = top_level
92            .members
93            .iter()
94            .filter_map(|item| match item {
95                spade_ast::Item::Unit(unit) => Some(format!(
96                    "{} {}",
97                    unit.head.unit_kind,
98                    unit.head.name.as_str()
99                )),
100                _ => None,
101            })
102            .collect::<Vec<_>>();
103        return syn::Error::new_spanned(
104            &args.name,
105            format!(
106                "Could not find top-level unit named `{}` in {}. Remember to use `#[no_mangle(all)]`. Unit names found are: {} (there are {} module item(s) in total)",
107                args.name.value(),
108                args.source_path.value(),
109                if names.is_empty() {"<none found>".into()} else {names.join(", ")},
110                top_level.members.len()
111            ),
112        )
113        .into_compile_error()
114        .into();
115    };
116
117    let Some(unit_mangle_attribute) = unit_head
118        .attributes
119        .0
120        .iter()
121        .find(|attribute| attribute.name() == "no_mangle")
122    else {
123        return syn::Error::new_spanned(
124            &args.name,
125            format!(
126                "Annotate `{}` with `#[no_mangle(all)]`",
127                args.name.value()
128            ),
129        )
130        .into_compile_error()
131        .into();
132    };
133    let is_no_mangle_all = matches!(
134        unit_mangle_attribute.inner,
135        spade_ast::Attribute::NoMangle { all: true }
136    );
137
138    if unit_head.output_type.is_some() {
139        return syn::Error::new_spanned(
140            &args.name,
141            format!(
142                "Unsupported output type on `{}` (verilator makes this annoying): use `inv &` instead",
143                args.name.value()
144            ),
145        )
146        .into_compile_error()
147        .into();
148    }
149
150    let mut ports = vec![];
151    for (attributes, port_name, port_type) in &unit_head.inputs.inner.args {
152        if !attributes
153            .0
154            .iter()
155            .any(|attribute| attribute.name() == "no_mangle")
156            && !is_no_mangle_all
157        {
158            return syn::Error::new_spanned(
159                &args.name,
160                format!(
161                    "Annotate the unit `{}` with `#[no_mangle(all)]` or just the port `{}` with `#[no_mangle]`",
162                    args.name.value(),
163                    port_name.inner,
164                ),
165            )
166            .into_compile_error()
167            .into();
168        }
169
170        let port_direction = match &port_type.inner {
171            spade_ast::TypeSpec::Inverted(_) => PortDirection::Output,
172            _ => PortDirection::Input,
173        };
174
175        let port_msb = spade_simple_type_width(&port_type.inner) - 1;
176
177        ports.push((
178            port_name.inner.as_str().to_string(),
179            port_msb,
180            0,
181            port_direction,
182        ));
183    }
184
185    build_verilated_struct(
186        "spade",
187        args.name,
188        verilog_source_path,
189        ports,
190        item.into(),
191    )
192    .into()
193}
194
195// TODO: make this decent with error handling. this is some of the worst code
196// I've written. This implementation is based off of https://gitlab.com/spade-lang/spade/-/blob/79cfd7ed12ee8a7328aa6e6650e394ed55ed2b2c/spade-mir/src/types.rs
197/// Determines the bit-width of a "simple" type present in a Spade top exposed
198/// to Verilog, e.g., integers and inverted integers, clocks, etc.
199fn spade_simple_type_width(type_spec: &spade_ast::TypeSpec) -> usize {
200    fn get_type_spec(
201        type_expression: &spade_ast::TypeExpression,
202    ) -> &spade_ast::TypeSpec {
203        match type_expression {
204            spade_ast::TypeExpression::TypeSpec(type_spec) => type_spec,
205            _ => panic!("Expected a type spec"),
206        }
207    }
208
209    fn get_constant(type_expression: &spade_ast::TypeExpression) -> usize {
210        // TODO: handle bigints correctly
211        match type_expression {
212            spade_ast::TypeExpression::Integer(big_int) => {
213                big_int.to_u64_digits().1[0] as usize
214            }
215            _ => panic!("Expected an integer"),
216        }
217    }
218
219    match type_spec {
220        spade_ast::TypeSpec::Tuple(inner) => inner
221            .iter()
222            .map(|type_expression| {
223                spade_simple_type_width(get_type_spec(type_expression))
224            })
225            .sum(),
226        spade_ast::TypeSpec::Named(name, args) => {
227            if name.inner.0.len() != 1 {
228                panic!("I'm so done writing error messages");
229            }
230            match name.inner.0[0].unwrap_named().as_str() {
231                "int" | "uint" => {
232                    if args.is_none() {
233                        panic!(
234                            "Found an integer without a size in the top module head"
235                        );
236                    }
237                    if args.as_ref().unwrap().len() != 1 {
238                        panic!(
239                            "Found an integer with more than one argument in the top module"
240                        );
241                    }
242                    get_constant(&args.as_ref().unwrap().inner[0])
243                }
244                "clock" | "bool" => 1,
245                other => panic!("Unsupported type in the top module: {other}"),
246            }
247        }
248        spade_ast::TypeSpec::Array { inner, size } => {
249            spade_simple_type_width(get_type_spec(inner)) * get_constant(size)
250        }
251        spade_ast::TypeSpec::Inverted(inner) => {
252            spade_simple_type_width(get_type_spec(inner))
253        }
254        spade_ast::TypeSpec::Wire(inner) => {
255            spade_simple_type_width(get_type_spec(inner))
256        }
257        spade_ast::TypeSpec::Wildcard => {
258            panic!("Invalid type for Verilog-exposed Spade top")
259        }
260    }
261}