Skip to main content

cefact_units_macros/
lib.rs

1//! Proc-macro for generating UN/CEFACT unit code types from rec20.xlsx.
2
3use std::collections::{HashMap, HashSet};
4use std::fs::File;
5use std::io::BufReader;
6use std::path::PathBuf;
7
8use calamine::{open_workbook, Data, Reader, Xlsx};
9use heck::ToPascalCase;
10use proc_macro::TokenStream;
11use quote::{format_ident, quote};
12use syn::{parse_macro_input, LitStr};
13use unicode_normalization::UnicodeNormalization;
14
15// ── data model ───────────────────────────────────────────────────────────────
16
/// One active unit row collected from the workbook, plus the identifiers
/// generated for it.
struct Record {
    /// Identifier of the `UnitOfMeasure` variant; derived from `name` and made
    /// unique in `parse` (left empty until then).
    variant: String,
    /// Identifier of the `UnitCode` variant; derived from `code` and made
    /// unique in `parse` (left empty until then).
    code_variant: String,
    /// Value of the "Common Code" column; also the de-duplication key.
    code: String,
    /// Value of the "Name" column, or an empty string when the cell is blank.
    name: String,
    /// Value of the "Symbol" column, if any.
    symbol: Option<String>,
    /// Value of the "Quantity" column (only read from the "Annex I" sheet).
    quantity: Option<String>,
    /// Value of the "Sector" column (only read from the "Annex I" sheet).
    sector: Option<String>,
    /// Value of the "Conversion Factor" column, if any.
    conversion_factor: Option<String>,
    /// Value of the "Level/Category" column, if any.
    level_category: Option<String>,
    /// Value of the "Description" column; `parse` back-fills it from the
    /// "Annex II & Annex III" sheet when the originating row had none.
    description: Option<String>,
}
29
30// ── helpers ──────────────────────────────────────────────────────────────────
31
32fn cell(data: &Data) -> Option<String> {
33    match data {
34        Data::String(s) => {
35            let t = s.trim().to_owned();
36            (!t.is_empty()).then_some(t)
37        }
38        Data::Float(f) => Some(f.to_string()),
39        Data::Int(i) => Some(i.to_string()),
40        _ => None,
41    }
42}
43
44fn is_active(status_cell: &Data) -> bool {
45    !matches!(
46        cell(status_cell).as_deref(),
47        Some("D") | Some("X") | Some("¦")
48    )
49}
50
51fn pascal(name: &str, code: &str) -> String {
52    let cleaned: String = name
53        .chars()
54        .map(|c| {
55            if c.is_ascii_alphanumeric() || c == ' ' {
56                c
57            } else {
58                ' '
59            }
60        })
61        .collect();
62    let base = cleaned.trim().to_pascal_case();
63    if base.is_empty() || base.starts_with(|c: char| c.is_ascii_digit()) {
64        format!(
65            "_{}",
66            code.chars()
67                .map(|c| if c.is_ascii_alphanumeric() {
68                    c.to_ascii_uppercase()
69                } else {
70                    '_'
71                })
72                .collect::<String>()
73        )
74    } else {
75        base
76    }
77}
78
79fn code_suffix_pascal(code: &str) -> String {
80    let cleaned: String = code
81        .chars()
82        .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
83        .collect();
84    let pascal = cleaned.to_pascal_case();
85    if pascal.starts_with(|c: char| c.is_ascii_digit()) {
86        format!("Code{pascal}")
87    } else {
88        pascal
89    }
90}
91
92fn code_variant(code: &str) -> String {
93    let cleaned: String = code
94        .chars()
95        .map(|c| if c.is_ascii_alphanumeric() { c } else { ' ' })
96        .collect();
97    let base = cleaned.trim().to_pascal_case();
98    if base.is_empty() || base.starts_with(|c: char| c.is_ascii_digit()) {
99        format!("Code{base}")
100    } else {
101        base
102    }
103}
104
105fn normalize_text(value: String) -> String {
106    // U+2126 (OHM SIGN) → U+03A9 (GREEK CAPITAL LETTER OMEGA)
107    value.replace('\u{2126}', "\u{03A9}").nfc().collect()
108}
109
/// Name of the generated `#[test]` function for a common code.
///
/// The code is lower-cased and every non-alphanumeric character becomes `_`.
/// The prefix is `test_`, or `test_code_` when the code starts with a digit.
fn test_fn_name(code: &str) -> String {
    let mut ident = String::from("test_");
    if code.starts_with(|ch: char| ch.is_ascii_digit()) {
        ident.push_str("code_");
    }
    ident.extend(code.chars().map(|ch| {
        if ch.is_ascii_alphanumeric() {
            ch.to_ascii_lowercase()
        } else {
            '_'
        }
    }));
    ident
}
122
123// ── parsing ───────────────────────────────────────────────────────────────────
124
125fn parse(wb: &mut Xlsx<BufReader<File>>) -> Vec<Record> {
126    let mut records: Vec<Record> = Vec::new();
127    let mut seen: HashSet<String> = HashSet::new();
128
129    // Annex I columns:
130    // 0 Group Number | 1 Sector | 2 Group ID | 3 Quantity | 4 Level/Category
131    // 5 Status | 6 Common Code | 7 Name | 8 Conversion Factor | 9 Symbol | 10 Description
132    {
133        let sheet = wb
134            .worksheet_range("Annex I")
135            .expect("sheet 'Annex I' missing");
136        let get = |row: &[Data], i: usize| row.get(i).and_then(cell).map(normalize_text);
137
138        for row in sheet.rows().skip(1) {
139            if !is_active(row.get(5).unwrap_or(&Data::Empty)) {
140                continue;
141            }
142            let Some(code) = get(row, 6) else {
143                continue;
144            };
145            if seen.contains(&code) {
146                continue;
147            }
148            seen.insert(code.clone());
149            records.push(Record {
150                variant: String::new(),
151                code_variant: String::new(),
152                code,
153                name: get(row, 7).unwrap_or_default(),
154                symbol: get(row, 9),
155                quantity: get(row, 3),
156                sector: get(row, 1),
157                conversion_factor: get(row, 8),
158                level_category: get(row, 4),
159                description: get(row, 10),
160            });
161        }
162    }
163
164    // Annex II & III columns:
165    // 0 Status | 1 Common Code | 2 Name | 3 Description
166    // 4 Level/Category | 5 Symbol | 6 Conversion Factor
167    let mut desc_supplement: HashMap<String, String> = HashMap::new();
168    {
169        let sheet = wb
170            .worksheet_range("Annex II & Annex III")
171            .expect("sheet 'Annex II & Annex III' missing");
172        let get = |row: &[Data], i: usize| row.get(i).and_then(cell).map(normalize_text);
173
174        for row in sheet.rows().skip(1) {
175            if !is_active(row.first().unwrap_or(&Data::Empty)) {
176                continue;
177            }
178            let Some(code) = get(row, 1) else {
179                continue;
180            };
181
182            if let Some(desc) = get(row, 3) {
183                desc_supplement.entry(code.clone()).or_insert(desc);
184            }
185
186            if seen.contains(&code) {
187                continue;
188            }
189            seen.insert(code.clone());
190            records.push(Record {
191                variant: String::new(),
192                code_variant: String::new(),
193                code,
194                name: get(row, 2).unwrap_or_default(),
195                symbol: get(row, 5),
196                quantity: None,
197                sector: None,
198                conversion_factor: get(row, 6),
199                level_category: get(row, 4),
200                description: get(row, 3),
201            });
202        }
203    }
204
205    // Fill missing descriptions from supplement map.
206    for rec in &mut records {
207        if rec.description.is_none() {
208            rec.description = desc_supplement.get(&rec.code).cloned();
209        }
210    }
211
212    // Assign unique PascalCase variant names.
213    let mut counts: HashMap<String, usize> = HashMap::new();
214    for rec in &mut records {
215        let base = pascal(&rec.name, &rec.code);
216        let n = counts.entry(base.clone()).or_insert(0);
217        rec.variant = if *n == 0 {
218            base
219        } else {
220            let suffix = code_suffix_pascal(&rec.code);
221            format!("{base}{suffix}")
222        };
223        *n += 1;
224    }
225
226    let mut code_variant_counts: HashMap<String, usize> = HashMap::new();
227    for rec in &mut records {
228        let base = code_variant(&rec.code);
229        let n = code_variant_counts.entry(base.clone()).or_insert(0);
230        rec.code_variant = if *n == 0 {
231            base
232        } else {
233            let suffix = code_suffix_pascal(&rec.code);
234            format!("{base}{suffix}")
235        };
236        *n += 1;
237    }
238
239    records
240}
241
242fn opt_str(v: &Option<String>) -> proc_macro2::TokenStream {
243    match v {
244        Some(s) => quote! { Some(#s) },
245        None => quote! { None },
246    }
247}
248
249/// Generate UN/CEFACT unit types from an Excel file.
250///
251/// Usage: `cefact_units!("rec20.xlsx");`
252#[proc_macro]
253pub fn cefact_units(input: TokenStream) -> TokenStream {
254    let path_lit = parse_macro_input!(input as LitStr);
255    let rel_path = path_lit.value();
256
257    // Resolve path relative to calling crate's CARGO_MANIFEST_DIR
258    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
259        .expect("CARGO_MANIFEST_DIR not set");
260    let xlsx_path = PathBuf::from(&manifest_dir).join(&rel_path);
261
262    let mut wb: Xlsx<_> = open_workbook(&xlsx_path)
263        .unwrap_or_else(|e| panic!("cannot open {}: {e}", xlsx_path.display()));
264
265    let records = parse(&mut wb);
266
267    // Generate identifiers
268    let unit_variants: Vec<_> = records
269        .iter()
270        .map(|r| format_ident!("{}", r.variant))
271        .collect();
272    let code_variants: Vec<_> = records
273        .iter()
274        .map(|r| format_ident!("{}", r.code_variant))
275        .collect();
276    let codes: Vec<_> = records.iter().map(|r| &r.code).collect();
277    let codes_upper: Vec<_> = records.iter().map(|r| r.code.to_ascii_uppercase()).collect();
278    let names: Vec<_> = records.iter().map(|r| &r.name).collect();
279    let symbols: Vec<_> = records.iter().map(|r| opt_str(&r.symbol)).collect();
280    let quantities: Vec<_> = records.iter().map(|r| opt_str(&r.quantity)).collect();
281    let sectors: Vec<_> = records.iter().map(|r| opt_str(&r.sector)).collect();
282    let conversion_factors: Vec<_> = records
283        .iter()
284        .map(|r| opt_str(&r.conversion_factor))
285        .collect();
286    let level_categories: Vec<_> = records
287        .iter()
288        .map(|r| opt_str(&r.level_category))
289        .collect();
290    let descriptions: Vec<_> = records.iter().map(|r| opt_str(&r.description)).collect();
291    let count = records.len();
292
293    let code_docs: Vec<_> = records
294        .iter()
295        .map(|r| format!("Code `{}`.", r.code))
296        .collect();
297    let unit_docs: Vec<_> = records
298        .iter()
299        .map(|r| format!("Unit `{}`.", r.code))
300        .collect();
301    let test_fns: Vec<_> = records
302        .iter()
303        .map(|r| format_ident!("{}", test_fn_name(&r.code)))
304        .collect();
305
306    let output = quote! {
307        // ── UnitCode enum ────────────────────────────────────────────────────────
308
309        /// Canonical UN/CEFACT common code (for example: `"MTR"`, `"KGM"`).
310        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
311        #[non_exhaustive]
312        pub enum UnitCode {
313            #(
314                #[doc = #code_docs]
315                #code_variants,
316            )*
317        }
318
319        impl UnitCode {
320            /// Returns this code as a static string.
321            #[inline]
322            #[must_use]
323            pub const fn as_str(self) -> &'static str {
324                match self {
325                    #( Self::#code_variants => #codes, )*
326                }
327            }
328
329            /// Look up a code by its string representation (case-sensitive).
330            #[inline]
331            #[must_use]
332            pub fn from_code(code: &str) -> Option<Self> {
333                match code {
334                    #( #codes => Some(Self::#code_variants), )*
335                    _ => None,
336                }
337            }
338
339            /// Look up a code by its string representation (case-insensitive).
340            #[inline]
341            #[cfg(feature = "case-insensitive")]
342            #[must_use]
343            pub fn from_code_ignore_case(code: &str) -> Option<Self> {
344                match code.to_ascii_uppercase().as_str() {
345                    #( #codes_upper => Some(Self::#code_variants), )*
346                    _ => None,
347                }
348            }
349        }
350
351        impl core::fmt::Display for UnitCode {
352            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
353                f.write_str(self.as_str())
354            }
355        }
356
357        impl core::str::FromStr for UnitCode {
358            type Err = UnknownCode;
359
360            fn from_str(s: &str) -> Result<Self, Self::Err> {
361                #[cfg(feature = "case-insensitive")]
362                let result = Self::from_code_ignore_case(s);
363
364                #[cfg(not(feature = "case-insensitive"))]
365                let result = Self::from_code(s);
366
367                result.ok_or_else(|| UnknownCode(s.to_owned()))
368            }
369        }
370
371        #[cfg(feature = "serde")]
372        impl serde::Serialize for UnitCode {
373            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
374            where
375                S: serde::Serializer,
376            {
377                serializer.serialize_str(self.as_str())
378            }
379        }
380
381        #[cfg(feature = "serde")]
382        impl<'de> serde::Deserialize<'de> for UnitCode {
383            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
384            where
385                D: serde::Deserializer<'de>,
386            {
387                let code = <&str>::deserialize(deserializer)?;
388                code.parse().map_err(serde::de::Error::custom)
389            }
390        }
391
392        impl<'a> TryFrom<&'a str> for UnitCode {
393            type Error = UnknownCode;
394
395            fn try_from(s: &'a str) -> Result<Self, Self::Error> {
396                s.parse()
397            }
398        }
399
400        // ── UnitOfMeasure enum ───────────────────────────────────────────────────
401
402        /// Every active UN/CEFACT unit of measure code (Rec 20, Rev 17 — 2021).
403        ///
404        /// Use [`UnitOfMeasure::from_code`] to look up by common code,
405        /// or iterate [`UnitOfMeasure::ALL`].
406        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
407        #[non_exhaustive]
408        pub enum UnitOfMeasure {
409            #(
410                #[doc = #unit_docs]
411                #unit_variants,
412            )*
413        }
414
415        // ── UnitInfo struct ──────────────────────────────────────────────────────
416
417        /// All metadata for a [`UnitOfMeasure`] variant.
418        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
419        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
420        pub struct UnitInfo {
421            /// 2- or 3-character alphanumeric common code (e.g. `"MTR"`, `"KGM"`).
422            pub code: &'static str,
423            /// English unit name.
424            pub name: &'static str,
425            /// Unit symbol (e.g. `m`, `kg`), if defined.
426            pub symbol: Option<&'static str>,
427            /// Physical quantity this unit measures (Annex I only).
428            pub quantity: Option<&'static str>,
429            /// Sector / scientific discipline (Annex I only).
430            pub sector: Option<&'static str>,
431            /// Informative conversion factor expression.
432            pub conversion_factor: Option<&'static str>,
433            /// Normative level/category code from the recommendation.
434            pub level_category: Option<&'static str>,
435            /// Human-readable description.
436            pub description: Option<&'static str>,
437        }
438
439        impl UnitOfMeasure {
440            /// Returns all static metadata for this unit.
441            #[inline]
442            #[must_use]
443            pub const fn info(self) -> UnitInfo {
444                match self {
445                    #(
446                        Self::#unit_variants => UnitInfo {
447                            code: #codes,
448                            name: #names,
449                            symbol: #symbols,
450                            quantity: #quantities,
451                            sector: #sectors,
452                            conversion_factor: #conversion_factors,
453                            level_category: #level_categories,
454                            description: #descriptions,
455                        },
456                    )*
457                }
458            }
459
460            /// 2- or 3-character common code.
461            #[inline]
462            #[must_use]
463            pub const fn code(self) -> &'static str {
464                self.info().code
465            }
466
467            /// English unit name.
468            #[inline]
469            #[must_use]
470            pub const fn name(self) -> &'static str {
471                self.info().name
472            }
473
474            /// Unit symbol, if any.
475            #[inline]
476            #[must_use]
477            pub const fn symbol(self) -> Option<&'static str> {
478                self.info().symbol
479            }
480
481            /// Physical quantity (Annex I only).
482            #[inline]
483            #[must_use]
484            pub const fn quantity(self) -> Option<&'static str> {
485                self.info().quantity
486            }
487
488            /// Sector / discipline (Annex I only).
489            #[inline]
490            #[must_use]
491            pub const fn sector(self) -> Option<&'static str> {
492                self.info().sector
493            }
494
495            /// Informative conversion factor expression.
496            #[inline]
497            #[must_use]
498            pub const fn conversion_factor(self) -> Option<&'static str> {
499                self.info().conversion_factor
500            }
501
502            /// Normative level/category code.
503            #[inline]
504            #[must_use]
505            pub const fn level_category(self) -> Option<&'static str> {
506                self.info().level_category
507            }
508
509            /// Description text.
510            #[inline]
511            #[must_use]
512            pub const fn description(self) -> Option<&'static str> {
513                self.info().description
514            }
515
516            /// Returns this unit's strongly-typed code.
517            #[inline]
518            #[must_use]
519            pub const fn unit_code(self) -> UnitCode {
520                match self {
521                    #( Self::#unit_variants => UnitCode::#code_variants, )*
522                }
523            }
524
525            /// Converts a code into its associated unit.
526            #[inline]
527            #[must_use]
528            pub const fn from_unit_code(code: UnitCode) -> Self {
529                match code {
530                    #( UnitCode::#code_variants => Self::#unit_variants, )*
531                }
532            }
533
534            /// Look up a unit by its common code (case-sensitive).
535            #[inline]
536            #[must_use]
537            pub fn from_code(code: &str) -> Option<Self> {
538                UnitCode::from_code(code).map(Self::from_unit_code)
539            }
540
541            /// Look up a unit by its common code (case-insensitive).
542            #[inline]
543            #[cfg(feature = "case-insensitive")]
544            #[must_use]
545            pub fn from_code_ignore_case(code: &str) -> Option<Self> {
546                UnitCode::from_code_ignore_case(code).map(Self::from_unit_code)
547            }
548
549            /// Every active unit in source order (Annex I first, then Annex II/III additions).
550            pub const ALL: &'static [Self; #count] = &[
551                #( Self::#unit_variants, )*
552            ];
553        }
554
555        impl core::fmt::Display for UnitOfMeasure {
556            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
557                f.write_str(self.code())
558            }
559        }
560
561        impl core::str::FromStr for UnitOfMeasure {
562            type Err = UnknownCode;
563
564            fn from_str(s: &str) -> Result<Self, Self::Err> {
565                #[cfg(feature = "case-insensitive")]
566                let result = Self::from_code_ignore_case(s);
567
568                #[cfg(not(feature = "case-insensitive"))]
569                let result = Self::from_code(s);
570
571                result.ok_or_else(|| UnknownCode(s.to_owned()))
572            }
573        }
574
575        #[cfg(feature = "serde")]
576        impl serde::Serialize for UnitOfMeasure {
577            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
578            where
579                S: serde::Serializer,
580            {
581                serializer.serialize_str(self.code())
582            }
583        }
584
585        #[cfg(feature = "serde")]
586        impl<'de> serde::Deserialize<'de> for UnitOfMeasure {
587            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
588            where
589                D: serde::Deserializer<'de>,
590            {
591                let code = <&str>::deserialize(deserializer)?;
592                code.parse().map_err(serde::de::Error::custom)
593            }
594        }
595
596        impl<'a> TryFrom<&'a str> for UnitOfMeasure {
597            type Error = UnknownCode;
598
599            fn try_from(s: &'a str) -> Result<Self, Self::Error> {
600                s.parse()
601            }
602        }
603
604        // ── UnknownCode error ────────────────────────────────────────────────────
605
606        /// Error returned when parsing an unrecognised UN/CEFACT common code.
607        #[derive(Debug, Clone, PartialEq, Eq)]
608        pub struct UnknownCode(pub String);
609
610        impl core::fmt::Display for UnknownCode {
611            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
612                write!(f, "unknown UN/CEFACT unit code: {:?}", self.0)
613            }
614        }
615
616        impl std::error::Error for UnknownCode {}
617
618        // ── Generated tests ──────────────────────────────────────────────────────
619
620        #[cfg(all(test, feature = "generated-tests"))]
621        mod generated_tests {
622            use super::*;
623            #(
624                #[test]
625                fn #test_fns() {
626                    let code: UnitCode = #codes.parse().unwrap();
627                    assert_eq!(code.as_str(), #codes);
628                    let unit = UnitOfMeasure::from_unit_code(code);
629                    assert_eq!(unit.code(), #codes);
630                    let parsed: UnitOfMeasure = #codes.parse().unwrap();
631                    assert_eq!(parsed.code(), #codes);
632                    assert_eq!(unit.unit_code(), code);
633                    let _ = unit.info();
634                }
635            )*
636        }
637    };
638
639    output.into()
640}