auto_impl_ops/
lib.rs

1#![doc = include_str!("../README.md")]
2#[cfg(test)]
3mod tests;
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::{format_ident, quote, ToTokens, TokenStreamExt};
6use std::collections::HashMap;
7use std::convert::{TryFrom, TryInto};
8use std::str::FromStr;
9use strum::{Display, EnumString};
10use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, *};
11
12fn is_ref(type_: &Type) -> bool {
13    matches!(type_, Type::Reference(_))
14}
15
16fn remove_reference(type_: &Type) -> &Type {
17    match type_ {
18        Type::Reference(ref_) => &ref_.elem,
19        _ => type_,
20    }
21}
22
23fn copy_reference(target: &Type, source: &Type) -> Type {
24    match source {
25        Type::Reference(inner) => {
26            let mut out = inner.clone();
27            out.elem = Box::new(target.clone());
28            Type::Reference(out)
29        }
30        _ => target.clone(),
31    }
32}
33
34fn get_last_segment(implement: &ItemImpl) -> Result<&PathSegment> {
35    if implement.trait_.is_none() {
36        return Err(Error::new(implement.span(), "Is not Trait impl"));
37    };
38    let trait_ = implement.trait_.as_ref().unwrap();
39    if let Some(bang) = trait_.0 {
40        return Err(Error::new(bang.span(), "Unexpected negative impl"));
41    }
42    let segments = &trait_.1.segments;
43    if segments.is_empty() {
44        return Err(Error::new(segments.span(), "Unexpected empty trait path"));
45    }
46    Ok(segments.last().unwrap())
47}
48
49fn get_rhs_type<'a>(args: &'a PathArguments, self_type: &'a Type) -> Result<&'a Type> {
50    match args {
51        PathArguments::None => Ok(self_type),
52        PathArguments::AngleBracketed(args) => {
53            let args = &args.args;
54            if args.len() != 1 {
55                return Err(Error::new(
56                    args.span(),
57                    "Number of trait arguments is not 1",
58                ));
59            }
60            if let GenericArgument::Type(rhs_type) = args.first().unwrap() {
61                Ok(rhs_type)
62            } else {
63                Err(Error::new(args.span(), "Is not type"))
64            }
65        }
66        _ => Err(Error::new(args.span(), "Unexpected trait arguments")),
67    }
68}
69
70#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
71struct Operate(OpTrait, bool, bool);
72impl Operate {
73    fn lhs_move(&self) -> bool {
74        !self.0.is_assign() && !self.1
75    }
76    fn rhs_move(&self) -> bool {
77        !self.2
78    }
79    fn require_lhs_clone(&self, op: Self) -> bool {
80        (self.lhs_move() || self.0.is_assign()) && op.1
81    }
82    fn require_rhs_clone(&self, op: Self) -> bool {
83        self.rhs_move() && op.2
84    }
85    fn require_clone(&self, op: Self) -> bool {
86        self.require_lhs_clone(op) || self.require_rhs_clone(op)
87    }
88}
89
90#[derive(Clone, Debug)]
91struct Generator<'a> {
92    implement: &'a ItemImpl,
93    source_op: Operate,
94    self_type: &'a Type,
95    rhs_type: &'a Type,
96}
97impl<'a> Generator<'a> {
98    fn get_arg_type(is_ref_: bool, target: &Type, source: &Type) -> Type {
99        if !is_ref_ {
100            remove_reference(target).clone()
101        } else if is_ref(target) {
102            target.clone()
103        } else if is_ref(source) {
104            copy_reference(target, source)
105        } else {
106            parse_quote! {
107                &#target
108            }
109        }
110    }
111    fn update_where_clause(&self, generics: &mut Generics, op: Operate) {
112        let rr_self_type = remove_reference(self.self_type);
113        if self.source_op.require_clone(op) {
114            let wc = generics.make_where_clause();
115            wc.predicates.push(parse_quote! {
116                #rr_self_type: Clone
117            });
118        }
119        if self.source_op.lhs_move() && op.0.is_assign() && cfg!(not(feature = "take_mut")) {
120            let wc = generics.make_where_clause();
121            wc.predicates.push(parse_quote! {
122                #rr_self_type: Default
123            });
124        }
125    }
126    fn assgin_body(source_op: Operate) -> TokenStream {
127        let source_fn_name = source_op.0.to_func_ident();
128        if source_op.0.is_assign() {
129            quote! {
130                self.#source_fn_name(rhs);
131            }
132        } else if source_op.1 {
133            quote! {
134                *self = (&*self).#source_fn_name(rhs);
135            }
136        } else if cfg!(feature = "take_mut") {
137            quote! {
138                take_mut::take(self, |x| x.#source_fn_name(rhs));
139            }
140        } else {
141            quote! {
142                let mut t = Self::default();
143                std::mem::swap(&mut t, self);
144                let mut u = t.#source_fn_name(rhs);
145                std::mem::swap(&mut u, self);
146            }
147        }
148    }
149    fn gen_rhs(source_op: Operate, op: Operate) -> TokenStream {
150        #[allow(clippy::collapsible_else_if)]
151        if source_op.2 {
152            if op.2 {
153                TokenStream::new()
154            } else {
155                quote!(let rhs = &rhs;)
156            }
157        } else {
158            if op.2 {
159                quote!(let rhs = rhs.clone();)
160            } else {
161                TokenStream::new()
162            }
163        }
164    }
165    fn gen_lhs(source_op: Operate, op: Operate) -> TokenStream {
166        #[allow(clippy::collapsible_else_if)]
167        if source_op.0.is_assign() {
168            if op.1 {
169                quote!(let mut lhs = self.clone();)
170            } else {
171                quote!(let mut lhs = self;)
172            }
173        } else if source_op.1 {
174            if op.1 {
175                quote!(let lhs = self;)
176            } else {
177                quote!(let lhs = &self;)
178            }
179        } else {
180            if op.1 {
181                quote!(let lhs = self.clone();)
182            } else {
183                quote!(let lhs = self;)
184            }
185        }
186    }
187    fn gen_output(&self) -> Result<Type> {
188        let rr_self_type = remove_reference(self.self_type);
189        if self.source_op.0.is_assign() {
190            Ok(rr_self_type.clone())
191        } else {
192            let v = self
193                .implement
194                .items
195                .iter()
196                .filter_map(|x| {
197                    if let ImplItem::Type(x) = x {
198                        Some(x)
199                    } else {
200                        None
201                    }
202                })
203                .filter_map(|x| {
204                    if x.ident == "Output" {
205                        Some(&x.ty)
206                    } else {
207                        None
208                    }
209                })
210                .collect::<Vec<_>>();
211            if let [x] = v[..] {
212                if x == &parse_quote!(Self) {
213                    Ok(rr_self_type.clone())
214                } else {
215                    Ok(x.clone())
216                }
217            } else {
218                Err(Error::new(
219                    Span::call_site(),
220                    "`type Output =` is not found or multiple",
221                ))
222            }
223        }
224    }
225    fn generate(&self, op: Operate) -> Result<TokenStream> {
226        if op.0.is_assign() && op.1 {
227            return Err(Error::new(
228                Span::call_site(),
229                "Type of LHS of assign operations must not reference",
230            ));
231        }
232        if op == self.source_op {
233            return Ok(self.implement.to_token_stream());
234        }
235        let mut work = self.implement.clone();
236        if let Operate(_, false, false) = op {
237            work.attrs.push(parse_quote! {
238                #[allow(clippy::extra_unused_lifetimes)]
239            });
240        }
241        let rhs_type = Self::get_arg_type(op.2, self.rhs_type, self.self_type);
242        let trait_ = op.0;
243        *work.trait_.as_mut().unwrap().1.segments.last_mut().unwrap() =
244            parse_quote! { #trait_<#rhs_type> };
245        *work.self_ty.as_mut() = Self::get_arg_type(op.1, self.self_type, self.rhs_type);
246        self.update_where_clause(&mut work.generics, op);
247        work.items.clear();
248        let fn_name = op.0.to_func_ident();
249        let preamble_rhs = Self::gen_rhs(self.source_op, op);
250        if op.0.is_assign() {
251            let body = Self::assgin_body(self.source_op);
252            work.items.push(parse_quote! {
253                fn #fn_name(&mut self, rhs: #rhs_type) {
254                    #preamble_rhs
255                    #body
256                }
257            });
258        } else {
259            let output_type = self.gen_output()?;
260            work.items.push(parse_quote! {
261                type Output = #output_type;
262            });
263            let preamble_lhs = Self::gen_lhs(self.source_op, op);
264            let source_fn_name = self.source_op.0.to_func_ident();
265            let body = if self.source_op.0.is_assign() {
266                quote! {
267                    lhs.#source_fn_name(rhs);
268                    lhs
269                }
270            } else {
271                quote! {
272                    lhs.#source_fn_name(rhs)
273                }
274            };
275            work.items.push(parse_quote! {
276                fn #fn_name(self, rhs: #rhs_type) -> Self::Output {
277                    #preamble_lhs
278                    #preamble_rhs
279                    #body
280                }
281            });
282        }
283        Ok(quote!(#work))
284    }
285}
286
287type Attributes = Punctuated<Ident, token::Comma>;
288fn auto_ops_generate(mut attrs: Attributes, implement: ItemImpl) -> Result<TokenStream> {
289    let last_segment = get_last_segment(&implement)?;
290    let op: OpTrait = last_segment.ident.clone().try_into()?;
291    let self_type = &implement.self_ty;
292    let rhs_type = get_rhs_type(&last_segment.arguments, self_type)?;
293    let generator = Generator {
294        implement: &implement,
295        source_op: Operate(op, is_ref(self_type), is_ref(rhs_type)),
296        self_type,
297        rhs_type,
298    };
299    let list = [
300        ("assign_ref", Operate(op.to_assign(), false, true)),
301        ("assign_val", Operate(op.to_assign(), false, false)),
302        ("ref_ref", Operate(op.to_non_assign(), true, true)),
303        ("ref_val", Operate(op.to_non_assign(), true, false)),
304        ("val_ref", Operate(op.to_non_assign(), false, true)),
305        ("val_val", Operate(op.to_non_assign(), false, false)),
306    ];
307    let map = HashMap::from(list);
308    let rev_map = list.iter().map(|&(v, k)| (k, v)).collect::<HashMap<_, _>>();
309    if attrs.is_empty() {
310        attrs = list.iter().map(|(x, _)| format_ident!("{}", x)).collect();
311    }
312    let source = rev_map[&generator.source_op];
313    if !attrs.iter().any(|x| x == source) {
314        attrs.push(format_ident!("{}", source));
315    }
316    let mut result = TokenStream::new();
317    for i in attrs.iter() {
318        let s = i.to_string();
319        if let Some(op) = map.get(s.as_str()) {
320            let code = generator.generate(*op)?;
321            result.extend(code);
322        }
323    }
324    Ok(result)
325}
326
/// Binary operator traits that `#[auto_ops]` accepts.
///
/// Each variant is spelled like the trait it stands for: the derived
/// `Display`/`EnumString` (strum) convert between a variant and its name,
/// which is how trait idents are parsed and emitted.
#[derive(Clone, Copy, Debug, Display, EnumString, PartialEq, Eq, Hash)]
enum OpTrait {
    Add,
    AddAssign,
    Sub,
    SubAssign,
    Mul,
    MulAssign,
    Div,
    DivAssign,
    Rem,
    RemAssign,
    BitAnd,
    BitAndAssign,
    BitOr,
    BitOrAssign,
    BitXor,
    BitXorAssign,
    Shl,
    ShlAssign,
    Shr,
    ShrAssign,
}
350impl TryFrom<Ident> for OpTrait {
351    type Error = Error;
352    fn try_from(ident: Ident) -> Result<Self> {
353        if let Ok(x) = Self::from_str(&ident.to_string()) {
354            Ok(x)
355        } else {
356            Err(Error::new(
357                ident.span(),
358                format!("unexpacted Ident: {}", ident),
359            ))
360        }
361    }
362}
363impl ToTokens for OpTrait {
364    fn to_tokens(&self, tokens: &mut TokenStream) {
365        tokens.append(Ident::new(&self.to_string(), Span::call_site()));
366    }
367}
368
369impl OpTrait {
370    fn to_assign(self) -> Self {
371        use OpTrait::*;
372        match self {
373            Add | AddAssign => AddAssign,
374            Sub | SubAssign => SubAssign,
375            Mul | MulAssign => MulAssign,
376            Div | DivAssign => DivAssign,
377            Rem | RemAssign => RemAssign,
378            BitAnd | BitAndAssign => BitAndAssign,
379            BitOr | BitOrAssign => BitOrAssign,
380            BitXor | BitXorAssign => BitXorAssign,
381            Shl | ShlAssign => ShlAssign,
382            Shr | ShrAssign => ShrAssign,
383        }
384    }
385    fn to_non_assign(self) -> Self {
386        use OpTrait::*;
387        match self {
388            Add | AddAssign => Add,
389            Sub | SubAssign => Sub,
390            Mul | MulAssign => Mul,
391            Div | DivAssign => Div,
392            Rem | RemAssign => Rem,
393            BitAnd | BitAndAssign => BitAnd,
394            BitOr | BitOrAssign => BitOr,
395            BitXor | BitXorAssign => BitXor,
396            Shl | ShlAssign => Shl,
397            Shr | ShrAssign => Shr,
398        }
399    }
400    fn is_assign(self) -> bool {
401        self.to_assign() == self
402    }
403    fn to_func_ident(self) -> Ident {
404        use OpTrait::*;
405        match self {
406            Add => format_ident!("add"),
407            AddAssign => format_ident!("add_assign"),
408            Sub => format_ident!("sub"),
409            SubAssign => format_ident!("sub_assign"),
410            Mul => format_ident!("mul"),
411            MulAssign => format_ident!("mul_assign"),
412            Div => format_ident!("div"),
413            DivAssign => format_ident!("div_assign"),
414            Rem => format_ident!("rem"),
415            RemAssign => format_ident!("rem_assign"),
416            BitAnd => format_ident!("bitand"),
417            BitAndAssign => format_ident!("bitand_assign"),
418            BitOr => format_ident!("bitor"),
419            BitOrAssign => format_ident!("bitor_assign"),
420            BitXor => format_ident!("bitxor"),
421            BitXorAssign => format_ident!("bitxor_assign"),
422            Shl => format_ident!("shl"),
423            ShlAssign => format_ident!("shl_assign"),
424            Shr => format_ident!("shr"),
425            ShrAssign => format_ident!("shr_assign"),
426        }
427    }
428}
429
430fn auto_ops_impl_inner(attrs: TokenStream, tokens: TokenStream) -> Result<TokenStream> {
431    let a = Punctuated::parse_terminated.parse2(attrs)?;
432    let i = parse2(tokens)?;
433    auto_ops_generate(a, i)
434}
435
436fn auto_ops_impl(attrs: TokenStream, tokens: TokenStream) -> TokenStream {
437    auto_ops_impl_inner(attrs, tokens).unwrap_or_else(Error::into_compile_error)
438}
439
440#[proc_macro_attribute]
441pub fn auto_ops(
442    attrs: proc_macro::TokenStream,
443    tokens: proc_macro::TokenStream,
444) -> proc_macro::TokenStream {
445    auto_ops_impl(attrs.into(), tokens.into()).into()
446}