1use std::borrow::Cow;
2
3use {
4    proc_macro2::Span,
5    syn::{spanned::Spanned, GenericArgument, TypePath},
6};
7
/// A `syn::Type` classified by the kind of wrapper it represents, as produced by
/// [`parse_rust_type`].
///
/// Every variant keeps the original `syn` node (borrowed from the parsed input, or
/// owned after [`RustType::into_owned`] / `replace_inner`) so it can be turned back
/// into a `syn::Type` with `to_syn`, together with a span. For values built by
/// [`parse_rust_type`] that span is the span of the parsed type.
#[derive(Debug, Clone)]
pub enum RustType<'a> {
    /// `Option<T>` (matched by the last path segment); `inner` is the parsed `T`.
    Optional {
        /// The whole `Option<T>` path (stored as a bare `TypePath`, not a `syn::Type`).
        syn: Cow<'a, TypePath>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// A list-like type: `Vec<T>`, `[T; N]`, `[T]`, or a reference to a slice
    /// (`&[T]` / `&mut [T]`); `inner` is the parsed element type `T`.
    List {
        syn: Cow<'a, syn::Type>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// A pointer-like type: `Box<T>`, `Arc<T>`, `Rc<T>`, or a reference `&T` /
    /// `&mut T` whose referent is not a slice; `inner` is the parsed `T`.
    Ref {
        syn: Cow<'a, syn::Type>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// Any type not recognised as one of the wrappers above (including a wrapper
    /// name that carries no type argument, e.g. a bare `Vec`).
    SimpleType {
        syn: Cow<'a, syn::Type>,
        span: Span,
    },
}
30
31impl RustType<'_> {
32    pub fn into_owned(self) -> RustType<'static> {
33        match self {
34            RustType::Optional { syn, inner, span } => RustType::Optional {
35                syn: Cow::Owned(syn.into_owned()),
36                inner: Box::new(inner.into_owned()),
37                span,
38            },
39            RustType::List { syn, inner, span } => RustType::List {
40                syn: Cow::Owned(syn.into_owned()),
41                inner: Box::new(inner.into_owned()),
42                span,
43            },
44            RustType::Ref { syn, inner, span } => RustType::Ref {
45                syn: Cow::Owned(syn.into_owned()),
46                inner: Box::new(inner.into_owned()),
47                span,
48            },
49            RustType::SimpleType { syn, span } => RustType::SimpleType {
50                syn: Cow::Owned(syn.into_owned()),
51                span,
52            },
53        }
54    }
55
56    pub fn span(&self) -> Span {
57        match self {
58            RustType::Optional { span, .. } => *span,
59            RustType::List { span, .. } => *span,
60            RustType::Ref { span, .. } => *span,
61            RustType::SimpleType { span, .. } => *span,
62        }
63    }
64}
65
66pub fn parse_rust_type(ty: &syn::Type) -> RustType<'_> {
67    let span = ty.span();
68    match ty {
69        syn::Type::Path(type_path) => {
70            if let Some(last_segment) = type_path.path.segments.last() {
71                match last_segment.ident.to_string().as_str() {
72                    "Box" | "Arc" | "Rc" => {
73                        if let Some(inner_type) = extract_generic_argument(last_segment) {
74                            return RustType::Ref {
75                                syn: Cow::Borrowed(ty),
76                                inner: Box::new(parse_rust_type(inner_type)),
77                                span,
78                            };
79                        }
80                    }
81                    "Option" => {
82                        if let Some(inner_type) = extract_generic_argument(last_segment) {
83                            return RustType::Optional {
84                                syn: Cow::Borrowed(type_path),
85                                inner: Box::new(parse_rust_type(inner_type)),
86                                span,
87                            };
88                        }
89                    }
90                    "Vec" => {
91                        if let Some(inner_type) = extract_generic_argument(last_segment) {
92                            return RustType::List {
93                                syn: Cow::Borrowed(ty),
94                                inner: Box::new(parse_rust_type(inner_type)),
95                                span,
96                            };
97                        }
98                    }
99                    _ => {}
100                }
101            }
102        }
103        syn::Type::Reference(syn::TypeReference { elem, .. })
104            if matches!(**elem, syn::Type::Slice(_)) =>
105        {
106            let syn::Type::Slice(array) = &**elem else {
107                unreachable!()
108            };
109            return RustType::List {
110                syn: Cow::Borrowed(ty),
111                inner: Box::new(parse_rust_type(&array.elem)),
112                span,
113            };
114        }
115        syn::Type::Reference(reference) => {
116            return RustType::Ref {
117                syn: Cow::Borrowed(ty),
118                inner: Box::new(parse_rust_type(&reference.elem)),
119                span,
120            }
121        }
122        syn::Type::Array(array) => {
123            return RustType::List {
124                syn: Cow::Borrowed(ty),
125                inner: Box::new(parse_rust_type(&array.elem)),
126                span,
127            }
128        }
129        syn::Type::Slice(slice) => {
130            return RustType::List {
131                syn: Cow::Borrowed(ty),
132                inner: Box::new(parse_rust_type(&slice.elem)),
133                span,
134            }
135        }
136        _ => {}
137    }
138
139    RustType::SimpleType {
140        syn: Cow::Borrowed(ty),
141        span,
142    }
143}
144
145fn extract_generic_argument(segment: &syn::PathSegment) -> Option<&syn::Type> {
147    if let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments {
148        for arg in &angle_bracketed.args {
149            if let syn::GenericArgument::Type(inner_type) = arg {
150                return Some(inner_type);
151            }
152        }
153    }
154
155    None
156}
157
158impl<'a> RustType<'a> {
159    pub fn to_syn(&self) -> syn::Type {
160        match self {
161            RustType::Optional { syn, .. } => syn::Type::Path(syn.clone().into_owned()),
162            RustType::List { syn, .. } => syn.clone().into_owned(),
163            RustType::Ref { syn, .. } => syn.clone().into_owned(),
164            RustType::SimpleType { syn, .. } => syn.clone().into_owned(),
165        }
166    }
167
168    pub fn replace_inner(self, new_inner: RustType<'a>) -> RustType<'a> {
169        match self {
170            RustType::SimpleType { .. } => {
171                panic!("Can't replace inner on simple or unknown types")
172            }
173            RustType::Optional { mut syn, span, .. } => {
174                syn.to_mut().replace_generic_param(&new_inner);
175                RustType::Optional {
176                    syn,
177                    inner: Box::new(new_inner),
178                    span,
179                }
180            }
181            RustType::Ref { mut syn, span, .. } => {
182                match syn.to_mut() {
183                    syn::Type::Path(path) => path.replace_generic_param(&new_inner),
184                    syn::Type::Reference(reference) => reference.elem = Box::new(new_inner.to_syn()),
185                    _ => panic!("We shouldn't have constructed RustType::Ref for anything else than these types")
186                }
187                RustType::Ref {
188                    syn,
189                    inner: Box::new(new_inner),
190                    span,
191                }
192            }
193            RustType::List { mut syn, span, .. } => {
194                match syn.to_mut() {
195                    syn::Type::Path(path) => path.replace_generic_param(&new_inner),
196                    syn::Type::Array(array) => array.elem = Box::new(new_inner.to_syn()),
197                    syn::Type::Slice(slice) => slice.elem = Box::new(new_inner.to_syn()),
198                    syn::Type::Reference(ref_to_slice) => {
199                        let syn::Type::Slice(slice) = &mut *ref_to_slice.elem
200                            else { panic!("We shouldn't have constructed RustType::List for a Ref unless the type beneath is a Slice") };
201                        slice.elem = Box::new(new_inner.to_syn());
202                    }
203                    _ => panic!("We shouldn't have constructed RustType::List for anything else than these types")
204                }
205
206                RustType::List {
207                    syn,
208                    inner: Box::new(new_inner),
209                    span,
210                }
211            }
212        }
213    }
214}
215
/// Extension methods on `syn::TypePath` used by `RustType::replace_inner` for
/// path-based wrappers (`Option<T>`, `Vec<T>`, `Box<T>`, ...).
trait TypePathExt {
    /// Replaces the wrapped generic argument of the path's last segment with
    /// `replacement.to_syn()`.
    ///
    /// NOTE(review): the `syn::TypePath` implementation panics when there is no
    /// suitable generic argument to replace.
    fn replace_generic_param(&mut self, replacement: &RustType<'_>);
}
219
220impl TypePathExt for syn::TypePath {
221    fn replace_generic_param(&mut self, replacement: &RustType<'_>) {
222        fn get_generic_argument(type_path: &mut syn::TypePath) -> Option<&mut GenericArgument> {
223            let segment = type_path.path.segments.last_mut()?;
224
225            match &mut segment.arguments {
226                syn::PathArguments::AngleBracketed(angle_bracketed) => {
227                    angle_bracketed.args.first_mut()
228                }
229                _ => None,
230            }
231        }
232
233        let generic_argument = get_generic_argument(self)
234            .expect("Don't call replace_generic_param on a type without a generic argument");
235
236        *generic_argument = syn::GenericArgument::Type(replacement.to_syn())
237    }
238}
239
#[cfg(test)]
mod tests {
    use {proc_macro2::TokenStream, quote::quote, rstest::rstest};

    use super::*;

    /// `replace_inner` must swap only the wrapped type and keep the wrapper's own
    /// tokens (path prefix, lifetime, mutability, array length) intact.
    #[rstest]
    #[case::replace_on_option(
        quote! { Option<i32> },
        quote! { Vec<i32> },
        quote! { Option<Vec<i32>> },
    )]
    #[case::replace_on_vec(
        quote! { Vec<i32> },
        quote! { Vec<i32> },
        quote! { Vec<Vec<i32>> },
    )]
    #[case::replace_on_box(
        quote! { Box<i32> },
        quote! { Vec<i32> },
        quote! { Box<Vec<i32>> },
    )]
    #[case::replace_on_arc(
        quote! { Arc<i32> },
        quote! { Vec<i32> },
        quote! { Arc<Vec<i32>> },
    )]
    #[case::replace_on_rc(
        quote! { Rc<i32> },
        quote! { Vec<i32> },
        quote! { Rc<Vec<i32>> },
    )]
    #[case::replace_with_complex_inner(
        quote! { Arc<i32> },
        quote! { Vec<chrono::DateTime<chrono::Utc>> },
        quote! { Arc<Vec<chrono::DateTime<chrono::Utc>>> },
    )]
    #[case::replace_with_a_full_path(
        quote! { std::sync::Arc<i32> },
        quote! { Vec<chrono::DateTime<chrono::Utc>> },
        quote! { std::sync::Arc<Vec<chrono::DateTime<chrono::Utc>>> },
    )]
    #[case::replace_on_full_path_option(
        quote! { std::option::Option<i32> },
        quote! { String },
        quote! { std::option::Option<String> },
    )]
    #[case::replace_on_reference(
        quote! { &i32 },
        quote! { Vec<i32> },
        quote! { &Vec<i32> },
    )]
    #[case::replace_on_mut_reference_with_lifetime(
        quote! { &'a mut i32 },
        quote! { Vec<i32> },
        quote! { &'a mut Vec<i32> },
    )]
    #[case::replace_on_slice_reference(
        quote! { &[i32] },
        quote! { Option<i32> },
        quote! { &[Option<i32>] },
    )]
    #[case::replace_on_array(
        quote! { [i32; 4] },
        quote! { Option<i32> },
        quote! { [Option<i32>; 4] },
    )]
    #[case::replace_on_slice(
        quote! { [i32] },
        quote! { Option<i32> },
        quote! { [Option<i32>] },
    )]
    fn test_replace_inner(
        #[case] original: TokenStream,
        #[case] replace: TokenStream,
        #[case] expected: TokenStream,
    ) {
        let original = syn::parse2(original).unwrap();
        let replace = syn::parse2(replace).unwrap();
        let expected = syn::parse2(expected).unwrap();

        let result = parse_rust_type(&original)
            .replace_inner(parse_rust_type(&replace))
            .to_syn();

        assert_eq!(result, expected);
    }

    /// A plain type has no inner type, so replacing it is a caller bug.
    #[test]
    #[should_panic(expected = "Can't replace inner on simple or unknown types")]
    fn test_replace_inner_on_simple_type_panics() {
        let original: syn::Type = syn::parse2(quote! { i32 }).unwrap();
        let replace: syn::Type = syn::parse2(quote! { Vec<i32> }).unwrap();

        parse_rust_type(&original).replace_inner(parse_rust_type(&replace));
    }

    /// A wrapper name without a type argument is not treated as a wrapper.
    #[test]
    fn test_wrapper_without_type_argument_is_simple() {
        let original: syn::Type = syn::parse2(quote! { Vec }).unwrap();

        assert!(matches!(
            parse_rust_type(&original),
            RustType::SimpleType { .. }
        ));
    }
}