1#![doc = include_str!("../README.md")]
2#[cfg(test)]
3mod tests;
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::{format_ident, quote, ToTokens, TokenStreamExt};
6use std::collections::HashMap;
7use std::convert::{TryFrom, TryInto};
8use std::str::FromStr;
9use strum::{Display, EnumString};
10use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, *};
11
12fn is_ref(type_: &Type) -> bool {
13 matches!(type_, Type::Reference(_))
14}
15
16fn remove_reference(type_: &Type) -> &Type {
17 match type_ {
18 Type::Reference(ref_) => &ref_.elem,
19 _ => type_,
20 }
21}
22
23fn copy_reference(target: &Type, source: &Type) -> Type {
24 match source {
25 Type::Reference(inner) => {
26 let mut out = inner.clone();
27 out.elem = Box::new(target.clone());
28 Type::Reference(out)
29 }
30 _ => target.clone(),
31 }
32}
33
34fn get_last_segment(implement: &ItemImpl) -> Result<&PathSegment> {
35 if implement.trait_.is_none() {
36 return Err(Error::new(implement.span(), "Is not Trait impl"));
37 };
38 let trait_ = implement.trait_.as_ref().unwrap();
39 if let Some(bang) = trait_.0 {
40 return Err(Error::new(bang.span(), "Unexpected negative impl"));
41 }
42 let segments = &trait_.1.segments;
43 if segments.is_empty() {
44 return Err(Error::new(segments.span(), "Unexpected empty trait path"));
45 }
46 Ok(segments.last().unwrap())
47}
48
49fn get_rhs_type<'a>(args: &'a PathArguments, self_type: &'a Type) -> Result<&'a Type> {
50 match args {
51 PathArguments::None => Ok(self_type),
52 PathArguments::AngleBracketed(args) => {
53 let args = &args.args;
54 if args.len() != 1 {
55 return Err(Error::new(
56 args.span(),
57 "Number of trait arguments is not 1",
58 ));
59 }
60 if let GenericArgument::Type(rhs_type) = args.first().unwrap() {
61 Ok(rhs_type)
62 } else {
63 Err(Error::new(args.span(), "Is not type"))
64 }
65 }
66 _ => Err(Error::new(args.span(), "Unexpected trait arguments")),
67 }
68}
69
70#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
71struct Operate(OpTrait, bool, bool);
72impl Operate {
73 fn lhs_move(&self) -> bool {
74 !self.0.is_assign() && !self.1
75 }
76 fn rhs_move(&self) -> bool {
77 !self.2
78 }
79 fn require_lhs_clone(&self, op: Self) -> bool {
80 (self.lhs_move() || self.0.is_assign()) && op.1
81 }
82 fn require_rhs_clone(&self, op: Self) -> bool {
83 self.rhs_move() && op.2
84 }
85 fn require_clone(&self, op: Self) -> bool {
86 self.require_lhs_clone(op) || self.require_rhs_clone(op)
87 }
88}
89
90#[derive(Clone, Debug)]
91struct Generator<'a> {
92 implement: &'a ItemImpl,
93 source_op: Operate,
94 self_type: &'a Type,
95 rhs_type: &'a Type,
96}
97impl<'a> Generator<'a> {
98 fn get_arg_type(is_ref_: bool, target: &Type, source: &Type) -> Type {
99 if !is_ref_ {
100 remove_reference(target).clone()
101 } else if is_ref(target) {
102 target.clone()
103 } else if is_ref(source) {
104 copy_reference(target, source)
105 } else {
106 parse_quote! {
107 &#target
108 }
109 }
110 }
111 fn update_where_clause(&self, generics: &mut Generics, op: Operate) {
112 let rr_self_type = remove_reference(self.self_type);
113 if self.source_op.require_clone(op) {
114 let wc = generics.make_where_clause();
115 wc.predicates.push(parse_quote! {
116 #rr_self_type: Clone
117 });
118 }
119 if self.source_op.lhs_move() && op.0.is_assign() && cfg!(not(feature = "take_mut")) {
120 let wc = generics.make_where_clause();
121 wc.predicates.push(parse_quote! {
122 #rr_self_type: Default
123 });
124 }
125 }
126 fn assgin_body(source_op: Operate) -> TokenStream {
127 let source_fn_name = source_op.0.to_func_ident();
128 if source_op.0.is_assign() {
129 quote! {
130 self.#source_fn_name(rhs);
131 }
132 } else if source_op.1 {
133 quote! {
134 *self = (&*self).#source_fn_name(rhs);
135 }
136 } else if cfg!(feature = "take_mut") {
137 quote! {
138 take_mut::take(self, |x| x.#source_fn_name(rhs));
139 }
140 } else {
141 quote! {
142 let mut t = Self::default();
143 std::mem::swap(&mut t, self);
144 let mut u = t.#source_fn_name(rhs);
145 std::mem::swap(&mut u, self);
146 }
147 }
148 }
149 fn gen_rhs(source_op: Operate, op: Operate) -> TokenStream {
150 #[allow(clippy::collapsible_else_if)]
151 if source_op.2 {
152 if op.2 {
153 TokenStream::new()
154 } else {
155 quote!(let rhs = &rhs;)
156 }
157 } else {
158 if op.2 {
159 quote!(let rhs = rhs.clone();)
160 } else {
161 TokenStream::new()
162 }
163 }
164 }
165 fn gen_lhs(source_op: Operate, op: Operate) -> TokenStream {
166 #[allow(clippy::collapsible_else_if)]
167 if source_op.0.is_assign() {
168 if op.1 {
169 quote!(let mut lhs = self.clone();)
170 } else {
171 quote!(let mut lhs = self;)
172 }
173 } else if source_op.1 {
174 if op.1 {
175 quote!(let lhs = self;)
176 } else {
177 quote!(let lhs = &self;)
178 }
179 } else {
180 if op.1 {
181 quote!(let lhs = self.clone();)
182 } else {
183 quote!(let lhs = self;)
184 }
185 }
186 }
187 fn gen_output(&self) -> Result<Type> {
188 let rr_self_type = remove_reference(self.self_type);
189 if self.source_op.0.is_assign() {
190 Ok(rr_self_type.clone())
191 } else {
192 let v = self
193 .implement
194 .items
195 .iter()
196 .filter_map(|x| {
197 if let ImplItem::Type(x) = x {
198 Some(x)
199 } else {
200 None
201 }
202 })
203 .filter_map(|x| {
204 if x.ident == "Output" {
205 Some(&x.ty)
206 } else {
207 None
208 }
209 })
210 .collect::<Vec<_>>();
211 if let [x] = v[..] {
212 if x == &parse_quote!(Self) {
213 Ok(rr_self_type.clone())
214 } else {
215 Ok(x.clone())
216 }
217 } else {
218 Err(Error::new(
219 Span::call_site(),
220 "`type Output =` is not found or multiple",
221 ))
222 }
223 }
224 }
225 fn generate(&self, op: Operate) -> Result<TokenStream> {
226 if op.0.is_assign() && op.1 {
227 return Err(Error::new(
228 Span::call_site(),
229 "Type of LHS of assign operations must not reference",
230 ));
231 }
232 if op == self.source_op {
233 return Ok(self.implement.to_token_stream());
234 }
235 let mut work = self.implement.clone();
236 if let Operate(_, false, false) = op {
237 work.attrs.push(parse_quote! {
238 #[allow(clippy::extra_unused_lifetimes)]
239 });
240 }
241 let rhs_type = Self::get_arg_type(op.2, self.rhs_type, self.self_type);
242 let trait_ = op.0;
243 *work.trait_.as_mut().unwrap().1.segments.last_mut().unwrap() =
244 parse_quote! { #trait_<#rhs_type> };
245 *work.self_ty.as_mut() = Self::get_arg_type(op.1, self.self_type, self.rhs_type);
246 self.update_where_clause(&mut work.generics, op);
247 work.items.clear();
248 let fn_name = op.0.to_func_ident();
249 let preamble_rhs = Self::gen_rhs(self.source_op, op);
250 if op.0.is_assign() {
251 let body = Self::assgin_body(self.source_op);
252 work.items.push(parse_quote! {
253 fn #fn_name(&mut self, rhs: #rhs_type) {
254 #preamble_rhs
255 #body
256 }
257 });
258 } else {
259 let output_type = self.gen_output()?;
260 work.items.push(parse_quote! {
261 type Output = #output_type;
262 });
263 let preamble_lhs = Self::gen_lhs(self.source_op, op);
264 let source_fn_name = self.source_op.0.to_func_ident();
265 let body = if self.source_op.0.is_assign() {
266 quote! {
267 lhs.#source_fn_name(rhs);
268 lhs
269 }
270 } else {
271 quote! {
272 lhs.#source_fn_name(rhs)
273 }
274 };
275 work.items.push(parse_quote! {
276 fn #fn_name(self, rhs: #rhs_type) -> Self::Output {
277 #preamble_lhs
278 #preamble_rhs
279 #body
280 }
281 });
282 }
283 Ok(quote!(#work))
284 }
285}
286
287type Attributes = Punctuated<Ident, token::Comma>;
288fn auto_ops_generate(mut attrs: Attributes, implement: ItemImpl) -> Result<TokenStream> {
289 let last_segment = get_last_segment(&implement)?;
290 let op: OpTrait = last_segment.ident.clone().try_into()?;
291 let self_type = &implement.self_ty;
292 let rhs_type = get_rhs_type(&last_segment.arguments, self_type)?;
293 let generator = Generator {
294 implement: &implement,
295 source_op: Operate(op, is_ref(self_type), is_ref(rhs_type)),
296 self_type,
297 rhs_type,
298 };
299 let list = [
300 ("assign_ref", Operate(op.to_assign(), false, true)),
301 ("assign_val", Operate(op.to_assign(), false, false)),
302 ("ref_ref", Operate(op.to_non_assign(), true, true)),
303 ("ref_val", Operate(op.to_non_assign(), true, false)),
304 ("val_ref", Operate(op.to_non_assign(), false, true)),
305 ("val_val", Operate(op.to_non_assign(), false, false)),
306 ];
307 let map = HashMap::from(list);
308 let rev_map = list.iter().map(|&(v, k)| (k, v)).collect::<HashMap<_, _>>();
309 if attrs.is_empty() {
310 attrs = list.iter().map(|(x, _)| format_ident!("{}", x)).collect();
311 }
312 let source = rev_map[&generator.source_op];
313 if !attrs.iter().any(|x| x == source) {
314 attrs.push(format_ident!("{}", source));
315 }
316 let mut result = TokenStream::new();
317 for i in attrs.iter() {
318 let s = i.to_string();
319 if let Some(op) = map.get(s.as_str()) {
320 let code = generator.generate(*op)?;
321 result.extend(code);
322 }
323 }
324 Ok(result)
325}
326
327#[derive(Clone, Copy, Debug, Display, EnumString, PartialEq, Eq, Hash)]
328enum OpTrait {
329 Add,
330 AddAssign,
331 Sub,
332 SubAssign,
333 Mul,
334 MulAssign,
335 Div,
336 DivAssign,
337 Rem,
338 RemAssign,
339 BitAnd,
340 BitAndAssign,
341 BitOr,
342 BitOrAssign,
343 BitXor,
344 BitXorAssign,
345 Shl,
346 ShlAssign,
347 Shr,
348 ShrAssign,
349}
350impl TryFrom<Ident> for OpTrait {
351 type Error = Error;
352 fn try_from(ident: Ident) -> Result<Self> {
353 if let Ok(x) = Self::from_str(&ident.to_string()) {
354 Ok(x)
355 } else {
356 Err(Error::new(
357 ident.span(),
358 format!("unexpacted Ident: {}", ident),
359 ))
360 }
361 }
362}
363impl ToTokens for OpTrait {
364 fn to_tokens(&self, tokens: &mut TokenStream) {
365 tokens.append(Ident::new(&self.to_string(), Span::call_site()));
366 }
367}
368
369impl OpTrait {
370 fn to_assign(self) -> Self {
371 use OpTrait::*;
372 match self {
373 Add | AddAssign => AddAssign,
374 Sub | SubAssign => SubAssign,
375 Mul | MulAssign => MulAssign,
376 Div | DivAssign => DivAssign,
377 Rem | RemAssign => RemAssign,
378 BitAnd | BitAndAssign => BitAndAssign,
379 BitOr | BitOrAssign => BitOrAssign,
380 BitXor | BitXorAssign => BitXorAssign,
381 Shl | ShlAssign => ShlAssign,
382 Shr | ShrAssign => ShrAssign,
383 }
384 }
385 fn to_non_assign(self) -> Self {
386 use OpTrait::*;
387 match self {
388 Add | AddAssign => Add,
389 Sub | SubAssign => Sub,
390 Mul | MulAssign => Mul,
391 Div | DivAssign => Div,
392 Rem | RemAssign => Rem,
393 BitAnd | BitAndAssign => BitAnd,
394 BitOr | BitOrAssign => BitOr,
395 BitXor | BitXorAssign => BitXor,
396 Shl | ShlAssign => Shl,
397 Shr | ShrAssign => Shr,
398 }
399 }
400 fn is_assign(self) -> bool {
401 self.to_assign() == self
402 }
403 fn to_func_ident(self) -> Ident {
404 use OpTrait::*;
405 match self {
406 Add => format_ident!("add"),
407 AddAssign => format_ident!("add_assign"),
408 Sub => format_ident!("sub"),
409 SubAssign => format_ident!("sub_assign"),
410 Mul => format_ident!("mul"),
411 MulAssign => format_ident!("mul_assign"),
412 Div => format_ident!("div"),
413 DivAssign => format_ident!("div_assign"),
414 Rem => format_ident!("rem"),
415 RemAssign => format_ident!("rem_assign"),
416 BitAnd => format_ident!("bitand"),
417 BitAndAssign => format_ident!("bitand_assign"),
418 BitOr => format_ident!("bitor"),
419 BitOrAssign => format_ident!("bitor_assign"),
420 BitXor => format_ident!("bitxor"),
421 BitXorAssign => format_ident!("bitxor_assign"),
422 Shl => format_ident!("shl"),
423 ShlAssign => format_ident!("shl_assign"),
424 Shr => format_ident!("shr"),
425 ShrAssign => format_ident!("shr_assign"),
426 }
427 }
428}
429
430fn auto_ops_impl_inner(attrs: TokenStream, tokens: TokenStream) -> Result<TokenStream> {
431 let a = Punctuated::parse_terminated.parse2(attrs)?;
432 let i = parse2(tokens)?;
433 auto_ops_generate(a, i)
434}
435
436fn auto_ops_impl(attrs: TokenStream, tokens: TokenStream) -> TokenStream {
437 auto_ops_impl_inner(attrs, tokens).unwrap_or_else(Error::into_compile_error)
438}
439
440#[proc_macro_attribute]
441pub fn auto_ops(
442 attrs: proc_macro::TokenStream,
443 tokens: proc_macro::TokenStream,
444) -> proc_macro::TokenStream {
445 auto_ops_impl(attrs.into(), tokens.into()).into()
446}