primwrap/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2
3mod accum;
4mod cmp;
5mod fmt;
6mod ops;
7mod util;
8
9use proc_macro::TokenStream;
10use strum::{EnumIter, EnumString, IntoEnumIterator};
11use virtue::parse::{Attribute, StructBody};
12use virtue::prelude::*;
13use crate::accum::generate_accum;
14use crate::cmp::generate_cmp;
15use crate::fmt::generate_fmt;
16use crate::ops::{Arithmetic, Bit, Op};
17
18#[derive(EnumIter, EnumString, Eq, PartialEq)]
19#[strum(ascii_case_insensitive)]
20enum Group {
21	Arithmetic,
22	Bitwise,
23	Formatting,
24	Comparison,
25	Accumulation
26}
27
28/// Derives arithmetic, bitwise, comparison, and formatting traits on a primitive
29/// wrapper struct, exposing its inner type. Integer, float, and boolean types are
30/// supported.
31///
32/// The implemented traits can be selected with a `#[primwrap(...)]` attribute:
33/// - `arithmetic` enables `Add`, `Sub`, `Mul`, `Div`, `Rem`, and `Neg`
34/// - `bitwise` enables `Not`, `BitAnd`, `BitOr`, `BitXor`, `Shl`, and `Shr`
35/// - `formatting` enables `Debug`, `Display`, `Binary`, `Octal`, `LowerExp`,
36///   `LowerHex`, `UpperExp`, and `UpperHex`
37/// - `comparison` enables `PartialEq`/`PartialOrd` with the inner type
38/// - `accumulation` enables `Sum` and `Product`
39#[proc_macro_derive(Primitive, attributes(primwrap))]
40pub fn primitive_derive(input: TokenStream) -> TokenStream {
41	let expand = || {
42		let parsed = Parse::new(input)?;
43		let groups = if let Parse::Struct { ref attributes, .. } = parsed {
44			parse_attributes(attributes)?
45		} else {
46			Vec::default()
47		};
48
49		let (
50			mut gen,
51			_,
52			Body::Struct(
53				StructBody {
54					fields: Some(Fields::Tuple(fields))
55				}
56			)
57		) = parsed.into_generator() else {
58			return Err(Error::custom("expected tuple struct"))
59		};
60
61		let [field] = &fields[..] else {
62			return Err(Error::custom("expected tuple struct with one field"))
63		};
64
65		let [TokenTree::Ident(inner_type)] = &field.r#type[..] else {
66			return Err(Error::custom("unknown type"))
67		};
68		let ref target = gen.target_name().to_string();
69		let ref inner = inner_type.to_string();
70
71		let has_arith = groups.contains(&Group::Arithmetic);
72		for group in groups {
73			match group {
74				Group::Arithmetic => Arithmetic::generate_all(&mut gen, target, inner)?,
75				Group::Bitwise => Bit::generate_all(&mut gen, target, inner)?,
76				Group::Formatting => generate_fmt(&mut gen, target, inner)?,
77				Group::Comparison => generate_cmp(&mut gen, target, inner)?,
78				Group::Accumulation => generate_accum(&mut gen, has_arith, target, inner)?,
79			}
80		}
81
82		gen.finish()
83	};
84
85	expand().unwrap_or_else(Error::into_token_stream)
86}
87
88fn parse_attributes(attributes: &Vec<Attribute>) -> Result<Vec<Group>> {
89	fn convert_error<T>(result: syn::Result<T>) -> Result<T> {
90		result.map_err(|err| Error::custom_at(err.to_string(), err.span().unwrap()))
91	}
92
93	for Attribute { tokens, .. } in attributes.iter() {
94		let stream = tokens.stream();
95		let meta: syn::Meta = convert_error(syn::parse(stream))?;
96		let list = convert_error(meta.require_list())?;
97		if !list.path.is_ident("primwrap") { continue }
98
99		let mut groups = Vec::with_capacity(4);
100		convert_error(list.parse_nested_meta(|meta| {
101			let ident = meta.path.require_ident()?.to_string();
102			let group = ident.parse().map_err(|_|
103				meta.input.error(r#"expected "arithmetic", "bitwise", "formatting", or "comparison""#)
104			)?;
105			groups.push(group);
106			Ok(())
107		}))?;
108
109		if !groups.is_empty() {
110			return Ok(groups)
111		}
112	}
113
114	Ok(Group::iter().collect())
115}