1#![deny(missing_docs)]
25#![forbid(unsafe_code)]
26
27use proc_macro::TokenStream;
28use proc_macro2::TokenStream as TokenStream2;
29use quote::quote;
30use syn::{
31 parse::{Parse, ParseStream},
32 parse_macro_input, Attribute, DeriveInput, Expr, Ident, LitStr, Token,
33};
34
35#[proc_macro_derive(Unit, attributes(unit))]
41pub fn derive_unit(input: TokenStream) -> TokenStream {
42 let input = parse_macro_input!(input as DeriveInput);
43
44 match derive_unit_impl(input) {
45 Ok(tokens) => tokens.into(),
46 Err(err) => err.to_compile_error().into(),
47 }
48}
49
50fn derive_unit_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
51 let name = &input.ident;
52
53 let unit_attr = parse_unit_attribute(&input.attrs)?;
55
56 let symbol = &unit_attr.symbol;
57 let dimension = &unit_attr.dimension;
58 let ratio = &unit_attr.ratio;
59
60 let expanded = quote! {
61 impl crate::Unit for #name {
62 const RATIO: f64 = #ratio;
63 type Dim = #dimension;
64 const SYMBOL: &'static str = #symbol;
65 }
66
67 impl ::core::fmt::Display for crate::Quantity<#name> {
68 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
69 write!(f, "{} {}", self.value(), <#name as crate::Unit>::SYMBOL)
70 }
71 }
72 };
73
74 Ok(expanded)
75}
76
77struct UnitAttribute {
79 symbol: LitStr,
80 dimension: Expr,
81 ratio: Expr,
82 }
89
90impl Parse for UnitAttribute {
91 fn parse(input: ParseStream) -> syn::Result<Self> {
92 let mut symbol: Option<LitStr> = None;
93 let mut dimension: Option<Expr> = None;
94 let mut ratio: Option<Expr> = None;
95
96 while !input.is_empty() {
97 let ident: Ident = input.parse()?;
98 input.parse::<Token![=]>()?;
99
100 match ident.to_string().as_str() {
101 "symbol" => {
102 symbol = Some(input.parse()?);
103 }
104 "dimension" => {
105 dimension = Some(input.parse()?);
106 }
107 "ratio" => {
108 ratio = Some(input.parse()?);
109 }
110 other => {
117 return Err(syn::Error::new(
118 ident.span(),
119 format!("unknown attribute `{}`", other),
120 ));
121 }
122 }
123
124 if input.peek(Token![,]) {
126 input.parse::<Token![,]>()?;
127 }
128 }
129
130 let symbol = symbol
131 .ok_or_else(|| syn::Error::new(input.span(), "missing required attribute `symbol`"))?;
132 let dimension = dimension.ok_or_else(|| {
133 syn::Error::new(input.span(), "missing required attribute `dimension`")
134 })?;
135 let ratio = ratio
136 .ok_or_else(|| syn::Error::new(input.span(), "missing required attribute `ratio`"))?;
137
138 Ok(UnitAttribute {
139 symbol,
140 dimension,
141 ratio,
142 })
143 }
144}
145
146fn parse_unit_attribute(attrs: &[Attribute]) -> syn::Result<UnitAttribute> {
147 for attr in attrs {
148 if attr.path().is_ident("unit") {
149 return attr.parse_args::<UnitAttribute>();
150 }
151 }
152
153 Err(syn::Error::new(
154 proc_macro2::Span::call_site(),
155 "missing #[unit(...)] attribute",
156 ))
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse_quote;

    /// Runs `parse_unit_attribute` on `input`'s attributes, asserting that it
    /// fails, and returns the rendered error message.
    ///
    /// `UnitAttribute` does not implement `Debug`, so `Result::unwrap_err`
    /// is not available here.
    fn attribute_error(input: &DeriveInput) -> String {
        match parse_unit_attribute(&input.attrs) {
            Ok(_) => panic!("expected parse_unit_attribute to fail"),
            Err(err) => err.to_string(),
        }
    }

    /// Expands `input` with `derive_unit_impl`, asserting success, and
    /// returns the generated tokens rendered as a string.
    fn expand(input: DeriveInput) -> String {
        derive_unit_impl(input)
            .expect("derive_unit_impl should succeed")
            .to_string()
    }

    #[test]
    fn test_parse_unit_attribute_complete() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", dimension = Length, ratio = 1.0)]
            pub enum Meter {}
        };

        let attr = parse_unit_attribute(&input.attrs).unwrap();
        assert_eq!(attr.symbol.value(), "m");
    }

    #[test]
    fn test_parse_unit_attribute_missing() {
        let input: DeriveInput = parse_quote! {
            pub enum Meter {}
        };

        assert!(attribute_error(&input).contains("missing #[unit(...)] attribute"));
    }

    #[test]
    fn test_parse_unit_attribute_missing_symbol() {
        let input: DeriveInput = parse_quote! {
            #[unit(dimension = Length, ratio = 1.0)]
            pub enum Meter {}
        };

        assert!(attribute_error(&input).contains("missing required attribute `symbol`"));
    }

    #[test]
    fn test_parse_unit_attribute_missing_dimension() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", ratio = 1.0)]
            pub enum Meter {}
        };

        assert!(attribute_error(&input).contains("missing required attribute `dimension`"));
    }

    #[test]
    fn test_parse_unit_attribute_missing_ratio() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", dimension = Length)]
            pub enum Meter {}
        };

        assert!(attribute_error(&input).contains("missing required attribute `ratio`"));
    }

    #[test]
    fn test_parse_unit_attribute_unknown_field() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", dimension = Length, ratio = 1.0, unknown = "value")]
            pub enum Meter {}
        };

        assert!(attribute_error(&input).contains("unknown attribute"));
    }

    #[test]
    fn test_derive_unit_impl_basic() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", dimension = Length, ratio = 1.0)]
            pub enum Meter {}
        };

        let code = expand(input);
        assert!(code.contains("impl crate :: Unit for Meter"));
        assert!(code.contains("const RATIO : f64 = 1.0"));
        assert!(code.contains("const SYMBOL : & 'static str = \"m\""));
        assert!(code.contains("type Dim = Length"));
    }

    #[test]
    fn test_derive_unit_impl_with_expression_ratio() {
        // A genuine arithmetic expression, not just a literal. This checks
        // that `ratio` is forwarded verbatim rather than evaluated or rejected.
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "mm", dimension = Length, ratio = 1.0 / 1000.0)]
            pub enum Millimeter {}
        };

        let code = expand(input);
        assert!(code.contains("const RATIO : f64 = 1.0 / 1000.0"));
    }

    #[test]
    fn test_derive_unit_impl_emits_display() {
        let input: DeriveInput = parse_quote! {
            #[unit(symbol = "m", dimension = Length, ratio = 1.0)]
            pub enum Meter {}
        };

        let code = expand(input);
        assert!(code.contains("Display for crate :: Quantity < Meter >"));
    }

    #[test]
    fn test_unit_attribute_parse_with_trailing_comma() {
        let tokens = quote! {
            symbol = "m", dimension = Length, ratio = 1.0,
        };
        let attr: UnitAttribute = syn::parse2(tokens).unwrap();
        assert_eq!(attr.symbol.value(), "m");
    }

    #[test]
    fn test_unit_attribute_parse_no_trailing_comma() {
        let tokens = quote! {
            symbol = "m", dimension = Length, ratio = 1.0
        };
        let attr: UnitAttribute = syn::parse2(tokens).unwrap();
        assert_eq!(attr.symbol.value(), "m");
    }

    #[test]
    fn test_unit_attribute_parse_duplicate_symbol() {
        // Repeated keys are not an error; the last occurrence wins.
        let tokens = quote! {
            symbol = "m", symbol = "km", dimension = Length, ratio = 1.0
        };
        let attr: UnitAttribute = syn::parse2(tokens).unwrap();
        assert_eq!(attr.symbol.value(), "km");
    }

    #[test]
    fn test_parse_empty_attribute() {
        let result: syn::Result<UnitAttribute> = syn::parse2(quote! {});
        assert!(result.is_err());
    }

    #[test]
    fn test_derive_unit_impl_error_path() {
        // No `#[unit(...)]` attribute at all.
        let input: DeriveInput = parse_quote! {
            pub enum Meter {}
        };

        let err = derive_unit_impl(input).expect_err("expansion without #[unit] must fail");
        let code = err.to_compile_error().to_string();
        assert!(code.contains("compile_error"));
    }
}