1use convert_case::{Case, Casing};
3use once_cell::sync::Lazy;
4use proc_macro::{TokenStream, TokenTree};
5use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
6use quote::quote;
7use quote::ToTokens;
8use std::collections::{HashMap, HashSet};
9use std::sync::Mutex;
10use syn::punctuated::Punctuated;
11use syn::token::Comma;
12use syn::DataStruct;
13use syn::FieldsNamed;
14use syn::ImplItem;
15use syn::ImplItemFn;
16use syn::Variant;
17use syn::{Data, DeriveInput, Field, Fields, ItemImpl, Type, TypePath};
18use thiserror::Error;
19
/// Kind of item a mixin was declared from.
#[derive(Debug, Eq, PartialEq, Clone)]
enum MixinType {
    /// No declaration seen yet (the entry was created by `#[overwrite]`).
    Unknown,
    /// The mixin is an `enum`.
    Enum,
    /// The mixin is a `struct`.
    Struct,
}
26
27#[derive(Error, Debug)]
28enum Error {
29 #[error("global data unavailable")]
30 GlobalUnavailable,
31 #[error("can't find mixin with name: {0}")]
32 NoMixin(String),
33 #[error("invalid expansion of the mixin")]
34 InvalidExpansion,
35 #[error("syn error: {0}")]
36 SynError(#[from] syn::Error),
37 #[error("lex error: {0}")]
38 LexError(#[from] proc_macro::LexError),
39 #[error("parameter error")]
40 ParameterError,
41 #[error("You SHOULD place overwrite for {0} before insert")]
42 OverWriteError(String),
43 #[error("You impl trait {0} twice")]
44 OverWriteTraitTwice(String),
45 #[error("Unsupport Type {0}")]
48 UnsupportType(String),
49}
50
51impl Error {
52 fn to_compile_error(self) -> TokenStream {
53 let txt = self.to_string();
54 let err = syn::Error::new(Span::call_site(), txt).to_compile_error();
55 TokenStream::from(err)
56 }
57}
58
59#[proc_macro_attribute]
68pub fn insert(args: TokenStream, input: TokenStream) -> TokenStream {
69 insert_impl(args, input).unwrap_or_else(Error::to_compile_error)
71}
72
73fn insert_impl(args: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
74 let mut output: TokenStream = "#[allow(dead_code)]".parse()?;
75 output.extend(input.clone().into_iter());
76 let mut data = GLOBAL_DATA.lock().map_err(|_| Error::GlobalUnavailable)?;
77 let the_struct: DeriveInput = syn::parse(input.clone())?;
78 let the_struct_name = the_struct.ident.to_string();
79 let the_struct_mixin = data.get(&the_struct_name); let mut the_struct_mixin_ctx = if let Some(mixin) = the_struct_mixin {
83 let mut ctx = MixinCtx::from(mixin);
84 ctx.declaration = Some(the_struct);
85 ctx
86 } else {
87 let mixin_type = match the_struct.data {
88 Data::Struct(_) => MixinType::Struct,
89 Data::Enum(_) => MixinType::Enum,
90 Data::Union(_) => todo!(),
91 };
92 MixinCtx {
93 name: the_struct.ident.clone(),
94 mixin_type,
95 declaration: Some(the_struct),
96 extensions: HashMap::new(),
97 overwrite_impls: HashMap::new(),
98 impl_traits: HashMap::new(),
99 over_traits: HashMap::new(),
100 }
101 };
102
103 let mut mixin_names = HashSet::new();
105 for ident in args.into_iter() {
106 if let TokenTree::Ident(idt) = ident {
108 mixin_names.insert(idt.to_string());
109 }
110 }
111 let mut mixed_fields = Vec::new();
115 let mut mixed_variants = Vec::new();
116 for mixin_name in mixin_names {
117 let mixin = data
118 .get(&mixin_name)
119 .ok_or_else(|| Error::NoMixin(mixin_name.clone()))?; let extend_mixin_ctx: MixinCtx = mixin.into();
122 if let Data::Struct(st) = extend_mixin_ctx.declaration.clone().unwrap().data {
126 if let Fields::Named(named) = st.fields {
127 mixed_fields.push(named.named); }
129 } else if let Data::Enum(en) = extend_mixin_ctx.declaration.unwrap().data {
130 for variant in en.variants {
131 mixed_variants.push(variant);
132 }
133 }
134 for (fn_name, fn_impl) in extend_mixin_ctx.extensions.iter() {
136 the_struct_mixin_ctx
137 .extensions
138 .insert(fn_name.clone(), fn_impl.clone());
139 }
140
141 for (trait_name, trait_impl) in extend_mixin_ctx.impl_traits.iter() {
153 let mut trait_impl = trait_impl.clone();
155
156 let ty = trait_impl.self_ty.as_mut(); let path = if let Type::Path(TypePath { path, .. }, ..) = ty {
158 path
159 } else {
160 return Err(Error::UnsupportType(ty.into_token_stream().to_string()));
161 };
162 let x = path.segments.last_mut().unwrap();
163 x.ident = Ident::new(&the_struct_name, x.ident.span());
164
165 the_struct_mixin_ctx
166 .impl_traits
167 .insert(trait_name.clone(), trait_impl);
168 }
169 }
176
177 for (fn_name, fn_impl) in the_struct_mixin_ctx.overwrite_impls.iter() {
179 the_struct_mixin_ctx
181 .extensions
182 .insert(fn_name.clone(), fn_impl.clone());
183 }
184 for (trait_name, trait_impl) in the_struct_mixin_ctx.over_traits.iter() {
185 the_struct_mixin_ctx
186 .impl_traits
187 .insert(trait_name.clone(), trait_impl.clone());
188 }
189
190 if let Data::Struct(ref mut st) = the_struct_mixin_ctx.declaration.as_mut().unwrap().data {
193 if let Fields::Named(ref mut named) = st.fields {
194 let mut the_struct_fields = HashSet::<String>::new();
195 for p in named.named.iter() {
196 if let Some(idt) = p.ident.clone() {
197 the_struct_fields.insert(idt.to_string());
198 }
199 } for fields in mixed_fields {
203 let mut new_fields: Punctuated<Field, Comma> = Punctuated::new();
204 for field in fields.iter() {
205 if let Some(idt) = field.ident.clone() {
206 if !the_struct_fields.contains(&idt.to_string()) {
208 the_struct_fields.insert(idt.to_string());
210 new_fields.push(field.clone());
211 }
212 }
213 }
214 named.named.extend(new_fields.into_pairs());
216 }
217 }
218 } else if let Data::Enum(ref mut en) = the_struct_mixin_ctx.declaration.as_mut().unwrap().data {
219 let mut the_enum_varients = HashSet::new();
221 for variant in en.variants.iter() {
222 the_enum_varients.insert(variant.ident.clone().to_string());
223 }
224 let mut new_variants: Punctuated<Variant, Comma> = Punctuated::new();
226 for variant in mixed_variants {
227 if !the_enum_varients.contains(&variant.ident.clone().to_string()) {
228 the_enum_varients.insert(variant.ident.clone().to_string());
229 new_variants.push(variant);
230 }
231 }
232 en.variants.extend(new_variants.into_pairs());
233 }
234
235 let declaration = the_struct_mixin_ctx.declaration.as_ref().unwrap();
237 if let Data::Struct(_) = declaration.data {
239 let get_set_impls_stream = gen_get_set_impls(declaration);
240 let get_set_impls = syn::parse::<ItemImpl>(get_set_impls_stream.into()).unwrap();
241 the_struct_mixin_ctx.add_extension(&get_set_impls);
242 }
243
244 let stream: TokenStream = the_struct_mixin_ctx.to_token_stream();
246 let the_struct_mixin = Mixin::from(&the_struct_mixin_ctx);
247
248 data.insert(the_struct_name, the_struct_mixin);
252 Ok(stream)
253}
254
255#[derive(Debug)]
256struct Mixin {
257 name: String,
258 mixin_type: MixinType,
259 declaration: Option<String>, extensions: Vec<String>, overwrite_impls: HashMap<String, String>, impl_traits: HashMap<String, String>, over_traits: HashMap<String, String>,
265}
266
267struct MixinCtx {
268 name: Ident,
269 mixin_type: MixinType,
270 declaration: Option<DeriveInput>,
271 extensions: HashMap<String, ImplItemFn>, overwrite_impls: HashMap<String, ImplItemFn>,
273 impl_traits: HashMap<String, ItemImpl>, over_traits: HashMap<String, ItemImpl>,
275}
276
277fn insert_impl_hm(hm: &mut HashMap<String, ImplItemFn>, item_impl: &ItemImpl) {
278 for impl_item in item_impl.items.iter() {
279 match impl_item {
280 ImplItem::Const(_) => todo!(),
281 ImplItem::Fn(impl_item_fn) => {
282 let ident_name = impl_item_fn.sig.ident.to_string();
283 let pre = hm.get(&ident_name);
284 if pre.is_some() {
285 }
287 hm.insert(ident_name, impl_item_fn.clone()); }
289 ImplItem::Type(_) => todo!(),
290 ImplItem::Macro(_) => todo!(),
291 ImplItem::Verbatim(_) => todo!(),
292 _ => todo!(),
293 }
294 }
295}
296
297impl MixinCtx {
298 #[allow(dead_code)]
299 fn dbg_print(&self) {
300 let mixin_name = self.name.to_string();
301 dbg!("=========================", mixin_name);
302 dbg!(self.declaration.to_token_stream().to_string());
303 for item_fn in self.extensions.iter() {
304 dbg!(item_fn.0, item_fn.1.to_token_stream().to_string());
305 }
306 for item_impl in self.impl_traits.iter() {
307 dbg!(item_impl.0, item_impl.1.to_token_stream().to_string());
308 }
309 }
310
311 fn add_overwrite_impls(&mut self, item_impl: &ItemImpl) {
312 insert_impl_hm(&mut self.overwrite_impls, item_impl)
313 }
314 fn add_extension(&mut self, item_impl: &ItemImpl) {
315 insert_impl_hm(&mut self.extensions, item_impl)
316 }
317
318 fn to_token_stream(&self) -> TokenStream {
319 if self.declaration.is_none() {
320 return Error::InvalidExpansion.to_compile_error();
321 }
322 let name = self.name.clone();
323 let derive_input = self.declaration.clone().unwrap();
324
325 let impl_fns_token: Vec<TokenStream2> = self
327 .extensions
328 .iter()
329 .map(|(_, impl_fn)| quote! { #impl_fn })
330 .collect();
331 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
333 let impl_token = quote! {
334 impl #impl_generics #name #ty_generics #where_clause{
335 #(#impl_fns_token)*
336 }
337 };
338
339 let mut stream: TokenStream2 = derive_input.clone().into_token_stream();
340 stream.extend(impl_token);
341
342 for (_, trait_impl) in self.impl_traits.iter() {
344 stream.extend(trait_impl.to_token_stream())
345 }
346 stream.into()
347 }
348}
349
350impl From<&Mixin> for MixinCtx {
351 fn from(value: &Mixin) -> Self {
352 let name = Ident::new(&value.name, Span::call_site());
353 let declaration = if let Some(declaration) = value.declaration.as_ref() {
354 Some(syn::parse::<DeriveInput>(declaration.parse::<TokenStream>().unwrap()).unwrap())
355 } else {
356 None
357 };
358
359 let mut extensions = HashMap::new();
360 for extention in value.extensions.iter() {
361 let ext_tokenstream = extention.parse::<TokenStream>().unwrap();
362 let ext_item_impl = syn::parse(ext_tokenstream).unwrap();
363 insert_impl_hm(&mut extensions, &ext_item_impl);
364 }
365
366 let mut overwrite_impls = HashMap::new();
367 for (_, overwrite_impl) in value.overwrite_impls.iter() {
368 let ov_tokenstream = overwrite_impl.parse::<TokenStream>().unwrap();
369 let ov_item_impl: ImplItemFn = syn::parse(ov_tokenstream).unwrap();
370 overwrite_impls.insert(ov_item_impl.sig.ident.to_string(), ov_item_impl);
371 }
372 let mut impl_traits = HashMap::new();
373 for ov in value.impl_traits.iter() {
374 let trait_impl = syn::parse::<ItemImpl>(ov.1.parse().unwrap()).unwrap();
375 impl_traits.insert(ov.0.clone(), trait_impl);
376 }
377
378 let mut over_traits = HashMap::new();
379 for ov in value.over_traits.iter() {
380 let ov_trait = syn::parse::<ItemImpl>(ov.1.parse().unwrap()).unwrap();
381 over_traits.insert(ov.0.clone(), ov_trait);
382 }
383
384 MixinCtx {
385 name,
386 mixin_type: value.mixin_type.clone(),
387 declaration: declaration,
388 extensions: extensions,
389 overwrite_impls: overwrite_impls,
390 impl_traits: impl_traits,
391 over_traits: over_traits,
392 }
393 }
394}
395
396impl From<&MixinCtx> for Mixin {
397 fn from(value: &MixinCtx) -> Self {
398 let name: Ident = value.name.clone();
399
400 let declaration = if let Some(declaration) = value.declaration.as_ref() {
401 Some(declaration.to_token_stream().to_string())
402 } else {
403 None
404 };
405
406 let mut extensions: Vec<String> = Vec::new(); let impl_fns_token: Vec<TokenStream2> = value
409 .extensions
410 .iter()
411 .map(|(_, impl_fn)| quote! { #impl_fn })
412 .collect();
413 let impl_token = quote! {
414 impl #name {
415 #(#impl_fns_token)*
416 }
417 };
418 extensions.push(impl_token.to_string());
419
420 let mut overwrite_impls = HashMap::new();
422 for trait_impl in value.overwrite_impls.iter() {
423 overwrite_impls.insert(
424 trait_impl.0.clone(),
425 trait_impl.1.to_token_stream().to_string(),
426 );
427 }
428
429 let mut impl_traits = HashMap::new();
431 for (trait_name, trait_impl) in value.impl_traits.iter() {
432 impl_traits.insert(trait_name.clone(), trait_impl.to_token_stream().to_string());
433 }
434
435 let mut over_traits = HashMap::new();
437 for (trait_name, trait_impl) in value.over_traits.iter() {
438 over_traits.insert(trait_name.clone(), trait_impl.to_token_stream().to_string());
439 }
440
441 Mixin {
442 name: name.to_string(),
443 mixin_type: value.mixin_type.clone(),
444 declaration: declaration,
445 extensions: extensions,
446 overwrite_impls: overwrite_impls,
447 impl_traits: impl_traits,
448 over_traits: over_traits,
449 }
450 }
451}
452static GLOBAL_DATA: Lazy<Mutex<HashMap<String, Mixin>>> = Lazy::new(|| Mutex::new(HashMap::new()));
454
455fn gen_get_set_impls(input: &DeriveInput) -> TokenStream2 {
456 let name = &input.ident;
457
458 let name_string = name.to_string();
459 let get_fn_name = Ident::new(
460 &("get".to_owned() + &name_string).to_case(Case::Snake),
461 name.span(),
462 );
463 let set_fn_name = Ident::new(
464 &("set".to_owned() + &name_string).to_case(Case::Snake),
465 name.span(),
466 );
467
468 let fields = if let Data::Struct(DataStruct {
469 fields: Fields::Named(FieldsNamed { named, .. }),
470 ..
471 }) = &input.data
472 {
473 named
474 } else {
475 panic!("Unsupported data type");
476 };
477 let fds: Vec<Ident> = fields
478 .into_iter()
479 .map(|f| f.ident.clone().unwrap())
480 .collect();
481 let get_fds_token: Vec<TokenStream2> = fds
482 .iter()
483 .map(|name| quote! { #name: self.#name.clone() })
484 .collect();
485 let set_fds_token: Vec<TokenStream2> = fds
486 .iter()
487 .map(|name| quote! { self.#name = p.#name.clone() })
488 .collect();
489 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
492 let impl_get_set = quote! {
493 impl #impl_generics #name #ty_generics #where_clause {
494 pub fn #get_fn_name(&self) -> #name #ty_generics{
495 #name {
496 #(#get_fds_token,)*
497 }
498 }
499 pub fn #set_fn_name(&mut self, p: &#name #ty_generics){
500 #(#set_fds_token;)*
501 }
502 }
503 };
504 impl_get_set
506}
507
508#[proc_macro_attribute]
509pub fn declare(_attribute: TokenStream, input: TokenStream) -> TokenStream {
510 declare_impl(input).unwrap_or_else(Error::to_compile_error)
511}
512
513fn declare_impl(input: TokenStream) -> Result<TokenStream, Error> {
514 let mut output: TokenStream = "#[allow(dead_code)]".parse()?;
516 output.extend(input.clone().into_iter());
517
518 let input = syn::parse::<DeriveInput>(input).unwrap();
519 let mixin_type = match input.data {
520 Data::Struct(_) => MixinType::Struct,
521 Data::Enum(_) => MixinType::Enum,
522 Data::Union(_) => todo!(),
523 };
524
525 let name_string = input.ident.clone().to_string();
526
527 let mut get_set_impls = None;
528
529 if mixin_type == MixinType::Struct {
530 let get_set_impls_stream = gen_get_set_impls(&input);
531 get_set_impls = Some(syn::parse::<ItemImpl>(get_set_impls_stream.into())?);
532 }
533
534 let mut mixin_ctx = MixinCtx {
535 name: input.ident.clone(),
536 mixin_type: mixin_type.clone(),
537 declaration: Some(input),
538 extensions: HashMap::new(),
539 overwrite_impls: HashMap::new(),
540 impl_traits: HashMap::new(),
541 over_traits: HashMap::new(),
542 };
543
544 if mixin_type == MixinType::Struct {
545 mixin_ctx.add_extension(&get_set_impls.unwrap());
546 }
547
548 let mixin = (&mixin_ctx).into();
549 let mut data: std::sync::MutexGuard<'_, HashMap<String, Mixin>> =
550 GLOBAL_DATA.lock().map_err(|_| Error::GlobalUnavailable)?;
551 data.insert(name_string, mixin);
552 Ok(mixin_ctx.to_token_stream())
553}
554
555#[proc_macro_attribute]
556pub fn expand(_attribute: TokenStream, input: TokenStream) -> TokenStream {
557 expand_impl(input).unwrap_or_else(Error::to_compile_error)
558}
559
560fn get_name_of_impl(input: &ItemImpl) -> Result<(String, String), Error> {
562 let trait_name = if let Some((_, path, _)) = &input.trait_ {
563 path.to_token_stream().to_string()
564 } else {
565 "".into()
566 };
567
568 let ty = input.self_ty.as_ref();
569 let path = if let Type::Path(TypePath { path, .. }, ..) = ty {
570 path
571 } else {
572 return Err(Error::UnsupportType(ty.to_token_stream().to_string()));
573 };
574
575 let idt = path.get_ident().unwrap();
576 let name = idt.to_string();
577
578 Ok((name, trait_name))
579}
580
581fn expand_impl(input: TokenStream) -> Result<TokenStream, Error> {
582 let input = syn::parse::<ItemImpl>(input).unwrap();
583 let output = input.to_token_stream().into();
584
585 let (name, trait_name) = get_name_of_impl(&input)?;
586
587 let mut data = GLOBAL_DATA.lock().map_err(|_| Error::GlobalUnavailable)?;
588 let mixin = data
589 .get(&name)
590 .ok_or_else(|| Error::NoMixin(name.clone()))?; let mut mixin_ctx: MixinCtx = mixin.into();
595
596 if trait_name != String::from("") {
597 mixin_ctx.impl_traits.insert(trait_name, input);
598 } else {
599 mixin_ctx.add_extension(&input);
600 }
601
602 let mixin: Mixin = (&mixin_ctx).into();
603 data.insert(name, mixin);
604
605 Ok(output)
606}
607
608#[proc_macro_attribute]
610pub fn overwrite(attribute: TokenStream, input: TokenStream) -> TokenStream {
611 if !attribute.is_empty() {
612 let e = Error::ParameterError;
613 return e.to_compile_error();
614 }
615
616 overwrite_impl(input).unwrap_or_else(Error::to_compile_error)
617}
618
619fn overwrite_impl(input: TokenStream) -> Result<TokenStream, Error> {
620 let input = syn::parse::<ItemImpl>(input).unwrap();
621
622 let (name, trait_name) = get_name_of_impl(&input)?;
623
624 let mut data = GLOBAL_DATA.lock().map_err(|_| Error::GlobalUnavailable)?;
625
626 let mut mixin_ctx = if let Some(mixin) = data.get(&name) {
627 MixinCtx::from(mixin)
628 } else {
629 MixinCtx {
631 name: Ident::new(&name, Span::call_site()),
632 mixin_type: MixinType::Unknown,
633 declaration: None,
634 extensions: HashMap::new(),
635 overwrite_impls: HashMap::new(),
636 impl_traits: HashMap::new(),
637 over_traits: HashMap::new(),
638 }
639 };
640
641 if mixin_ctx.declaration.is_some() {
642 return Err(Error::OverWriteError(name));
644 }
645
646 if trait_name == String::from("") {
647 mixin_ctx.add_overwrite_impls(&input);
649 } else {
650 if mixin_ctx.over_traits.contains_key(&trait_name) {
651 return Err(Error::OverWriteTraitTwice(trait_name));
652 }
653 mixin_ctx.over_traits.insert(trait_name, input);
654 }
655
656 let mixin = Mixin::from(&mixin_ctx);
657 data.insert(name, mixin);
658
659 let output = "".parse::<TokenStream>().unwrap();
660 Ok(output) }