1use std::{env, fmt, path::PathBuf};
8
9use marlin_verilog_macro_builder::{
10 MacroArgs, build_verilated_struct, parse_verilog_ports,
11};
12use proc_macro::TokenStream;
13use quote::{format_ident, quote};
14use syn::{parse_macro_input, spanned::Spanned};
15
16#[proc_macro_attribute]
17pub fn verilog(args: TokenStream, item: TokenStream) -> TokenStream {
18 let args = syn::parse_macro_input!(args as MacroArgs);
19
20 let manifest_directory = PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("Please compile using `cargo` or set the `CARGO_MANIFEST_DIR` environment variable"));
21 let source_path = manifest_directory.join(args.source_path.value());
22
23 let ports = match parse_verilog_ports(
24 &args.name,
25 &args.source_path,
26 &source_path,
27 ) {
28 Ok(ports) => ports,
29 Err(error) => {
30 return error.into();
31 }
32 };
33
34 build_verilated_struct(
35 "verilog",
36 args.name,
37 syn::LitStr::new(
38 source_path.to_string_lossy().as_ref(),
39 args.source_path.span(),
40 ),
41 ports,
42 args.clock_port,
43 args.reset_port,
44 item.into(),
45 )
46 .into()
47}
48
/// A fixed-width integer type allowed in a DPI function signature.
enum DPIPrimitiveType {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
}

/// Writes the Rust spelling of the type (e.g. `u32`), honouring any
/// width/fill/precision flags exactly as formatting a `&str` would.
impl fmt::Display for DPIPrimitiveType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rust_name = match self {
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
        };
        f.pad(rust_name)
    }
}

