1use abol_parser::dictionary::{
2 AttributeType, Dictionary, DictionaryAttribute, DictionaryValue, SizeFlag,
3};
4use heck::{ToPascalCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use std::collections::{HashMap, HashSet};
8use std::io::Write;
9use std::process::{Command, Stdio};
10pub mod aruba;
11pub mod microsoft;
12pub mod rfc2865;
13pub mod rfc2866;
14pub mod rfc2869;
15pub mod rfc3576;
16pub mod rfc6911;
17pub mod wispr;
18
/// Configuration for turning a parsed RADIUS dictionary into Rust source.
///
/// Build one with [`Generator::new`], optionally adjust the public fields, then call
/// [`Generator::generate`].
#[derive(Debug, Clone)]
pub struct Generator {
    /// Base name of the generated extension trait. `generate` converts it to PascalCase and
    /// appends `Ext` to form the trait identifier.
    pub trait_name: String,
    /// Dictionary attribute names (exact match) for which no code is generated.
    pub ignored_attributes: Vec<String>,
    /// Attributes whose `<NAME>_TYPE` constant is not emitted by this generator, keyed by
    /// attribute name. Accessors are still generated and still refer to the constant by name.
    /// The map values are not read anywhere in this file.
    pub external_attributes: HashMap<String, String>,
}
31
32impl Generator {
33 pub fn new(trait_name: &str) -> Self {
38 Self {
39 trait_name: trait_name.to_string(),
40 ignored_attributes: Vec::new(),
41 external_attributes: HashMap::new(),
42 }
43 }
44 fn validate_attr(&self, attr: &DictionaryAttribute) -> Result<(), String> {
49 if attr.oid.vendor.is_none() && attr.oid.code > 255 {
51 return Err(format!(
52 "Standard attribute {} OID must be <= 255",
53 attr.name
54 ));
55 }
56
57 if attr.size.is_constrained()
59 && !matches!(
60 attr.attr_type,
61 AttributeType::String | AttributeType::Octets
62 )
63 {
64 return Err(format!(
65 "Size constraint invalid for non-binary type in {}",
66 attr.name
67 ));
68 }
69
70 if let Some(enc) = attr.encrypt
72 && enc != 1
73 && enc != 2
74 {
75 return Err(format!(
76 "Unsupported encryption type {} on {}",
77 enc, attr.name
78 ));
79 }
80
81 if attr.concat.unwrap_or(false) {
83 let is_binary = matches!(
84 attr.attr_type,
85 AttributeType::String | AttributeType::Octets
86 );
87 let flags_present =
88 attr.encrypt.is_some() || attr.has_tag.is_some() || attr.size.is_constrained();
89 if !is_binary || flags_present {
90 return Err(format!("Invalid Concat configuration for {}", attr.name));
91 }
92 }
93
94 Ok(())
95 }
96 fn format_code(&self, content: &str) -> String {
97 let child = Command::new("rustfmt")
98 .stdin(Stdio::piped())
99 .stdout(Stdio::piped())
100 .stderr(Stdio::null())
101 .spawn()
102 .ok()
103 .and_then(|mut child| {
104 let mut stdin = child.stdin.take()?;
105 stdin.write_all(content.as_bytes()).ok()?;
106 drop(stdin);
107 let output = child.wait_with_output().ok()?;
108 if output.status.success() {
109 Some(String::from_utf8_lossy(&output.stdout).to_string())
110 } else {
111 None
112 }
113 });
114
115 child.unwrap_or_else(|| content.to_string())
116 }
117 pub fn generate(&self, dict: &Dictionary) -> Result<String, Box<dyn std::error::Error>> {
123 let mut tokens = TokenStream::new();
124 let mut trait_signatures = TokenStream::new();
125 let mut trait_impl_bodies = TokenStream::new();
126
127 let trait_ident = format_ident!("{}Ext", self.trait_name.to_pascal_case());
128 let ignored: HashSet<_> = self.ignored_attributes.iter().collect();
129
130 let mut value_map: HashMap<String, Vec<&DictionaryValue>> = HashMap::new();
132 for val in &dict.values {
133 value_map
134 .entry(val.attribute_name.clone())
135 .or_default()
136 .push(val);
137 }
138
139 tokens.extend(quote! {
141 use std::net::{Ipv4Addr, Ipv6Addr};
142 use abol_core::{packet::Packet, attribute::FromRadiusAttribute, attribute::ToRadiusAttribute};
143 use std::time::SystemTime;
144 });
145
146 for attr in &dict.attributes {
148 self.process_attribute(
149 attr,
150 &ignored,
151 &value_map,
152 &mut tokens,
153 &mut trait_signatures,
154 &mut trait_impl_bodies,
155 );
156 }
157
158 for vendor in &dict.vendors {
160 let vendor_id = vendor.code;
161 let vendor_const = format_ident!("VENDOR_{}", vendor.name.to_shouty_snake_case());
162
163 tokens.extend(quote! { pub const #vendor_const: u32 = #vendor_id; });
164
165 let mut vendor_val_map: HashMap<String, Vec<&DictionaryValue>> = HashMap::new();
167 for val in &vendor.values {
168 vendor_val_map
169 .entry(val.attribute_name.clone())
170 .or_default()
171 .push(val);
172 }
173
174 for attr in &vendor.attributes {
175 self.process_attribute(
176 attr,
177 &ignored,
178 &vendor_val_map,
179 &mut tokens,
180 &mut trait_signatures,
181 &mut trait_impl_bodies,
182 );
183 }
184 }
185
186 tokens.extend(quote! {
188 pub trait #trait_ident {
189 #trait_signatures
190 }
191 impl #trait_ident for Packet {
192 #trait_impl_bodies
193 }
194 });
195 let raw_code = tokens.to_string();
196
197 Ok(self.format_code(&raw_code))
198 }
199 fn process_attribute(
201 &self,
202 attr: &DictionaryAttribute,
203 ignored: &HashSet<&String>,
204 value_map: &HashMap<String, Vec<&DictionaryValue>>,
205 tokens: &mut TokenStream,
206 signatures: &mut TokenStream,
207 bodies: &mut TokenStream,
208 ) {
209 if ignored.contains(&attr.name) {
210 return;
211 }
212 if let Err(e) = self.validate_attr(attr) {
213 eprintln!("Skipping {}: {}", attr.name, e);
214 return;
215 }
216
217 let (wire_type, user_get_type, user_set_type, needs_into) = match attr.attr_type {
219 AttributeType::String => (
220 quote! { String },
221 quote! { String },
222 quote! { impl Into<String> },
223 true,
224 ),
225 AttributeType::Integer => (quote! { u32 }, quote! { u32 }, quote! { u32 }, false),
226 AttributeType::IpAddr => (
227 quote! { Ipv4Addr },
228 quote! { Ipv4Addr },
229 quote! { Ipv4Addr },
230 false,
231 ),
232 AttributeType::Ipv6Addr => (
233 quote! { Ipv6Addr },
234 quote! { Ipv6Addr },
235 quote! { Ipv6Addr },
236 false,
237 ),
238 AttributeType::Octets
239 | AttributeType::Ether
240 | AttributeType::ABinary
241 | AttributeType::Vsa => (
242 quote! { Vec<u8> },
243 quote! { Vec<u8> },
244 quote! { impl Into<Vec<u8>> },
245 true, ),
247 AttributeType::Date => (
248 quote! { SystemTime },
249 quote! { SystemTime },
250 quote! { SystemTime },
251 false,
252 ),
253 AttributeType::Byte => (quote! { u8 }, quote! { u8 }, quote! { u8 }, false),
254 AttributeType::Short => (quote! { u16 }, quote! { u16 }, quote! { u16 }, false),
255 AttributeType::Signed => (quote! { i32 }, quote! { i32 }, quote! { i32 }, false),
256 AttributeType::Tlv => (quote! { Tlv }, quote! { Tlv }, quote! { Tlv }, false),
257 AttributeType::Ipv4Prefix | AttributeType::Ipv6Prefix => (
258 quote! { Vec<u8> },
259 quote! { Vec<u8> },
260 quote! { Vec<u8> },
261 false,
262 ),
263 AttributeType::Ifid | AttributeType::InterfaceId => {
264 (quote! { u64 }, quote! { u64 }, quote! { u64 }, false)
265 }
266 _ => return,
267 };
268
269 let has_values = value_map.contains_key(&attr.name);
270 let normalized_name = attr.name.replace("-", "_").to_lowercase();
271 let enum_name = format_ident!("{}", normalized_name.to_upper_camel_case());
272 let is_external = self.external_attributes.contains_key(&attr.name);
273 let const_type_ident = format_ident!("{}_TYPE", normalized_name.to_shouty_snake_case());
274 let (final_get_type, final_set_type, final_needs_into) = if has_values {
276 (quote! { #enum_name }, quote! { #enum_name }, true)
277 } else {
278 (user_get_type, user_set_type, needs_into)
279 };
280
281 if !is_external {
283 let code = attr.oid.code as u8;
284 tokens.extend(quote! { pub const #const_type_ident: u8 = #code; });
285 }
286
287 if let Some(values) = value_map.get(&attr.name) {
289 let mut variants = Vec::new();
290 let mut from_arms = Vec::new();
291 let mut to_arms = Vec::new();
292 let mut seen_values = HashSet::new();
293 for val in values {
294 let variant_ident = format_ident!("{}", val.name.to_upper_camel_case());
295 let val_lit = val.value as u32;
296
297 variants.push(quote! { #variant_ident });
298 if seen_values.insert(val_lit) {
299 from_arms.push(quote! { #val_lit => Self::#variant_ident });
300 }
301 to_arms.push(quote! { #enum_name::#variant_ident => #val_lit });
302 }
303
304 tokens.extend(quote! {
305 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
306 #[repr(u32)]
307 pub enum #enum_name {
308 #(#variants),*,
309 Unknown(u32),
310 }
311
312 impl From<u32> for #enum_name {
313 fn from(v: u32) -> Self {
314 match v {
315 #(#from_arms),*,
316 other => Self::Unknown(other),
317 }
318 }
319 }
320
321 impl From<#enum_name> for u32 {
322 fn from(e: #enum_name) -> Self {
323 match e {
324 #(#to_arms),*,
325 #enum_name::Unknown(v) => v,
326 }
327 }
328 }
329 });
330 }
331
332 let get_ident = format_ident!("get_{}", normalized_name.to_snake_case());
333 let set_ident = format_ident!("set_{}", normalized_name.to_snake_case());
334
335 signatures.extend(quote! {
337 fn #get_ident(&self) -> Option<#final_get_type>;
338 fn #set_ident(&mut self, value: #final_set_type);
339 });
340
341 let size_validation = match attr.size {
343 SizeFlag::Exact(n) => quote! {
344 if ToRadiusAttribute::to_bytes(&wire_val).len() != #n as usize {
345 return;
346 }
347 },
348 SizeFlag::Range(min, max) => quote! {
349 let len = ToRadiusAttribute::to_bytes(&wire_val).len();
350 if len < #min as usize || len > #max as usize {
351 return;
352 }
353 },
354 SizeFlag::Any => quote! {},
355 };
356
357 let is_vsa = attr.oid.vendor.is_some();
359 let v_const = if let Some(vid) = attr.oid.vendor {
360 format_ident!("VENDOR_{}", vid)
361 } else {
362 format_ident!("UNUSED")
363 };
364
365 let (method, args) = if is_vsa {
366 (
367 quote!(get_vsa_attribute_as),
368 quote!(#v_const, #const_type_ident),
369 )
370 } else {
371 (quote!(get_attribute_as), quote!(#const_type_ident))
372 };
373
374 let (target_type, map_clause) = if has_values {
375 (quote!(u32), quote!(.map(#enum_name::from)))
376 } else {
377 (quote!(#wire_type), quote!())
378 };
379
380 let body_get = quote! {
381 self.#method::<#target_type>(#args) #map_clause
382 };
383
384 let (set_method, set_args) = if is_vsa {
385 (
386 quote!(set_vsa_attribute_as),
387 quote!(#v_const, #const_type_ident),
388 )
389 } else {
390 (quote!(set_attribute_as), quote!(#const_type_ident))
391 };
392
393 let value_type = if has_values {
394 quote!(u32)
395 } else {
396 quote!(#wire_type)
397 };
398
399 let body_set = if final_needs_into {
400 quote! {
401 let wire_val: #value_type = value.into();
402 #size_validation
403 self.#set_method::<#value_type>(#set_args, wire_val);
404 }
405 } else {
406 quote! {
407 let wire_val = value; #size_validation
409 self.#set_method::<#value_type>(#set_args, wire_val);
410 }
411 };
412
413 bodies.extend(quote! {
414 fn #get_ident(&self) -> Option<#final_get_type> { #body_get }
415 fn #set_ident(&mut self, value: #final_set_type) { #body_set }
416 });
417 }
418}
#[cfg(test)]
mod tests {
    use abol_parser::dictionary;

    use super::*;

    /// Builds a `String` attribute with no size constraint and no flags set.
    fn string_attr(name: &str, oid: dictionary::Oid) -> DictionaryAttribute {
        DictionaryAttribute {
            name: name.to_string(),
            oid,
            attr_type: AttributeType::String,
            size: dictionary::SizeFlag::Any,
            encrypt: None,
            has_tag: None,
            concat: None,
        }
    }

    /// Fresh output buffers in the order `process_attribute` takes them:
    /// items, trait signatures, impl bodies.
    fn empty_buffers() -> (TokenStream, TokenStream, TokenStream) {
        (TokenStream::new(), TokenStream::new(), TokenStream::new())
    }

    #[test]
    fn test_generator_new() {
        let generator = Generator::new("Rfc2865Ext");

        assert_eq!(generator.trait_name, "Rfc2865Ext");
        assert!(generator.ignored_attributes.is_empty());
    }

    #[test]
    fn test_validate_attr_oid_overflow() {
        // A code of 256 on an attribute without a vendor must be rejected.
        let attr = string_attr(
            "Test-Attr",
            dictionary::Oid {
                vendor: None,
                code: 256,
            },
        );

        let outcome = Generator::new("test").validate_attr(&attr);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_attr_size_constraint_type() {
        let generator = Generator::new("test");
        let mut attr = DictionaryAttribute {
            attr_type: AttributeType::Integer,
            size: dictionary::SizeFlag::Range(1, 10),
            ..string_attr(
                "Test-Attr",
                dictionary::Oid {
                    vendor: None,
                    code: 100,
                },
            )
        };

        // A size range on an integer attribute is rejected...
        assert!(generator.validate_attr(&attr).is_err());

        // ...while the same range on a string attribute is accepted.
        attr.attr_type = AttributeType::String;
        assert!(generator.validate_attr(&attr).is_ok());
    }

    #[test]
    fn test_process_attribute_generation() {
        let generator = Generator::new("Rfc2865");
        let (mut tokens, mut signatures, mut bodies) = empty_buffers();
        let attr = string_attr(
            "User-Name",
            dictionary::Oid {
                vendor: None,
                code: 1,
            },
        );
        let nothing_ignored = HashSet::new();
        let no_values = HashMap::new();

        generator.process_attribute(
            &attr,
            &nothing_ignored,
            &no_values,
            &mut tokens,
            &mut signatures,
            &mut bodies,
        );

        let rendered = signatures.to_string();
        for accessor in ["get_user_name", "set_user_name"] {
            assert!(rendered.contains(accessor));
        }
    }

    #[test]
    fn test_ignored_attributes() {
        let mut generator = Generator::new("test");
        generator.ignored_attributes.push("Password".to_string());
        let ignored: HashSet<_> = generator.ignored_attributes.iter().collect();

        let (mut tokens, mut signatures, mut bodies) = empty_buffers();
        let attr = string_attr(
            "Password",
            dictionary::Oid {
                vendor: None,
                code: 2,
            },
        );
        let no_values = HashMap::new();

        generator.process_attribute(
            &attr,
            &ignored,
            &no_values,
            &mut tokens,
            &mut signatures,
            &mut bodies,
        );

        assert!(signatures.is_empty());
    }
}