// packtab_macro/lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::parse::{Parse, ParseStream};
6use syn::{parse_macro_input, Expr, Ident, LitInt, Token, Visibility, Type};
7
8/// Input syntax:
9/// ```text
10/// packtab_macro::pack_table! {
11///     pub fn lookup(u: usize) -> u8 {
12///         data: [1, 2, 3, 4, 5],
13///         default: 0,
14///         compression: 1.0,
15///     }
16/// }
17/// ```
18struct PackTableInput {
19    vis: Visibility,
20    fn_name: Ident,
21    _arg_name: Ident,
22    _ret_type: Type,
23    data: Vec<i64>,
24    default: i64,
25    compression: f64,
26    unsafe_access: bool,
27}
28
29impl Parse for PackTableInput {
30    fn parse(input: ParseStream) -> syn::Result<Self> {
31        let vis: Visibility = input.parse()?;
32        input.parse::<Token![fn]>()?;
33        let fn_name: Ident = input.parse()?;
34
35        let paren_content;
36        syn::parenthesized!(paren_content in input);
37        let arg_name: Ident = paren_content.parse()?;
38        paren_content.parse::<Token![:]>()?;
39        let _arg_type: Type = paren_content.parse()?;
40
41        input.parse::<Token![->]>()?;
42        let ret_type: Type = input.parse()?;
43
44        let brace_content;
45        syn::braced!(brace_content in input);
46
47        // data: [...]
48        let data_ident: Ident = brace_content.parse()?;
49        if data_ident != "data" {
50            return Err(syn::Error::new_spanned(data_ident, "expected 'data'"));
51        }
52        brace_content.parse::<Token![:]>()?;
53
54        let bracket_content;
55        syn::bracketed!(bracket_content in brace_content);
56        let mut data = Vec::new();
57        while !bracket_content.is_empty() {
58            if bracket_content.peek(Token![-]) {
59                bracket_content.parse::<Token![-]>()?;
60                let lit: LitInt = bracket_content.parse()?;
61                data.push(-(lit.base10_parse::<i64>()?));
62            } else {
63                let lit: LitInt = bracket_content.parse()?;
64                data.push(lit.base10_parse::<i64>()?);
65            }
66            if bracket_content.peek(Token![,]) {
67                bracket_content.parse::<Token![,]>()?;
68            }
69        }
70        brace_content.parse::<Token![,]>()?;
71
72        // default: N
73        let default_ident: Ident = brace_content.parse()?;
74        if default_ident != "default" {
75            return Err(syn::Error::new_spanned(default_ident, "expected 'default'"));
76        }
77        brace_content.parse::<Token![:]>()?;
78        let default = if brace_content.peek(Token![-]) {
79            brace_content.parse::<Token![-]>()?;
80            let lit: LitInt = brace_content.parse()?;
81            -(lit.base10_parse::<i64>()?)
82        } else {
83            let lit: LitInt = brace_content.parse()?;
84            lit.base10_parse::<i64>()?
85        };
86
87        // Optional trailing fields: compression, unsafe
88        let mut compression = 1.0f64;
89        let mut unsafe_access = false;
90        while brace_content.peek(Token![,]) {
91            brace_content.parse::<Token![,]>()?;
92            if brace_content.is_empty() {
93                break;
94            }
95            if brace_content.peek(Token![unsafe]) {
96                let kw: Token![unsafe] = brace_content.parse()?;
97                brace_content.parse::<Token![:]>()?;
98                let lit: syn::LitBool = brace_content.parse()
99                    .map_err(|_| syn::Error::new_spanned(kw, "expected bool after 'unsafe:'"))?;
100                unsafe_access = lit.value;
101            } else {
102                let ident: Ident = brace_content.parse()?;
103                match ident.to_string().as_str() {
104                    "compression" => {
105                        brace_content.parse::<Token![:]>()?;
106                        let expr: Expr = brace_content.parse()?;
107                        compression = match &expr {
108                            Expr::Lit(lit) => match &lit.lit {
109                                syn::Lit::Float(f) => f.base10_parse::<f64>()?,
110                                syn::Lit::Int(i) => i.base10_parse::<f64>()?,
111                                _ => return Err(syn::Error::new_spanned(lit, "expected number")),
112                            },
113                            _ => return Err(syn::Error::new_spanned(expr, "expected number literal")),
114                        };
115                    }
116                    _ => return Err(syn::Error::new_spanned(ident, "expected 'compression' or 'unsafe'")),
117                }
118            }
119        }
120
121        Ok(PackTableInput {
122            vis,
123            fn_name,
124            _arg_name: arg_name,
125            _ret_type: ret_type,
126            data,
127            default,
128            compression,
129            unsafe_access,
130        })
131    }
132}
133
134/// Pack a table of integers into compact multi-level lookup tables at compile time.
135///
136/// # Example
137///
138/// ```text
139/// packtab_macro::pack_table! {
140///     pub fn lookup(u: usize) -> u8 {
141///         data: [1, 2, 3, 4, 5, 6, 7, 8],
142///         default: 0,
143///     }
144/// }
145/// ```
146#[proc_macro]
147pub fn pack_table(input: TokenStream) -> TokenStream {
148    let input = parse_macro_input!(input as PackTableInput);
149
150    let (info, best_idx) = packtab::pack_table(&input.data, input.default, input.compression);
151    let code_str = packtab::generate(
152        &info,
153        best_idx,
154        &input.fn_name.to_string(),
155        packtab::codegen::Language::Rust { unsafe_access: input.unsafe_access },
156    );
157
158    // Adjust visibility: replace "pub(crate) fn name_get" with user's visibility + name.
159    let vis_str = match &input.vis {
160        Visibility::Public(_) => "pub",
161        Visibility::Inherited => "",
162        _ => "pub(crate)",
163    };
164
165    let fn_name_str = input.fn_name.to_string();
166    let adjusted = code_str.replace(
167        &format!("pub(crate) fn {}_get", fn_name_str),
168        &format!("{} fn {}", vis_str, fn_name_str),
169    );
170    // Replace internal references to name_get with just name
171    let adjusted = adjusted.replace(
172        &format!("{}_get", fn_name_str),
173        &fn_name_str,
174    );
175
176    let generated: proc_macro2::TokenStream = adjusted
177        .parse()
178        .unwrap_or_else(|e| panic!("Failed to parse generated code: {}\n\nCode:\n{}", e, adjusted));
179
180    let output = quote! {
181        #generated
182    };
183
184    output.into()
185}