impl DPIPrimitiveType {
    /// Returns the C type name (e.g. `int32_t`) reported for this integer in
    /// the DPI signature.
    ///
    /// NOTE(review): every unsigned type maps to the *signed* C type of the
    /// same width (e.g. `u8` -> `int8_t`). Presumably this matches what the
    /// simulator side expects; confirm before changing it.
    fn as_c(&self) -> &'static str {
        match self {
            Self::U8 | Self::I8 => "int8_t",
            Self::U16 | Self::I16 => "int16_t",
            Self::U32 | Self::I32 => "int32_t",
            Self::U64 | Self::I64 => "int64_t",
        }
    }
}
91
92fn parse_dpi_primitive_type(
93 ty: &syn::TypePath,
94) -> Result<DPIPrimitiveType, syn::Error> {
95 if let Some(qself) = &ty.qself {
96 return Err(syn::Error::new_spanned(
97 qself.lt_token,
98 "Primitive integer type should not be qualified in DPI function",
99 ));
100 }
101
102 match ty
103 .path
104 .require_ident()
105 .or(Err(syn::Error::new_spanned(
106 ty,
107 "Primitive integer type should not have multiple path segments",
108 )))?
109 .to_string()
110 .as_str()
111 {
112 "u8" => Ok(DPIPrimitiveType::U8),
113 "u16" => Ok(DPIPrimitiveType::U16),
114 "u32" => Ok(DPIPrimitiveType::U32),
115 "u64" => Ok(DPIPrimitiveType::U64),
116 "i8" => Ok(DPIPrimitiveType::I8),
117 "i16" => Ok(DPIPrimitiveType::I16),
118 "i32" => Ok(DPIPrimitiveType::I32),
119 "i64" => Ok(DPIPrimitiveType::I64),
120 _ => Err(syn::Error::new_spanned(
121 ty,
122 "Unknown primitive integer type",
123 )),
124 }
125}
126
127enum DPIType {
128 Input(DPIPrimitiveType),
129 Inout(DPIPrimitiveType),
131}
132
133fn parse_dpi_type(ty: &syn::Type) -> Result<DPIType, syn::Error> {
134 match ty {
135 syn::Type::Path(type_path) => {
136 Ok(DPIType::Input(parse_dpi_primitive_type(type_path)?))
137 }
138 syn::Type::Reference(syn::TypeReference {
139 and_token,
140 lifetime,
141 mutability,
142 elem,
143 }) => {
144 if mutability.is_none() {
145 return Err(syn::Error::new_spanned(
146 and_token,
147 "DPI output or inout type must be represented with a mutable reference",
148 ));
149 }
150 if let Some(lifetime) = lifetime {
151 return Err(syn::Error::new_spanned(
152 lifetime,
153 "DPI output or inout type cannot use lifetimes",
154 ));
155 }
156
157 let syn::Type::Path(type_path) = elem.as_ref() else {
158 return Err(syn::Error::new_spanned(
159 elem,
160 "DPI output or inout type must be a mutable reference to a primitive integer type",
161 ));
162 };
163 Ok(DPIType::Inout(parse_dpi_primitive_type(type_path)?))
164 }
165 other => Err(syn::Error::new_spanned(
166 other,
167 "This type is not supported in DPI. Please use primitive integers or mutable references to them",
168 )),
169 }
170}
171
172#[proc_macro_attribute]
173pub fn dpi(_args: TokenStream, item: TokenStream) -> TokenStream {
174 let item_fn = parse_macro_input!(item as syn::ItemFn);
175
176 if !matches!(item_fn.vis, syn::Visibility::Public(_)) {
177 return syn::Error::new_spanned(
178 item_fn.vis,
179 "Marking the function `pub` is required to expose this Rust function to C",
180 )
181 .into_compile_error()
182 .into();
183 }
184
185 let Some(abi) = &item_fn.sig.abi else {
186 return syn::Error::new_spanned(
187 item_fn,
188 "`extern \"C\"` is required to expose this Rust function to C",
189 )
190 .into_compile_error()
191 .into();
192 };
193
194 if !abi
195 .name
196 .as_ref()
197 .map(|name| name.value().as_str() == "C")
198 .unwrap_or(true)
199 {
200 return syn::Error::new_spanned(
201 item_fn,
202 "You must specify the C ABI for the `extern` marking",
203 )
204 .into_compile_error()
205 .into();
206 }
207
208 if item_fn.sig.generics.lt_token.is_some() {
209 return syn::Error::new_spanned(
210 item_fn.sig.generics,
211 "Generics are not supported for DPI functions",
212 )
213 .into_compile_error()
214 .into();
215 }
216
217 if let Some(asyncness) = &item_fn.sig.asyncness {
218 return syn::Error::new_spanned(
219 asyncness,
220 "DPI functions must be synchronous",
221 )
222 .into_compile_error()
223 .into();
224 }
225
226 if let syn::ReturnType::Type(_, return_type) = &item_fn.sig.output {
227 return syn::Error::new_spanned(
228 return_type,
229 "DPI functions cannot have a return value",
230 )
231 .into_compile_error()
232 .into();
233 }
234
235 let ports =
236 match item_fn
237 .sig
238 .inputs
239 .iter()
240 .try_fold(vec![], |mut ports, input| {
241 let syn::FnArg::Typed(parameter) = input else {
242 return Err(syn::Error::new_spanned(
243 input,
244 "Invalid parameter on DPI function",
245 ));
246 };
247
248 let syn::Pat::Ident(name) = &*parameter.pat else {
249 return Err(syn::Error::new_spanned(
250 parameter,
251 "Function argument must be an identifier",
252 ));
253 };
254
255 let attrs = parameter.attrs.clone();
256 ports.push((name, attrs, parse_dpi_type(¶meter.ty)?));
257 Ok(ports)
258 }) {
259 Ok(ports) => ports,
260 Err(error) => {
261 return error.into_compile_error().into();
262 }
263 };
264
265 let attributes = item_fn.attrs;
266 let function_name = item_fn.sig.ident;
267 let body = item_fn.block;
268
269 let struct_name = format_ident!("__DPI_{}", function_name);
270
271 let mut parameter_types = vec![];
272 let mut parameters = vec![];
273
274 for (name, attributes, dpi_type) in &ports {
275 let parameter_type = match dpi_type {
276 DPIType::Input(inner) => {
277 let type_ident = format_ident!("{}", inner.to_string());
278 quote! { #type_ident }
279 }
280 DPIType::Inout(inner) => {
281 let type_ident = format_ident!("{}", inner.to_string());
282 quote! { *mut #type_ident }
283 }
284 };
285 parameter_types.push(parameter_type.clone());
286 parameters.push(quote! {
287 #(#attributes)* #name: #parameter_type
288 });
289 }
290
291 let preamble =
292 ports
293 .iter()
294 .filter_map(|(name, _, dpi_type)| match dpi_type {
295 DPIType::Inout(_) => Some(quote! {
296 let #name = unsafe { &mut *#name };
297 }),
298 _ => None,
299 });
300
301 let function_name_literal = syn::LitStr::new(
302 function_name.to_string().as_str(),
303 function_name.span(),
304 );
305
306 let c_signature = ports
307 .iter()
308 .map(|(name, _, dpi_type)| {
309 let c_type = match dpi_type {
310 DPIType::Input(inner) => inner.as_c().to_string(),
311 DPIType::Inout(inner) => format!("{}*", inner.as_c()),
312 };
313 let name_literal =
314 syn::LitStr::new(name.ident.to_string().as_str(), name.span());
315 let type_literal = syn::LitStr::new(&c_type, name.span());
316 quote! {
317 (#name_literal, #type_literal)
318 }
319 })
320 .collect::<Vec<_>>();
321
322 quote! {
323 #[allow(non_camel_case_types)]
324 struct #struct_name;
325
326 impl #struct_name {
327 #(#attributes)*
328 pub extern "C" fn call(#(#parameters),*) {
329 #(#preamble)*
330 #body
331 }
332 }
333
334 impl verilog::__reexports::verilator::dpi::DpiFunction for #struct_name {
335 fn name(&self) -> &'static str {
336 #function_name_literal
337 }
338
339 fn signature(&self) -> &'static [(&'static str, &'static str)] {
340 &[#(#c_signature),*]
341 }
342
343 fn pointer(&self) -> *const verilog::__reexports::libc::c_void {
344 #struct_name::call as extern "C" fn(#(#parameter_types),*) as *const verilog::__reexports::libc::c_void
345 }
346 }
347
348 #[allow(non_upper_case_globals)]
349 pub static #function_name: &'static dyn verilog::__reexports::verilator::dpi::DpiFunction = &#struct_name;
350 }
351 .into()
352}