fp_bindgen/types/type_ident.rs

1use super::is_runtime_bound;
2use crate::primitives::Primitive;
3use std::num::NonZeroUsize;
4use std::{convert::TryFrom, fmt::Display, str::FromStr};
5use syn::{PathArguments, TypeParamBound, TypePath, TypeTuple};
6
/// Identifies a type by name, together with its generic arguments and, for fixed-size
/// arrays, the array length.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub struct TypeIdent {
    /// The type name, possibly a `::`-separated path (e.g. `u32`, `Vec`, `std::string::String`).
    pub name: String,
    /// Generic arguments in declaration order. Each argument is paired with the textual trait
    /// bounds attached to it, so `T: Debug + Display` becomes `(T, ["Debug", "Display"])`.
    pub generic_args: Vec<(TypeIdent, Vec<String>)>,
    /// `Some(len)` when this identifier denotes the array `[name; len]`, `None` otherwise.
    pub array: Option<NonZeroUsize>,
}
15
16impl TypeIdent {
17    pub fn new(name: impl Into<String>, generic_args: Vec<(TypeIdent, Vec<String>)>) -> Self {
18        Self {
19            name: name.into(),
20            generic_args,
21            array: None,
22        }
23    }
24
25    pub fn is_array(&self) -> bool {
26        self.array.is_some()
27    }
28
29    pub fn is_primitive(&self) -> bool {
30        self.as_primitive().is_some()
31    }
32
33    pub fn as_primitive(&self) -> Option<Primitive> {
34        if self.array.is_none() {
35            Primitive::from_str(&self.name).ok()
36        } else {
37            None
38        }
39    }
40
41    pub fn format(&self, include_bounds: bool) -> String {
42        let ty = if self.generic_args.is_empty() {
43            self.name.clone()
44        } else {
45            format_args!(
46                "{}<{}>",
47                self.name,
48                self.generic_args
49                    .iter()
50                    .map(|(arg, bounds)| {
51                        if bounds.is_empty() || !include_bounds {
52                            format!("{arg}")
53                        } else {
54                            format!(
55                                "{}: {}",
56                                arg,
57                                bounds
58                                    .iter()
59                                    .filter(|b| is_runtime_bound(b))
60                                    .cloned()
61                                    .collect::<Vec<_>>()
62                                    .join(" + ")
63                            )
64                        }
65                    })
66                    .collect::<Vec<_>>()
67                    .join(", ")
68            )
69            .to_string()
70        };
71
72        match self.array {
73            Some(len) => format!("[{ty}; {len}]"),
74            None => ty,
75        }
76    }
77}
78
79impl Display for TypeIdent {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.write_str(&self.format(true))
82    }
83}
84
85impl From<&str> for TypeIdent {
86    fn from(name: &str) -> Self {
87        Self::from_str(name)
88            .unwrap_or_else(|_| panic!("Could not convert '{}' into a TypeIdent", name))
89    }
90}
91
92impl FromStr for TypeIdent {
93    type Err = String;
94
95    fn from_str(string: &str) -> Result<Self, Self::Err> {
96        let (string, array) = if string.starts_with('[') {
97            // Remove brackets and split on ;
98            let split = string
99                .strip_prefix('[')
100                .and_then(|s| s.strip_suffix(']'))
101                .ok_or(format!("Invalid array syntax in: {string}"))?
102                .split(';')
103                .collect::<Vec<_>>();
104
105            let element = split[0].trim();
106            let len = usize::from_str(split[1].trim())
107                .map_err(|_| format!("Invalid array length in: {string}"))?;
108
109            let primitive = Primitive::from_str(element)?;
110            if primitive.js_array_name().is_none() {
111                return Err(format!(
112                    "Only arrays of primitives supported by Javascript are allowed, found: {string}"
113                ));
114            }
115
116            (element, NonZeroUsize::new(len))
117        } else {
118            (string, None)
119        };
120
121        if let Some(start_index) = string.find('<') {
122            let end_index = string.rfind('>').unwrap_or(string.len());
123            Ok(Self {
124                name: string[0..start_index]
125                    .trim_end_matches(|c: char| c.is_whitespace() || c == ':')
126                    .to_owned(),
127                generic_args: string[start_index + 1..end_index]
128                    .split(',')
129                    .map(|arg| {
130                        let (arg, bounds) = arg.split_once(':').unwrap_or((arg, ""));
131                        let ident = Self::from_str(arg.trim());
132                        let bounds = bounds
133                            .split('+')
134                            .map(|b| b.trim().to_string())
135                            .filter(|b| !b.is_empty())
136                            .collect();
137                        ident.map(|ident| (ident, bounds))
138                    })
139                    .collect::<Result<Vec<(Self, Vec<String>)>, Self::Err>>()?,
140                array,
141            })
142        } else {
143            Ok(Self {
144                name: string.into(),
145                generic_args: vec![],
146                array,
147            })
148        }
149    }
150}
151
152impl From<String> for TypeIdent {
153    fn from(name: String) -> Self {
154        Self {
155            name,
156            generic_args: Vec::new(),
157            ..Default::default()
158        }
159    }
160}
161
162impl Ord for TypeIdent {
163    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
164        // We only compare the name and array so that any type is only included once in
165        // a map, regardless of how many concrete instances are used with
166        // different generic arguments.
167        (&self.name, self.array).cmp(&(&other.name, other.array))
168    }
169}
170
171impl PartialOrd for TypeIdent {
172    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
173        // We only compare the name and array so that any type is only included once in
174        // a map, regardless of how many concrete instances are used with
175        // different generic arguments.
176        (&self.name, self.array).partial_cmp(&(&other.name, other.array))
177    }
178}
179
180impl TryFrom<&syn::Type> for TypeIdent {
181    type Error = String;
182
183    fn try_from(ty: &syn::Type) -> Result<Self, Self::Error> {
184        match ty {
185            syn::Type::Array(syn::TypeArray {
186                elem,
187                len: syn::Expr::Lit(syn::ExprLit { lit, .. }),
188                ..
189            }) => {
190                let array_len = match lit {
191                    syn::Lit::Int(int) => int.base10_digits().parse::<usize>(),
192                    _ => panic!(),
193                }
194                .unwrap();
195                let elem_ident = TypeIdent::try_from(elem.as_ref())?;
196
197                Ok(Self {
198                    name: elem_ident.name,
199                    generic_args: vec![],
200                    array: NonZeroUsize::new(array_len),
201                })
202            }
203            syn::Type::Path(TypePath { path, qself }) if qself.is_none() => {
204                let mut generic_args = vec![];
205                if let Some(segment) = path.segments.last() {
206                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
207                        for arg in &args.args {
208                            let generic_arg_ident;
209                            let mut generic_arg_bounds = vec![];
210                            match arg {
211                                syn::GenericArgument::Type(ty) => {
212                                    generic_arg_ident = Some(TypeIdent::try_from(ty)?);
213                                }
214                                syn::GenericArgument::Constraint(cons) => {
215                                    generic_arg_ident =
216                                        Some(TypeIdent::new(cons.ident.to_string(), vec![]));
217
218                                    let bounds = cons
219                                        .bounds
220                                        .iter()
221                                        .map(|bound| match bound {
222                                            TypeParamBound::Trait(tr) => {
223                                                Ok(path_to_string(&tr.path))
224                                            }
225                                            TypeParamBound::Lifetime(_) => Err(format!(
226                                                "Lifecycle bounds are not supported: {bound:?}"
227                                            )),
228                                        })
229                                        .collect::<Vec<_>>();
230                                    for bound in bounds {
231                                        generic_arg_bounds.push(bound?);
232                                    }
233                                }
234                                arg => {
235                                    return Err(format!("Unsupported generic argument: {arg:?}"));
236                                }
237                            }
238                            if let Some(ident) = generic_arg_ident {
239                                generic_args.push((ident, generic_arg_bounds));
240                            }
241                        }
242                    }
243                }
244
245                Ok(Self {
246                    name: path_to_string(path),
247                    generic_args,
248                    ..Default::default()
249                })
250            }
251            syn::Type::Tuple(TypeTuple {
252                elems,
253                paren_token: _,
254            }) if elems.is_empty() => Ok(TypeIdent::from("()")),
255            ty => Err(format!("Unsupported type: {ty:?}")),
256        }
257    }
258}
259
260fn path_to_string(path: &syn::Path) -> String {
261    path.segments
262        .iter()
263        .map(|segment| segment.ident.to_string())
264        .collect::<Vec<_>>()
265        .join("::")
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a `syn::Type` and converts it into a `TypeIdent`.
    fn from_syn(source: &str) -> TypeIdent {
        let parsed = syn::parse_str::<syn::Type>(source).unwrap();
        TypeIdent::try_from(&parsed).unwrap()
    }

    /// A generic argument named `name` without bounds.
    fn unbounded(name: &str) -> (TypeIdent, Vec<String>) {
        (TypeIdent::new(name, vec![]), vec![])
    }

    /// A generic argument named `name` carrying the given bounds.
    fn bounded(name: &str, bounds: &[&str]) -> (TypeIdent, Vec<String>) {
        (
            TypeIdent::new(name, vec![]),
            bounds.iter().map(|&b| b.to_string()).collect(),
        )
    }

    #[test]
    fn type_ident_from_syn_type() {
        let scalar = from_syn("u32");
        assert_eq!(scalar.name, "u32");
        assert!(scalar.generic_args.is_empty());

        let list = from_syn("Vec<u32>");
        assert_eq!(list.name, "Vec");
        assert_eq!(list.generic_args, vec![unbounded("u32")]);

        let map = from_syn("BTreeMap<K, V>");
        assert_eq!(map.name, "BTreeMap");
        assert_eq!(map.generic_args, vec![unbounded("K"), unbounded("V")]);

        let constrained = from_syn("Vec<T: Debug + Display>");
        assert_eq!(constrained.name, "Vec");
        assert_eq!(
            constrained.generic_args,
            vec![bounded("T", &["Debug", "Display"])]
        );
    }

    #[test]
    fn type_ident_from_str() {
        let parse = |source: &str| TypeIdent::from_str(source).unwrap();

        let scalar = parse("u32");
        assert_eq!(scalar.name, "u32");
        assert!(scalar.generic_args.is_empty());

        let list = parse("Vec<u32>");
        assert_eq!(list.name, "Vec");
        assert_eq!(list.generic_args, vec![unbounded("u32")]);

        let map = parse("BTreeMap<K, V>");
        assert_eq!(map.name, "BTreeMap");
        assert_eq!(map.generic_args, vec![unbounded("K"), unbounded("V")]);

        let constrained = parse("Vec<T: Debug + Display>");
        assert_eq!(constrained.name, "Vec");
        assert_eq!(
            constrained.generic_args,
            vec![bounded("T", &["Debug", "Display"])]
        );
    }

    #[test]
    fn type_ident_from_str_array() {
        let array = TypeIdent::from_str("[u32; 8]").unwrap();
        assert_eq!(array.name, "u32");
        assert!(array.generic_args.is_empty());
        assert_eq!(array.array, NonZeroUsize::new(8));

        // Non-primitive elements, malformed lengths, and primitives without a JavaScript
        // array counterpart (`u64`) must all be rejected.
        for invalid in ["[Vec<f32>; 8]", "[u32;]", "[u32; foo]", "[u32; -1]", "[u64; 8]"] {
            assert!(
                TypeIdent::from_str(invalid).is_err(),
                "expected an error for {invalid}"
            );
        }
    }
}
360}