1use std::borrow::Cow;
2
3use {
4 proc_macro2::Span,
5 syn::{spanned::Spanned, GenericArgument, TypePath},
6};
7
/// A Rust type classified by the kind of wrapper it is, with the wrapped
/// type parsed recursively into `inner`.
///
/// Every variant keeps the `syn` node it was built from and the span of that
/// node. The node is borrowed from the parsed input until
/// [`RustType::into_owned`] or a `replace_inner` call makes it owned.
#[derive(Debug, Clone)]
pub enum RustType<'a> {
    /// `Option<T>` (matched on the last path segment's name); `inner` is `T`.
    Optional {
        syn: Cow<'a, TypePath>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// A list-like type: `Vec<T>`, `[T; N]`, `[T]` or `&[T]`; `inner` is the
    /// element type.
    List {
        syn: Cow<'a, syn::Type>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// A pointer-like type: `Box<T>`, `Arc<T>`, `Rc<T>` (matched on the last
    /// path segment's name) or a reference `&T` whose referent is not a slice;
    /// `inner` is `T`.
    Ref {
        syn: Cow<'a, syn::Type>,
        inner: Box<RustType<'a>>,
        span: Span,
    },
    /// Any type that is not recognised as one of the wrappers above.
    SimpleType {
        syn: Cow<'a, syn::Type>,
        span: Span,
    },
}
30
31impl RustType<'_> {
32 pub fn into_owned(self) -> RustType<'static> {
33 match self {
34 RustType::Optional { syn, inner, span } => RustType::Optional {
35 syn: Cow::Owned(syn.into_owned()),
36 inner: Box::new(inner.into_owned()),
37 span,
38 },
39 RustType::List { syn, inner, span } => RustType::List {
40 syn: Cow::Owned(syn.into_owned()),
41 inner: Box::new(inner.into_owned()),
42 span,
43 },
44 RustType::Ref { syn, inner, span } => RustType::Ref {
45 syn: Cow::Owned(syn.into_owned()),
46 inner: Box::new(inner.into_owned()),
47 span,
48 },
49 RustType::SimpleType { syn, span } => RustType::SimpleType {
50 syn: Cow::Owned(syn.into_owned()),
51 span,
52 },
53 }
54 }
55
56 pub fn span(&self) -> Span {
57 match self {
58 RustType::Optional { span, .. } => *span,
59 RustType::List { span, .. } => *span,
60 RustType::Ref { span, .. } => *span,
61 RustType::SimpleType { span, .. } => *span,
62 }
63 }
64}
65
66pub fn parse_rust_type(ty: &syn::Type) -> RustType<'_> {
67 let span = ty.span();
68 match ty {
69 syn::Type::Path(type_path) => {
70 if let Some(last_segment) = type_path.path.segments.last() {
71 match last_segment.ident.to_string().as_str() {
72 "Box" | "Arc" | "Rc" => {
73 if let Some(inner_type) = extract_generic_argument(last_segment) {
74 return RustType::Ref {
75 syn: Cow::Borrowed(ty),
76 inner: Box::new(parse_rust_type(inner_type)),
77 span,
78 };
79 }
80 }
81 "Option" => {
82 if let Some(inner_type) = extract_generic_argument(last_segment) {
83 return RustType::Optional {
84 syn: Cow::Borrowed(type_path),
85 inner: Box::new(parse_rust_type(inner_type)),
86 span,
87 };
88 }
89 }
90 "Vec" => {
91 if let Some(inner_type) = extract_generic_argument(last_segment) {
92 return RustType::List {
93 syn: Cow::Borrowed(ty),
94 inner: Box::new(parse_rust_type(inner_type)),
95 span,
96 };
97 }
98 }
99 _ => {}
100 }
101 }
102 }
103 syn::Type::Reference(syn::TypeReference { elem, .. })
104 if matches!(**elem, syn::Type::Slice(_)) =>
105 {
106 let syn::Type::Slice(array) = &**elem else {
107 unreachable!()
108 };
109 return RustType::List {
110 syn: Cow::Borrowed(ty),
111 inner: Box::new(parse_rust_type(&array.elem)),
112 span,
113 };
114 }
115 syn::Type::Reference(reference) => {
116 return RustType::Ref {
117 syn: Cow::Borrowed(ty),
118 inner: Box::new(parse_rust_type(&reference.elem)),
119 span,
120 }
121 }
122 syn::Type::Array(array) => {
123 return RustType::List {
124 syn: Cow::Borrowed(ty),
125 inner: Box::new(parse_rust_type(&array.elem)),
126 span,
127 }
128 }
129 syn::Type::Slice(slice) => {
130 return RustType::List {
131 syn: Cow::Borrowed(ty),
132 inner: Box::new(parse_rust_type(&slice.elem)),
133 span,
134 }
135 }
136 _ => {}
137 }
138
139 RustType::SimpleType {
140 syn: Cow::Borrowed(ty),
141 span,
142 }
143}
144
145fn extract_generic_argument(segment: &syn::PathSegment) -> Option<&syn::Type> {
147 if let syn::PathArguments::AngleBracketed(angle_bracketed) = &segment.arguments {
148 for arg in &angle_bracketed.args {
149 if let syn::GenericArgument::Type(inner_type) = arg {
150 return Some(inner_type);
151 }
152 }
153 }
154
155 None
156}
157
158impl<'a> RustType<'a> {
159 pub fn to_syn(&self) -> syn::Type {
160 match self {
161 RustType::Optional { syn, .. } => syn::Type::Path(syn.clone().into_owned()),
162 RustType::List { syn, .. } => syn.clone().into_owned(),
163 RustType::Ref { syn, .. } => syn.clone().into_owned(),
164 RustType::SimpleType { syn, .. } => syn.clone().into_owned(),
165 }
166 }
167
168 pub fn replace_inner(self, new_inner: RustType<'a>) -> RustType<'a> {
169 match self {
170 RustType::SimpleType { .. } => {
171 panic!("Can't replace inner on simple or unknown types")
172 }
173 RustType::Optional { mut syn, span, .. } => {
174 syn.to_mut().replace_generic_param(&new_inner);
175 RustType::Optional {
176 syn,
177 inner: Box::new(new_inner),
178 span,
179 }
180 }
181 RustType::Ref { mut syn, span, .. } => {
182 match syn.to_mut() {
183 syn::Type::Path(path) => path.replace_generic_param(&new_inner),
184 syn::Type::Reference(reference) => reference.elem = Box::new(new_inner.to_syn()),
185 _ => panic!("We shouldn't have constructed RustType::Ref for anything else than these types")
186 }
187 RustType::Ref {
188 syn,
189 inner: Box::new(new_inner),
190 span,
191 }
192 }
193 RustType::List { mut syn, span, .. } => {
194 match syn.to_mut() {
195 syn::Type::Path(path) => path.replace_generic_param(&new_inner),
196 syn::Type::Array(array) => array.elem = Box::new(new_inner.to_syn()),
197 syn::Type::Slice(slice) => slice.elem = Box::new(new_inner.to_syn()),
198 syn::Type::Reference(ref_to_slice) => {
199 let syn::Type::Slice(slice) = &mut *ref_to_slice.elem
200 else { panic!("We shouldn't have constructed RustType::List for a Ref unless the type beneath is a Slice") };
201 slice.elem = Box::new(new_inner.to_syn());
202 }
203 _ => panic!("We shouldn't have constructed RustType::List for anything else than these types")
204 }
205
206 RustType::List {
207 syn,
208 inner: Box::new(new_inner),
209 span,
210 }
211 }
212 }
213 }
214}
215
216trait TypePathExt {
217 fn replace_generic_param(&mut self, replacement: &RustType<'_>);
218}
219
220impl TypePathExt for syn::TypePath {
221 fn replace_generic_param(&mut self, replacement: &RustType<'_>) {
222 fn get_generic_argument(type_path: &mut syn::TypePath) -> Option<&mut GenericArgument> {
223 let segment = type_path.path.segments.last_mut()?;
224
225 match &mut segment.arguments {
226 syn::PathArguments::AngleBracketed(angle_bracketed) => {
227 angle_bracketed.args.first_mut()
228 }
229 _ => None,
230 }
231 }
232
233 let generic_argument = get_generic_argument(self)
234 .expect("Don't call replace_generic_param on a type without a generic argument");
235
236 *generic_argument = syn::GenericArgument::Type(replacement.to_syn())
237 }
238}
239
#[cfg(test)]
mod tests {
    use {proc_macro2::TokenStream, quote::quote, rstest::rstest};

    use super::*;

    #[rstest]
    #[case::replace_on_option(
        quote! { Option<i32> },
        quote! { Vec<i32> },
        quote! { Option<Vec<i32>> },
    )]
    #[case::replace_on_vec(
        quote! { Vec<i32> },
        quote! { Vec<i32> },
        quote! { Vec<Vec<i32>> },
    )]
    #[case::replace_on_box(
        quote! { Box<i32> },
        quote! { Vec<i32> },
        quote! { Box<Vec<i32>> },
    )]
    #[case::replace_on_arc(
        quote! { Arc<i32> },
        quote! { Vec<i32> },
        quote! { Arc<Vec<i32>> },
    )]
    #[case::replace_with_complex_inner(
        quote! { Arc<i32> },
        quote! { Vec<chrono::DateTime<chrono::Utc>> },
        quote! { Arc<Vec<chrono::DateTime<chrono::Utc>>> },
    )]
    #[case::replace_with_a_full_path(
        quote! { std::sync::Arc<i32> },
        quote! { Vec<chrono::DateTime<chrono::Utc>> },
        quote! { std::sync::Arc<Vec<chrono::DateTime<chrono::Utc>>> },
    )]
    fn test_replace_inner(
        #[case] original: TokenStream,
        #[case] replace: TokenStream,
        #[case] expected: TokenStream,
    ) {
        // Parse every case input up front; a malformed case is a bug in the test itself.
        let original_ty: syn::Type = syn::parse2(original).unwrap();
        let replacement_ty: syn::Type = syn::parse2(replace).unwrap();
        let expected_ty: syn::Type = syn::parse2(expected).unwrap();

        let outer = parse_rust_type(&original_ty);
        let new_inner = parse_rust_type(&replacement_ty);
        let actual = outer.replace_inner(new_inner).to_syn();

        assert_eq!(actual, expected_ty);
    }
}