Skip to main content

ts_function/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    Error, FnArg, GenericArgument, Ident, Item, ItemImpl, ItemType, PathArguments, ReturnType,
5    Type, parse_macro_input,
6};
7
8#[macro_use]
9mod ts_type;
10mod ts_macro;
11
12use crate::ts_type::ToTsType;
13
14/// Generates TypeScript interface bindings from a Rust struct.
15///
16/// This attribute works identically to the upstream `ts-macro` attribute, allowing
17/// the struct to define a TypeScript interface with property bindings seamlessly
18/// mapped to Javascript functions.
19///
20/// It generates:
21/// 1. A TypeScript interface string exposed as a custom wasm section
22/// 2. Extensible bindings and trait implementations
23///
24/// The default behavior for field names is to convert to `camelCase` for Javascript conventions.
25/// However, you can opt-out by adding `rename_all = "none"`:
26///
27/// ```rust,ignore
28/// #[ts(rename_all = "none")]
29/// struct MyStruct {
30///     my_field_name: String, // Will remain "my_field_name" in TypeScript
31/// }
32/// ```
33#[proc_macro_attribute]
34pub fn ts(attr: TokenStream, input: TokenStream) -> TokenStream {
35    ts_macro::ts(attr, input)
36}
37
38struct ParsedSignature<'a> {
39    struct_ident: &'a Ident,
40    args: Vec<(Ident, &'a Type)>,
41    output: &'a ReturnType,
42}
43
44/// Generates TypeScript type aliases and `wasm-bindgen` ABI trait implementations
45/// for Rust callback wrapper structs.
46///
47/// `ts-function` acts as a bridge for function/callback types in pure Rust when
48/// interoperating with TypeScript using `ts-macro`. It can be applied to either
49/// type aliases (`pub type MyCb = fn(args: ...)`) or `impl` blocks (the "escape hatch").
50///
51/// # Examples
52///
53/// **Basic Usage**
54///
55/// ```rust,ignore
56/// use ts_function::{ts, ts_function};
57///
58/// #[ts_function]
59/// pub type OnReadyCb = fn(msg: String);
60///
61/// #[ts]
62/// struct AppCallbacks {
63///     on_ready: OnReadyCb,
64/// }
65/// ```
66///
67/// **Escape Hatch Usage**
68///
69/// For completely custom serialization or embedding specific side-effects and error
70/// handling directly into the callback execution:
71///
72/// ```rust,ignore
73/// use wasm_bindgen::prelude::*;
74/// use ts_function::ts_function;
75///
76/// pub struct CustomLoggingCallback(pub js_sys::Function);
77///
78/// #[ts_function]
79/// impl CustomLoggingCallback {
80///     pub fn call(&self, val: f64) {
81///         // Call the JS function and handle errors internally
82///         let _ = self.0.call1(
83///             &wasm_bindgen::JsValue::NULL,
84///             &wasm_bindgen::JsValue::from_f64(val),
85///         );
86///     }
87/// }
88/// ```
89#[proc_macro_attribute]
90pub fn ts_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
91    let item = parse_macro_input!(item as Item);
92
93    let result = match &item {
94        Item::Type(item_type) => parse_item_type(item_type),
95        Item::Impl(item_impl) => parse_item_impl(item_impl),
96        _ => {
97            return Error::new_spanned(
98                item,
99                "#[ts_function] can only be applied to a type alias or an impl block",
100            )
101            .to_compile_error()
102            .into();
103        }
104    };
105
106    match result {
107        Ok(tokens) => tokens.into(),
108        Err(err) => err.to_compile_error().into(),
109    }
110}
111
112fn generate_return_conversion(ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
113    match ty {
114        Type::Path(type_path) => {
115            let segment = type_path.path.segments.last().unwrap();
116            let ident = &segment.ident;
117            let ident_str = ident.to_string();
118
119            if let Some(inner_ty) = get_slice_element_type(ty)
120                && let Some(arr_type) = get_typed_array_ident(inner_ty)
121            {
122                return Ok(quote! {
123                    let arr: ::js_sys::#arr_type = ::wasm_bindgen::JsCast::unchecked_into(res);
124                    Ok(::std::convert::Into::<#ty>::into(arr.to_vec()))
125                });
126            }
127
128            match ident_str.as_str() {
129                "f32" | "f64" | "i8" | "i16" | "i32" | "u8" | "u16" | "u32" => Ok(quote! {
130                    res.as_f64().map(|v| v as #ty).ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a number"))
131                }),
132                "i64" | "u64" => Ok(quote! {
133                    ::std::convert::TryInto::<#ty>::try_into(res).map_err(|_| ::wasm_bindgen::JsValue::from_str("Expected a BigInt"))
134                }),
135                "bool" => Ok(quote! {
136                    res.as_bool().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a boolean"))
137                }),
138                "String" => Ok(quote! {
139                    res.as_string().ok_or_else(|| ::wasm_bindgen::JsValue::from_str("Expected a string"))
140                }),
141                "JsValue" => Ok(quote! {
142                    Ok(res)
143                }),
144                "Option" => {
145                    let PathArguments::AngleBracketed(args) = &segment.arguments else {
146                        return Err(Error::new_spanned(
147                            ty,
148                            "Expected generic argument for Option",
149                        ));
150                    };
151                    let syn::GenericArgument::Type(inner_ty) = args.args.first().unwrap() else {
152                        return Err(Error::new_spanned(ty, "Expected type argument for Option"));
153                    };
154                    let inner_conversion = generate_return_conversion(inner_ty)?;
155                    Ok(quote! {
156                        if res.is_null() || res.is_undefined() {
157                            Ok(None)
158                        } else {
159                            let res = { #inner_conversion };
160                            res.map(Some)
161                        }
162                    })
163                }
164                _ => Ok(quote! {
165                    Ok(::wasm_bindgen::JsCast::unchecked_into::<#ty>(res))
166                }),
167            }
168        }
169        _ => Err(Error::new_spanned(
170            ty,
171            "Unsupported return type in type alias pattern. Use the `impl` escape hatch instead.",
172        )),
173    }
174}
175
176fn parse_item_type(item_type: &ItemType) -> syn::Result<proc_macro2::TokenStream> {
177    let Type::BareFn(bare_fn) = &*item_type.ty else {
178        return Err(Error::new_spanned(
179            &item_type.ty,
180            "Expected a function pointer type (e.g., `fn(x: f64)`)",
181        ));
182    };
183
184    let struct_ident = &item_type.ident;
185    let mut args = Vec::new();
186
187    for (i, arg) in bare_fn.inputs.iter().enumerate() {
188        let ident = match &arg.name {
189            Some((ident, _)) => ident.clone(),
190            None => format_ident!("arg{}", i),
191        };
192        args.push((ident, &arg.ty));
193    }
194
195    let parsed = ParsedSignature {
196        struct_ident,
197        args: args.clone(),
198        output: &bare_fn.output,
199    };
200
201    let abi_traits = generate_abi_traits(&parsed)?;
202
203    let mut fn_args = Vec::new();
204    let mut arg_conversions = Vec::new();
205    let mut call_args = Vec::new();
206    for (ident, ty) in &args {
207        fn_args.push(quote! { #ident: #ty });
208        let conversion = generate_conversion(ident, ty)?;
209        arg_conversions.push(conversion);
210        call_args.push(quote! { &#ident });
211    }
212
213    let args_len = call_args.len();
214    if args_len > 9 {
215        return Err(Error::new_spanned(
216            item_type,
217            "Functions with more than 9 arguments are not supported yet",
218        ));
219    }
220    let call_method_name = format_ident!("call{}", args_len);
221    let call_method = quote! { #call_method_name(&::wasm_bindgen::JsValue::NULL, #(#call_args),*) };
222
223    let output = parsed.output;
224    let (ret_type, ret_stmt) = match output {
225        ReturnType::Default => (quote! { () }, quote! { self.0.#call_method.map(|_| ()) }),
226        ReturnType::Type(_, ty) => {
227            let conversion = generate_return_conversion(ty)?;
228            (
229                quote! { #ty },
230                quote! {
231                    let res = self.0.#call_method?;
232                    #conversion
233                },
234            )
235        }
236    };
237
238    Ok(quote! {
239        pub struct #struct_ident(pub ::js_sys::Function);
240
241        impl #struct_ident {
242            pub fn call(&self, #(#fn_args),*) -> Result<#ret_type, ::wasm_bindgen::JsValue> {
243                #(#arg_conversions)*
244                #ret_stmt
245            }
246        }
247
248        #abi_traits
249    })
250}
251
252fn generate_conversion(ident: &Ident, ty: &Type) -> syn::Result<proc_macro2::TokenStream> {
253    if let Type::ImplTrait(type_impl) = ty {
254        for bound in &type_impl.bounds {
255            if let syn::TypeParamBound::Trait(trait_bound) = bound
256                && let Some(segment) = trait_bound.path.segments.last()
257                && let PathArguments::AngleBracketed(args) = &segment.arguments
258                && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
259            {
260                match segment.ident.to_string().as_str() {
261                    "Into" => {
262                        let inner_conversion = generate_conversion(ident, inner_ty)?;
263                        return Ok(quote! {
264                            let #ident = ::std::convert::Into::<#inner_ty>::into(#ident);
265                            #inner_conversion
266                        });
267                    }
268                    "AsRef" => {
269                        if let Type::Slice(slice) = inner_ty {
270                            return Ok(generate_typed_array_conversion(ident, &slice.elem));
271                        }
272                    }
273                    _ => {}
274                }
275            }
276        }
277        return Err(Error::new_spanned(
278            ty,
279            "Unsupported `impl Trait`. Only `impl Into<T>` and `impl AsRef<[T]>` are supported.",
280        ));
281    }
282
283    if let Some(inner_ty) = get_slice_element_type(ty) {
284        Ok(generate_typed_array_conversion(ident, inner_ty))
285    } else {
286        Ok(quote! {
287            let #ident = ::std::convert::Into::<::wasm_bindgen::JsValue>::into(#ident);
288        })
289    }
290}
291
292fn generate_typed_array_conversion(ident: &Ident, inner_ty: &Type) -> proc_macro2::TokenStream {
293    if let Some(arr_type) = get_typed_array_ident(inner_ty) {
294        quote! {
295            let #ident = ::wasm_bindgen::JsValue::from(::js_sys::#arr_type::from(::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)));
296        }
297    } else {
298        quote! {
299            let #ident = ::wasm_bindgen::JsValue::from(
300                ::std::convert::AsRef::<[#inner_ty]>::as_ref(&#ident)
301                    .iter()
302                    .map(::wasm_bindgen::JsValue::from)
303                    .collect::<::js_sys::Array>()
304            );
305        }
306    }
307}
308
309fn get_typed_array_ident(inner_ty: &Type) -> Option<proc_macro2::TokenStream> {
310    let inner_str = match inner_ty {
311        Type::Path(p) => p.path.segments.last().map(|s| s.ident.to_string()),
312        _ => None,
313    };
314
315    match inner_str.as_deref() {
316        Some("u8") => Some(quote! { Uint8Array }),
317        Some("i8") => Some(quote! { Int8Array }),
318        Some("u16") => Some(quote! { Uint16Array }),
319        Some("i16") => Some(quote! { Int16Array }),
320        Some("u32") => Some(quote! { Uint32Array }),
321        Some("i32") => Some(quote! { Int32Array }),
322        Some("f32") => Some(quote! { Float32Array }),
323        Some("f64") => Some(quote! { Float64Array }),
324        Some("u64") => Some(quote! { BigUint64Array }),
325        Some("i64") => Some(quote! { BigInt64Array }),
326        _ => None,
327    }
328}
329
330fn get_slice_element_type(ty: &Type) -> Option<&Type> {
331    match ty {
332        Type::Path(type_path) => {
333            let segment = type_path.path.segments.last()?;
334            // Types that implement AsRef<[T]> and we can easily extract T from AST
335            if matches!(
336                segment.ident.to_string().as_str(),
337                "Vec" | "Box" | "Arc" | "Rc"
338            ) && let PathArguments::AngleBracketed(args) = &segment.arguments
339                && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
340            {
341                if let Type::Slice(slice) = inner {
342                    return Some(&*slice.elem);
343                }
344                return Some(inner);
345            }
346        }
347        Type::Reference(type_ref) => {
348            if let Type::Slice(type_slice) = &*type_ref.elem {
349                return Some(&*type_slice.elem);
350            }
351            return get_slice_element_type(&type_ref.elem);
352        }
353        _ => {}
354    }
355    None
356}
357
358fn parse_item_impl(item_impl: &ItemImpl) -> syn::Result<proc_macro2::TokenStream> {
359    if item_impl.trait_.is_some() {
360        return Err(Error::new_spanned(
361            item_impl,
362            "#[ts_function] cannot be applied to trait impls",
363        ));
364    }
365
366    let Type::Path(type_path) = &*item_impl.self_ty else {
367        return Err(Error::new_spanned(
368            &item_impl.self_ty,
369            "Expected a simple path for the struct",
370        ));
371    };
372
373    let struct_ident = type_path.path.get_ident().ok_or_else(|| {
374        Error::new_spanned(
375            &type_path.path,
376            "Expected a single identifier for the struct",
377        )
378    })?;
379
380    let method = item_impl
381        .items
382        .iter()
383        .find_map(|item| {
384            if let syn::ImplItem::Fn(method) = item
385                && method.sig.ident == "call"
386            {
387                return Some(method);
388            }
389            None
390        })
391        .ok_or_else(|| Error::new_spanned(item_impl, "Missing `call` method in impl block"))?;
392
393    let mut args = Vec::new();
394    let mut inputs_iter = method.sig.inputs.iter();
395
396    // Check first argument is `&self` or `&mut self`
397    match inputs_iter.next() {
398        Some(FnArg::Receiver(_)) => {}
399        _ => {
400            return Err(Error::new_spanned(
401                &method.sig,
402                "The `call` method must take `&self` or `&mut self` as its first parameter",
403            ));
404        }
405    }
406
407    for (i, arg) in inputs_iter.enumerate() {
408        let FnArg::Typed(pat_type) = arg else {
409            return Err(Error::new_spanned(arg, "Expected a typed argument"));
410        };
411
412        let ident = if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
413            pat_ident.ident.clone()
414        } else {
415            format_ident!("arg{}", i)
416        };
417
418        args.push((ident, &*pat_type.ty));
419    }
420
421    let parsed = ParsedSignature {
422        struct_ident,
423        args,
424        output: &method.sig.output,
425    };
426
427    let abi_traits = generate_abi_traits(&parsed)?;
428
429    Ok(quote! {
430        #item_impl
431        #abi_traits
432    })
433}
434
435fn generate_abi_traits(parsed: &ParsedSignature) -> syn::Result<proc_macro2::TokenStream> {
436    let struct_ident = parsed.struct_ident;
437    let mut ts_args = Vec::new();
438
439    for (ident, ty) in &parsed.args {
440        let ts_ty = ty
441            .to_ts_type()
442            .map_err(|e| Error::new_spanned(ty, e.message))?
443            .to_string();
444        ts_args.push(format!("{}: {}", ident, ts_ty));
445    }
446
447    let ts_output = match parsed.output {
448        ReturnType::Default => "void".to_string(),
449        ReturnType::Type(_, ty) => ty
450            .to_ts_type()
451            .map_err(|e| Error::new_spanned(ty, e.message))?
452            .to_string(),
453    };
454
455    let ts_string = format!(
456        "type {} = ({}) => {};",
457        struct_ident,
458        ts_args.join(", "),
459        ts_output
460    );
461
462    let generated = quote! {
463        #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
464        const _: &'static str = #ts_string;
465
466        impl ::wasm_bindgen::describe::WasmDescribe for #struct_ident {
467            fn describe() {
468                <::js_sys::Function as ::wasm_bindgen::describe::WasmDescribe>::describe()
469            }
470        }
471
472        impl ::wasm_bindgen::convert::FromWasmAbi for #struct_ident {
473            type Abi = <::js_sys::Function as ::wasm_bindgen::convert::FromWasmAbi>::Abi;
474
475            unsafe fn from_abi(js: Self::Abi) -> Self {
476                Self(::js_sys::Function::from_abi(js))
477            }
478        }
479
480        impl ::wasm_bindgen::convert::OptionFromWasmAbi for #struct_ident {
481            fn is_none(abi: &Self::Abi) -> bool {
482                <::js_sys::Function as ::wasm_bindgen::convert::OptionFromWasmAbi>::is_none(abi)
483            }
484        }
485
486        impl From<::js_sys::Function> for #struct_ident {
487            fn from(f: ::js_sys::Function) -> Self {
488                Self(f)
489            }
490        }
491    };
492
493    Ok(generated)
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Runs the type-alias expansion and returns its token string.
    fn expand_alias(item_type: &ItemType) -> String {
        parse_item_type(item_type).unwrap().to_string()
    }

    /// Runs the impl-block expansion and returns its token string.
    fn expand_impl(item_impl: &ItemImpl) -> String {
        parse_item_impl(item_impl).unwrap().to_string()
    }

    #[test]
    fn test_item_type() {
        let alias: ItemType = parse_quote! {
            pub type OnClick = fn(x: f64, y: impl Into<f64>, arr: js_sys::Float64Array);
        };
        let expanded = expand_alias(&alias);

        for fragment in [
            // TypeScript alias emitted through the custom section.
            "type OnClick = (x: number, y: number, arr: Float64Array) => void;",
            // Newtype wrapper around the JS function.
            "pub struct OnClick (pub :: js_sys :: Function) ;",
            // `call` keeps the Rust parameter types verbatim.
            "pub fn call (& self , x : f64 , y : impl Into < f64 > , arr : js_sys :: Float64Array)",
        ] {
            assert!(expanded.contains(fragment));
        }
    }

    #[test]
    fn test_item_impl() {
        let impl_block: ItemImpl = parse_quote! {
            impl OnScroll {
                pub fn call(&self, y: f64) {
                    // body
                }
            }
        };
        let expanded = expand_impl(&impl_block);

        for fragment in [
            "type OnScroll = (y: number) => void;",
            "impl :: wasm_bindgen :: describe :: WasmDescribe for OnScroll",
        ] {
            assert!(expanded.contains(fragment));
        }
    }

    #[test]
    fn test_recursive_generics() {
        let cases: [(ItemType, &str); 2] = [
            (
                parse_quote! { pub type ResultCb = fn(res: Result<String, i32>); },
                "type ResultCb = (res: Result<string, number>) => void;",
            ),
            (
                parse_quote! { pub type NestedVecCb = fn(args: Vec<Vec<f64>>); },
                "type NestedVecCb = (args: Float64Array[]) => void;",
            ),
        ];

        for (alias, expected) in cases {
            assert!(expand_alias(&alias).contains(expected));
        }
    }
}
555}