1use std::collections::HashMap;
2
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::{ToTokens, format_ident, quote};
6use syn::parse::Parse;
7use syn::punctuated::Punctuated;
8use syn::token::Comma;
9use syn::{
10 Attribute, GenericParam, Generics, Index, Meta, Path, PredicateType, Token, Type, TypeParam,
11 TypePath, WhereClause, WherePredicate,
12};
13
14#[proc_macro_derive(With, attributes(with))]
66pub fn derive(input: TokenStream) -> TokenStream {
67 let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse item");
68 let result = match ast.data {
69 syn::Data::Struct(ref s) => with_for_struct(&ast, &s.fields),
70 syn::Data::Enum(_) => panic!("doesn't work with enums yet"),
71 syn::Data::Union(_) => panic!("doesn't work with unions yet"),
72 };
73 result.into()
74}
75
76fn with_for_struct(ast: &syn::DeriveInput, fields: &syn::Fields) -> proc_macro2::TokenStream {
77 match *fields {
78 syn::Fields::Named(ref fields) => with_constructor_for_named(ast, &fields.named),
79 syn::Fields::Unnamed(ref fields) => with_constructor_for_unnamed(ast, &fields.unnamed),
80 syn::Fields::Unit => panic!("Unit structs are not supported"),
81 }
82}
83
84fn with_constructor_for_named(
85 ast: &syn::DeriveInput,
86 fields: &Punctuated<syn::Field, Token![,]>,
87) -> proc_macro2::TokenStream {
88 let name = &ast.ident;
89 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
90 let generics_map = index_generics(&ast.generics);
91 let where_predicate_map = index_where_predicates(&ast.generics.where_clause);
92 let with_args = parse_with_args::<Ident>(&ast.attrs);
93 let field_count = fields.len();
94
95 let mut constructors = quote!();
96 for field in fields {
97 let field_name = field.ident.as_ref().unwrap();
98 if !contains_field(&with_args, field_name) {
99 continue;
100 }
101 let field_type = &field.ty;
102 let constructor_name = format_ident!("with_{}", field_name);
103
104 let constructor = match field_type {
106 Type::Path(type_path) => {
108 match generics_map.get(&type_path.path).cloned() {
110 None => generate_constructor_for_named(
112 &constructor_name,
113 field_name,
114 field_type,
115 field_count,
116 ),
117 Some(mut generic) => {
119 let new_generic = format_ident!("W{}", generic.ident);
120 generic.ident = new_generic.clone();
122
123 let mut new_generic_params = Vec::new();
125 for param in &ast.generics.params {
126 new_generic_params.push(match param {
127 GenericParam::Type(type_param)
129 if type_path.path.is_ident(&type_param.ident) =>
130 {
131 new_generic.to_token_stream()
133 }
134 GenericParam::Type(type_param) => {
135 type_param.ident.to_token_stream()
136 }
137 GenericParam::Lifetime(lifetime_param) => {
138 lifetime_param.lifetime.to_token_stream()
139 }
140 GenericParam::Const(const_param) => {
141 const_param.ident.to_token_stream()
142 }
143 });
144 }
145
146 let mut other_fields = Vec::new();
148 for other_field in fields {
149 let other_field_name = other_field.ident.as_ref().unwrap();
150 if other_field_name != field_name {
151 other_fields
152 .push(quote! { #other_field_name: self.#other_field_name });
153 } else {
154 other_fields.push(quote! { #field_name });
155 }
156 }
157
158 let where_clause = where_predicate_map.get(&type_path.path).cloned().map(
160 |mut predicate| {
161 predicate.bounded_ty = Type::Path(TypePath {
163 qself: None,
164 path: Path::from(new_generic.clone()),
165 });
166 quote! { where #predicate }
167 },
168 );
169
170 quote! {
171 pub fn #constructor_name <#generic> (self, #field_name: #new_generic)
172 -> #name < #(#new_generic_params),* >
173 #where_clause
174 {
175 #name {
176 #(#other_fields),*
177 }
178 }
179 }
180 }
181 }
182 }
183 _ => generate_constructor_for_named(
185 &constructor_name,
186 field_name,
187 field_type,
188 field_count,
189 ),
190 };
191
192 constructors = quote! {
193 #constructors
194 #constructor
195 };
196 }
197 quote! {
198 #[automatically_derived]
199 impl #impl_generics #name #ty_generics #where_clause {
200 #constructors
201 }
202 }
203}
204
205fn with_constructor_for_unnamed(
206 ast: &syn::DeriveInput,
207 fields: &Punctuated<syn::Field, Token![,]>,
208) -> proc_macro2::TokenStream {
209 let name = &ast.ident;
210 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
211 let generics_map = index_generics(&ast.generics);
212 let where_predicate_map = index_where_predicates(&ast.generics.where_clause);
213 let with_args = parse_with_args::<Index>(&ast.attrs);
214
215 let mut constructors = quote!();
216 for (index, field) in fields.iter().enumerate() {
217 let index = syn::Index::from(index);
218 if !contains_field(&with_args, &index) {
219 continue;
220 }
221 let field_type = &field.ty;
222 let field_name = format_ident!("field_{}", index);
223 let constructor_name = format_ident!("with_{}", index);
224
225 let constructor = match field_type {
227 Type::Path(type_path) => {
229 match generics_map.get(&type_path.path).cloned() {
231 None => generate_constructor_for_unnamed(
233 &constructor_name,
234 index,
235 &field_name,
236 field_type,
237 ),
238
239 Some(mut generic) => {
241 let new_generic = format_ident!("W{}", generic.ident);
242 generic.ident = new_generic.clone();
244
245 let mut new_generic_params = Vec::new();
247 for param in &ast.generics.params {
248 new_generic_params.push(match param {
249 GenericParam::Type(type_param)
251 if type_path.path.is_ident(&type_param.ident) =>
252 {
253 new_generic.to_token_stream()
255 }
256 GenericParam::Type(type_param) => {
257 type_param.ident.to_token_stream()
258 }
259 GenericParam::Lifetime(lifetime_param) => {
260 lifetime_param.lifetime.to_token_stream()
261 }
262 GenericParam::Const(const_param) => {
263 const_param.ident.to_token_stream()
264 }
265 });
266 }
267
268 let mut other_fields = Vec::new();
270 for (other_index, _) in fields.iter().enumerate() {
271 let other_index = syn::Index::from(other_index);
272 if other_index != index {
273 other_fields.push(quote! { self.#other_index });
274 } else {
275 other_fields.push(quote! { #field_name });
276 }
277 }
278
279 let where_clause = where_predicate_map.get(&type_path.path).cloned().map(
281 |mut predicate| {
282 predicate.bounded_ty = Type::Path(TypePath {
284 qself: None,
285 path: Path::from(new_generic.clone()),
286 });
287 quote! { where #predicate }
288 },
289 );
290
291 quote! {
292 pub fn #constructor_name <#generic> (self, #field_name: #new_generic)
293 -> #name < #(#new_generic_params),* >
294 #where_clause
295 {
296 #name ( #(#other_fields),* )
297 }
298 }
299 }
300 }
301 }
302 _ => {
304 generate_constructor_for_unnamed(&constructor_name, index, &field_name, field_type)
305 }
306 };
307
308 constructors = quote! {
309 #constructors
310 #constructor
311 };
312 }
313 quote! {
314 #[automatically_derived]
315 impl #impl_generics #name #ty_generics #where_clause {
316 #constructors
317 }
318 }
319}
320
321fn parse_with_args<T: Parse>(attrs: &[Attribute]) -> Option<Punctuated<T, Comma>> {
322 if let Some(attr) = attrs.iter().find(|attr| attr.path().is_ident("with")) {
323 match &attr.meta {
324 Meta::List(list) => Some(
325 list.parse_args_with(Punctuated::<T, Comma>::parse_terminated)
326 .expect("Couldn't parse with args"),
327 ),
328 _ => panic!("`with` attribute should like `#[with(a, b, c)]`"),
329 }
330 } else {
331 None
332 }
333}
334
335fn contains_field<T: Parse + PartialEq>(
336 with_args: &Option<Punctuated<T, Comma>>,
337 item: &T,
338) -> bool {
339 with_args.is_none() || with_args.as_ref().unwrap().iter().any(|arg| arg == item)
340}
341
342fn index_generics(generics: &Generics) -> HashMap<Path, TypeParam> {
343 generics
344 .params
345 .iter()
346 .filter_map(|p| match p {
347 GenericParam::Type(type_param) => Some(type_param),
348 _ => None,
349 })
350 .map(|p| (Path::from(p.ident.clone()), p.clone()))
351 .collect()
352}
353
354fn index_where_predicates(where_clause: &Option<WhereClause>) -> HashMap<Path, PredicateType> {
355 where_clause
356 .as_ref()
357 .map(|w| {
358 w.predicates
359 .iter()
360 .filter_map(|p| match p {
361 WherePredicate::Type(t) => Some(t),
362 _ => None,
363 })
364 .filter_map(|t| match &t.bounded_ty {
365 Type::Path(type_path) => Some((type_path.path.clone(), t.clone())),
366 _ => None,
367 })
368 .collect()
369 })
370 .unwrap_or_default()
371}
372
373fn generate_constructor_for_named(
374 constructor_name: &Ident,
375 field_name: &Ident,
376 field_type: &Type,
377 field_count: usize,
378) -> proc_macro2::TokenStream {
379 let field_arg_type = match field_type {
380 Type::Path(type_path) if is_builtin_numeric_type(&type_path.path) => quote! { #field_type },
381 _ => quote! { impl Into<#field_type> },
382 };
383 if field_count == 1 {
384 quote! {
385 pub fn #constructor_name(self, #field_name: #field_arg_type) -> Self {
386 Self {
387 #field_name: #field_name.into(),
388 }
389 }
390 }
391 } else {
392 quote! {
393 pub fn #constructor_name(self, #field_name: #field_arg_type) -> Self {
394 Self {
395 #field_name: #field_name.into(),
396 ..self
397 }
398 }
399 }
400 }
401}
402
403fn generate_constructor_for_unnamed(
404 constructor_name: &Ident,
405 field_index: Index,
406 field_name: &Ident,
407 field_type: &Type,
408) -> proc_macro2::TokenStream {
409 let field_arg_type = match field_type {
410 Type::Path(type_path) if is_builtin_numeric_type(&type_path.path) => {
411 quote! { #field_type }
412 }
413 _ => quote! { impl Into<#field_type> },
414 };
415 quote! {
416 pub fn #constructor_name(mut self, #field_name: #field_arg_type) -> Self {
417 self.#field_index = #field_name.into();
418 self
419 }
420 }
421}
422
423fn is_builtin_numeric_type(path: &Path) -> bool {
425 let path_str = path.to_token_stream().to_string();
427
428 matches!(
430 path_str.as_str(),
431 "i8" | "i16"
432 | "i32"
433 | "i64"
434 | "i128"
435 | "isize"
436 | "u8"
437 | "u16"
438 | "u32"
439 | "u64"
440 | "u128"
441 | "usize"
442 | "f32"
443 | "f64"
444 )
445}