1use proc_macro::TokenStream;
46use proc_macro2::Span;
47use quote::{quote, ToTokens};
48use syn::{parse_macro_input, BinOp, Expr, FnArg, Ident, ItemFn, Lit, PatType, Stmt, Type, UnOp};
49
50#[proc_macro_attribute]
52pub fn kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
53 let input = parse_macro_input!(item as ItemFn);
54 expand_kernel(&input).map_or_else(|e| e.to_compile_error().into(), Into::into)
55}
56
/// A scalar type that a kernel parameter, `let` binding or return value may have.
///
/// `Debug`/`PartialEq`/`Eq` are derived so types can be compared and printed
/// in diagnostics and tests.
#[derive(Clone, Debug, PartialEq, Eq)]
enum ScalarTy {
    /// A single bit, written `bool` in kernel source.
    Bool,
    /// `Bits<N>` with the given width `N`.
    Bits(u32),
    /// `SignedBits<N>` with the given width `N`.
    Signed(u32),
}
64
65impl ScalarTy {
66 fn wire_ty_tokens(&self) -> proc_macro2::TokenStream {
67 match self {
68 Self::Bool => quote! { ::hdl_cat_ir::WireTy::Bit },
69 Self::Bits(n) => quote! { ::hdl_cat_ir::WireTy::Bits(#n) },
70 Self::Signed(n) => quote! { ::hdl_cat_ir::WireTy::Signed(#n) },
71 }
72 }
73
74 fn obj_ty_tokens(&self) -> proc_macro2::TokenStream {
75 match self {
76 Self::Bool => quote! { ::hdl_cat_circuit::Obj<bool> },
77 Self::Bits(n) => {
78 let n_literal = *n as usize;
79 quote! { ::hdl_cat_circuit::Obj<::hdl_cat_bits::Bits<#n_literal>> }
80 }
81 Self::Signed(n) => {
82 let n_literal = *n as usize;
83 quote! { ::hdl_cat_circuit::Obj<::hdl_cat_bits::SignedBits<#n_literal>> }
84 }
85 }
86 }
87}
88
89fn parse_scalar_ty(ty: &Type) -> Result<ScalarTy, syn::Error> {
90 let Type::Path(p) = ty else {
91 return Err(syn::Error::new_spanned(ty, "unsupported type"));
92 };
93 let segment = p
94 .path
95 .segments
96 .last()
97 .ok_or_else(|| syn::Error::new_spanned(p, "empty path"))?;
98 let name = segment.ident.to_string();
99 match name.as_str() {
100 "bool" => Ok(ScalarTy::Bool),
101 "Bits" | "SignedBits" => {
102 let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
103 return Err(syn::Error::new_spanned(
104 segment,
105 "Bits/SignedBits requires a const generic width",
106 ));
107 };
108 let arg = args.args.first().ok_or_else(|| {
109 syn::Error::new_spanned(args, "expected single const generic arg")
110 })?;
111 let width = const_width_from_generic_arg(arg)?;
112 if name == "Bits" {
113 Ok(ScalarTy::Bits(width))
114 } else {
115 Ok(ScalarTy::Signed(width))
116 }
117 }
118 other => Err(syn::Error::new_spanned(
119 segment,
120 format!("unsupported type `{other}`"),
121 )),
122 }
123}
124
125fn const_width_from_generic_arg(arg: &syn::GenericArgument) -> Result<u32, syn::Error> {
126 let expr = match arg {
127 syn::GenericArgument::Const(e) => Ok(e),
128 syn::GenericArgument::Type(Type::Path(p)) => Err(syn::Error::new_spanned(
129 p,
130 "expected a literal width, not a type path",
131 )),
132 other => Err(syn::Error::new_spanned(other, "expected a const literal width")),
133 }?;
134 let Expr::Lit(lit) = expr else {
135 return Err(syn::Error::new_spanned(expr, "expected a const literal"));
136 };
137 let Lit::Int(n) = &lit.lit else {
138 return Err(syn::Error::new_spanned(&lit.lit, "expected an integer literal"));
139 };
140 n.base10_parse::<u32>()
141}
142
143#[derive(Clone)]
147struct BodyCtx {
148 stmts: Vec<proc_macro2::TokenStream>,
149 env: Vec<(String, Ident, ScalarTy)>,
150 fresh_counter: usize,
151}
152
153impl BodyCtx {
154 fn new() -> Self {
155 Self {
156 stmts: Vec::new(),
157 env: Vec::new(),
158 fresh_counter: 0,
159 }
160 }
161
162 fn fresh_wire_ident(self) -> (Self, Ident) {
163 let id = Ident::new(
164 &format!("__k_tmp_{}", self.fresh_counter),
165 Span::call_site(),
166 );
167 (
168 Self {
169 fresh_counter: self.fresh_counter + 1,
170 ..self
171 },
172 id,
173 )
174 }
175
176 fn bind(self, source_name: String, wire_ident: Ident, ty: ScalarTy) -> Self {
177 let new_env = self
178 .env
179 .into_iter()
180 .chain(core::iter::once((source_name, wire_ident, ty)))
181 .collect();
182 Self {
183 env: new_env,
184 ..self
185 }
186 }
187
188 fn lookup(&self, name: &str) -> Option<(Ident, ScalarTy)> {
189 self.env
190 .iter()
191 .rev()
192 .find(|(n, _, _)| n == name)
193 .map(|(_, id, ty)| (id.clone(), ty.clone()))
194 }
195
196 fn push_stmt(self, ts: proc_macro2::TokenStream) -> Self {
197 let new_stmts = self
198 .stmts
199 .into_iter()
200 .chain(core::iter::once(ts))
201 .collect();
202 Self {
203 stmts: new_stmts,
204 ..self
205 }
206 }
207}
208
209fn expand_kernel(func: &ItemFn) -> Result<proc_macro2::TokenStream, syn::Error> {
210 let name = &func.sig.ident;
211 let vis = &func.vis;
212
213 let args: Vec<(String, ScalarTy, Ident)> = func
215 .sig
216 .inputs
217 .iter()
218 .map(parse_kernel_arg)
219 .collect::<Result<Vec<_>, _>>()?;
220
221 (!args.is_empty())
222 .then_some(())
223 .ok_or_else(|| syn::Error::new_spanned(&func.sig, "kernel needs at least one parameter"))?;
224
225 let out_ty = match &func.sig.output {
227 syn::ReturnType::Default => {
228 return Err(syn::Error::new_spanned(
229 &func.sig,
230 "kernel must return a scalar",
231 ));
232 }
233 syn::ReturnType::Type(_, t) => parse_scalar_ty(t)?,
234 };
235
236 let input_ty_tokens = build_input_type_tokens(&args);
238 let output_ty_tokens = out_ty.obj_ty_tokens();
239
240 let ctx = compile_body(&args, &func.block, &out_ty)?;
242
243 let arg_wire_decls: Vec<proc_macro2::TokenStream> = args
245 .iter()
246 .map(|(_, sty, ident)| {
247 let ty_tok = sty.wire_ty_tokens();
248 quote! {
249 let (bld, #ident) = bld.with_wire(#ty_tok);
250 }
251 })
252 .collect();
253
254 let arg_wire_idents: Vec<&Ident> = args.iter().map(|(_, _, id)| id).collect();
255
256 let final_output = ctx
259 .final_output
260 .ok_or_else(|| syn::Error::new_spanned(&func.block, "kernel body produced no value"))?;
261 let body_stmts = ctx.ctx.stmts;
262
263 Ok(quote! {
264 #vis fn #name() -> ::core::result::Result<
265 ::hdl_cat_circuit::CircuitArrow<#input_ty_tokens, #output_ty_tokens>,
266 ::hdl_cat_error::Error,
267 > {
268 let bld = ::hdl_cat_ir::HdlGraphBuilder::new();
269 #(#arg_wire_decls)*
270 #(#body_stmts)*
271 ::core::result::Result::Ok(
272 ::hdl_cat_circuit::CircuitArrow::from_raw_parts(
273 bld.build(),
274 vec![#(#arg_wire_idents),*],
275 vec![#final_output],
276 )
277 )
278 }
279 })
280}
281
282fn parse_kernel_arg(arg: &FnArg) -> Result<(String, ScalarTy, Ident), syn::Error> {
283 let FnArg::Typed(PatType { pat, ty, .. }) = arg else {
284 return Err(syn::Error::new_spanned(
285 arg,
286 "self parameters not supported",
287 ));
288 };
289 let syn::Pat::Ident(pat_ident) = pat.as_ref() else {
290 return Err(syn::Error::new_spanned(pat, "expected a simple identifier"));
291 };
292 let source_name = pat_ident.ident.to_string();
293 let wire_ident = Ident::new(
294 &format!("__k_arg_{source_name}"),
295 pat_ident.ident.span(),
296 );
297 let sty = parse_scalar_ty(ty)?;
298 Ok((source_name, sty, wire_ident))
299}
300
301fn build_input_type_tokens(
302 args: &[(String, ScalarTy, Ident)],
303) -> proc_macro2::TokenStream {
304 match args.len() {
305 0 => quote! { ::hdl_cat_circuit::CircuitUnit },
306 1 => args[0].1.obj_ty_tokens(),
307 _ => {
308 let (first_rest, last) = args.split_at(args.len() - 1);
309 let head = build_input_type_tokens_owned(first_rest);
310 let tail = last[0].1.obj_ty_tokens();
311 quote! { ::hdl_cat_circuit::CircuitTensor<#head, #tail> }
312 }
313 }
314}
315
316fn build_input_type_tokens_owned(
317 args: &[(String, ScalarTy, Ident)],
318) -> proc_macro2::TokenStream {
319 match args.len() {
320 0 => quote! { ::hdl_cat_circuit::CircuitUnit },
321 1 => args[0].1.obj_ty_tokens(),
322 _ => {
323 let (first_rest, last) = args.split_at(args.len() - 1);
324 let head = build_input_type_tokens_owned(first_rest);
325 let tail = last[0].1.obj_ty_tokens();
326 quote! { ::hdl_cat_circuit::CircuitTensor<#head, #tail> }
327 }
328 }
329}
330
331struct CompiledBody {
334 ctx: BodyCtx,
335 final_output: Option<Ident>,
336}
337
338fn compile_body(
339 args: &[(String, ScalarTy, Ident)],
340 block: &syn::Block,
341 _out_ty: &ScalarTy,
342) -> Result<CompiledBody, syn::Error> {
343 let initial_ctx = args.iter().fold(BodyCtx::new(), |ctx, (name, sty, wire_ident)| {
344 ctx.bind(name.clone(), wire_ident.clone(), sty.clone())
345 });
346 let (ctx, final_output, _ty) = compile_block(initial_ctx, block)?;
347 Ok(CompiledBody {
348 ctx,
349 final_output: Some(final_output),
350 })
351}
352
353fn compile_block(
354 ctx: BodyCtx,
355 block: &syn::Block,
356) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
357 let (head, tail) = block
358 .stmts
359 .split_last()
360 .ok_or_else(|| syn::Error::new_spanned(block, "empty kernel body"))?;
361
362 let ctx_after_lets = tail
364 .iter()
365 .try_fold(ctx, compile_let_stmt)?;
366
367 let tail_expr = match head {
369 Stmt::Expr(e, _) => Ok(e),
370 other => Err(syn::Error::new_spanned(
371 other,
372 "kernel body must end in an expression",
373 )),
374 }?;
375 compile_expr(ctx_after_lets, tail_expr)
376}
377
378fn compile_let_stmt(ctx: BodyCtx, stmt: &Stmt) -> Result<BodyCtx, syn::Error> {
379 let Stmt::Local(local) = stmt else {
380 return Err(syn::Error::new_spanned(
381 stmt,
382 "only `let` bindings allowed before the tail expression",
383 ));
384 };
385 let syn::Pat::Ident(pat_ident) = &local.pat else {
386 return Err(syn::Error::new_spanned(
387 &local.pat,
388 "expected a simple identifier",
389 ));
390 };
391 let name = pat_ident.ident.to_string();
392 let init = local
393 .init
394 .as_ref()
395 .ok_or_else(|| syn::Error::new_spanned(local, "`let` requires an initializer"))?;
396 let (ctx_after_rhs, wire, ty) = compile_expr(ctx, &init.expr)?;
397 Ok(ctx_after_rhs.bind(name, wire, ty))
398}
399
400fn compile_expr(
401 ctx: BodyCtx,
402 expr: &Expr,
403) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
404 match expr {
405 Expr::Path(p) => {
406 let ident = p
407 .path
408 .get_ident()
409 .ok_or_else(|| syn::Error::new_spanned(p, "expected bare identifier"))?;
410 let (id, ty) = ctx
411 .lookup(&ident.to_string())
412 .ok_or_else(|| syn::Error::new_spanned(ident, "unknown identifier"))?;
413 Ok((ctx, id, ty))
414 }
415 Expr::Binary(b) => compile_binary(ctx, b),
416 Expr::Unary(u) => compile_unary(ctx, u),
417 Expr::Paren(p) => compile_expr(ctx, &p.expr),
418 other => Err(syn::Error::new_spanned(
419 other,
420 "unsupported expression in kernel body",
421 )),
422 }
423}
424
425fn compile_binary(
426 ctx: BodyCtx,
427 b: &syn::ExprBinary,
428) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
429 let (ctx_l, lhs, lhs_ty) = compile_expr(ctx, &b.left)?;
430 let (ctx_lr, rhs, _rhs_ty) = compile_expr(ctx_l, &b.right)?;
431 let op_tok = bin_op_tokens(&b.op)?;
432 let (ctx_fresh, output) = ctx_lr.fresh_wire_ident();
433 let out_ty = lhs_ty;
436 let out_ty_tok = out_ty.wire_ty_tokens();
437 let stmt = quote! {
438 let (bld, #output) = bld.with_wire(#out_ty_tok);
439 let bld = bld.with_instruction(
440 ::hdl_cat_ir::Op::Bin(#op_tok),
441 vec![#lhs, #rhs],
442 #output,
443 )?;
444 };
445 let ctx_final = ctx_fresh.push_stmt(stmt);
446 Ok((ctx_final, output, out_ty))
447}
448
449fn compile_unary(
450 ctx: BodyCtx,
451 u: &syn::ExprUnary,
452) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
453 match u.op {
454 UnOp::Not(_) => {
455 let (ctx_inner, operand, operand_ty) = compile_expr(ctx, &u.expr)?;
456 let (ctx_fresh, output) = ctx_inner.fresh_wire_ident();
457 let ty_tok = operand_ty.wire_ty_tokens();
458 let stmt = quote! {
459 let (bld, #output) = bld.with_wire(#ty_tok);
460 let bld = bld.with_instruction(
461 ::hdl_cat_ir::Op::Not,
462 vec![#operand],
463 #output,
464 )?;
465 };
466 let ctx_final = ctx_fresh.push_stmt(stmt);
467 Ok((ctx_final, output, operand_ty))
468 }
469 other => Err(syn::Error::new_spanned(
470 other.into_token_stream(),
471 "only unary `!` is supported",
472 )),
473 }
474}
475
476fn bin_op_tokens(op: &BinOp) -> Result<proc_macro2::TokenStream, syn::Error> {
477 Ok(match op {
478 BinOp::Add(_) => quote! { ::hdl_cat_ir::BinOp::Add },
479 BinOp::Sub(_) => quote! { ::hdl_cat_ir::BinOp::Sub },
480 BinOp::Mul(_) => quote! { ::hdl_cat_ir::BinOp::Mul },
481 BinOp::BitAnd(_) => quote! { ::hdl_cat_ir::BinOp::And },
482 BinOp::BitOr(_) => quote! { ::hdl_cat_ir::BinOp::Or },
483 BinOp::BitXor(_) => quote! { ::hdl_cat_ir::BinOp::Xor },
484 BinOp::Eq(_) => quote! { ::hdl_cat_ir::BinOp::Eq },
485 BinOp::Lt(_) => quote! { ::hdl_cat_ir::BinOp::Lt },
486 other => {
487 return Err(syn::Error::new_spanned(
488 other.into_token_stream(),
489 "unsupported binary operator",
490 ));
491 }
492 })
493}