1use std::{env, fmt, path::PathBuf};
8
9use marlin_verilog_macro_builder::{
10    MacroArgs, build_verilated_struct, parse_verilog_ports,
11};
12use proc_macro::TokenStream;
13use quote::{format_ident, quote};
14use syn::{parse_macro_input, spanned::Spanned};
15
16#[proc_macro_attribute]
17pub fn verilog(args: TokenStream, item: TokenStream) -> TokenStream {
18    let args = syn::parse_macro_input!(args as MacroArgs);
19
20    let manifest_directory = PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("Please compile using `cargo` or set the `CARGO_MANIFEST_DIR` environment variable"));
21    let source_path = manifest_directory.join(args.source_path.value());
22
23    let ports = match parse_verilog_ports(
24        &args.name,
25        &args.source_path,
26        &source_path,
27    ) {
28        Ok(ports) => ports,
29        Err(error) => {
30            return error.into();
31        }
32    };
33
34    build_verilated_struct(
35        "verilog",
36        args.name,
37        syn::LitStr::new(
38            source_path.to_string_lossy().as_ref(),
39            args.source_path.span(),
40        ),
41        ports,
42        args.clock_port,
43        args.reset_port,
44        item.into(),
45    )
46    .into()
47}
48
/// The primitive Rust types accepted in the signature of a DPI function.
///
/// Each variant corresponds to the Rust primitive of the same (lowercased)
/// name, i.e. `Bool` is `bool`, `U8` is `u8`, and so on.
///
/// This is a field-less tag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DPIPrimitiveType {
    /// Rust `bool`.
    Bool,
    /// Rust `u8`.
    U8,
    /// Rust `u16`.
    U16,
    /// Rust `u32`.
    U32,
    /// Rust `u64`.
    U64,
    /// Rust `i8`.
    I8,
    /// Rust `i16`.
    I16,
    /// Rust `i32`.
    I32,
    /// Rust `i64`.
    I64,
}
60
61impl fmt::Display for DPIPrimitiveType {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        match self {
64            DPIPrimitiveType::Bool => "bool",
65            DPIPrimitiveType::U8 => "u8",
66            DPIPrimitiveType::U16 => "u16",
67            DPIPrimitiveType::U32 => "u32",
68            DPIPrimitiveType::U64 => "u64",
69            DPIPrimitiveType::I8 => "i8",
70            DPIPrimitiveType::I16 => "i16",
71            DPIPrimitiveType::I32 => "i32",
72            DPIPrimitiveType::I64 => "i64",
73        }
74        .fmt(f)
75    }
76}
77
78impl DPIPrimitiveType {
79    fn as_c(&self) -> &'static str {
80        match self {
81            DPIPrimitiveType::Bool => "svBit",
82            DPIPrimitiveType::U8 => "uint8_t",
83            DPIPrimitiveType::U16 => "uint16_t",
84            DPIPrimitiveType::U32 => "uint32_t",
85            DPIPrimitiveType::U64 => "uint64_t",
86            DPIPrimitiveType::I8 => "int8_t",
87            DPIPrimitiveType::I16 => "int16_t",
88            DPIPrimitiveType::I32 => "int32_t",
89            DPIPrimitiveType::I64 => "int64_t",
90        }
91    }
92}
93
94fn parse_dpi_primitive_type(
95    ty: &syn::TypePath,
96) -> Result<DPIPrimitiveType, syn::Error> {
97    if let Some(qself) = &ty.qself {
98        return Err(syn::Error::new_spanned(
99            qself.lt_token,
100            "Primitive integer type should not be qualified in DPI function",
101        ));
102    }
103
104    match ty
105        .path
106        .require_ident()
107        .or(Err(syn::Error::new_spanned(
108            ty,
109            "Primitive integer type should not have multiple path segments",
110        )))?
111        .to_string()
112        .as_str()
113    {
114        "bool" => Ok(DPIPrimitiveType::Bool),
115        "u8" => Ok(DPIPrimitiveType::U8),
116        "u16" => Ok(DPIPrimitiveType::U16),
117        "u32" => Ok(DPIPrimitiveType::U32),
118        "u64" => Ok(DPIPrimitiveType::U64),
119        "i8" => Ok(DPIPrimitiveType::I8),
120        "i16" => Ok(DPIPrimitiveType::I16),
121        "i32" => Ok(DPIPrimitiveType::I32),
122        "i64" => Ok(DPIPrimitiveType::I64),
123        _ => Err(syn::Error::new_spanned(
124            ty,
125            "Unknown primitive integer type",
126        )),
127    }
128}
129
/// How a single DPI function parameter is passed, as classified by
/// `parse_dpi_type`.
enum DPIType {
    /// A primitive taken by value (written as a plain type path, e.g. `u32`).
    Input(DPIPrimitiveType),
    /// A primitive taken through a mutable reference (e.g. `&mut u32`); the
    /// diagnostics in `parse_dpi_type` call this an "output or inout" type.
    Inout(DPIPrimitiveType),
}
135
136fn parse_dpi_type(ty: &syn::Type) -> Result<DPIType, syn::Error> {
137    match ty {
138        syn::Type::Path(type_path) => {
139            Ok(DPIType::Input(parse_dpi_primitive_type(type_path)?))
140        }
141        syn::Type::Reference(syn::TypeReference {
142            and_token,
143            lifetime,
144            mutability,
145            elem,
146        }) => {
147            if mutability.is_none() {
148                return Err(syn::Error::new_spanned(
149                    and_token,
150                    "DPI output or inout type must be represented with a mutable reference",
151                ));
152            }
153            if let Some(lifetime) = lifetime {
154                return Err(syn::Error::new_spanned(
155                    lifetime,
156                    "DPI output or inout type cannot use lifetimes",
157                ));
158            }
159
160            let syn::Type::Path(type_path) = elem.as_ref() else {
161                return Err(syn::Error::new_spanned(
162                    elem,
163                    "DPI output or inout type must be a mutable reference to a primitive integer type",
164                ));
165            };
166            Ok(DPIType::Inout(parse_dpi_primitive_type(type_path)?))
167        }
168        other => Err(syn::Error::new_spanned(
169            other,
170            "This type is not supported in DPI. Please use primitive integers or mutable references to them",
171        )),
172    }
173}
174
175#[proc_macro_attribute]
214pub fn dpi(_args: TokenStream, item: TokenStream) -> TokenStream {
215    let item_fn = parse_macro_input!(item as syn::ItemFn);
216
217    if !matches!(item_fn.vis, syn::Visibility::Public(_)) {
218        return syn::Error::new_spanned(
219            item_fn.vis,
220            "Marking the function `pub` is required to expose this Rust function to C",
221        )
222        .into_compile_error()
223        .into();
224    }
225
226    let Some(abi) = &item_fn.sig.abi else {
227        return syn::Error::new_spanned(
228            item_fn,
229            "`extern \"C\"` is required to expose this Rust function to C",
230        )
231        .into_compile_error()
232        .into();
233    };
234
235    if !abi
236        .name
237        .as_ref()
238        .map(|name| name.value().as_str() == "C")
239        .unwrap_or(true)
240    {
241        return syn::Error::new_spanned(
242            item_fn,
243            "You must specify the C ABI for the `extern` marking",
244        )
245        .into_compile_error()
246        .into();
247    }
248
249    if item_fn.sig.generics.lt_token.is_some() {
250        return syn::Error::new_spanned(
251            item_fn.sig.generics,
252            "Generics are not supported for DPI functions",
253        )
254        .into_compile_error()
255        .into();
256    }
257
258    if let Some(asyncness) = &item_fn.sig.asyncness {
259        return syn::Error::new_spanned(
260            asyncness,
261            "DPI functions must be synchronous",
262        )
263        .into_compile_error()
264        .into();
265    }
266
267    if let syn::ReturnType::Type(_, return_type) = &item_fn.sig.output {
268        return syn::Error::new_spanned(
269            return_type,
270            "DPI functions cannot have a return value",
271        )
272        .into_compile_error()
273        .into();
274    }
275
276    let ports =
277        match item_fn
278            .sig
279            .inputs
280            .iter()
281            .try_fold(vec![], |mut ports, input| {
282                let syn::FnArg::Typed(parameter) = input else {
283                    return Err(syn::Error::new_spanned(
284                        input,
285                        "Invalid parameter on DPI function",
286                    ));
287                };
288
289                let syn::Pat::Ident(name) = &*parameter.pat else {
290                    return Err(syn::Error::new_spanned(
291                        parameter,
292                        "Function argument must be an identifier",
293                    ));
294                };
295
296                let attrs = parameter.attrs.clone();
297                ports.push((name, attrs, parse_dpi_type(¶meter.ty)?));
298                Ok(ports)
299            }) {
300            Ok(ports) => ports,
301            Err(error) => {
302                return error.into_compile_error().into();
303            }
304        };
305
306    let attributes = item_fn.attrs;
307    let function_name = item_fn.sig.ident;
308    let body = item_fn.block;
309
310    let struct_name = format_ident!("__DPI_{}", function_name);
311
312    let mut parameter_types = vec![];
313    let mut parameters = vec![];
314
315    for (name, attributes, dpi_type) in &ports {
316        let parameter_type = match dpi_type {
317            DPIType::Input(inner) => {
318                let type_ident = format_ident!("{}", inner.to_string());
319                quote! { #type_ident }
320            }
321            DPIType::Inout(inner) => {
322                let type_ident = format_ident!("{}", inner.to_string());
323                quote! { *mut #type_ident }
324            }
325        };
326        parameter_types.push(parameter_type.clone());
327        parameters.push(quote! {
328            #(#attributes)* #name: #parameter_type
329        });
330    }
331
332    let preamble =
333        ports
334            .iter()
335            .filter_map(|(name, _, dpi_type)| match dpi_type {
336                DPIType::Inout(_) => Some(quote! {
337                    let #name = unsafe { &mut *#name };
338                }),
339                _ => None,
340            });
341
342    let function_name_literal = syn::LitStr::new(
343        function_name.to_string().as_str(),
344        function_name.span(),
345    );
346
347    let c_signature = ports
348        .iter()
349        .map(|(name, _, dpi_type)| {
350            let c_type = match dpi_type {
351                DPIType::Input(inner) => inner.as_c().to_string(),
352                DPIType::Inout(inner) => format!("{}*", inner.as_c()),
353            };
354            let name_literal =
355                syn::LitStr::new(name.ident.to_string().as_str(), name.span());
356            let type_literal = syn::LitStr::new(&c_type, name.span());
357            quote! {
358                (#name_literal, #type_literal)
359            }
360        })
361        .collect::<Vec<_>>();
362
363    quote! {
364        #[allow(non_camel_case_types)]
365        struct #struct_name;
366
367        impl #struct_name {
368            #(#attributes)*
369            pub extern "C" fn call(#(#parameters),*) {
370                #(#preamble)*
371                #body
372            }
373        }
374
375        impl verilog::__reexports::verilator::dpi::DpiFunction for #struct_name {
376            fn name(&self) -> &'static str {
377                #function_name_literal
378            }
379
380            fn signature(&self) -> &'static [(&'static str, &'static str)] {
381                &[#(#c_signature),*]
382            }
383
384            fn pointer(&self) -> *const std::ffi::c_void {
385                #struct_name::call as extern "C" fn(#(#parameter_types),*) as *const std::ffi::c_void
386            }
387        }
388
389        #[allow(non_upper_case_globals)]
390        pub static #function_name: &'static dyn verilog::__reexports::verilator::dpi::DpiFunction = &#struct_name;
391    }
392    .into()
393}