dioxus_use_js_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use core::panic;
4use indexmap::IndexMap;
5use proc_macro::TokenStream;
6use proc_macro2::{Literal, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use std::collections::HashMap;
9use std::str::FromStr;
10use std::{fs, path::Path};
11use swc_common::comments::{CommentKind, Comments};
12use swc_common::{SourceMap, Span, comments::SingleThreadedComments};
13use swc_common::{SourceMapper, Spanned};
14use swc_ecma_ast::{
15    Decl, ExportDecl, ExportSpecifier, FnDecl, NamedExport, Pat, TsType, TsTypeAnn, VarDeclarator,
16};
17use swc_ecma_parser::EsSyntax;
18use swc_ecma_parser::{Parser, StringInput, Syntax, lexer::Lexer};
19use swc_ecma_visit::{Visit, VisitWith};
20use syn::TypeParam;
21use syn::{
22    Ident, LitStr, Result, Token,
23    parse::{Parse, ParseStream},
24    parse_macro_input,
25};
26
// ---- `JsValue` handles ---------------------------------------------------------------

/// TypeScript type name (prefix) that marks a `JsValue<T>` handle.
const JSVALUE_START: &str = "JsValue";
/// Rust path of the `JsValue` handle type emitted in generated signatures.
const JSVALUE: &str = "dioxus_use_js::JsValue";

// ---- Fallbacks for unannotated / unmappable types -------------------------------------

/// Rust parameter type used when an input's TypeScript type is missing or unsupported.
const DEFAULT_GENRIC_INPUT: &str = "impl dioxus_use_js::SerdeSerialize";
/// Rust type used when an output's TypeScript type is missing or unsupported
/// (presumably declared as a generic parameter via `DEFAULT_OUTPUT_GENERIC_DECLARTION`).
const DEFAULT_GENERIC_OUTPUT: &str = "DeserializeOwned";
/// `DeserializeOwned` together with its trait bound.
const DEFAULT_OUTPUT_GENERIC_DECLARTION: &str =
    "DeserializeOwned: dioxus_use_js::SerdeDeDeserializeOwned";

// ---- JSON -------------------------------------------------------------------------------

/// TypeScript type name that maps to an untyped JSON value.
const JSON: &str = "Json";
/// Rust path of the untyped JSON value type that `Json` maps to.
const SERDE_VALUE: &str = "dioxus_use_js::SerdeJsonValue";

// ---- Callbacks and special parameters ----------------------------------------------------

/// TypeScript type name (prefix) that marks a `RustCallback<I, O>` parameter.
const RUST_CALLBACK_JS_START: &str = "RustCallback";
/// Rust unit type; `void`, `undefined`, `never` and `null` map to it.
const UNIT: &str = "()";
/// Type annotation that marks a parameter as the special drop parameter.
const DROP_TYPE: &str = "Drop";
/// Name that marks an *unannotated* parameter as the special drop parameter.
const DROP_NAME: &str = "drop";
41
42#[derive(Debug, Clone)]
43enum ImportSpec {
44    /// *
45    All,
46    /// {greeting, other_func}
47    Named(Vec<Ident>),
48    /// greeting
49    Single(Ident),
50}
51
52struct UseJsInput {
53    js_bundle_path: LitStr,
54    ts_source_path: Option<LitStr>,
55    import_spec: ImportSpec,
56}
57
58impl Parse for UseJsInput {
59    fn parse(input: ParseStream) -> Result<Self> {
60        let first_str: LitStr = input.parse()?;
61
62        // Check if => follows (i.e., we have "src.ts" => "bundle.js")
63        let (ts_source_path, js_bundle_path) = if input.peek(Token![,]) {
64            input.parse::<Token![,]>()?;
65            let second_str: LitStr = input.parse()?;
66            (Some(first_str), second_str)
67        } else {
68            (None, first_str)
69        };
70
71        // Check for optional :: following bundle path
72        let import_spec = if input.peek(Token![::]) {
73            input.parse::<Token![::]>()?;
74
75            if input.peek(Token![*]) {
76                input.parse::<Token![*]>()?;
77                ImportSpec::All
78            } else if input.peek(Ident) {
79                let ident: Ident = input.parse()?;
80                ImportSpec::Single(ident)
81            } else if input.peek(syn::token::Brace) {
82                let content;
83                syn::braced!(content in input);
84                let idents: syn::punctuated::Punctuated<Ident, Token![,]> =
85                    content.parse_terminated(Ident::parse, Token![,])?;
86                ImportSpec::Named(idents.into_iter().collect())
87            } else {
88                return Err(input.error("Expected `*`, an identifier, or a brace group after `::`"));
89            }
90        } else {
91            return Err(input
92                .error("Expected `::` followed by an import spec (even for wildcard with `*`)"));
93        };
94
95        Ok(UseJsInput {
96            js_bundle_path,
97            ts_source_path,
98            import_spec,
99        })
100    }
101}
102
103#[derive(Debug, Clone)]
104struct ParamInfo {
105    name: String,
106    #[allow(unused)]
107    js_type: Option<String>,
108    rust_type: RustType,
109}
110
111impl ParamInfo {
112    fn is_drop(&self) -> bool {
113        match self.js_type.as_ref() {
114            Some(js_type) => js_type == DROP_TYPE,
115            None => self.name == DROP_NAME,
116        }
117    }
118}
119
120#[derive(Debug, Clone)]
121struct FunctionInfo {
122    name: String,
123    /// If specified in the `use_js!` declaration. Used to link the generated code to this span
124    name_ident: Option<Ident>,
125    /// js param types
126    params: Vec<ParamInfo>,
127    // js return type
128    #[allow(unused)]
129    js_return_type: Option<String>,
130    rust_return_type: RustType,
131    is_exported: bool,
132    is_async: bool,
133    /// The stripped lines
134    doc_comment: Vec<String>,
135}
136
137struct FunctionVisitor {
138    functions: Vec<FunctionInfo>,
139    comments: SingleThreadedComments,
140    source_map: SourceMap,
141}
142
143impl FunctionVisitor {
144    fn new(comments: SingleThreadedComments, source_map: SourceMap) -> Self {
145        Self {
146            functions: Vec::new(),
147            comments,
148            source_map,
149        }
150    }
151
152    fn extract_doc_comment(&self, span: &Span) -> Vec<String> {
153        // Get leading comments for the span
154        let leading_comment = self.comments.get_leading(span.lo());
155
156        if let Some(comments) = leading_comment {
157            let mut doc_lines = Vec::new();
158
159            for comment in comments.iter() {
160                let comment_text = &comment.text;
161                match comment.kind {
162                    // Handle `///`. `//` is already stripped
163                    CommentKind::Line => {
164                        if let Some(content) = comment_text.strip_prefix("/") {
165                            let cleaned = content.trim_start();
166                            doc_lines.push(cleaned.to_string());
167                        }
168                    }
169                    // Handle `/*` `*/`. `/*` `*/` is already stripped
170                    CommentKind::Block => {
171                        for line in comment_text.lines() {
172                            if let Some(cleaned) = line.trim_start().strip_prefix("*") {
173                                doc_lines.push(cleaned.to_string());
174                            }
175                        }
176                    }
177                };
178            }
179
180            doc_lines
181        } else {
182            Vec::new()
183        }
184    }
185}
186
187#[derive(Debug, Clone)]
188enum RustType {
189    Regular(String),
190    Callback(RustCallback),
191    JsValue(JsValue),
192}
193
194impl ToString for RustType {
195    fn to_string(&self) -> String {
196        match self {
197            RustType::Regular(ty) => ty.clone(),
198            RustType::Callback(callback) => callback.to_string(),
199            RustType::JsValue(js_value) => js_value.to_string(),
200        }
201    }
202}
203
204impl RustType {
205    fn to_tokens(&self) -> TokenStream2 {
206        self.to_string()
207            .parse::<TokenStream2>()
208            .expect("Calculated Rust type should always be valid")
209    }
210}
211
212#[derive(Debug, Clone)]
213struct RustCallback {
214    input: Option<String>,
215    output: Option<String>,
216}
217
218impl ToString for RustCallback {
219    fn to_string(&self) -> String {
220        let input = self.input.as_deref();
221        let output = self.output.as_deref().unwrap_or(UNIT);
222        format!(
223            "dioxus::core::Callback<{}, impl Future<Output = Result<{}, dioxus_use_js::SerdeJsonValue>> + 'static>",
224            input.unwrap_or("()"),
225            output
226        )
227    }
228}
229
230#[derive(Debug, Clone)]
231struct JsValue {
232    is_option: bool,
233    is_input: bool,
234}
235
236impl ToString for JsValue {
237    fn to_string(&self) -> String {
238        if self.is_option {
239            format!(
240                "Option<{}>",
241                if self.is_input {
242                    format!("&{}", JSVALUE)
243                } else {
244                    JSVALUE.to_owned()
245                }
246            )
247        } else {
248            if self.is_input {
249                format!("&{}", JSVALUE)
250            } else {
251                JSVALUE.to_owned()
252            }
253        }
254    }
255}
256
/// Repeatedly removes a pair of parentheses that wraps the *whole* type, trimming the
/// whitespace inside each removed pair: `(( string ))` becomes `string`.
///
/// A pair is only removed when the `(` at the start is closed by the `)` at the very end,
/// so `(A) | (B)` is returned unchanged instead of being mangled into `A) | (B`.
fn strip_parenthesis(mut ts_type: &str) -> &str {
    /// Whether the `(` at index 0 of `s` is matched by the `)` at its last index.
    fn wraps_whole(s: &str) -> bool {
        if !(s.starts_with('(') && s.ends_with(')')) {
            return false;
        }
        // `s` starts with `(`, so `depth >= 1` before any `)` and we stop as soon as it
        // reaches 0; it can never underflow.
        let mut depth: usize = 0;
        for (i, c) in s.char_indices() {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        return i == s.len() - 1;
                    }
                }
                _ => {}
            }
        }
        false
    }

    while wraps_whole(ts_type) {
        ts_type = ts_type[1..ts_type.len() - 1].trim();
    }
    ts_type
}
263
/// Splits a generic argument list such as `string, Map<K, V>` at its top-level commas,
/// trimming each argument.
///
/// Commas nested inside `<>`, `[]`, `()` or `{}` do not split, and the `>` of an arrow
/// (`=>`) is not treated as a closing angle bracket. A trailing empty argument (e.g. after
/// a trailing comma) is dropped.
fn split_into_args(ts_type: &str) -> Vec<&str> {
    let mut depth_angle: u16 = 0;
    let mut depth_square: u16 = 0;
    let mut depth_paren: u16 = 0;
    let mut depth_brace: u16 = 0;
    let mut splits = Vec::new();
    let mut last: usize = 0;
    let mut prev: Option<char> = None;
    for (i, c) in ts_type.char_indices() {
        match c {
            '<' => depth_angle += 1,
            // The `>` of `=>` belongs to a function type, not to a generic argument list.
            '>' if prev != Some('=') => depth_angle = depth_angle.saturating_sub(1),
            '[' => depth_square += 1,
            ']' => depth_square = depth_square.saturating_sub(1),
            '(' => depth_paren += 1,
            ')' => depth_paren = depth_paren.saturating_sub(1),
            '{' => depth_brace += 1,
            '}' => depth_brace = depth_brace.saturating_sub(1),
            ','
                if depth_angle == 0
                    && depth_square == 0
                    && depth_paren == 0
                    && depth_brace == 0 =>
            {
                splits.push(ts_type[last..i].trim());
                last = i + 1;
            }
            _ => {}
        }
        prev = Some(c);
    }
    let tail = ts_type[last..].trim();
    if !tail.is_empty() {
        splits.push(tail);
    }
    splits
}
295
296fn ts_type_to_rust_type(ts_type: Option<&str>, is_input: bool) -> RustType {
297    let Some(mut ts_type) = ts_type else {
298        return RustType::Regular(
299            (if is_input {
300                DEFAULT_GENRIC_INPUT
301            } else {
302                DEFAULT_GENERIC_OUTPUT
303            })
304            .to_owned(),
305        );
306    };
307    ts_type = strip_parenthesis(&mut ts_type);
308    if ts_type.starts_with("Promise<") && ts_type.ends_with(">") {
309        assert!(!is_input, "Promise cannot be used as input type");
310        ts_type = &ts_type[8..ts_type.len() - 1];
311    }
312    ts_type = strip_parenthesis(&mut ts_type);
313    if ts_type.contains(JSVALUE_START) {
314        let parts = split_top_level_union(ts_type);
315        let len = parts.len();
316        if len == 1 && parts[0].starts_with(JSVALUE_START) {
317            return RustType::JsValue(JsValue {
318                is_option: false,
319                is_input,
320            });
321        }
322
323        if len == 2 && parts.contains(&"null") {
324            return RustType::JsValue(JsValue {
325                is_option: true,
326                is_input,
327            });
328        } else {
329            panic!("Invalid use of `{}` for `{}`", JSVALUE_START, ts_type);
330        }
331    }
332    if ts_type.contains(RUST_CALLBACK_JS_START) {
333        if !ts_type.starts_with(RUST_CALLBACK_JS_START) {
334            panic!("Nested RustCallback is not valid: {}", ts_type);
335        }
336        assert!(is_input, "Cannot return a RustCallback: {}", ts_type);
337        let ts_type = &ts_type[RUST_CALLBACK_JS_START.len()..];
338        if !(ts_type.starts_with("<") && ts_type.ends_with(">")) {
339            panic!("Invalid RustCallback type: {}", ts_type);
340        }
341        let inner = &ts_type[1..ts_type.len() - 1];
342        let parts = split_into_args(inner);
343        let len = parts.len();
344        if len != 2 {
345            panic!(
346                "A RustCallback type expects two parameters, got: {:?}",
347                parts
348            );
349        }
350        let ts_input = parts[0];
351        let rs_input = if ts_input == "void" {
352            None
353        } else {
354            let rs_input = ts_type_to_rust_type_helper(ts_input, false);
355            if rs_input.is_none() || rs_input.as_ref().is_some_and(|e| e == UNIT) {
356                panic!("Type `{ts_input}` is not a valid input for `{RUST_CALLBACK_JS_START}`");
357            }
358            rs_input
359        };
360        let ts_output = parts[1];
361        let rs_output = if ts_output == "void" {
362            None
363        } else {
364            let rs_output = ts_type_to_rust_type_helper(ts_output, false);
365            if rs_output.is_none() || rs_output.as_ref().is_some_and(|e| e == UNIT) {
366                panic!("Type `{ts_output}` is not a valid output for `{RUST_CALLBACK_JS_START}`");
367            }
368            rs_output
369        };
370        return RustType::Callback(RustCallback {
371            input: rs_input,
372            output: rs_output,
373        });
374    }
375    RustType::Regular(match ts_type_to_rust_type_helper(ts_type, is_input) {
376        Some(value) => {
377            if value.contains(UNIT) && (is_input || &value != UNIT) {
378                // Would cause serialization errors since `serde_json::Value::Null` or any, cannot be deserialized into `()`.
379                // We handle `()` special case to account for this if this is the root type in the output. But not input or output nested.
380                panic!("`{}` is not valid in this position", ts_type);
381            }
382            value
383        }
384        None => (if is_input {
385            DEFAULT_GENRIC_INPUT
386        } else {
387            DEFAULT_GENERIC_OUTPUT
388        })
389        .to_owned(),
390    })
391}
392
393/// Returns None if could not determine type
394fn ts_type_to_rust_type_helper(mut ts_type: &str, can_be_ref: bool) -> Option<String> {
395    ts_type = ts_type.trim();
396    ts_type = strip_parenthesis(&mut ts_type);
397
398    let parts = split_top_level_union(ts_type);
399    if parts.len() > 1 {
400        // Handle single null union: T | null or null | T
401        if parts.len() == 2 && parts.contains(&"null") {
402            let inner = parts.iter().find(|p| **p != "null")?;
403            let inner_rust = ts_type_to_rust_type_helper(inner, can_be_ref)?;
404            return Some(format!("Option<{}>", inner_rust));
405        }
406        // Unsupported union type
407        return None;
408    }
409
410    ts_type = parts[0];
411
412    if ts_type.ends_with("[]") {
413        let inner = ts_type.strip_suffix("[]").unwrap();
414        let inner_rust = ts_type_to_rust_type_helper(inner, false)?;
415        return Some(if can_be_ref {
416            format!("&[{}]", inner_rust)
417        } else {
418            format!("Vec<{}>", inner_rust)
419        });
420    }
421
422    if ts_type.starts_with("Array<") && ts_type.ends_with(">") {
423        let inner = &ts_type[6..ts_type.len() - 1];
424        let inner_rust = ts_type_to_rust_type_helper(inner, false)?;
425        return Some(if can_be_ref {
426            format!("&[{}]", inner_rust)
427        } else {
428            format!("Vec<{}>", inner_rust)
429        });
430    }
431
432    if ts_type.starts_with("Set<") && ts_type.ends_with(">") {
433        let inner = &ts_type[4..ts_type.len() - 1];
434        let inner_rust = ts_type_to_rust_type_helper(inner, false)?;
435        if can_be_ref {
436            return Some(format!("&std::collections::HashSet<{}>", inner_rust));
437        } else {
438            return Some(format!("std::collections::HashSet<{}>", inner_rust));
439        }
440    }
441
442    if ts_type.starts_with("Map<") && ts_type.ends_with(">") {
443        let inner = &ts_type[4..ts_type.len() - 1];
444        let mut depth = 0;
445        let mut split_index = None;
446        for (i, c) in inner.char_indices() {
447            match c {
448                '<' => depth += 1,
449                '>' => depth -= 1,
450                ',' if depth == 0 => {
451                    split_index = Some(i);
452                    break;
453                }
454                _ => {}
455            }
456        }
457
458        if let Some(i) = split_index {
459            let (key, value) = inner.split_at(i);
460            let value = &value[1..]; // skip comma
461            let key_rust = ts_type_to_rust_type_helper(key.trim(), false)?;
462            let value_rust = ts_type_to_rust_type_helper(value.trim(), false)?;
463            if can_be_ref {
464                return Some(format!(
465                    "&std::collections::HashMap<{}, {}>",
466                    key_rust, value_rust
467                ));
468            } else {
469                return Some(format!(
470                    "std::collections::HashMap<{}, {}>",
471                    key_rust, value_rust
472                ));
473            }
474        } else {
475            return None;
476        }
477    }
478
479    // Base types
480    let rust_type = match ts_type {
481        "string" => {
482            if can_be_ref {
483                Some("&str".to_owned())
484            } else {
485                Some("String".to_owned())
486            }
487        }
488        "number" => Some("f64".to_owned()),
489        "boolean" => Some("bool".to_owned()),
490        "void" | "undefined" | "never" | "null" => Some(UNIT.to_owned()),
491        JSON => {
492            if can_be_ref {
493                Some(format!("&{SERDE_VALUE}"))
494            } else {
495                Some(SERDE_VALUE.to_owned())
496            }
497        }
498        "Promise" => {
499            panic!("`{}` - nested promises are not valid", ts_type)
500        }
501        // "any" | "unknown" | "object" | .. etc.
502        _ => None,
503    };
504
505    rust_type
506}
507
/// Splits a union such as `number | null | string` at its top-level `|` separators,
/// trimming each member.
///
/// Separators nested inside `<>`, `()`, `[]` or `{}` are ignored (e.g. `(number | null)[]`
/// is a single member), and the `>` of an arrow (`=>`) does not close an angle bracket.
/// Empty members are dropped, so a leading `|` (`| A | B`, valid TypeScript) yields
/// `[A, B]` and an empty string yields no members.
fn split_top_level_union(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut last = 0;
    let mut depth_angle: u32 = 0;
    let mut depth_paren: u32 = 0;
    let mut depth_square: u32 = 0;
    let mut depth_brace: u32 = 0;
    let mut prev: Option<char> = None;

    for (i, c) in s.char_indices() {
        match c {
            '<' => depth_angle += 1,
            '>' if prev != Some('=') => depth_angle = depth_angle.saturating_sub(1),
            '(' => depth_paren += 1,
            ')' => depth_paren = depth_paren.saturating_sub(1),
            '[' => depth_square += 1,
            ']' => depth_square = depth_square.saturating_sub(1),
            '{' => depth_brace += 1,
            '}' => depth_brace = depth_brace.saturating_sub(1),
            '|' if depth_angle == 0
                && depth_paren == 0
                && depth_square == 0
                && depth_brace == 0 =>
            {
                parts.push(s[last..i].trim());
                last = i + 1;
            }
            _ => {}
        }
        prev = Some(c);
    }

    parts.push(s[last..].trim());
    parts.retain(|part| !part.is_empty());
    parts
}
543
544fn type_to_string(ty: &Box<TsType>, source_map: &SourceMap) -> String {
545    let span = ty.span();
546    source_map
547        .span_to_snippet(span)
548        .expect("Could not get snippet from span for type")
549}
550
551fn function_pat_to_param_info<'a, I>(pats: I, source_map: &SourceMap) -> Vec<ParamInfo>
552where
553    I: Iterator<Item = &'a Pat>,
554{
555    pats.enumerate()
556        .map(|(i, pat)| to_param_info_helper(i, pat, source_map))
557        .collect()
558}
559
560fn to_param_info_helper(i: usize, pat: &Pat, source_map: &SourceMap) -> ParamInfo {
561    let name = if let Some(ident) = pat.as_ident() {
562        ident.id.sym.to_string()
563    } else {
564        format!("arg{}", i)
565    };
566
567    let js_type = pat
568        .as_ident()
569        .and_then(|ident| ident.type_ann.as_ref())
570        .map(|type_ann| {
571            let ty = &type_ann.type_ann;
572            type_to_string(ty, source_map)
573        });
574    let rust_type = ts_type_to_rust_type(js_type.as_deref(), true);
575
576    ParamInfo {
577        name,
578        js_type,
579        rust_type,
580    }
581}
582
583fn function_info_helper<'a, I>(
584    visitor: &FunctionVisitor,
585    name: String,
586    span: &Span,
587    params: I,
588    return_type: Option<&Box<TsTypeAnn>>,
589    is_async: bool,
590    is_exported: bool,
591) -> FunctionInfo
592where
593    I: Iterator<Item = &'a Pat>,
594{
595    let doc_comment = visitor.extract_doc_comment(span);
596
597    let params = function_pat_to_param_info(params, &visitor.source_map);
598
599    let js_return_type = return_type.as_ref().map(|type_ann| {
600        let ty = &type_ann.type_ann;
601        type_to_string(ty, &visitor.source_map)
602    });
603    if !is_async
604        && let Some(ref js_return_type) = js_return_type
605        && js_return_type.starts_with("Promise")
606    {
607        panic!(
608            "Promise return type is only supported for async functions, use `async fn` instead. For `{js_return_type}`"
609        );
610    }
611
612    let rust_return_type = ts_type_to_rust_type(js_return_type.as_deref(), false);
613
614    FunctionInfo {
615        name,
616        name_ident: None,
617        params,
618        js_return_type,
619        rust_return_type,
620        is_exported,
621        is_async,
622        doc_comment,
623    }
624}
625
626impl Visit for FunctionVisitor {
627    /// Visit function declarations: function foo() {}
628    fn visit_fn_decl(&mut self, node: &FnDecl) {
629        let name = node.ident.sym.to_string();
630        self.functions.push(function_info_helper(
631            self,
632            name,
633            &node.span(),
634            node.function.params.iter().map(|e| &e.pat),
635            node.function.return_type.as_ref(),
636            node.function.is_async,
637            false,
638        ));
639        node.visit_children_with(self);
640    }
641
642    /// Visit function expressions: const foo = function() {}
643    fn visit_var_declarator(&mut self, node: &VarDeclarator) {
644        if let swc_ecma_ast::Pat::Ident(ident) = &node.name {
645            if let Some(init) = &node.init {
646                let span = node.span();
647                let name = ident.id.sym.to_string();
648                match &**init {
649                    swc_ecma_ast::Expr::Fn(fn_expr) => {
650                        self.functions.push(function_info_helper(
651                            &self,
652                            name,
653                            &span,
654                            fn_expr.function.params.iter().map(|e| &e.pat),
655                            fn_expr.function.return_type.as_ref(),
656                            fn_expr.function.is_async,
657                            false,
658                        ));
659                    }
660                    swc_ecma_ast::Expr::Arrow(arrow_fn) => {
661                        self.functions.push(function_info_helper(
662                            &self,
663                            name,
664                            &span,
665                            arrow_fn.params.iter(),
666                            arrow_fn.return_type.as_ref(),
667                            arrow_fn.is_async,
668                            false,
669                        ));
670                    }
671                    _ => {}
672                }
673            }
674        }
675        node.visit_children_with(self);
676    }
677
678    /// Visit export declarations: export function foo() {}
679    fn visit_export_decl(&mut self, node: &ExportDecl) {
680        if let Decl::Fn(fn_decl) = &node.decl {
681            let span = node.span();
682            let name = fn_decl.ident.sym.to_string();
683            self.functions.push(function_info_helper(
684                &self,
685                name,
686                &span,
687                fn_decl.function.params.iter().map(|e| &e.pat),
688                fn_decl.function.return_type.as_ref(),
689                fn_decl.function.is_async,
690                true,
691            ));
692        }
693        node.visit_children_with(self);
694    }
695
696    /// Visit named exports: export { foo }
697    fn visit_named_export(&mut self, node: &NamedExport) {
698        for spec in &node.specifiers {
699            if let ExportSpecifier::Named(named) = spec {
700                let original_name = named.orig.atom().to_string();
701                let out_name = named
702                    .exported
703                    .as_ref()
704                    .map(|e| e.atom().to_string())
705                    .unwrap_or_else(|| original_name.clone());
706
707                if let Some(func) = self.functions.iter_mut().find(|f| f.name == original_name) {
708                    let mut func = func.clone();
709                    func.name = out_name;
710                    func.is_exported = true;
711                    self.functions.push(func);
712                }
713            }
714        }
715        node.visit_children_with(self);
716    }
717}
718
719fn parse_script_file(file_path: &Path, is_js: bool) -> Result<Vec<FunctionInfo>> {
720    let js_content = fs::read_to_string(file_path).map_err(|e| {
721        syn::Error::new(
722            proc_macro2::Span::call_site(),
723            format!("Could not read file '{}': {}", file_path.display(), e),
724        )
725    })?;
726
727    let source_map = SourceMap::default();
728    let fm = source_map.new_source_file(
729        swc_common::FileName::Custom(file_path.display().to_string()).into(),
730        js_content.clone(),
731    );
732    let comments = SingleThreadedComments::default();
733
734    // Enable TypeScript parsing to handle type annotations
735    let syntax = if is_js {
736        Syntax::Es(EsSyntax {
737            jsx: false,
738            fn_bind: false,
739            decorators: false,
740            decorators_before_export: false,
741            export_default_from: false,
742            import_attributes: false,
743            allow_super_outside_method: false,
744            allow_return_outside_function: false,
745            auto_accessors: false,
746            explicit_resource_management: false,
747        })
748    } else {
749        Syntax::Typescript(swc_ecma_parser::TsSyntax {
750            tsx: false,
751            decorators: false,
752            dts: false,
753            no_early_errors: false,
754            disallow_ambiguous_jsx_like: true,
755        })
756    };
757
758    let lexer = Lexer::new(
759        syntax,
760        Default::default(),
761        StringInput::from(&*fm),
762        Some(&comments),
763    );
764
765    let mut parser = Parser::new_from(lexer);
766
767    let module = parser.parse_module().map_err(|e| {
768        syn::Error::new(
769            proc_macro2::Span::call_site(),
770            format!(
771                "Failed to parse script file '{}': {:?}",
772                file_path.display(),
773                e
774            ),
775        )
776    })?;
777
778    let mut visitor = FunctionVisitor::new(comments, source_map);
779    module.visit_with(&mut visitor);
780
781    // Functions are added twice for some reason.
782    visitor
783        .functions
784        .dedup_by(|e1, e2| e1.name.as_str() == e2.name.as_str());
785    Ok(visitor.functions)
786}
787
788fn take_function_by_name(
789    name: &str,
790    functions: &mut Vec<FunctionInfo>,
791    file: &Path,
792) -> Result<FunctionInfo> {
793    let function_info = if let Some(pos) = functions.iter().position(|f| f.name == name) {
794        functions.remove(pos)
795    } else {
796        return Err(syn::Error::new(
797            proc_macro2::Span::call_site(),
798            format!("Function '{}' not found in file '{}'", name, file.display()),
799        ));
800    };
801    if !function_info.is_exported {
802        return Err(syn::Error::new(
803            proc_macro2::Span::call_site(),
804            format!(
805                "Function '{}' not exported in file '{}'",
806                name,
807                file.display()
808            ),
809        ));
810    }
811    Ok(function_info)
812}
813
814fn get_functions_to_generate(
815    mut functions: Vec<FunctionInfo>,
816    import_spec: &ImportSpec,
817    file: &Path,
818) -> Result<Vec<FunctionInfo>> {
819    match import_spec {
820        ImportSpec::All => Ok(functions.into_iter().filter(|e| e.is_exported).collect()),
821        ImportSpec::Single(name) => {
822            let mut func = take_function_by_name(name.to_string().as_str(), &mut functions, file)?;
823            func.name_ident.replace(name.clone());
824            Ok(vec![func])
825        }
826        ImportSpec::Named(names) => {
827            let mut result = Vec::new();
828            for name in names {
829                let mut func =
830                    take_function_by_name(name.to_string().as_str(), &mut functions, file)?;
831                func.name_ident.replace(name.clone());
832                result.push(func);
833            }
834            Ok(result)
835        }
836    }
837}
838
839fn generate_function_wrapper(func: &FunctionInfo, asset_path: &LitStr) -> TokenStream2 {
840    // If we have callbacks, we cant do a simpl return, we have to do message passing
841    let mut callback_name_to_index: HashMap<String, u64> = HashMap::new();
842    let mut callback_name_to_info: IndexMap<String, &RustCallback> = IndexMap::new();
843    let mut index: u64 = 0;
844    for param in &func.params {
845        if let RustType::Callback(callback) = &param.rust_type {
846            callback_name_to_index.insert(param.name.to_owned(), index);
847            index += 1;
848            callback_name_to_info.insert(param.name.to_owned(), callback);
849        }
850    }
851    let js_func_name = &func.name;
852    let js_func_name_ident = quote! { FUNC_NAME };
853
854    let mut has_callbacks = false;
855    let send_calls: Vec<TokenStream2> = func
856        .params
857        .iter()
858        .flat_map(|param| {
859            if param.is_drop() {
860                return None;
861            }
862            let param_name = format_ident!("{}", param.name);
863            match &param.rust_type {
864                RustType::Regular(_) => Some(quote! {
865                    eval.send(#param_name).map_err(|e| dioxus_use_js::JsError::Eval { func: #js_func_name_ident, error: e })?;
866                }),
867                RustType::JsValue(js_value) => {
868                    if js_value.is_option {
869                        Some(quote! {
870                            #[allow(deprecated)]
871                            eval.send(#param_name.map(|e| e.internal_get())).map_err(|e| dioxus_use_js::JsError::Eval { func: #js_func_name_ident, error: e })?;
872                        })
873                    } else {
874                        Some(quote! {
875                            #[allow(deprecated)]
876                            eval.send(#param_name.internal_get()).map_err(|e| dioxus_use_js::JsError::Eval { func: #js_func_name_ident, error: e })?;
877                        })
878                    }
879                },
880                RustType::Callback(_) => {
881                    has_callbacks = true;
882                    None
883                },
884            }
885        })
886        .collect();
887
888    let params_list = func
889        .params
890        .iter()
891        .map(|p| p.name.as_str())
892        .collect::<Vec<&str>>()
893        .join(", ");
894    let prepare_callbacks = if has_callbacks {
895        "let _i_=\"**INVOCATION_ID**\";let _l_={};window[_i_]=_l_;let _g_ = 0;let _a_=true;const _c_=(c, v)=>{if(!_a_){return Promise.reject(new Error(\"Channel already destroyed\"));}_g_+=1;if(_g_>Number.MAX_SAFE_INTEGER){_g_= 0;}let o, e;let p=new Promise((rs, rj)=>{o=rs;e=rj});_l_[_g_]=[o, e];dioxus.send([c,_g_,v]);return p;};"
896    } else {
897        ""
898    };
899    let param_declarations = func
900        .params
901        .iter()
902        .map(|param| {
903            if param.is_drop() {
904                return format!("let {}=_dp_;", param.name);
905            }
906            match &param.rust_type {
907            RustType::Regular(_) => {
908                format!("let {}=await dioxus.recv();", param.name)
909            }
910            RustType::JsValue(js_value) => {
911                let param_name = &param.name;
912                if js_value.is_option {
913                format!(
914                    "let {param_name}Temp_=await dioxus.recv();let {param_name}=null;if({param_name}Temp_!==null){{{param_name}=window[{param_name}Temp_]}};",
915                )
916            }
917            else {
918                format!(
919                    "let {param_name}Temp_=await dioxus.recv();let {param_name}=window[{param_name}Temp_];",
920                )
921            }
922            },
923            RustType::Callback(rust_callback) => {
924                let name = &param.name;
925                let index = callback_name_to_index.get(name).unwrap();
926                let RustCallback { input, output } = rust_callback;
927                match (input, output) {
928                    (None, None) => {
929                        // no return, but still need to await ack
930                        format!(
931                            "const {}=async()=>{{await _c_({},null);}};",
932                            name, index
933                        )
934                    },
935                    (None, Some(_)) => {
936                        format!(
937                            "const {}=async()=>{{return await _c_({},null);}};",
938                            name, index
939
940                        )
941                    },
942                    (Some(_), None) => {
943                        // no return, but still need to await ack
944                        format!(
945                            "const {}=async(v)=>{{await _c_({},v);}};",
946                            name, index
947                        )
948                    },
949                    (Some(_), Some(_)) => {
950                        format!(
951                            "const {}=async(v)=>{{return await _c_({},v);}};",
952                            name, index
953                        )
954                    },
955                }
956            },
957        }})
958        .collect::<Vec<_>>()
959        .join("");
960    let mut maybe_await = String::new();
961    if func.is_async {
962        maybe_await.push_str("await");
963    }
964    let call_function = match &func.rust_return_type {
965        RustType::Regular(_) => {
966            format!("return [true, {maybe_await} {js_func_name}({params_list})];")
967        }
968        RustType::Callback(_) => {
969            unreachable!("This cannot be an output type, the macro should have panicked earlier.")
970        }
971        RustType::JsValue(js_value) => {
972            let check = if js_value.is_option {
973                // null or undefined is valid, since this is e.g. `Option<JsValue>`
974                "if (_v_===null||_v_===undefined){return [true,null];}".to_owned()
975            } else {
976                format!(
977                    "if (_v_===null||_v_===undefined){{console.error(\"The result of `{js_func_name}` was null or undefined, but a value is needed for JsValue\");return [true,null];}}"
978                )
979            };
980            format!(
981                "const _v_={maybe_await} {js_func_name}({params_list});{check}let _j_=\"__js-value-\"+crypto.randomUUID();window[_j_]=_v_;return [true,_j_];"
982            )
983        }
984    };
985    let drop_declare = "let _d_;let _dp_=new Promise((r)=>_d_=r);";
986    let drop_handle = if has_callbacks {
987        "(async()=>{await dioxus.recv();dioxus.close();_d_();_a_=false;let w=window[_i_];delete window[_i_];for(const[o, e] of Object.values(w)){e(new Error(\"Channel destroyed\"));}})();"
988    } else {
989        "(async()=>{await dioxus.recv();dioxus.close();_d_();})();"
990    };
991    let asset_path_string = asset_path.value();
992    // Note: eval will fail if returning undefined. undefined happens if there is no return type
993    let js = format!(
994        "const{{{js_func_name}}}=await import(\"{asset_path_string}\");{prepare_callbacks}{drop_declare}{param_declarations}{drop_handle}try{{{call_function}}}catch(e){{console.warn(\"Executing `{js_func_name}` threw:\", e);return [false,null];}}"
995    );
996    fn to_raw_string_literal(s: &str) -> Literal {
997        let mut hashes = String::from("#");
998        while s.contains(&format!("\"{}", hashes)) {
999            hashes.push('#');
1000        }
1001
1002        let raw = format!("r{h}\"{s}\"{h}", h = hashes);
1003        Literal::from_str(&raw).unwrap()
1004    }
1005    let comment = to_raw_string_literal(&js);
1006    // Easier debugging to see what the generated js is. Will be compiled away.
1007    let js_in_comment = quote! {
1008        #[doc = #comment]
1009        fn ___above_is_the_generated_js___() {}
1010    };
1011    let js_format = js
1012        .replace("{", "{{")
1013        .replace("}", "}}")
1014        .replace(&asset_path_string, "{}");
1015    let js_format = if has_callbacks {
1016        js_format.replace("**INVOCATION_ID**", "{}")
1017    } else {
1018        js_format
1019    };
1020
1021    // Generate parameter types with extracted type information
1022    let param_types: Vec<_> = func
1023        .params
1024        .iter()
1025        .filter_map(|param| {
1026            if param.is_drop() {
1027                return None;
1028            }
1029            let param_name = format_ident!("{}", param.name);
1030            let type_tokens = param.rust_type.to_tokens();
1031            if let RustType::Callback(_) = param.rust_type {
1032                Some(quote! { mut #param_name: #type_tokens })
1033            } else {
1034                Some(quote! { #param_name: #type_tokens })
1035            }
1036        })
1037        .collect();
1038
1039    let parsed_type = func.rust_return_type.to_tokens();
1040    let (return_type_tokens, generic_tokens) = if func.rust_return_type.to_string()
1041        == DEFAULT_GENERIC_OUTPUT
1042    {
1043        let span = func
1044            .name_ident
1045            .as_ref()
1046            .map(|e| e.span())
1047            .unwrap_or_else(|| proc_macro2::Span::call_site());
1048        let generic = Ident::new(DEFAULT_GENERIC_OUTPUT, span);
1049        let generic_decl: TypeParam = syn::parse_str(DEFAULT_OUTPUT_GENERIC_DECLARTION).unwrap();
1050        (
1051            quote! { Result<#generic, dioxus_use_js::JsError> },
1052            Some(quote! { <#generic_decl> }),
1053        )
1054    } else {
1055        (
1056            quote! { Result<#parsed_type, dioxus_use_js::JsError> },
1057            None,
1058        )
1059    };
1060
1061    // Generate documentation comment if available - preserve original JSDoc format
1062    let doc_comment = if func.doc_comment.is_empty() {
1063        quote! {}
1064    } else {
1065        let doc_lines: Vec<_> = func
1066            .doc_comment
1067            .iter()
1068            .map(|line| quote! { #[doc = #line] })
1069            .collect();
1070        quote! { #(#doc_lines)* }
1071    };
1072
1073    let func_name = func
1074        .name_ident
1075        .clone()
1076        // Can not exist if `::*`
1077        .unwrap_or_else(|| Ident::new(func.name.as_str(), proc_macro2::Span::call_site()));
1078
1079    // void like returns always send back "Null" as an ack
1080    let void_output_mapping = if func.rust_return_type.to_string() == UNIT {
1081        quote! {
1082            .and_then(|e| {
1083                if matches!(e, dioxus_use_js::SerdeJsonValue::Null) {
1084                    Ok(())
1085                } else {
1086                    Err(dioxus_use_js::JsError::Eval {
1087                        func: #js_func_name_ident,
1088                        error: dioxus::document::EvalError::Serialization(
1089                            <dioxus_use_js::SerdeJsonError as dioxus_use_js::SerdeDeError>::custom(dioxus_use_js::__BAD_VOID_RETURN.to_owned())
1090                        )
1091                    })
1092                }
1093            })
1094        }
1095    } else {
1096        quote! {}
1097    };
1098
1099    let callback_arms: Vec<TokenStream2> = callback_name_to_index
1100        .iter()
1101        .map(|(name, index)| {
1102            let callback_name = format_ident!("{}", name);
1103            let callback_info = callback_name_to_info.get(name).unwrap();
1104            let callback_call = match (&callback_info.input, &callback_info.output) {
1105                (None, None) => {
1106                    quote! {
1107                    dioxus::prelude::spawn({let responder = responder.clone(); async move {
1108                        let result = #callback_name(()).await;
1109
1110                        match result {
1111                            // send ack
1112                            Ok(_) => responder.respond(request_id, true, dioxus_use_js::SerdeJsonValue::Null),
1113                            Err(error) => responder.respond(request_id, false, error),
1114                        }
1115                    }});
1116                    }
1117                },
1118                (None, Some(_)) => {
1119                    quote! {
1120                    dioxus::prelude::spawn({let responder = responder.clone(); async move {
1121                        let result = #callback_name(()).await;
1122
1123                        match result {
1124                            Ok(value) => responder.respond(request_id, true, value),
1125                            Err(error) => responder.respond(request_id, false, error),
1126                        }
1127                    }});
1128                }
1129                },
1130                (Some(_), None) => {
1131                    quote! {
1132                    let value = values.next().unwrap();
1133                    let value = match dioxus_use_js::serde_json_from_value(value) {
1134                        Ok(value) => value,
1135                        Err(value) => {
1136                            responder.respond(request_id, false, dioxus_use_js::SerdeJsonValue::String(dioxus_use_js::__UNEXPECTED_CALLBACK_TYPE.to_owned()));
1137                            continue;
1138                        }
1139                    };
1140
1141                    dioxus::prelude::spawn({let responder = responder.clone(); async move {
1142                        let result = #callback_name(value).await;
1143
1144                        match result {
1145                            // send ack
1146                            Ok(_) => responder.respond(request_id, true, dioxus_use_js::SerdeJsonValue::Null),
1147                            Err(error) => responder.respond(request_id, false, error),
1148                        }
1149                    }});
1150                }
1151                },
1152                (Some(_), Some(_)) => {
1153                    quote! {
1154                    let value = values.next().unwrap();
1155                    let value = match dioxus_use_js::serde_json_from_value(value) {
1156                        Ok(value) => value,
1157                        Err(value) => {
1158                            responder.respond(request_id, false, dioxus_use_js::SerdeJsonValue::String(dioxus_use_js::__UNEXPECTED_CALLBACK_TYPE.to_owned()));
1159                            continue;
1160                        }
1161                    };
1162
1163                    dioxus::prelude::spawn({let responder = responder.clone(); async move {
1164                        let result = #callback_name(value).await;
1165
1166                        match result {
1167                            Ok(value) => responder.respond(request_id, true, value),
1168                            Err(error) => responder.respond(request_id, false, error),
1169                        }
1170                    }});
1171                }
1172                }
1173            };
1174            quote! {
1175                #index => {
1176                    #callback_call
1177                }
1178            }
1179        })
1180        .collect();
1181
1182    let callback_spawn = if !callback_arms.is_empty() {
1183        quote! {
1184            dioxus::prelude::spawn({
1185                let mut eval = dioxus_use_js::EvalDrop::new(eval);
1186                    async move {
1187                        let responder = dioxus_use_js::CallbackResponder::new(&invocation_id);
1188                        loop {
1189                            let result = eval.recv::<dioxus_use_js::SerdeJsonValue>().await;
1190                            let value = match result {
1191                                Ok(v) => v,
1192                                Err(e) => {
1193                                    // Though we still may be able to accept more callback requests,
1194                                    // We shutdown otherwise the invocation of this callback will be awaiting forever
1195                                    // since we can't cancel it since we do not know the id. (Dropping the eval triggers shutdown)
1196                                    dioxus::prelude::error!(
1197                                        "Callback receiver errored. Shutting down all callbacks for invocation id `{}`: {:?}",
1198                                        &invocation_id,
1199                                        e
1200                                    );
1201                                    return;
1202                                }
1203                            };
1204                            let dioxus_use_js::SerdeJsonValue::Array(values) = value else {
1205                                unreachable!("{}", dioxus_use_js::__CALLBACK_SEND_VALIDATION_MSG);
1206                            };
1207                            let len = values.len();
1208                            if len != 3 {
1209                                unreachable!("{}", dioxus_use_js::__CALLBACK_SEND_VALIDATION_MSG);
1210                            }
1211                            let mut values = values.into_iter();
1212                            let action = values.next().unwrap().as_u64().expect(dioxus_use_js::__INDEX_VALIDATION_MSG);
1213                            let request_id = values.next().unwrap().as_u64().expect(dioxus_use_js::__INDEX_VALIDATION_MSG);
1214                            match action {
1215                                #(#callback_arms,)*
1216                                _ => unreachable!("{}", dioxus_use_js::__BAD_CALL_MSG),
1217                            }
1218                        }
1219                    }
1220            });
1221        }
1222    } else {
1223        quote! {}
1224    };
1225
1226    let end_statement = quote! {
1227        let value = eval.await.map_err(|e| {
1228            dioxus_use_js::JsError::Eval {
1229                func: #js_func_name_ident,
1230                error: e,
1231            }
1232        })?;
1233        let dioxus_use_js::SerdeJsonValue::Array(values) = value else {
1234            unreachable!("{}", dioxus_use_js::__RESULT_SEND_VALIDATION_MSG);
1235        };
1236        if values.len() != 2 {
1237            unreachable!("{}", dioxus_use_js::__RESULT_SEND_VALIDATION_MSG);
1238        }
1239        let mut values = values.into_iter();
1240        let success = values.next().unwrap().as_bool().expect(dioxus_use_js::__INDEX_VALIDATION_MSG);
1241        if success {
1242            let value = values.next().unwrap();
1243            return dioxus_use_js::serde_json_from_value(value).map_err(|e| {
1244                dioxus_use_js::JsError::Eval {
1245                    func: #js_func_name_ident,
1246                    error: dioxus::document::EvalError::Serialization(e),
1247                }
1248            })
1249            #void_output_mapping;
1250        } else {
1251             return Err(dioxus_use_js::JsError::Threw { func: #js_func_name_ident });
1252        }
1253    };
1254    let macro_invocation_id = uuid::Uuid::now_v7().to_string();
1255    let js_string = if has_callbacks {
1256        quote! {
1257            static INVOCATION_NUM: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
1258            // Each invocation id garuntees a unique namespace for the callback invocation where new requested/responded and everything clean/promises rejected on drop.
1259            let invocation_id = format!("{}{}", #macro_invocation_id, INVOCATION_NUM.fetch_add(1, std::sync::atomic::Ordering::Relaxed));
1260            let js = format!(#js_format, MODULE, &invocation_id);
1261        }
1262    } else {
1263        quote! {let js = format!(#js_format, MODULE);}
1264    };
1265
1266    quote! {
1267        #doc_comment
1268        #[allow(non_snake_case)]
1269        pub async fn #func_name #generic_tokens(#(#param_types),*) -> #return_type_tokens {
1270            const MODULE: Asset = asset!(#asset_path);
1271            const #js_func_name_ident: &str = #js_func_name;
1272            #js_in_comment
1273            #js_string
1274            let mut eval = dioxus::document::eval(js.as_str());
1275            #(#send_calls)*
1276            #callback_spawn
1277            #end_statement
1278        }
1279    }
1280}
1281
1282/// A macro to create rust bindings to javascript and typescript functions. See [README](https://github.com/mcmah309/dioxus-use-js) and [example](https://github.com/mcmah309/dioxus-use-js/blob/master/example/src/main.rs) for more.
1283#[proc_macro]
1284pub fn use_js(input: TokenStream) -> TokenStream {
1285    let input = parse_macro_input!(input as UseJsInput);
1286
1287    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
1288        Ok(dir) => dir,
1289        Err(_) => {
1290            return TokenStream::from(
1291                syn::Error::new(
1292                    proc_macro2::Span::call_site(),
1293                    "CARGO_MANIFEST_DIR environment variable not found",
1294                )
1295                .to_compile_error(),
1296            );
1297        }
1298    };
1299
1300    let UseJsInput {
1301        js_bundle_path,
1302        ts_source_path,
1303        import_spec,
1304    } = input;
1305
1306    let js_file_path = std::path::Path::new(&manifest_dir).join(js_bundle_path.value());
1307
1308    let js_all_functions = match parse_script_file(&js_file_path, true) {
1309        Ok(funcs) => funcs,
1310        Err(e) => return TokenStream::from(e.to_compile_error()),
1311    };
1312
1313    let js_functions_to_generate =
1314        match get_functions_to_generate(js_all_functions, &import_spec, &js_file_path) {
1315            Ok(funcs) => funcs,
1316            Err(e) => return TokenStream::from(e.to_compile_error()),
1317        };
1318
1319    let functions_to_generate = if let Some(ts_file_path) = ts_source_path {
1320        let ts_file_path = std::path::Path::new(&manifest_dir).join(ts_file_path.value());
1321        let ts_all_functions = match parse_script_file(&ts_file_path, false) {
1322            Ok(funcs) => funcs,
1323            Err(e) => return TokenStream::from(e.to_compile_error()),
1324        };
1325
1326        let ts_functions_to_generate =
1327            match get_functions_to_generate(ts_all_functions, &import_spec, &ts_file_path) {
1328                Ok(funcs) => funcs,
1329                Err(e) => {
1330                    return TokenStream::from(e.to_compile_error());
1331                }
1332            };
1333
1334        for ts_func in ts_functions_to_generate.iter() {
1335            if let Some(js_func) = js_functions_to_generate
1336                .iter()
1337                .find(|f| f.name == ts_func.name)
1338            {
1339                if ts_func.params.len() != js_func.params.len() {
1340                    return TokenStream::from(syn::Error::new(
1341                        proc_macro2::Span::call_site(),
1342                        format!(
1343                            "Function '{}' has different parameter count in JS and TS files. Bundle may be out of date",
1344                            ts_func.name
1345                        ),
1346                    )
1347                    .to_compile_error());
1348                }
1349            } else {
1350                return TokenStream::from(syn::Error::new(
1351                    proc_macro2::Span::call_site(),
1352                    format!(
1353                        "Function '{}' is defined in TS file but not in JS file. Bundle may be out of date",
1354                        ts_func.name
1355                    ),
1356                )
1357                .to_compile_error());
1358            }
1359        }
1360        ts_functions_to_generate
1361    } else {
1362        js_functions_to_generate
1363    };
1364    for function in functions_to_generate.iter() {
1365        for param in function.params.iter() {
1366            if param.name.starts_with("_") && param.name.ends_with("_") {
1367                panic!(
1368                    "Parameter name '{}' in function '{}' is invalid. Parameters starting and ending with underscores are reserved.",
1369                    param.name, function.name
1370                );
1371            }
1372            if param.name == "dioxus" {
1373                panic!(
1374                    "Parameter name 'dioxus' in function '{}' is invalid. This parameter name is reserved.",
1375                    function.name
1376                );
1377            }
1378            if param.name == function.name {
1379                panic!(
1380                    "Parameter name '{}' in function '{}' is invalid. Parameters cannot have the same name as the function.",
1381                    param.name, function.name
1382                );
1383            }
1384        }
1385    }
1386
1387    let function_wrappers: Vec<TokenStream2> = functions_to_generate
1388        .iter()
1389        .map(|func| generate_function_wrapper(func, &js_bundle_path))
1390        .collect();
1391
1392    let expanded = quote! {
1393        #(#function_wrappers)*
1394    };
1395
1396    TokenStream::from(expanded)
1397}
1398
1399//************************************************************************//
1400
1401#[cfg(test)]
1402mod tests {
1403    use super::*;
1404
1405    #[test]
1406    fn test_primitives() {
1407        assert_eq!(
1408            ts_type_to_rust_type(Some("string"), false).to_string(),
1409            "String"
1410        );
1411        assert_eq!(
1412            ts_type_to_rust_type(Some("string"), true).to_string(),
1413            "&str"
1414        );
1415        assert_eq!(
1416            ts_type_to_rust_type(Some("number"), false).to_string(),
1417            "f64"
1418        );
1419        assert_eq!(
1420            ts_type_to_rust_type(Some("number"), true).to_string(),
1421            "f64"
1422        );
1423        assert_eq!(
1424            ts_type_to_rust_type(Some("boolean"), false).to_string(),
1425            "bool"
1426        );
1427        assert_eq!(
1428            ts_type_to_rust_type(Some("boolean"), true).to_string(),
1429            "bool"
1430        );
1431    }
1432
1433    #[test]
1434    fn test_nullable_primitives() {
1435        assert_eq!(
1436            ts_type_to_rust_type(Some("string | null"), true).to_string(),
1437            "Option<&str>"
1438        );
1439        assert_eq!(
1440            ts_type_to_rust_type(Some("string | null"), false).to_string(),
1441            "Option<String>"
1442        );
1443        assert_eq!(
1444            ts_type_to_rust_type(Some("number | null"), true).to_string(),
1445            "Option<f64>"
1446        );
1447        assert_eq!(
1448            ts_type_to_rust_type(Some("number | null"), false).to_string(),
1449            "Option<f64>"
1450        );
1451        assert_eq!(
1452            ts_type_to_rust_type(Some("boolean | null"), true).to_string(),
1453            "Option<bool>"
1454        );
1455        assert_eq!(
1456            ts_type_to_rust_type(Some("boolean | null"), false).to_string(),
1457            "Option<bool>"
1458        );
1459    }
1460
1461    #[test]
1462    fn test_arrays() {
1463        assert_eq!(
1464            ts_type_to_rust_type(Some("string[]"), true).to_string(),
1465            "&[String]"
1466        );
1467        assert_eq!(
1468            ts_type_to_rust_type(Some("string[]"), false).to_string(),
1469            "Vec<String>"
1470        );
1471        assert_eq!(
1472            ts_type_to_rust_type(Some("Array<number>"), true).to_string(),
1473            "&[f64]"
1474        );
1475        assert_eq!(
1476            ts_type_to_rust_type(Some("Array<number>"), false).to_string(),
1477            "Vec<f64>"
1478        );
1479    }
1480
1481    #[test]
1482    fn test_nullable_array_elements() {
1483        assert_eq!(
1484            ts_type_to_rust_type(Some("(string | null)[]"), true).to_string(),
1485            "&[Option<String>]"
1486        );
1487        assert_eq!(
1488            ts_type_to_rust_type(Some("(string | null)[]"), false).to_string(),
1489            "Vec<Option<String>>"
1490        );
1491        assert_eq!(
1492            ts_type_to_rust_type(Some("Array<number | null>"), true).to_string(),
1493            "&[Option<f64>]"
1494        );
1495        assert_eq!(
1496            ts_type_to_rust_type(Some("Array<number | null>"), false).to_string(),
1497            "Vec<Option<f64>>"
1498        );
1499    }
1500
1501    #[test]
1502    fn test_nullable_array_itself() {
1503        assert_eq!(
1504            ts_type_to_rust_type(Some("string[] | null"), true).to_string(),
1505            "Option<&[String]>"
1506        );
1507        assert_eq!(
1508            ts_type_to_rust_type(Some("string[] | null"), false).to_string(),
1509            "Option<Vec<String>>"
1510        );
1511        assert_eq!(
1512            ts_type_to_rust_type(Some("Array<number> | null"), true).to_string(),
1513            "Option<&[f64]>"
1514        );
1515        assert_eq!(
1516            ts_type_to_rust_type(Some("Array<number> | null"), false).to_string(),
1517            "Option<Vec<f64>>"
1518        );
1519    }
1520
1521    #[test]
1522    fn test_nullable_array_and_elements() {
1523        assert_eq!(
1524            ts_type_to_rust_type(Some("Array<string | null> | null"), true).to_string(),
1525            "Option<&[Option<String>]>"
1526        );
1527        assert_eq!(
1528            ts_type_to_rust_type(Some("Array<string | null> | null"), false).to_string(),
1529            "Option<Vec<Option<String>>>"
1530        );
1531    }
1532
1533    #[test]
1534    fn test_fallback_for_union() {
1535        assert_eq!(
1536            ts_type_to_rust_type(Some("string | number"), true).to_string(),
1537            "impl dioxus_use_js::SerdeSerialize"
1538        );
1539        assert_eq!(
1540            ts_type_to_rust_type(Some("string | number"), false).to_string(),
1541            "DeserializeOwned"
1542        );
1543        assert_eq!(
1544            ts_type_to_rust_type(Some("string | number | null"), true).to_string(),
1545            "impl dioxus_use_js::SerdeSerialize"
1546        );
1547        assert_eq!(
1548            ts_type_to_rust_type(Some("string | number | null"), false).to_string(),
1549            "DeserializeOwned"
1550        );
1551    }
1552
1553    #[test]
1554    fn test_unknown_types() {
1555        assert_eq!(
1556            ts_type_to_rust_type(Some("foo"), true).to_string(),
1557            "impl dioxus_use_js::SerdeSerialize"
1558        );
1559        assert_eq!(
1560            ts_type_to_rust_type(Some("foo"), false).to_string(),
1561            "DeserializeOwned"
1562        );
1563
1564        assert_eq!(
1565            ts_type_to_rust_type(Some("any"), true).to_string(),
1566            "impl dioxus_use_js::SerdeSerialize"
1567        );
1568        assert_eq!(
1569            ts_type_to_rust_type(Some("any"), false).to_string(),
1570            "DeserializeOwned"
1571        );
1572        assert_eq!(
1573            ts_type_to_rust_type(Some("object"), true).to_string(),
1574            "impl dioxus_use_js::SerdeSerialize"
1575        );
1576        assert_eq!(
1577            ts_type_to_rust_type(Some("object"), false).to_string(),
1578            "DeserializeOwned"
1579        );
1580        assert_eq!(
1581            ts_type_to_rust_type(Some("unknown"), true).to_string(),
1582            "impl dioxus_use_js::SerdeSerialize"
1583        );
1584        assert_eq!(
1585            ts_type_to_rust_type(Some("unknown"), false).to_string(),
1586            "DeserializeOwned"
1587        );
1588
1589        assert_eq!(ts_type_to_rust_type(Some("void"), false).to_string(), "()");
1590        assert_eq!(
1591            ts_type_to_rust_type(Some("undefined"), false).to_string(),
1592            "()"
1593        );
1594        assert_eq!(ts_type_to_rust_type(Some("null"), false).to_string(), "()");
1595    }
1596
1597    #[test]
1598    fn test_extra_whitespace() {
1599        assert_eq!(
1600            ts_type_to_rust_type(Some("  string | null  "), true).to_string(),
1601            "Option<&str>"
1602        );
1603        assert_eq!(
1604            ts_type_to_rust_type(Some("  string | null  "), false).to_string(),
1605            "Option<String>"
1606        );
1607        assert_eq!(
1608            ts_type_to_rust_type(Some(" Array< string > "), true).to_string(),
1609            "&[String]"
1610        );
1611        assert_eq!(
1612            ts_type_to_rust_type(Some(" Array< string > "), false).to_string(),
1613            "Vec<String>"
1614        );
1615    }
1616
1617    #[test]
1618    fn test_map_types() {
1619        assert_eq!(
1620            ts_type_to_rust_type(Some("Map<string, number>"), true).to_string(),
1621            "&std::collections::HashMap<String, f64>"
1622        );
1623        assert_eq!(
1624            ts_type_to_rust_type(Some("Map<string, number>"), false).to_string(),
1625            "std::collections::HashMap<String, f64>"
1626        );
1627        assert_eq!(
1628            ts_type_to_rust_type(Some("Map<string, boolean>"), true).to_string(),
1629            "&std::collections::HashMap<String, bool>"
1630        );
1631        assert_eq!(
1632            ts_type_to_rust_type(Some("Map<string, boolean>"), false).to_string(),
1633            "std::collections::HashMap<String, bool>"
1634        );
1635        assert_eq!(
1636            ts_type_to_rust_type(Some("Map<number, string>"), true).to_string(),
1637            "&std::collections::HashMap<f64, String>"
1638        );
1639        assert_eq!(
1640            ts_type_to_rust_type(Some("Map<number, string>"), false).to_string(),
1641            "std::collections::HashMap<f64, String>"
1642        );
1643    }
1644
1645    #[test]
1646    fn test_set_types() {
1647        assert_eq!(
1648            ts_type_to_rust_type(Some("Set<string>"), true).to_string(),
1649            "&std::collections::HashSet<String>"
1650        );
1651        assert_eq!(
1652            ts_type_to_rust_type(Some("Set<string>"), false).to_string(),
1653            "std::collections::HashSet<String>"
1654        );
1655        assert_eq!(
1656            ts_type_to_rust_type(Some("Set<number>"), true).to_string(),
1657            "&std::collections::HashSet<f64>"
1658        );
1659        assert_eq!(
1660            ts_type_to_rust_type(Some("Set<number>"), false).to_string(),
1661            "std::collections::HashSet<f64>"
1662        );
1663        assert_eq!(
1664            ts_type_to_rust_type(Some("Set<boolean>"), true).to_string(),
1665            "&std::collections::HashSet<bool>"
1666        );
1667        assert_eq!(
1668            ts_type_to_rust_type(Some("Set<boolean>"), false).to_string(),
1669            "std::collections::HashSet<bool>"
1670        );
1671    }
1672
1673    #[test]
1674    fn test_rust_callback() {
1675        assert_eq!(
1676            ts_type_to_rust_type(Some("RustCallback<number,string>"), true).to_string(),
1677            "dioxus::core::Callback<f64, impl Future<Output = Result<String, dioxus_use_js::SerdeJsonValue>> + 'static>"
1678        );
1679        assert_eq!(
1680            ts_type_to_rust_type(Some("RustCallback<void,string>"), true).to_string(),
1681            "dioxus::core::Callback<(), impl Future<Output = Result<String, dioxus_use_js::SerdeJsonValue>> + 'static>"
1682        );
1683        assert_eq!(
1684            ts_type_to_rust_type(Some("RustCallback<void,void>"), true).to_string(),
1685            "dioxus::core::Callback<(), impl Future<Output = Result<(), dioxus_use_js::SerdeJsonValue>> + 'static>"
1686        );
1687        assert_eq!(
1688            ts_type_to_rust_type(Some("RustCallback<number,void>"), true).to_string(),
1689            "dioxus::core::Callback<f64, impl Future<Output = Result<(), dioxus_use_js::SerdeJsonValue>> + 'static>"
1690        );
1691    }
1692
1693    #[test]
1694    fn test_promise_types() {
1695        assert_eq!(
1696            ts_type_to_rust_type(Some("Promise<string>"), false).to_string(),
1697            "String"
1698        );
1699        assert_eq!(
1700            ts_type_to_rust_type(Some("Promise<number>"), false).to_string(),
1701            "f64"
1702        );
1703        assert_eq!(
1704            ts_type_to_rust_type(Some("Promise<boolean>"), false).to_string(),
1705            "bool"
1706        );
1707    }
1708
1709    #[test]
1710    fn test_json_types() {
1711        assert_eq!(
1712            ts_type_to_rust_type(Some("Json"), true).to_string(),
1713            "&dioxus_use_js::SerdeJsonValue"
1714        );
1715        assert_eq!(
1716            ts_type_to_rust_type(Some("Json"), false).to_string(),
1717            "dioxus_use_js::SerdeJsonValue"
1718        );
1719    }
1720
1721    #[test]
1722    fn test_js_value() {
1723        assert_eq!(
1724            ts_type_to_rust_type(Some("JsValue"), true).to_string(),
1725            "&dioxus_use_js::JsValue"
1726        );
1727        assert_eq!(
1728            ts_type_to_rust_type(Some("JsValue"), false).to_string(),
1729            "dioxus_use_js::JsValue"
1730        );
1731        assert_eq!(
1732            ts_type_to_rust_type(Some("JsValue<CustomType>"), true).to_string(),
1733            "&dioxus_use_js::JsValue"
1734        );
1735        assert_eq!(
1736            ts_type_to_rust_type(Some("JsValue<CustomType>"), false).to_string(),
1737            "dioxus_use_js::JsValue"
1738        );
1739
1740        assert_eq!(
1741            ts_type_to_rust_type(Some("Promise<JsValue>"), false).to_string(),
1742            "dioxus_use_js::JsValue"
1743        );
1744
1745        assert_eq!(
1746            ts_type_to_rust_type(Some("Promise<JsValue | null>"), false).to_string(),
1747            "Option<dioxus_use_js::JsValue>"
1748        );
1749        assert_eq!(
1750            ts_type_to_rust_type(Some("JsValue | null"), true).to_string(),
1751            "Option<&dioxus_use_js::JsValue>"
1752        );
1753        assert_eq!(
1754            ts_type_to_rust_type(Some("JsValue | null"), false).to_string(),
1755            "Option<dioxus_use_js::JsValue>"
1756        );
1757    }
1758}