Skip to main content

hdl_cat_macros/
lib.rs

1//! `#[kernel]` attribute macro for `hdl-cat`.
2//!
3//! Lifts a Rust function with a restricted body into an IR
4//! builder returning a combinational
5//! `hdl_cat_circuit::CircuitArrow`.
6//!
7//! # Supported subset
8//!
9//! - Parameter types: `bool`, `Bits<N>`, `SignedBits<N>` (with
10//!   a literal `N`).
11//! - Return type: a single scalar (same constraints as parameters).
12//! - Body: zero or more `let` bindings followed by a tail
13//!   expression.
//! - Expressions: identifiers, parenthesised sub-expressions,
//!   binary operators (`+`, `-`, `*`, `&`, `|`, `^`, `==`, `<`),
//!   and unary `!`.
16//!
17//! Function bodies outside this subset raise a compile error.
18//!
19//! # Example
20//!
21//! The macro cannot be doctested from within its own crate
22//! because expanded code references `hdl_cat_ir`, `hdl_cat_bits`,
23//! `hdl_cat_circuit`, and `hdl_cat_error`, which aren't in this
//! crate's dependency graph.  Executable examples live in the
//! umbrella `hdl-cat` crate's integration tests
//! (`tests/kernel_macro.rs`).
26//!
27//! Conceptually:
28//!
29//! ```ignore
30//! use hdl_cat_macros::kernel;
31//! use hdl_cat_bits::Bits;
32//!
33//! #[kernel]
34//! fn xor_plus_a(a: Bits<8>, b: Bits<8>) -> Bits<8> {
35//!     let x = a ^ b;
36//!     x + a
37//! }
38//! ```
39//!
40//! After expansion, `xor_plus_a` becomes a nullary function
41//! returning
42//! `Result<CircuitArrow<CircuitTensor<Obj<Bits<8>>, Obj<Bits<8>>>, Obj<Bits<8>>>, Error>`
43//! that builds the IR for the given expression.
44
45use proc_macro::TokenStream;
46use proc_macro2::Span;
47use quote::{quote, ToTokens};
48use syn::{parse_macro_input, BinOp, Expr, FnArg, Ident, ItemFn, Lit, PatType, Stmt, Type, UnOp};
49
50/// Lift a Rust function into an `hdl-cat` IR builder.
51#[proc_macro_attribute]
52pub fn kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
53    let input = parse_macro_input!(item as ItemFn);
54    expand_kernel(&input).map_or_else(|e| e.to_compile_error().into(), Into::into)
55}
56
/// An abstract hardware type extracted from the Rust type syntax.
///
/// Widths are the literal const-generic argument written in the
/// `Bits<N>` / `SignedBits<N>` annotation.
#[derive(Clone, Debug, PartialEq, Eq)]
enum ScalarTy {
    /// `bool`, lowered to a single-bit wire.
    Bool,
    /// `Bits<N>` with the given width `N`.
    Bits(u32),
    /// `SignedBits<N>` with the given width `N`.
    Signed(u32),
}
64
65impl ScalarTy {
66    fn wire_ty_tokens(&self) -> proc_macro2::TokenStream {
67        match self {
68            Self::Bool => quote! { ::hdl_cat_ir::WireTy::Bit },
69            Self::Bits(n) => quote! { ::hdl_cat_ir::WireTy::Bits(#n) },
70            Self::Signed(n) => quote! { ::hdl_cat_ir::WireTy::Signed(#n) },
71        }
72    }
73
74    fn obj_ty_tokens(&self) -> proc_macro2::TokenStream {
75        match self {
76            Self::Bool => quote! { ::hdl_cat_circuit::Obj<bool> },
77            Self::Bits(n) => {
78                let n_literal = *n as usize;
79                quote! { ::hdl_cat_circuit::Obj<::hdl_cat_bits::Bits<#n_literal>> }
80            }
81            Self::Signed(n) => {
82                let n_literal = *n as usize;
83                quote! { ::hdl_cat_circuit::Obj<::hdl_cat_bits::SignedBits<#n_literal>> }
84            }
85        }
86    }
87}
88
89fn parse_scalar_ty(ty: &Type) -> Result<ScalarTy, syn::Error> {
90    let Type::Path(p) = ty else {
91        return Err(syn::Error::new_spanned(ty, "unsupported type"));
92    };
93    let segment = p
94        .path
95        .segments
96        .last()
97        .ok_or_else(|| syn::Error::new_spanned(p, "empty path"))?;
98    let name = segment.ident.to_string();
99    match name.as_str() {
100        "bool" => Ok(ScalarTy::Bool),
101        "Bits" | "SignedBits" => {
102            let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
103                return Err(syn::Error::new_spanned(
104                    segment,
105                    "Bits/SignedBits requires a const generic width",
106                ));
107            };
108            let arg = args.args.first().ok_or_else(|| {
109                syn::Error::new_spanned(args, "expected single const generic arg")
110            })?;
111            let width = const_width_from_generic_arg(arg)?;
112            if name == "Bits" {
113                Ok(ScalarTy::Bits(width))
114            } else {
115                Ok(ScalarTy::Signed(width))
116            }
117        }
118        other => Err(syn::Error::new_spanned(
119            segment,
120            format!("unsupported type `{other}`"),
121        )),
122    }
123}
124
125fn const_width_from_generic_arg(arg: &syn::GenericArgument) -> Result<u32, syn::Error> {
126    let expr = match arg {
127        syn::GenericArgument::Const(e) => Ok(e),
128        syn::GenericArgument::Type(Type::Path(p)) => Err(syn::Error::new_spanned(
129            p,
130            "expected a literal width, not a type path",
131        )),
132        other => Err(syn::Error::new_spanned(other, "expected a const literal width")),
133    }?;
134    let Expr::Lit(lit) = expr else {
135        return Err(syn::Error::new_spanned(expr, "expected a const literal"));
136    };
137    let Lit::Int(n) = &lit.lit else {
138        return Err(syn::Error::new_spanned(&lit.lit, "expected an integer literal"));
139    };
140    n.base10_parse::<u32>()
141}
142
143/// Immutable state threaded through body compilation.
144///
145/// Every mutation returns a new `BodyCtx` — no `&mut`.
146#[derive(Clone)]
147struct BodyCtx {
148    stmts: Vec<proc_macro2::TokenStream>,
149    env: Vec<(String, Ident, ScalarTy)>,
150    fresh_counter: usize,
151}
152
153impl BodyCtx {
154    fn new() -> Self {
155        Self {
156            stmts: Vec::new(),
157            env: Vec::new(),
158            fresh_counter: 0,
159        }
160    }
161
162    fn fresh_wire_ident(self) -> (Self, Ident) {
163        let id = Ident::new(
164            &format!("__k_tmp_{}", self.fresh_counter),
165            Span::call_site(),
166        );
167        (
168            Self {
169                fresh_counter: self.fresh_counter + 1,
170                ..self
171            },
172            id,
173        )
174    }
175
176    fn bind(self, source_name: String, wire_ident: Ident, ty: ScalarTy) -> Self {
177        let new_env = self
178            .env
179            .into_iter()
180            .chain(core::iter::once((source_name, wire_ident, ty)))
181            .collect();
182        Self {
183            env: new_env,
184            ..self
185        }
186    }
187
188    fn lookup(&self, name: &str) -> Option<(Ident, ScalarTy)> {
189        self.env
190            .iter()
191            .rev()
192            .find(|(n, _, _)| n == name)
193            .map(|(_, id, ty)| (id.clone(), ty.clone()))
194    }
195
196    fn push_stmt(self, ts: proc_macro2::TokenStream) -> Self {
197        let new_stmts = self
198            .stmts
199            .into_iter()
200            .chain(core::iter::once(ts))
201            .collect();
202        Self {
203            stmts: new_stmts,
204            ..self
205        }
206    }
207}
208
209fn expand_kernel(func: &ItemFn) -> Result<proc_macro2::TokenStream, syn::Error> {
210    let name = &func.sig.ident;
211    let vis = &func.vis;
212
213    // Extract args.
214    let args: Vec<(String, ScalarTy, Ident)> = func
215        .sig
216        .inputs
217        .iter()
218        .map(parse_kernel_arg)
219        .collect::<Result<Vec<_>, _>>()?;
220
221    (!args.is_empty())
222        .then_some(())
223        .ok_or_else(|| syn::Error::new_spanned(&func.sig, "kernel needs at least one parameter"))?;
224
225    // Return type.
226    let out_ty = match &func.sig.output {
227        syn::ReturnType::Default => {
228            return Err(syn::Error::new_spanned(
229                &func.sig,
230                "kernel must return a scalar",
231            ));
232        }
233        syn::ReturnType::Type(_, t) => parse_scalar_ty(t)?,
234    };
235
236    // Build input-side Obj<...> and CircuitTensor<...> types (left-nested).
237    let input_ty_tokens = build_input_type_tokens(&args);
238    let output_ty_tokens = out_ty.obj_ty_tokens();
239
240    // Compile body.
241    let ctx = compile_body(&args, &func.block, &out_ty)?;
242
243    // Generate argument wire declarations.
244    let arg_wire_decls: Vec<proc_macro2::TokenStream> = args
245        .iter()
246        .map(|(_, sty, ident)| {
247            let ty_tok = sty.wire_ty_tokens();
248            quote! {
249                let (bld, #ident) = bld.with_wire(#ty_tok);
250            }
251        })
252        .collect();
253
254    let arg_wire_idents: Vec<&Ident> = args.iter().map(|(_, _, id)| id).collect();
255
256    // The final output wire identifier is the last fresh wire
257    // emitted by `compile_body`, stored in `ctx.final_output`.
258    let final_output = ctx
259        .final_output
260        .ok_or_else(|| syn::Error::new_spanned(&func.block, "kernel body produced no value"))?;
261    let body_stmts = ctx.ctx.stmts;
262
263    Ok(quote! {
264        #vis fn #name() -> ::core::result::Result<
265            ::hdl_cat_circuit::CircuitArrow<#input_ty_tokens, #output_ty_tokens>,
266            ::hdl_cat_error::Error,
267        > {
268            let bld = ::hdl_cat_ir::HdlGraphBuilder::new();
269            #(#arg_wire_decls)*
270            #(#body_stmts)*
271            ::core::result::Result::Ok(
272                ::hdl_cat_circuit::CircuitArrow::from_raw_parts(
273                    bld.build(),
274                    vec![#(#arg_wire_idents),*],
275                    vec![#final_output],
276                )
277            )
278        }
279    })
280}
281
282fn parse_kernel_arg(arg: &FnArg) -> Result<(String, ScalarTy, Ident), syn::Error> {
283    let FnArg::Typed(PatType { pat, ty, .. }) = arg else {
284        return Err(syn::Error::new_spanned(
285            arg,
286            "self parameters not supported",
287        ));
288    };
289    let syn::Pat::Ident(pat_ident) = pat.as_ref() else {
290        return Err(syn::Error::new_spanned(pat, "expected a simple identifier"));
291    };
292    let source_name = pat_ident.ident.to_string();
293    let wire_ident = Ident::new(
294        &format!("__k_arg_{source_name}"),
295        pat_ident.ident.span(),
296    );
297    let sty = parse_scalar_ty(ty)?;
298    Ok((source_name, sty, wire_ident))
299}
300
301fn build_input_type_tokens(
302    args: &[(String, ScalarTy, Ident)],
303) -> proc_macro2::TokenStream {
304    match args.len() {
305        0 => quote! { ::hdl_cat_circuit::CircuitUnit },
306        1 => args[0].1.obj_ty_tokens(),
307        _ => {
308            let (first_rest, last) = args.split_at(args.len() - 1);
309            let head = build_input_type_tokens_owned(first_rest);
310            let tail = last[0].1.obj_ty_tokens();
311            quote! { ::hdl_cat_circuit::CircuitTensor<#head, #tail> }
312        }
313    }
314}
315
316fn build_input_type_tokens_owned(
317    args: &[(String, ScalarTy, Ident)],
318) -> proc_macro2::TokenStream {
319    match args.len() {
320        0 => quote! { ::hdl_cat_circuit::CircuitUnit },
321        1 => args[0].1.obj_ty_tokens(),
322        _ => {
323            let (first_rest, last) = args.split_at(args.len() - 1);
324            let head = build_input_type_tokens_owned(first_rest);
325            let tail = last[0].1.obj_ty_tokens();
326            quote! { ::hdl_cat_circuit::CircuitTensor<#head, #tail> }
327        }
328    }
329}
330
331/// Bundle of context and final-output wire identifier after
332/// compiling a kernel body.
333struct CompiledBody {
334    ctx: BodyCtx,
335    final_output: Option<Ident>,
336}
337
338fn compile_body(
339    args: &[(String, ScalarTy, Ident)],
340    block: &syn::Block,
341    _out_ty: &ScalarTy,
342) -> Result<CompiledBody, syn::Error> {
343    let initial_ctx = args.iter().fold(BodyCtx::new(), |ctx, (name, sty, wire_ident)| {
344        ctx.bind(name.clone(), wire_ident.clone(), sty.clone())
345    });
346    let (ctx, final_output, _ty) = compile_block(initial_ctx, block)?;
347    Ok(CompiledBody {
348        ctx,
349        final_output: Some(final_output),
350    })
351}
352
353fn compile_block(
354    ctx: BodyCtx,
355    block: &syn::Block,
356) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
357    let (head, tail) = block
358        .stmts
359        .split_last()
360        .ok_or_else(|| syn::Error::new_spanned(block, "empty kernel body"))?;
361
362    // Fold over leading `let` statements, threading the ctx.
363    let ctx_after_lets = tail
364        .iter()
365        .try_fold(ctx, compile_let_stmt)?;
366
367    // Tail expression produces the block's value.
368    let tail_expr = match head {
369        Stmt::Expr(e, _) => Ok(e),
370        other => Err(syn::Error::new_spanned(
371            other,
372            "kernel body must end in an expression",
373        )),
374    }?;
375    compile_expr(ctx_after_lets, tail_expr)
376}
377
378fn compile_let_stmt(ctx: BodyCtx, stmt: &Stmt) -> Result<BodyCtx, syn::Error> {
379    let Stmt::Local(local) = stmt else {
380        return Err(syn::Error::new_spanned(
381            stmt,
382            "only `let` bindings allowed before the tail expression",
383        ));
384    };
385    let syn::Pat::Ident(pat_ident) = &local.pat else {
386        return Err(syn::Error::new_spanned(
387            &local.pat,
388            "expected a simple identifier",
389        ));
390    };
391    let name = pat_ident.ident.to_string();
392    let init = local
393        .init
394        .as_ref()
395        .ok_or_else(|| syn::Error::new_spanned(local, "`let` requires an initializer"))?;
396    let (ctx_after_rhs, wire, ty) = compile_expr(ctx, &init.expr)?;
397    Ok(ctx_after_rhs.bind(name, wire, ty))
398}
399
400fn compile_expr(
401    ctx: BodyCtx,
402    expr: &Expr,
403) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
404    match expr {
405        Expr::Path(p) => {
406            let ident = p
407                .path
408                .get_ident()
409                .ok_or_else(|| syn::Error::new_spanned(p, "expected bare identifier"))?;
410            let (id, ty) = ctx
411                .lookup(&ident.to_string())
412                .ok_or_else(|| syn::Error::new_spanned(ident, "unknown identifier"))?;
413            Ok((ctx, id, ty))
414        }
415        Expr::Binary(b) => compile_binary(ctx, b),
416        Expr::Unary(u) => compile_unary(ctx, u),
417        Expr::Paren(p) => compile_expr(ctx, &p.expr),
418        other => Err(syn::Error::new_spanned(
419            other,
420            "unsupported expression in kernel body",
421        )),
422    }
423}
424
425fn compile_binary(
426    ctx: BodyCtx,
427    b: &syn::ExprBinary,
428) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
429    let (ctx_l, lhs, lhs_ty) = compile_expr(ctx, &b.left)?;
430    let (ctx_lr, rhs, _rhs_ty) = compile_expr(ctx_l, &b.right)?;
431    let op_tok = bin_op_tokens(&b.op)?;
432    let (ctx_fresh, output) = ctx_lr.fresh_wire_ident();
433    // Output type = lhs type (binary ops preserve width in our IR
434    // except for comparisons, which are out-of-scope for v1).
435    let out_ty = lhs_ty;
436    let out_ty_tok = out_ty.wire_ty_tokens();
437    let stmt = quote! {
438        let (bld, #output) = bld.with_wire(#out_ty_tok);
439        let bld = bld.with_instruction(
440            ::hdl_cat_ir::Op::Bin(#op_tok),
441            vec![#lhs, #rhs],
442            #output,
443        )?;
444    };
445    let ctx_final = ctx_fresh.push_stmt(stmt);
446    Ok((ctx_final, output, out_ty))
447}
448
449fn compile_unary(
450    ctx: BodyCtx,
451    u: &syn::ExprUnary,
452) -> Result<(BodyCtx, Ident, ScalarTy), syn::Error> {
453    match u.op {
454        UnOp::Not(_) => {
455            let (ctx_inner, operand, operand_ty) = compile_expr(ctx, &u.expr)?;
456            let (ctx_fresh, output) = ctx_inner.fresh_wire_ident();
457            let ty_tok = operand_ty.wire_ty_tokens();
458            let stmt = quote! {
459                let (bld, #output) = bld.with_wire(#ty_tok);
460                let bld = bld.with_instruction(
461                    ::hdl_cat_ir::Op::Not,
462                    vec![#operand],
463                    #output,
464                )?;
465            };
466            let ctx_final = ctx_fresh.push_stmt(stmt);
467            Ok((ctx_final, output, operand_ty))
468        }
469        other => Err(syn::Error::new_spanned(
470            other.into_token_stream(),
471            "only unary `!` is supported",
472        )),
473    }
474}
475
476fn bin_op_tokens(op: &BinOp) -> Result<proc_macro2::TokenStream, syn::Error> {
477    Ok(match op {
478        BinOp::Add(_) => quote! { ::hdl_cat_ir::BinOp::Add },
479        BinOp::Sub(_) => quote! { ::hdl_cat_ir::BinOp::Sub },
480        BinOp::Mul(_) => quote! { ::hdl_cat_ir::BinOp::Mul },
481        BinOp::BitAnd(_) => quote! { ::hdl_cat_ir::BinOp::And },
482        BinOp::BitOr(_) => quote! { ::hdl_cat_ir::BinOp::Or },
483        BinOp::BitXor(_) => quote! { ::hdl_cat_ir::BinOp::Xor },
484        BinOp::Eq(_) => quote! { ::hdl_cat_ir::BinOp::Eq },
485        BinOp::Lt(_) => quote! { ::hdl_cat_ir::BinOp::Lt },
486        other => {
487            return Err(syn::Error::new_spanned(
488                other.into_token_stream(),
489                "unsupported binary operator",
490            ));
491        }
492    })
493}