abol_codegen/
lib.rs

1use abol_parser::dictionary::{
2    AttributeType, Dictionary, DictionaryAttribute, DictionaryValue, SizeFlag,
3};
4use heck::{ToPascalCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use std::collections::{HashMap, HashSet};
8use std::io::Write;
9use std::process::{Command, Stdio};
10pub mod aruba;
11pub mod microsoft;
12pub mod rfc2865;
13pub mod rfc2866;
14pub mod rfc2869;
15pub mod rfc3576;
16pub mod rfc6911;
17pub mod wispr;
18
/// A code generator that transforms RADIUS dictionary definitions into type-safe Rust traits.
///
/// [`Generator::generate`] produces an extension trait (e.g. `Rfc2865Ext`) implemented for the
/// base `Packet` struct, with a getter and a setter for every supported attribute defined in
/// the dictionary.
#[derive(Debug, Clone)]
pub struct Generator {
    /// Base name of the generated trait. During generation it is converted to PascalCase and
    /// suffixed with `Ext`, so `"rfc2865"` yields a trait named `Rfc2865Ext`.
    pub trait_name: String,
    /// Names of dictionary attributes that are skipped during generation.
    pub ignored_attributes: Vec<String>,
    /// Attributes whose `*_TYPE` constant is defined in another module, keyed by attribute
    /// name. Only their accessors are generated; the mapped module path is not read by the
    /// generator itself, so the constant has to be brought into scope separately.
    pub external_attributes: HashMap<String, String>,
}
31
32impl Generator {
33    /// Creates a new generator instance for the specified module name.
34    ///
35    /// # Arguments
36    /// * `module_name` - The base name for the generated trait (will be converted to PascalCase).
37    pub fn new(trait_name: &str) -> Self {
38        Self {
39            trait_name: trait_name.to_string(),
40            ignored_attributes: Vec::new(),
41            external_attributes: HashMap::new(),
42        }
43    }
44    /// Validates an attribute against RADIUS protocol constraints and generator logic.
45    ///
46    /// This ensures that attributes follow standard OID limits, size constraints,
47    /// and encryption requirements.
48    fn validate_attr(&self, attr: &DictionaryAttribute) -> Result<(), String> {
49        // OID Check: Standard attributes must fit in 1 byte
50        if attr.oid.vendor.is_none() && attr.oid.code > 255 {
51            return Err(format!(
52                "Standard attribute {} OID must be <= 255",
53                attr.name
54            ));
55        }
56
57        // Size Check: Only String/Octets support size constraints
58        if attr.size.is_constrained()
59            && !matches!(
60                attr.attr_type,
61                AttributeType::String | AttributeType::Octets
62            )
63        {
64            return Err(format!(
65                "Size constraint invalid for non-binary type in {}",
66                attr.name
67            ));
68        }
69
70        // Encryption: Only specific flags (User-Password/Tunnel) supported
71        if let Some(enc) = attr.encrypt
72            && enc != 1
73            && enc != 2
74        {
75            return Err(format!(
76                "Unsupported encryption type {} on {}",
77                enc, attr.name
78            ));
79        }
80
81        // Concat: Strict rules (no encryption/tag/size allowed with concat)
82        if attr.concat.unwrap_or(false) {
83            let is_binary = matches!(
84                attr.attr_type,
85                AttributeType::String | AttributeType::Octets
86            );
87            let flags_present =
88                attr.encrypt.is_some() || attr.has_tag.is_some() || attr.size.is_constrained();
89            if !is_binary || flags_present {
90                return Err(format!("Invalid Concat configuration for {}", attr.name));
91            }
92        }
93
94        Ok(())
95    }
96    fn format_code(&self, content: &str) -> String {
97        let child = Command::new("rustfmt")
98            .stdin(Stdio::piped())
99            .stdout(Stdio::piped())
100            .stderr(Stdio::null())
101            .spawn()
102            .ok()
103            .and_then(|mut child| {
104                let mut stdin = child.stdin.take()?;
105                stdin.write_all(content.as_bytes()).ok()?;
106                drop(stdin);
107                let output = child.wait_with_output().ok()?;
108                if output.status.success() {
109                    Some(String::from_utf8_lossy(&output.stdout).to_string())
110                } else {
111                    None
112                }
113            });
114
115        child.unwrap_or_else(|| content.to_string())
116    }
117    /// Generates the Rust source code for the given dictionary.
118    ///
119    /// # Errors
120    /// Returns an error if the generation process fails or if there are severe inconsistencies
121    /// in the dictionary structure.
122    pub fn generate(&self, dict: &Dictionary) -> Result<String, Box<dyn std::error::Error>> {
123        let mut tokens = TokenStream::new();
124        let mut trait_signatures = TokenStream::new();
125        let mut trait_impl_bodies = TokenStream::new();
126
127        let trait_ident = format_ident!("{}Ext", self.trait_name.to_pascal_case());
128        let ignored: HashSet<_> = self.ignored_attributes.iter().collect();
129
130        // 1. Group Values by Attribute Name for easier lookup
131        let mut value_map: HashMap<String, Vec<&DictionaryValue>> = HashMap::new();
132        for val in &dict.values {
133            value_map
134                .entry(val.attribute_name.clone())
135                .or_default()
136                .push(val);
137        }
138
139        // 2. Base Imports
140        tokens.extend(quote! {
141            use std::net::{Ipv4Addr, Ipv6Addr};
142          use abol_core::{packet::Packet, attribute::FromRadiusAttribute, attribute::ToRadiusAttribute};
143            use std::time::SystemTime;
144        });
145
146        // 3. Process Standard Attributes
147        for attr in &dict.attributes {
148            self.process_attribute(
149                attr,
150                &ignored,
151                &value_map,
152                &mut tokens,
153                &mut trait_signatures,
154                &mut trait_impl_bodies,
155            );
156        }
157
158        // 4. Process Vendors and their specific Attributes/Values
159        for vendor in &dict.vendors {
160            let vendor_id = vendor.code;
161            let vendor_const = format_ident!("VENDOR_{}", vendor.name.to_shouty_snake_case());
162
163            tokens.extend(quote! { pub const #vendor_const: u32 = #vendor_id; });
164
165            // Create a specific value map for this vendor
166            let mut vendor_val_map: HashMap<String, Vec<&DictionaryValue>> = HashMap::new();
167            for val in &vendor.values {
168                vendor_val_map
169                    .entry(val.attribute_name.clone())
170                    .or_default()
171                    .push(val);
172            }
173
174            for attr in &vendor.attributes {
175                self.process_attribute(
176                    attr,
177                    &ignored,
178                    &vendor_val_map,
179                    &mut tokens,
180                    &mut trait_signatures,
181                    &mut trait_impl_bodies,
182                );
183            }
184        }
185
186        // 5. Wrap in Trait
187        tokens.extend(quote! {
188            pub trait #trait_ident {
189                #trait_signatures
190            }
191            impl #trait_ident for Packet {
192                #trait_impl_bodies
193            }
194        });
195        let raw_code = tokens.to_string();
196
197        Ok(self.format_code(&raw_code))
198    }
199    /// Internal helper to process a single attribute and generate its types, constants, and methods.
200    fn process_attribute(
201        &self,
202        attr: &DictionaryAttribute,
203        ignored: &HashSet<&String>,
204        value_map: &HashMap<String, Vec<&DictionaryValue>>,
205        tokens: &mut TokenStream,
206        signatures: &mut TokenStream,
207        bodies: &mut TokenStream,
208    ) {
209        if ignored.contains(&attr.name) {
210            return;
211        }
212        if let Err(e) = self.validate_attr(attr) {
213            eprintln!("Skipping {}: {}", attr.name, e);
214            return;
215        }
216
217        // 1. Map Dictionary Type to Rust "Wire" Types
218        let (wire_type, user_get_type, user_set_type, needs_into) = match attr.attr_type {
219            AttributeType::String => (
220                quote! { String },
221                quote! { String },
222                quote! { impl Into<String> },
223                true,
224            ),
225            AttributeType::Integer => (quote! { u32 }, quote! { u32 }, quote! { u32 }, false),
226            AttributeType::IpAddr => (
227                quote! { Ipv4Addr },
228                quote! { Ipv4Addr },
229                quote! { Ipv4Addr },
230                false,
231            ),
232            AttributeType::Ipv6Addr => (
233                quote! { Ipv6Addr },
234                quote! { Ipv6Addr },
235                quote! { Ipv6Addr },
236                false,
237            ),
238            AttributeType::Octets
239            | AttributeType::Ether
240            | AttributeType::ABinary
241            | AttributeType::Vsa => (
242                quote! { Vec<u8> },
243                quote! { Vec<u8> },
244                quote! { impl Into<Vec<u8>> },
245                true, // Needs .into()
246            ),
247            AttributeType::Date => (
248                quote! { SystemTime },
249                quote! { SystemTime },
250                quote! { SystemTime },
251                false,
252            ),
253            AttributeType::Byte => (quote! { u8 }, quote! { u8 }, quote! { u8 }, false),
254            AttributeType::Short => (quote! { u16 }, quote! { u16 }, quote! { u16 }, false),
255            AttributeType::Signed => (quote! { i32 }, quote! { i32 }, quote! { i32 }, false),
256            AttributeType::Tlv => (quote! { Tlv }, quote! { Tlv }, quote! { Tlv }, false),
257            AttributeType::Ipv4Prefix | AttributeType::Ipv6Prefix => (
258                quote! { Vec<u8> },
259                quote! { Vec<u8> },
260                quote! { Vec<u8> },
261                false,
262            ),
263            AttributeType::Ifid | AttributeType::InterfaceId => {
264                (quote! { u64 }, quote! { u64 }, quote! { u64 }, false)
265            }
266            _ => return,
267        };
268
269        let has_values = value_map.contains_key(&attr.name);
270        let normalized_name = attr.name.replace("-", "_").to_lowercase();
271        let enum_name = format_ident!("{}", normalized_name.to_upper_camel_case());
272        let is_external = self.external_attributes.contains_key(&attr.name);
273        let const_type_ident = format_ident!("{}_TYPE", normalized_name.to_shouty_snake_case());
274        // 2. Determine Final Method Types (Override if Enum exists)
275        let (final_get_type, final_set_type, final_needs_into) = if has_values {
276            (quote! { #enum_name }, quote! { #enum_name }, true)
277        } else {
278            (user_get_type, user_set_type, needs_into)
279        };
280
281        // 3. Generate Type Constants (Always generated)
282        if !is_external {
283            let code = attr.oid.code as u8;
284            tokens.extend(quote! { pub const #const_type_ident: u8 = #code; });
285        }
286
287        // 4. Generate the Enum definition if values exist
288        if let Some(values) = value_map.get(&attr.name) {
289            let mut variants = Vec::new();
290            let mut from_arms = Vec::new();
291            let mut to_arms = Vec::new();
292            let mut seen_values = HashSet::new();
293            for val in values {
294                let variant_ident = format_ident!("{}", val.name.to_upper_camel_case());
295                let val_lit = val.value as u32;
296
297                variants.push(quote! { #variant_ident });
298                if seen_values.insert(val_lit) {
299                    from_arms.push(quote! { #val_lit => Self::#variant_ident });
300                }
301                to_arms.push(quote! { #enum_name::#variant_ident => #val_lit });
302            }
303
304            tokens.extend(quote! {
305                #[derive(Debug, Clone, Copy, PartialEq, Eq)]
306                #[repr(u32)]
307                pub enum #enum_name {
308                    #(#variants),*,
309                    Unknown(u32),
310                }
311
312                impl From<u32> for #enum_name {
313                    fn from(v: u32) -> Self {
314                        match v {
315                            #(#from_arms),*,
316                            other => Self::Unknown(other),
317                        }
318                    }
319                }
320
321                impl From<#enum_name> for u32 {
322                    fn from(e: #enum_name) -> Self {
323                        match e {
324                            #(#to_arms),*,
325                            #enum_name::Unknown(v) => v,
326                        }
327                    }
328                }
329            });
330        }
331
332        let get_ident = format_ident!("get_{}", normalized_name.to_snake_case());
333        let set_ident = format_ident!("set_{}", normalized_name.to_snake_case());
334
335        // 5. Generate Trait Signatures
336        signatures.extend(quote! {
337            fn #get_ident(&self) -> Option<#final_get_type>;
338            fn #set_ident(&mut self, value: #final_set_type);
339        });
340
341        // 6. Generate Validation Logic
342        let size_validation = match attr.size {
343            SizeFlag::Exact(n) => quote! {
344                if ToRadiusAttribute::to_bytes(&wire_val).len() != #n as usize {
345                    return;
346                }
347            },
348            SizeFlag::Range(min, max) => quote! {
349                let len = ToRadiusAttribute::to_bytes(&wire_val).len();
350                if len < #min as usize || len > #max as usize {
351                    return;
352                }
353            },
354            SizeFlag::Any => quote! {},
355        };
356
357        // 7. Generate Method Bodies
358        let is_vsa = attr.oid.vendor.is_some();
359        let v_const = if let Some(vid) = attr.oid.vendor {
360            format_ident!("VENDOR_{}", vid)
361        } else {
362            format_ident!("UNUSED")
363        };
364
365        let (method, args) = if is_vsa {
366            (
367                quote!(get_vsa_attribute_as),
368                quote!(#v_const, #const_type_ident),
369            )
370        } else {
371            (quote!(get_attribute_as), quote!(#const_type_ident))
372        };
373
374        let (target_type, map_clause) = if has_values {
375            (quote!(u32), quote!(.map(#enum_name::from)))
376        } else {
377            (quote!(#wire_type), quote!())
378        };
379
380        let body_get = quote! {
381            self.#method::<#target_type>(#args) #map_clause
382        };
383
384        let (set_method, set_args) = if is_vsa {
385            (
386                quote!(set_vsa_attribute_as),
387                quote!(#v_const, #const_type_ident),
388            )
389        } else {
390            (quote!(set_attribute_as), quote!(#const_type_ident))
391        };
392
393        let value_type = if has_values {
394            quote!(u32)
395        } else {
396            quote!(#wire_type)
397        };
398
399        let body_set = if final_needs_into {
400            quote! {
401                let wire_val: #value_type = value.into();
402                #size_validation
403                self.#set_method::<#value_type>(#set_args, wire_val);
404            }
405        } else {
406            quote! {
407                let wire_val = value; // Direct assignment, no .into()
408                #size_validation
409                self.#set_method::<#value_type>(#set_args, wire_val);
410            }
411        };
412
413        bodies.extend(quote! {
414            fn #get_ident(&self) -> Option<#final_get_type> { #body_get }
415            fn #set_ident(&mut self, value: #final_set_type) { #body_set }
416        });
417    }
418}
#[cfg(test)]
mod tests {
    use abol_parser::dictionary;

    use super::*;

    /// Builds an attribute with no encryption, tag or concat flag for the tests below.
    fn plain_attr(
        name: &str,
        oid: dictionary::Oid,
        attr_type: AttributeType,
        size: dictionary::SizeFlag,
    ) -> DictionaryAttribute {
        DictionaryAttribute {
            name: name.to_string(),
            oid,
            attr_type,
            size,
            encrypt: None,
            has_tag: None,
            concat: None,
        }
    }

    /// Runs `process_attribute` on fresh output streams and returns the emitted trait
    /// signatures.
    fn emitted_signatures(
        generator: &Generator,
        attr: &DictionaryAttribute,
        ignored: &HashSet<&String>,
    ) -> TokenStream {
        let mut tokens = TokenStream::new();
        let mut signatures = TokenStream::new();
        let mut bodies = TokenStream::new();
        generator.process_attribute(
            attr,
            ignored,
            &HashMap::new(),
            &mut tokens,
            &mut signatures,
            &mut bodies,
        );
        signatures
    }

    #[test]
    fn test_generator_new() {
        let generator = Generator::new("Rfc2865Ext");
        assert_eq!(generator.trait_name, "Rfc2865Ext");
        assert!(generator.ignored_attributes.is_empty());
    }

    #[test]
    fn test_validate_attr_oid_overflow() {
        // Standard RADIUS attribute types are a single byte (0-255), so 256 is rejected.
        let attr = plain_attr(
            "Test-Attr",
            dictionary::Oid {
                vendor: None,
                code: 256,
            },
            AttributeType::String,
            dictionary::SizeFlag::Any,
        );
        assert!(Generator::new("test").validate_attr(&attr).is_err());
    }

    #[test]
    fn test_validate_attr_size_constraint_type() {
        let generator = Generator::new("test");
        let integer_attr = plain_attr(
            "Test-Attr",
            dictionary::Oid {
                vendor: None,
                code: 100,
            },
            AttributeType::Integer,
            dictionary::SizeFlag::Range(1, 10),
        );
        // Size constraints are only accepted on binary (string/octets) types.
        assert!(generator.validate_attr(&integer_attr).is_err());

        let string_attr = DictionaryAttribute {
            attr_type: AttributeType::String,
            ..integer_attr
        };
        assert!(generator.validate_attr(&string_attr).is_ok());
    }

    #[test]
    fn test_process_attribute_generation() {
        let attr = plain_attr(
            "User-Name",
            dictionary::Oid {
                vendor: None,
                code: 1,
            },
            AttributeType::String,
            dictionary::SizeFlag::Any,
        );
        let rendered =
            emitted_signatures(&Generator::new("Rfc2865"), &attr, &HashSet::new()).to_string();
        assert!(rendered.contains("get_user_name"));
        assert!(rendered.contains("set_user_name"));
    }

    #[test]
    fn test_ignored_attributes() {
        let mut generator = Generator::new("test");
        generator.ignored_attributes.push("Password".to_string());
        let ignored: HashSet<_> = generator.ignored_attributes.iter().collect();

        let attr = plain_attr(
            "Password",
            dictionary::Oid {
                vendor: None,
                code: 2,
            },
            AttributeType::String,
            dictionary::SizeFlag::Any,
        );
        assert!(emitted_signatures(&generator, &attr, &ignored).is_empty());
    }
}