kit_macros/
lib.rs

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use std::path::{Component, Path, PathBuf};
use syn::{parse::Parse, parse::ParseStream, parse_macro_input, DeriveInput, Expr, LitStr, Token};
6
7#[proc_macro_derive(InertiaProps)]
8pub fn derive_inertia_props(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let name = &input.ident;
11    let generics = &input.generics;
12    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
13
14    // Extract field information for generating Serialize impl
15    let fields = match &input.data {
16        syn::Data::Struct(data) => match &data.fields {
17            syn::Fields::Named(fields) => &fields.named,
18            _ => {
19                return syn::Error::new_spanned(
20                    &input,
21                    "InertiaProps only supports structs with named fields",
22                )
23                .to_compile_error()
24                .into();
25            }
26        },
27        _ => {
28            return syn::Error::new_spanned(&input, "InertiaProps can only be derived for structs")
29                .to_compile_error()
30                .into();
31        }
32    };
33
34    let field_count = fields.len();
35    let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
36    let field_name_strings: Vec<_> = fields
37        .iter()
38        .map(|f| f.ident.as_ref().unwrap().to_string())
39        .collect();
40
41    let expanded = quote! {
42        impl #impl_generics ::kit::serde::Serialize for #name #ty_generics #where_clause {
43            fn serialize<S>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error>
44            where
45                S: ::kit::serde::Serializer,
46            {
47                use ::kit::serde::ser::SerializeStruct;
48                let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
49                #(
50                    state.serialize_field(#field_name_strings, &self.#field_names)?;
51                )*
52                state.end()
53            }
54        }
55    };
56
57    expanded.into()
58}
59
60/// Props can be either a typed struct expression or JSON-like syntax
61enum PropsKind {
62    /// Typed struct: `HomeProps { title: "Welcome".into(), user }`
63    Typed(Expr),
64    /// JSON-like syntax: `{ "title": "Welcome" }`
65    Json(proc_macro2::TokenStream),
66}
67
68/// Custom parser for inertia_response! arguments
69struct InertiaResponseInput {
70    component: LitStr,
71    _comma: Token![,],
72    props: PropsKind,
73    config: Option<ConfigArg>,
74}
75
76struct ConfigArg {
77    _comma: Token![,],
78    expr: Expr,
79}
80
81impl Parse for InertiaResponseInput {
82    fn parse(input: ParseStream) -> syn::Result<Self> {
83        let component: LitStr = input.parse()?;
84        let comma: Token![,] = input.parse()?;
85
86        // Determine if this is a typed struct or JSON syntax
87        // Typed struct: identifier followed by { }
88        // JSON syntax: directly { }
89        let props = if input.peek(syn::Ident) {
90            // This is a typed struct expression: `HomeProps { ... }`
91            let expr: Expr = input.parse()?;
92            PropsKind::Typed(expr)
93        } else {
94            // This is JSON-like syntax: `{ "key": value }`
95            let props_content;
96            syn::braced!(props_content in input);
97            let props_tokens: proc_macro2::TokenStream = props_content.parse()?;
98            PropsKind::Json(props_tokens)
99        };
100
101        // Check for optional config argument
102        let config = if input.peek(Token![,]) {
103            let config_comma: Token![,] = input.parse()?;
104            let config_expr: Expr = input.parse()?;
105            Some(ConfigArg {
106                _comma: config_comma,
107                expr: config_expr,
108            })
109        } else {
110            None
111        };
112
113        Ok(InertiaResponseInput {
114            component,
115            _comma: comma,
116            props,
117            config,
118        })
119    }
120}
121
122/// Create an Inertia response with compile-time component validation
123///
124/// # Examples
125///
126/// ## With typed struct (recommended for type safety):
127/// ```rust,ignore
128/// #[derive(InertiaProps)]
129/// struct HomeProps {
130///     title: String,
131///     user: User,
132/// }
133///
134/// inertia_response!("Home", HomeProps { title: "Welcome".into(), user })
135/// ```
136///
137/// ## With JSON-like syntax (for quick prototyping):
138/// ```rust,ignore
139/// inertia_response!("Dashboard", { "user": { "name": "John" } })
140/// ```
141///
142/// This macro validates that the component file exists at compile time.
143/// If `frontend/src/pages/Dashboard.tsx` doesn't exist, you'll get a compile error.
144#[proc_macro]
145pub fn inertia_response(input: TokenStream) -> TokenStream {
146    let input = parse_macro_input!(input as InertiaResponseInput);
147
148    let component_name = input.component.value();
149    let component_lit = &input.component;
150
151    // Validate the component exists at compile time
152    if let Err(err) = validate_component_exists(&component_name, component_lit.span()) {
153        return err.to_compile_error().into();
154    }
155
156    // Generate props conversion based on props kind
157    let props_expr = match &input.props {
158        PropsKind::Typed(expr) => {
159            // Typed struct: serialize using serde_json::to_value
160            quote! {
161                ::kit::serde_json::to_value(&#expr)
162                    .expect("Failed to serialize InertiaProps")
163            }
164        }
165        PropsKind::Json(tokens) => {
166            // JSON-like syntax: use serde_json::json! macro
167            quote! {
168                ::kit::serde_json::json!({#tokens})
169            }
170        }
171    };
172
173    // Generate the appropriate expansion based on whether config is provided
174    let expanded = if let Some(config) = input.config {
175        let config_expr = config.expr;
176        quote! {{
177            let props = #props_expr;
178            let url = ::kit::InertiaContext::current_path();
179            let response = ::kit::InertiaResponse::new(#component_lit, props, url)
180                .with_config(#config_expr);
181
182            if ::kit::InertiaContext::is_inertia_request() {
183                Ok(response.to_json_response())
184            } else {
185                Ok(response.to_html_response())
186            }
187        }}
188    } else {
189        quote! {{
190            let props = #props_expr;
191            let url = ::kit::InertiaContext::current_path();
192            let response = ::kit::InertiaResponse::new(#component_lit, props, url);
193
194            if ::kit::InertiaContext::is_inertia_request() {
195                Ok(response.to_json_response())
196            } else {
197                Ok(response.to_html_response())
198            }
199        }}
200    };
201
202    expanded.into()
203}
204
205fn validate_component_exists(component_name: &str, span: Span) -> Result<(), syn::Error> {
206    // Get the manifest directory (where Cargo.toml is)
207    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
208        Ok(dir) => dir,
209        Err(_) => {
210            // In environments where CARGO_MANIFEST_DIR isn't set (e.g., some IDEs),
211            // skip validation gracefully
212            return Ok(());
213        }
214    };
215
216    let project_root = PathBuf::from(&manifest_dir);
217
218    // Build the expected component path
219    // Support nested paths: "Users/Profile" -> frontend/src/pages/Users/Profile.tsx
220    let component_path = project_root
221        .join("frontend")
222        .join("src")
223        .join("pages")
224        .join(format!("{}.tsx", component_name));
225
226    if !component_path.exists() {
227        // Build helpful error message with available components
228        let available = list_available_components(&project_root);
229
230        let mut error_msg = format!(
231            "Inertia component '{}' not found.\nExpected file: frontend/src/pages/{}.tsx",
232            component_name, component_name
233        );
234
235        if !available.is_empty() {
236            error_msg.push_str("\n\nAvailable components:");
237            for comp in &available {
238                error_msg.push_str(&format!("\n  - {}", comp));
239            }
240
241            // Suggest similar components (fuzzy matching)
242            if let Some(suggestion) = find_similar_component(component_name, &available) {
243                error_msg.push_str(&format!("\n\nDid you mean '{}'?", suggestion));
244            }
245        } else {
246            error_msg.push_str("\n\nNo components found in frontend/src/pages/");
247            error_msg.push_str("\nMake sure your frontend directory structure is set up correctly.");
248        }
249
250        return Err(syn::Error::new(span, error_msg));
251    }
252
253    Ok(())
254}
255
/// Lists every page component under `<project_root>/frontend/src/pages`.
///
/// Each entry is the `.tsx` file's path relative to the pages directory, without the
/// extension and with `/` as separator (e.g. `"Users/Profile"`). The result is sorted.
/// A missing or unreadable pages directory yields an empty list.
fn list_available_components(project_root: &Path) -> Vec<String> {
    let pages_dir = project_root.join("frontend").join("src").join("pages");

    let mut components = Vec::new();
    collect_components_recursive(&pages_dir, &pages_dir, &mut components);
    components.sort();
    components
}

/// Appends to `components` the name of every `.tsx` file found under `current_dir`.
///
/// Each name is the path relative to `base_dir` with its extension dropped and
/// platform separators replaced by `/`. Unreadable directories and entries, and paths
/// that are not valid UTF-8, are skipped silently.
fn collect_components_recursive(base_dir: &Path, current_dir: &Path, components: &mut Vec<String>) {
    let entries = match std::fs::read_dir(current_dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };

    for path in entries.flatten().map(|entry| entry.path()) {
        if path.is_dir() {
            collect_components_recursive(base_dir, &path, components);
        } else if path.extension().map_or(false, |ext| ext == "tsx") {
            let name = path.strip_prefix(base_dir).ok().and_then(|relative| {
                relative
                    .with_extension("")
                    .to_str()
                    // Normalize separators so Windows paths match "Users/Profile".
                    .map(|stem| stem.replace(std::path::MAIN_SEPARATOR, "/"))
            });
            components.extend(name);
        }
    }
}
289
/// Suggests the entry of `available` most likely intended by a mistyped `target`.
///
/// A case-insensitive exact match wins outright. Otherwise the candidate with the
/// smallest case-insensitive Levenshtein distance is returned, provided that distance
/// is at most `max(2, n / 3)`, where `n` is the number of characters in the lowercased
/// target. Ties go to the earliest entry in `available`. Returns `None` when nothing is
/// close enough.
fn find_similar_component(target: &str, available: &[String]) -> Option<String> {
    let target_lower = target.to_lowercase();

    if let Some(exact) = available
        .iter()
        .find(|comp| comp.to_lowercase() == target_lower)
    {
        return Some(exact.clone());
    }

    // Count characters, not bytes: the distance below is per character, so a byte
    // length would loosen the tolerance for non-ASCII names. Computed once, outside
    // the scan.
    let threshold = std::cmp::max(2, target_lower.chars().count() / 3);

    available
        .iter()
        .map(|comp| (comp, levenshtein_distance(&target_lower, &comp.to_lowercase())))
        .filter(|&(_, distance)| distance <= threshold)
        // `min_by_key` returns the first of several equal minima, preserving the
        // earliest-candidate tie-break.
        .min_by_key(|&(_, distance)| distance)
        .map(|(comp, _)| comp.clone())
}

/// Returns the Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// This is the minimum number of single-character insertions, deletions and
/// substitutions that turn one string into the other.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();

    // `row[j]` is the distance between the prefix of `a` consumed so far and the first
    // `j` chars of `b`. Each DP row depends only on the previous one, so a single row
    // is updated in place.
    let mut row: Vec<usize> = (0..=b_chars.len()).collect();

    for (i, ca) in a.chars().enumerate() {
        // Previous row's value at column j (the substitution predecessor).
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &cb) in b_chars.iter().enumerate() {
            let substitution = diagonal + usize::from(ca != cb);
            diagonal = row[j + 1];
            let deletion = row[j + 1] + 1;
            let insertion = row[j] + 1;
            row[j + 1] = deletion.min(insertion).min(substitution);
        }
    }

    row[b_chars.len()]
}