kstool_helper_generator/
lib.rs

1mod enchanted;
2
3use once_cell::sync::Lazy;
4use proc_macro2::TokenStream;
5use proc_macro2::{Ident, Span};
6use quote::{quote, ToTokens};
7use syn::{spanned::Spanned, Attribute, DataEnum, Fields, Variant};
8
9static WORD_RE: Lazy<fancy_regex::Regex> =
10    Lazy::new(|| fancy_regex::Regex::new(r".(?:[^A-Z0-9]+|[A-Z0-9]*)(?![^A-Z0-9])").unwrap());
11static WORD_RE_ERROR: Lazy<String> = Lazy::new(|| "ERROR_PLEASE_REPORT".to_string());
12
13type TyType = TokenStream;
14type IdentType = Ident;
15
16#[derive(Clone)]
17enum FieldsType {
18    Named(Vec<(IdentType, TyType)>),
19    Unnamed(Vec<TyType>),
20    None,
21}
22
23impl FieldsType {
24    fn to_arg_def(&self, span: Span) -> TokenStream {
25        match self {
26            FieldsType::Named(v) => {
27                let updated = v
28                    .iter()
29                    .map(|(ident, ty)| {
30                        quote!(
31                            #ident: #ty
32                        )
33                    })
34                    .collect::<Vec<_>>();
35                quote!(#(#updated), *)
36            }
37            FieldsType::Unnamed(v) => {
38                let i = (0..v.len())
39                    .map(|i| Ident::new(&format!("input{}", i), span))
40                    .collect::<Vec<_>>();
41                quote!(#(#i: #v), *)
42            }
43            FieldsType::None => TokenStream::new(),
44        }
45    }
46
47    fn to_arg(&self, span: Span) -> TokenStream {
48        match self {
49            FieldsType::Named(v) => {
50                let idents = v
51                    .iter()
52                    .map(|(ident, _)| quote!(#ident))
53                    .collect::<Vec<_>>();
54                quote!({#(#idents), *})
55            }
56            FieldsType::Unnamed(v) => {
57                let i = (0..v.len())
58                    .map(|i| (Ident::new(&format!("input{}", i), span)))
59                    .collect::<Vec<_>>();
60                quote!((#(#i), *))
61            }
62            FieldsType::None => TokenStream::new(),
63        }
64    }
65
66    fn enum_arg_def(&self, return_type: Option<TokenStream>) -> TokenStream {
67        match self {
68            FieldsType::Named(v) => {
69                let mut idents = v
70                    .iter()
71                    .map(|(ident, ty)| {
72                        quote!(
73                            #ident: #ty
74                        )
75                    })
76                    .collect::<Vec<_>>();
77                if let Some(ret) = return_type {
78                    let ident = Ident::new("__private_sender", ret.span());
79                    idents.push(quote! ( #ident: tokio::sync::oneshot::Sender< #ret >));
80                }
81                quote!({#(#idents), *})
82            }
83            FieldsType::Unnamed(v) => {
84                if let Some(ret) = return_type {
85                    let mut v = v.clone();
86                    v.push(quote! (tokio::sync::oneshot::Sender< #ret >));
87                    quote!((#(#v), *))
88                } else {
89                    quote!((#(#v), *))
90                }
91            }
92            FieldsType::None => {
93                if let Some(ret) = return_type {
94                    quote! {(tokio::sync::oneshot::Sender< #ret >)}
95                } else {
96                    TokenStream::new()
97                }
98            }
99        }
100    }
101
102    fn enchant_arg(&self, span: Span) -> TokenStream {
103        let private_receiver = Ident::new("__private_sender", span);
104        match self {
105            FieldsType::Named(v) => {
106                let mut idents = v
107                    .iter()
108                    .map(|(ident, _)| quote!(#ident))
109                    .collect::<Vec<_>>();
110                idents.push(quote! {#private_receiver});
111                quote!({#(#idents), *})
112            }
113            FieldsType::Unnamed(v) => {
114                let mut i = (0..v.len())
115                    .map(|i| (Ident::new(&format!("input{}", i), span)))
116                    .collect::<Vec<_>>();
117                i.push(private_receiver);
118                quote!((#(#i), *))
119            }
120            FieldsType::None => quote! {(#private_receiver)},
121        }
122    }
123}
124
125enum FieldType {
126    Named(IdentType, TyType),
127    Unnamed(TyType),
128}
129
130impl FieldType {
131    fn named(ident: IdentType, ty: TyType) -> Self {
132        Self::Named(ident, ty)
133    }
134
135    fn unnamed(ty: TyType) -> Self {
136        Self::Unnamed(ty)
137    }
138
139    fn get_named(self) -> Option<(IdentType, TyType)> {
140        match self {
141            Self::Named(a, b) => Some((a, b)),
142            _ => None,
143        }
144    }
145
146    fn get_unnamed(self) -> Option<TyType> {
147        match self {
148            FieldType::Unnamed(s) => Some(s),
149            _ => None,
150        }
151    }
152}
153
154impl From<Vec<(IdentType, TyType)>> for FieldsType {
155    fn from(value: Vec<(IdentType, TyType)>) -> Self {
156        Self::Named(value)
157    }
158}
159
160impl From<Vec<TyType>> for FieldsType {
161    fn from(value: Vec<TyType>) -> Self {
162        Self::Unnamed(value)
163    }
164}
165
166impl Default for FieldsType {
167    fn default() -> Self {
168        Self::None
169    }
170}
171
172impl TryFrom<Vec<FieldType>> for FieldsType {
173    type Error = ();
174
175    fn try_from(value: Vec<FieldType>) -> Result<Self, Self::Error> {
176        if value.is_empty() {
177            return Ok(Self::None);
178        }
179
180        let is_named = match value.first() {
181            Some(value) => match value {
182                FieldType::Named(_, _) => true,
183                FieldType::Unnamed(_) => false,
184                //FieldType::None => unreachable!("If type is None, should never going to this way"),
185            },
186            None => unreachable!("Has checked vec is empty"),
187        };
188
189        if is_named {
190            value
191                .into_iter()
192                .map(|f| f.get_named())
193                .collect::<Option<Vec<_>>>()
194                .map(|v| v.into())
195        } else {
196            value
197                .into_iter()
198                .map(|f| f.get_unnamed())
199                .collect::<Option<Vec<_>>>()
200                .map(|v| v.into())
201        }
202        .ok_or(())
203    }
204}
205
206impl TryFrom<&syn::Fields> for FieldsType {
207    type Error = syn::Error;
208    fn try_from(fields: &syn::Fields) -> Result<Self, Self::Error> {
209        let span = fields.span();
210        let fs: Vec<FieldType> = match fields {
211            Fields::Named(fields) => {
212                //fields.named.iter().map(|field| )
213                fields
214                    .named
215                    .iter()
216                    .map(|field| {
217                        Ok(FieldType::named(
218                            field.ident.clone().ok_or_else(|| {
219                                syn::Error::new_spanned(field, "Field should have a name")
220                            })?,
221                            field.ty.to_token_stream(),
222                        ))
223                    })
224                    .collect::<syn::Result<Vec<FieldType>>>()?
225            }
226            Fields::Unnamed(fields) => fields
227                .unnamed
228                .iter()
229                .map(|field| FieldType::unnamed(field.ty.to_token_stream()))
230                .collect::<Vec<_>>(),
231            Fields::Unit => Vec::new(),
232        };
233
234        let fs = fs.try_into();
235        match fs {
236            Ok(f) => Ok(f),
237            Err(_) => Err(syn::Error::new(
238                span,
239                "Type not match, it should never happened",
240            )),
241        }
242    }
243}
244
245#[derive(Clone)]
246struct EnumDefinition {
247    ident: String,
248    fields: FieldsType,
249}
250
251impl EnumDefinition {
252    fn name_into_snake_case(&self) -> String {
253        match WORD_RE
254            .find_iter(&self.ident)
255            .collect::<Result<Vec<_>, _>>()
256        {
257            Ok(matches) => matches
258                .iter()
259                .map(|s| s.as_str().to_ascii_lowercase())
260                .collect::<Vec<_>>()
261                .join("_"),
262            Err(_) => WORD_RE_ERROR.clone(),
263        }
264    }
265
266    fn get_name(&self, span: Span) -> Ident {
267        Ident::new(&self.name_into_snake_case(), span)
268    }
269
270    fn get_normal_name(&self, span: Span) -> Ident {
271        Ident::new(&self.ident, span)
272    }
273
274    fn get_name_block(&self, span: Span) -> Ident {
275        // TODO: Optimize this
276        let block_name = format!("{}_b", self.name_into_snake_case());
277        Ident::new(&block_name, span)
278    }
279
280    fn fields(&self) -> &FieldsType {
281        &self.fields
282    }
283}
284
285impl TryFrom<&Variant> for EnumDefinition {
286    type Error = syn::Error;
287
288    fn try_from(value: &Variant) -> Result<Self, Self::Error> {
289        let ident = value.ident.to_string();
290        Ok(Self {
291            ident,
292            fields: FieldsType::try_from(&value.fields)?,
293        })
294    }
295}
296
297/* fn print_fields(data: &DataEnum) {
298    let idents: Vec<_> = data.variants.iter().map(|f| &f.ident).collect();
299    //let types: Vec<_> = data.variants.iter().map(|f| &f.ty).collect();
300
301   //eprintln!("{:#?}", idents);
302    //eprintln!("{:#?}", types);
303} */
304
305/* fn check_is_enum(st: &syn::DeriveInput) -> syn::Result<&DataEnum> {
306    match st.data {
307        syn::Data::Enum(ref data_enum) => Ok(data_enum),
308        _ => Err(syn::Error::new_spanned(
309            st,
310            "Must defined a enum, not struct".to_string(),
311        )),
312    }
313} */
314
315fn generate_function(
316    st: &syn::DeriveInput,
317    de: &DataEnum,
318    block: bool,
319    no_async: bool,
320) -> syn::Result<TokenStream> {
321    let mut ret = TokenStream::new();
322    let basic = &st.ident;
323    for variant in &de.variants {
324        let definition = EnumDefinition::try_from(variant)?;
325        let arg_def = definition.fields().to_arg_def(variant.span());
326        let arg = definition.fields().to_arg(variant.span());
327        let function_name = definition.get_name(variant.span());
328        let member = &variant.ident;
329        if !no_async {
330            let result = quote! {
331                pub async fn #function_name (&self, #arg_def) -> std::option::Option<()> {
332                    self.sender
333                        .send(#basic::#member #arg)
334                        .await
335                        .ok()
336                }
337            };
338            //eprintln!("{:#?}", result.to_string());
339            ret.extend(result);
340        }
341        if block {
342            let function_name = if no_async {
343                function_name
344            } else {
345                definition.get_name_block(variant.span())
346            };
347            let result = quote! {
348                pub fn #function_name (&self, #arg_def) -> std::option::Option<()> {
349                    self.sender
350                        .blocking_send(#basic::#member #arg)
351                        .ok()
352                }
353            };
354            ret.extend(result);
355        }
356    }
357    Ok(ret)
358}
359
360pub(crate) fn parse_tokens(token_stream: &TokenStream) -> syn::Result<(bool, bool)> {
361    let mut no_async = false;
362    let mut block = false;
363
364    for token in token_stream.clone().into_iter() {
365        match &token {
366            proc_macro2::TokenTree::Ident(ident) => {
367                if ident.eq("no_async") {
368                    no_async = true;
369                } else if ident.eq("block") {
370                    block = true;
371                } else {
372                    return Err(syn::Error::new(ident.span(), "Unrecognized token"));
373                }
374            }
375            _ => continue,
376        }
377    }
378    Ok((block, no_async))
379}
380
381pub(crate) fn parse_arguments(attrs: &[Attribute]) -> syn::Result<(bool, bool)> {
382    if attrs.is_empty() {
383        return Ok((false, false));
384    }
385    for attr in attrs {
386        match &attr.meta {
387            syn::Meta::Path(_) => {
388                /* return Err(syn::Error::new(
389                    attr.span(),
390                    "Unimplemented syn::Meta::Path",
391                )) */
392            }
393            syn::Meta::List(list) => {
394                if let Some(seg) = list.path.segments.first() {
395                    if !seg.ident.eq("helper") {
396                        continue;
397                    }
398                    let (block, no_async) = parse_tokens(&list.tokens)?;
399                    if !block && no_async {
400                        return Err(syn::Error::new(
401                            list.span(),
402                            "This code generate `new' function only!",
403                        ));
404                    }
405                    return Ok((block, no_async));
406                }
407            }
408            syn::Meta::NameValue(_) => {
409                /* return Err(syn::Error::new(
410                    attr.span(),
411                    "Unimplemented syn::Meta::NameValue",
412                )) */
413            }
414        }
415    }
416    Ok((false, false))
417}
418
419type GenMemberFn = fn(&syn::DeriveInput, &syn::DataEnum, bool, bool) -> syn::Result<TokenStream>;
420
421pub(crate) fn do_expand(
422    st: &syn::DeriveInput,
423    replace_function: Option<GenMemberFn>,
424) -> syn::Result<TokenStream> {
425    /* if !st.data.eq(syn::Data) {
426        return Err(syn::Error::new(
427            st.span(),
428            "Should apply this drive to enum",
429        ));
430    } */
431    let (block, no_async) = parse_arguments(&st.attrs)?;
432    //eprintln!("{} {}", block, no_async);
433    let data_enum = extract_enum(st)?;
434    //print_fields();
435    let enum_name = st.ident.to_string();
436    let (basic_name, _) = enum_name.rsplit_once("Event").unwrap();
437
438    let helper_receiver_type = format!("{}Receiver", enum_name);
439    let helper_receiver_type_indent = syn::Ident::new(&helper_receiver_type, st.ident.span());
440
441    let helper_name = format!("{}Helper", basic_name);
442    let helper_name_ident = syn::Ident::new(&helper_name, st.ident.span());
443
444    let enum_ident = &st.ident;
445
446    let member_function = match replace_function {
447        Some(func) => func(st, data_enum, block, no_async),
448        None => generate_function(st, data_enum, block, no_async),
449    }?;
450
451    let ret = quote! {
452
453        #[derive(Clone, Debug)]
454        pub struct #helper_name_ident {
455            sender: tokio::sync::mpsc::Sender<#enum_ident>
456        }
457
458        pub type #helper_receiver_type_indent = tokio::sync::mpsc::Receiver<#enum_ident>;
459
460        impl #helper_name_ident {
461            pub fn new(size: usize) -> (Self, #helper_receiver_type_indent) {
462                let (a, b) = tokio::sync::mpsc::channel(size);
463                (a.into(), b)
464            }
465
466            #member_function
467        }
468
469        impl From<tokio::sync::mpsc::Sender<#enum_ident>> for #helper_name_ident {
470            fn from(value: tokio::sync::mpsc::Sender<#enum_ident>) -> Self {
471                Self {
472                    sender: value
473                }
474            }
475        }
476
477    };
478
479    Ok(ret)
480}
481
482pub(crate) fn early_check(st: &syn::DeriveInput) -> syn::Result<()> {
483    match st.data {
484        syn::Data::Enum(_) => {
485            if !st.ident.to_string().contains("Event") {
486                return Err(syn::Error::new(
487                    st.ident.span(),
488                    "Should contains Event in name",
489                ));
490            }
491            Ok(())
492        }
493        _ => Err(syn::Error::new_spanned(
494            st,
495            "Must defined a enum, not struct".to_string(),
496        )),
497    }
498}
499
500pub(crate) fn extract_enum(st: &syn::DeriveInput) -> syn::Result<&DataEnum> {
501    match st.data {
502        syn::Data::Enum(ref data_enum) => Ok(data_enum),
503        _ => unreachable!(),
504    }
505}
506
507#[proc_macro_derive(Helper, attributes(helper))]
508pub fn enum_helper_generator(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
509    let st = syn::parse_macro_input!(input as syn::DeriveInput);
510    //eprintln!("{:#?}", st.attrs);
511
512    if let Err(e) = crate::early_check(&st) {
513        return e.into_compile_error().into();
514    }
515
516    match do_expand(&st, None) {
517        Ok(token_stream) => {
518            //eprintln!("{}", token_stream.to_string());
519            token_stream.into()
520        }
521        Err(e) => e.to_compile_error().into(),
522    }
523    //eprintln!("{:#?}", item);
524}
525
526#[proc_macro]
527pub fn oneshot_helper(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
528    let early_st = syn::parse_macro_input!(input as syn::DeriveInput);
529    match enchanted::handle_new(early_st) {
530        Ok(ret) => ret,
531        Err(e) => e.into_compile_error(),
532    }
533    .into()
534}
535
#[cfg(test)]
mod test {
    use crate::{EnumDefinition, FieldsType};

    /// Runs the CamelCase → snake_case conversion on a field-less definition.
    fn snake(input: &str) -> String {
        EnumDefinition {
            ident: input.to_string(),
            fields: FieldsType::None,
        }
        .name_into_snake_case()
    }

    #[test]
    fn test_snake_case_convert() {
        let cases = [
            ("GetHTTPResponse", "get_http_response"),
            ("CSV", "csv"),
            ("IPChecker", "ip_checker"),
            ("UserAdd", "user_add"),
            ("IsHTTPSpecifyASpecicalAdd", "is_http_specify_a_specical_add"),
            ("IPV4", "ipv4"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(snake(input), expected, "converting {}", input);
        }
    }
}