1use crate::ast::{Definition, Pattern, QName, Schema};
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::fmt::Write;
7
/// Name lookup tables for a single module (or for the shared fallback set).
///
/// Every table is optional in the serialized form: `#[serde(default)]` makes a
/// missing key deserialize to an empty map.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ModuleMappings {
    /// Rust type names keyed by schema type name (read by `NameMappings::resolve_type`).
    #[serde(default)]
    pub types: HashMap<String, String>,
    /// Rust field names keyed by schema field name (read by `NameMappings::resolve_field`).
    #[serde(default)]
    pub fields: HashMap<String, String>,
    /// Rust enum variant names keyed by schema value (read by `NameMappings::resolve_variant`).
    #[serde(default)]
    pub variants: HashMap<String, String>,
    /// Values keyed by Rust type name (read by `NameMappings::resolve_element`); the
    /// generator emits the value as the struct-level `#[serde(rename = "...")]`.
    #[serde(default)]
    pub elements: HashMap<String, String>,
}
25
/// Name mappings for all known modules plus a shared fallback set.
///
/// Lookups consult the module-specific table first and then `shared`
/// (see the `resolve_*` methods). Every section may be omitted in the
/// serialized form and then defaults to empty tables.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct NameMappings {
    /// Fallback mappings; also what `for_module` returns for an unrecognized module name.
    #[serde(default)]
    pub shared: ModuleMappings,
    /// Mappings selected by the module name `"sml"`.
    #[serde(default)]
    pub sml: ModuleMappings,
    /// Mappings selected by the module name `"wml"`.
    #[serde(default)]
    pub wml: ModuleMappings,
    /// Mappings selected by the module name `"pml"`.
    #[serde(default)]
    pub pml: ModuleMappings,
    /// Mappings selected by the module name `"dml"`.
    #[serde(default)]
    pub dml: ModuleMappings,
}
45
46impl NameMappings {
47 pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
49 serde_yaml::from_str(yaml)
50 }
51
52 pub fn from_yaml_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
54 let contents = std::fs::read_to_string(path)?;
55 Ok(Self::from_yaml(&contents)?)
56 }
57
58 pub fn for_module(&self, module: &str) -> &ModuleMappings {
60 match module {
61 "sml" => &self.sml,
62 "wml" => &self.wml,
63 "pml" => &self.pml,
64 "dml" => &self.dml,
65 _ => &self.shared,
66 }
67 }
68
69 pub fn resolve_type(&self, module: &str, spec_name: &str) -> Option<&str> {
71 self.for_module(module)
72 .types
73 .get(spec_name)
74 .or_else(|| self.shared.types.get(spec_name))
75 .map(|s| s.as_str())
76 }
77
78 pub fn resolve_field(&self, module: &str, spec_name: &str) -> Option<&str> {
80 self.for_module(module)
81 .fields
82 .get(spec_name)
83 .or_else(|| self.shared.fields.get(spec_name))
84 .map(|s| s.as_str())
85 }
86
87 pub fn resolve_variant(&self, module: &str, spec_name: &str) -> Option<&str> {
89 self.for_module(module)
90 .variants
91 .get(spec_name)
92 .or_else(|| self.shared.variants.get(spec_name))
93 .map(|s| s.as_str())
94 }
95
96 pub fn resolve_element(&self, module: &str, rust_type_name: &str) -> Option<&str> {
99 self.for_module(module)
100 .elements
101 .get(rust_type_name)
102 .or_else(|| self.shared.elements.get(rust_type_name))
103 .map(|s| s.as_str())
104 }
105}
106
/// Tags for one element, keyed by field name. `FeatureMappings::get_tags`
/// falls back to the key `"*"` when the exact field name is absent.
pub type ElementFeatures = HashMap<String, Vec<String>>;

/// Per-element tag tables for one module, keyed by element name.
pub type ModuleFeatures = HashMap<String, ElementFeatures>;
118
/// Feature tag tables for each known module.
///
/// Every section may be omitted in the serialized form and then defaults to
/// an empty table.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct FeatureMappings {
    /// Tags selected by the module name `"sml"` (also the fallback used by `for_module`).
    #[serde(default)]
    pub sml: ModuleFeatures,
    /// Tags selected by the module name `"wml"`.
    #[serde(default)]
    pub wml: ModuleFeatures,
    /// Tags selected by the module name `"pml"`.
    #[serde(default)]
    pub pml: ModuleFeatures,
    /// Tags selected by the module name `"dml"`.
    #[serde(default)]
    pub dml: ModuleFeatures,
}
135
136impl FeatureMappings {
137 pub fn from_yaml(yaml: &str) -> Result<Self, serde_yaml::Error> {
139 serde_yaml::from_str(yaml)
140 }
141
142 pub fn from_yaml_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
144 let contents = std::fs::read_to_string(path)?;
145 Ok(Self::from_yaml(&contents)?)
146 }
147
148 pub fn for_module(&self, module: &str) -> &ModuleFeatures {
150 match module {
151 "sml" => &self.sml,
152 "wml" => &self.wml,
153 "pml" => &self.pml,
154 "dml" => &self.dml,
155 _ => &self.sml, }
157 }
158
159 pub fn get_tags(&self, module: &str, element: &str, field: &str) -> Option<&[String]> {
163 self.for_module(module)
164 .get(element)
165 .and_then(|elem| elem.get(field).or_else(|| elem.get("*")))
166 .map(|v| v.as_slice())
167 }
168
169 pub fn is_core(&self, module: &str, element: &str, field: &str) -> bool {
171 self.get_tags(module, element, field)
172 .is_some_and(|tags| tags.iter().any(|t| t == "core"))
173 }
174
175 pub fn primary_feature(&self, module: &str, element: &str, field: &str) -> Option<&str> {
178 self.get_tags(module, element, field).and_then(|tags| {
179 if tags.iter().any(|t| t == "core") {
181 return None;
182 }
183 tags.first().map(|s| s.as_str())
185 })
186 }
187}
188
/// Options controlling code generation.
#[derive(Debug, Clone, Default)]
pub struct CodegenConfig {
    /// NOTE(review): not read by the generator code visible in this part of the
    /// file; presumably a name prefix to strip from definitions — confirm.
    pub strip_prefix: Option<String>,
    /// Module name used to select name/feature mapping tables and as the
    /// prefix of emitted feature names (`"{module_name}-{feature}"`).
    pub module_name: String,
    /// Optional name mappings consulted when deriving type, field, variant and
    /// element names.
    pub name_mappings: Option<NameMappings>,
    /// Optional feature mappings consulted to gate fields behind `#[cfg(feature = ...)]`.
    pub feature_mappings: Option<FeatureMappings>,
    /// When set (and name mappings are present), a warning is printed to stderr
    /// for each type or field name that has no mapping.
    pub warn_unmapped: bool,
    /// NOTE(review): not read by the generator code visible in this part of the
    /// file — confirm intended use.
    pub xml_serialize_prefix: Option<String>,
    /// NOTE(review): not read by the generator code visible in this part of the
    /// file — confirm intended use.
    pub cross_crate_imports: Vec<String>,
    /// Cross-crate type resolution table: definition-name prefix →
    /// `(crate_path, module_name)`. The crate path is prepended verbatim to the
    /// resolved type name, and the module name selects the name-mapping tables.
    pub cross_crate_type_prefix: HashMap<String, (String, String)>,
}
217
218pub fn generate(schema: &Schema, config: &CodegenConfig) -> String {
220 let mut g = Generator::new(schema, config);
221 g.run()
222}
223
/// Working state for one code generation run.
struct Generator<'a> {
    /// Schema being translated.
    schema: &'a Schema,
    /// Generation options.
    config: &'a CodegenConfig,
    /// Accumulated output text; `run` hands it out and leaves it empty.
    output: String,
    /// Definition name → pattern, built from `schema.definitions` in `new`.
    definitions: HashMap<&'a str, &'a Pattern>,
    /// Rust type names already emitted; `run` skips a definition whose name is
    /// already present.
    generated_names: std::collections::HashSet<String>,
}
233
234impl<'a> Generator<'a> {
235 fn new(schema: &'a Schema, config: &'a CodegenConfig) -> Self {
236 let definitions: HashMap<&str, &Pattern> = schema
237 .definitions
238 .iter()
239 .map(|d| (d.name.as_str(), &d.pattern))
240 .collect();
241
242 Self {
243 schema,
244 config,
245 output: String::new(),
246 definitions,
247 generated_names: std::collections::HashSet::new(),
248 }
249 }
250
251 fn run(&mut self) -> String {
252 self.write_header();
253
254 let mut simple_types = Vec::new();
256 let mut element_groups = Vec::new();
257 let mut complex_types = Vec::new();
258
259 for def in &self.schema.definitions {
260 if def.name.contains("_ST_") || self.is_simple_type(&def.pattern) {
261 simple_types.push(def);
262 } else if def.name.contains("_EG_") && self.is_element_choice(&def.pattern) {
263 element_groups.push(def);
264 } else if self.is_inline_attribute_ref(&def.name, &def.pattern) {
265 continue;
268 } else {
269 complex_types.push(def);
270 }
271 }
272
273 for def in &simple_types {
275 let rust_name = self.to_rust_type_name(&def.name);
276 if !self.generated_names.insert(rust_name) {
277 continue; }
279 if let Some(code) = self.gen_simple_type(def) {
280 self.output.push_str(&code);
281 self.output.push('\n');
282 }
283 }
284
285 for def in &element_groups {
287 let rust_name = self.to_rust_type_name(&def.name);
288 if !self.generated_names.insert(rust_name) {
289 continue;
290 }
291 if let Some(code) = self.gen_element_group(def) {
292 self.output.push_str(&code);
293 self.output.push('\n');
294 }
295 }
296
297 for def in &complex_types {
299 let rust_name = self.to_rust_type_name(&def.name);
300 if !self.generated_names.insert(rust_name) {
301 continue;
302 }
303 if let Some(code) = self.gen_complex_type(def) {
304 self.output.push_str(&code);
305 self.output.push('\n');
306 }
307 }
308
309 std::mem::take(&mut self.output)
310 }
311
312 fn write_header(&mut self) {
313 writeln!(self.output, "// Generated from ECMA-376 RELAX NG schema.").unwrap();
314 writeln!(self.output, "// Do not edit manually.").unwrap();
315 writeln!(self.output).unwrap();
316 writeln!(self.output, "use serde::{{Deserialize, Serialize}};").unwrap();
317 writeln!(self.output).unwrap();
318
319 if !self.schema.namespaces.is_empty() {
321 writeln!(self.output, "/// XML namespace URIs used in this schema.").unwrap();
322 writeln!(self.output, "pub mod ns {{").unwrap();
323
324 for ns in &self.schema.namespaces {
325 if ns.prefix.is_empty() {
327 continue;
328 }
329 let const_name = ns.prefix.to_uppercase();
330 if ns.is_default {
331 writeln!(
332 self.output,
333 " /// Default namespace (prefix: {})",
334 ns.prefix
335 )
336 .unwrap();
337 } else {
338 writeln!(self.output, " /// Namespace prefix: {}", ns.prefix).unwrap();
339 }
340 writeln!(
341 self.output,
342 " pub const {}: &str = \"{}\";",
343 const_name, ns.uri
344 )
345 .unwrap();
346 }
347
348 writeln!(self.output, "}}").unwrap();
349 writeln!(self.output).unwrap();
350 }
351 }
352
353 fn is_simple_type(&self, pattern: &Pattern) -> bool {
354 match pattern {
355 Pattern::Choice(variants) => variants
356 .iter()
357 .all(|v| matches!(v, Pattern::StringLiteral(_))),
358 Pattern::StringLiteral(_) => true,
359 Pattern::Datatype { .. } => true,
360 Pattern::List(_) => true, Pattern::Ref(name) => {
362 self.definitions
364 .get(name.as_str())
365 .is_some_and(|p| self.is_simple_type(p))
366 }
367 _ => false,
368 }
369 }
370
371 fn is_inline_attribute_ref(&self, name: &str, pattern: &Pattern) -> bool {
375 !name.contains("_CT_") && matches!(pattern, Pattern::Attribute { .. })
378 }
379
380 fn is_string_type(&self, pattern: &Pattern) -> bool {
382 match pattern {
383 Pattern::Datatype { library, name, .. } => {
384 library == "xsd" && (name == "string" || name == "token" || name == "NCName")
385 }
386 Pattern::Ref(name) => {
387 self.definitions
389 .get(name.as_str())
390 .is_some_and(|p| self.is_string_type(p))
391 }
392 _ => false,
393 }
394 }
395
396 fn is_element_choice(&self, pattern: &Pattern) -> bool {
398 match pattern {
399 Pattern::Choice(variants) => {
400 variants.iter().any(Self::is_direct_element_variant)
403 }
404 _ => false,
405 }
406 }
407
408 fn is_direct_element_variant(pattern: &Pattern) -> bool {
410 match pattern {
411 Pattern::Element { .. } => true,
412 Pattern::Optional(inner) | Pattern::ZeroOrMore(inner) | Pattern::OneOrMore(inner) => {
413 Self::is_direct_element_variant(inner)
414 }
415 _ => false,
416 }
417 }
418
419 fn gen_simple_type(&self, def: &Definition) -> Option<String> {
420 let rust_name = self.to_rust_type_name(&def.name);
421
422 match &def.pattern {
423 Pattern::Choice(variants) => {
424 let string_variants: Vec<_> = variants
425 .iter()
426 .filter_map(|v| match v {
427 Pattern::StringLiteral(s) => Some(s.as_str()),
428 _ => None,
429 })
430 .collect();
431
432 if !string_variants.is_empty() {
433 let mut seen_variants = std::collections::HashSet::new();
435 let dedup_variants: Vec<_> = string_variants
436 .iter()
437 .filter(|v| {
438 let name = self.to_rust_variant_name(v);
439 seen_variants.insert(name)
440 })
441 .copied()
442 .collect();
443
444 let mut code = String::new();
446 writeln!(
447 code,
448 "#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]"
449 )
450 .unwrap();
451 writeln!(code, "pub enum {} {{", rust_name).unwrap();
452
453 for variant in &dedup_variants {
454 let variant_name = self.to_rust_variant_name(variant);
455 writeln!(code, " #[serde(rename = \"{}\")]", variant).unwrap();
457 writeln!(code, " {},", variant_name).unwrap();
458 }
459
460 writeln!(code, "}}").unwrap();
461 writeln!(code).unwrap();
462
463 writeln!(code, "impl std::fmt::Display for {} {{", rust_name).unwrap();
465 writeln!(
466 code,
467 " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
468 )
469 .unwrap();
470 writeln!(code, " match self {{").unwrap();
471 for variant in &dedup_variants {
472 let variant_name = self.to_rust_variant_name(variant);
473 writeln!(
474 code,
475 " Self::{} => write!(f, \"{}\"),",
476 variant_name, variant
477 )
478 .unwrap();
479 }
480 writeln!(code, " }}").unwrap();
481 writeln!(code, " }}").unwrap();
482 writeln!(code, "}}").unwrap();
483 writeln!(code).unwrap();
484
485 writeln!(code, "impl std::str::FromStr for {} {{", rust_name).unwrap();
487 writeln!(code, " type Err = String;").unwrap();
488 writeln!(code).unwrap();
489 writeln!(
490 code,
491 " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
492 )
493 .unwrap();
494 writeln!(code, " match s {{").unwrap();
495 for variant in &string_variants {
497 let variant_name = self.to_rust_variant_name(variant);
498 writeln!(
499 code,
500 " \"{}\" => Ok(Self::{}),",
501 variant, variant_name
502 )
503 .unwrap();
504 }
505 writeln!(
506 code,
507 " _ => Err(format!(\"unknown {} value: {{}}\", s)),",
508 rust_name
509 )
510 .unwrap();
511 writeln!(code, " }}").unwrap();
512 writeln!(code, " }}").unwrap();
513 writeln!(code, "}}").unwrap();
514
515 return Some(code);
516 }
517
518 let mut code = String::new();
521 writeln!(code, "pub type {} = String;", rust_name).unwrap();
522 Some(code)
523 }
524 Pattern::Datatype { library, name, .. } => {
525 let rust_type = self.xsd_to_rust(library, name);
527 let mut code = String::new();
528 writeln!(code, "pub type {} = {};", rust_name, rust_type).unwrap();
529 Some(code)
530 }
531 Pattern::Ref(target) => {
532 let target_rust = if self.definitions.contains_key(target.as_str()) {
534 self.to_rust_type_name(target)
535 } else if let Some((cross_crate_type, _)) = self.resolve_cross_crate_type(target) {
536 cross_crate_type
538 } else {
539 "String".to_string()
541 };
542 let mut code = String::new();
543 writeln!(code, "pub type {} = {};", rust_name, target_rust).unwrap();
544 Some(code)
545 }
546 Pattern::List(_) => {
547 let mut code = String::new();
549 writeln!(code, "pub type {} = String;", rust_name).unwrap();
550 Some(code)
551 }
552 _ => None,
553 }
554 }
555
556 fn gen_element_group(&self, def: &Definition) -> Option<String> {
557 let rust_name = self.to_rust_type_name(&def.name);
558
559 let mut element_variants = Vec::new();
561 let mut visited = std::collections::HashSet::new();
562 visited.insert(def.name.clone()); self.collect_element_variants(&def.pattern, &mut element_variants, &mut visited);
564
565 let mut seen = std::collections::HashSet::new();
567 element_variants.retain(|(xml_name, _)| seen.insert(xml_name.clone()));
568
569 if element_variants.is_empty() {
570 let mut code = String::new();
572 writeln!(code, "pub type {} = String;", rust_name).unwrap();
573 return Some(code);
574 }
575
576 let mut code = String::new();
577 writeln!(code, "#[derive(Debug, Clone, Serialize, Deserialize)]").unwrap();
578 writeln!(code, "pub enum {} {{", rust_name).unwrap();
579
580 for (xml_name, inner_type) in &element_variants {
581 let variant_name = self.to_rust_variant_name(xml_name);
582 writeln!(code, " #[serde(rename = \"{}\")]", xml_name).unwrap();
583 writeln!(code, " {}({}),", variant_name, inner_type).unwrap();
584 }
585
586 writeln!(code, "}}").unwrap();
587
588 Some(code)
589 }
590
591 fn collect_element_variants(
594 &self,
595 pattern: &Pattern,
596 variants: &mut Vec<(String, String)>,
597 visited: &mut std::collections::HashSet<String>,
598 ) {
599 match pattern {
600 Pattern::Element { name, pattern } => {
601 let inner_type = self.pattern_to_rust_type(pattern, false, false);
603 variants.push((name.local.clone(), inner_type));
604 }
605 Pattern::Optional(inner)
606 | Pattern::ZeroOrMore(inner)
607 | Pattern::OneOrMore(inner)
608 | Pattern::Group(inner) => {
609 self.collect_element_variants(inner, variants, visited);
610 }
611 Pattern::Ref(name) => {
612 if name.contains("_EG_")
614 && visited.insert(name.clone())
615 && let Some(def_pattern) = self.definitions.get(name.as_str())
616 {
617 self.collect_element_variants(def_pattern, variants, visited);
618 }
619 }
620 Pattern::Choice(items) | Pattern::Sequence(items) | Pattern::Interleave(items) => {
621 for item in items {
622 self.collect_element_variants(item, variants, visited);
623 }
624 }
625 _ => {}
626 }
627 }
628
629 fn has_xml_children_pattern(&self, pattern: &Pattern) -> bool {
632 match pattern {
633 Pattern::Empty => false,
634 Pattern::Attribute { .. } => false,
635 Pattern::Element { .. } => true,
636 Pattern::Ref(name) => {
637 if name.contains("_AG_") {
639 return false;
640 }
641 if let Some(def_pattern) = self.definitions.get(name.as_str()) {
643 self.has_xml_children_pattern(def_pattern)
644 } else {
645 true
647 }
648 }
649 Pattern::Sequence(items) | Pattern::Interleave(items) | Pattern::Choice(items) => {
650 items.iter().any(|i| self.has_xml_children_pattern(i))
651 }
652 Pattern::Optional(inner)
653 | Pattern::ZeroOrMore(inner)
654 | Pattern::OneOrMore(inner)
655 | Pattern::Group(inner)
656 | Pattern::Mixed(inner) => self.has_xml_children_pattern(inner),
657 Pattern::Text => true,
658 _ => false,
659 }
660 }
661
662 fn has_xml_attr_pattern(&self, pattern: &Pattern) -> bool {
664 match pattern {
665 Pattern::Attribute { .. } => true,
666 Pattern::Ref(name) if name.contains("_AG_") => true,
667 Pattern::Ref(name) => {
668 if let Some(def_pattern) = self.definitions.get(name.as_str()) {
669 self.has_xml_attr_pattern(def_pattern)
670 } else {
671 false
672 }
673 }
674 Pattern::Sequence(items) | Pattern::Interleave(items) | Pattern::Choice(items) => {
675 items.iter().any(|i| self.has_xml_attr_pattern(i))
676 }
677 Pattern::Optional(inner)
678 | Pattern::ZeroOrMore(inner)
679 | Pattern::OneOrMore(inner)
680 | Pattern::Group(inner) => self.has_xml_attr_pattern(inner),
681 _ => false,
682 }
683 }
684
685 fn is_eg_content_field(&self, field: &Field) -> bool {
687 if let Pattern::Ref(name) = &field.pattern
688 && name.contains("_EG_")
689 && let Some(pattern) = self.definitions.get(name.as_str())
690 {
691 return self.is_element_choice(pattern);
692 }
693 false
694 }
695
696 fn eg_ref_to_field_name(&self, name: &str) -> String {
699 let spec_name = strip_namespace_prefix(name);
700 let short = spec_name.strip_prefix("EG_").unwrap_or(spec_name);
702 if let Some(mappings) = &self.config.name_mappings
704 && let Some(mapped) = mappings.resolve_field(&self.config.module_name, short)
705 {
706 return mapped.to_string();
707 }
708 to_snake_case(short)
709 }
710
    /// Emit the Rust item for a complex-type definition.
    ///
    /// * A definition that is itself an element becomes a type alias for the
    ///   element's content type.
    /// * A definition with no extractable fields becomes a `Default` struct:
    ///   a unit struct, or one holding only the feature-gated `extra_attrs` /
    ///   `extra_children` capture fields when the pattern has attributes or
    ///   child content.
    /// * Otherwise a struct with one member per extracted field is emitted,
    ///   followed by the same capture fields.
    ///
    /// Always returns `Some`.
    fn gen_complex_type(&self, def: &Definition) -> Option<String> {
        // Element definitions are aliases for their content type.
        if let Pattern::Element { pattern, .. } = &def.pattern {
            let rust_name = self.to_rust_type_name(&def.name);
            let inner_type = self.pattern_to_rust_type(pattern, false, false);
            let mut code = String::new();
            writeln!(code, "pub type {} = {};", rust_name, inner_type).unwrap();
            return Some(code);
        }

        let rust_name = self.to_rust_type_name(&def.name);
        let mut code = String::new();

        let fields = self.extract_fields(&def.pattern);

        // Optional struct-level serde rename taken from the `elements` mappings.
        let element_rename = self.get_element_name(&rust_name);

        if fields.is_empty() {
            // No typed fields could be extracted, but the pattern may still
            // describe attributes or children.
            let has_unresolved_children = self.has_xml_children_pattern(&def.pattern);
            let has_unresolved_attrs = self.has_xml_attr_pattern(&def.pattern);

            writeln!(
                code,
                "#[derive(Debug, Clone, Default, Serialize, Deserialize)]"
            )
            .unwrap();
            if let Some(xml_name) = &element_rename {
                writeln!(code, "#[serde(rename = \"{}\")]", xml_name).unwrap();
            }

            if has_unresolved_children || has_unresolved_attrs {
                writeln!(code, "pub struct {} {{", rust_name).unwrap();
                if has_unresolved_attrs {
                    writeln!(
                        code,
                        " /// Unknown attributes captured for roundtrip fidelity."
                    )
                    .unwrap();
                    // The cfg gate is emitted once per attribute line and once
                    // more before the field itself.
                    writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                    writeln!(code, " #[serde(skip)]").unwrap();
                    writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                    writeln!(code, " #[serde(default)]").unwrap();
                    writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                    writeln!(
                        code,
                        " pub extra_attrs: std::collections::HashMap<String, String>,"
                    )
                    .unwrap();
                }
                if has_unresolved_children {
                    writeln!(
                        code,
                        " /// Unknown child elements captured for roundtrip fidelity."
                    )
                    .unwrap();
                    writeln!(code, " #[cfg(feature = \"extra-children\")]").unwrap();
                    writeln!(code, " #[serde(skip)]").unwrap();
                    writeln!(code, " #[cfg(feature = \"extra-children\")]").unwrap();
                    writeln!(
                        code,
                        " pub extra_children: Vec<ooxml_xml::PositionedNode>,"
                    )
                    .unwrap();
                }
                writeln!(code, "}}").unwrap();
            } else {
                writeln!(code, "pub struct {};", rust_name).unwrap();
            }
        } else {
            // `Default` is derived only when every field ends up as an
            // `Option`, a `Vec`, or element-group content (wrapped in `Option`
            // or `Vec` below).
            let all_defaultable = fields
                .iter()
                .all(|f| f.is_optional || f.is_vec || self.is_eg_content_field(f));
            if all_defaultable {
                writeln!(
                    code,
                    "#[derive(Debug, Clone, Default, Serialize, Deserialize)]"
                )
                .unwrap();
            } else {
                writeln!(code, "#[derive(Debug, Clone, Serialize, Deserialize)]").unwrap();
            }
            if let Some(xml_name) = &element_rename {
                writeln!(code, "#[serde(rename = \"{}\")]", xml_name).unwrap();
            }
            writeln!(code, "pub struct {} {{", rust_name).unwrap();

            for field in &fields {
                let is_eg_content = self.is_eg_content_field(field);
                let inner_type = self.pattern_to_rust_type(&field.pattern, false, field.is_vec);
                // Booleans get dedicated serde `with` modules below.
                let is_bool = inner_type == "bool";
                // Required element-group content is still wrapped in `Option`.
                let eg_needs_option = is_eg_content && !field.is_optional && !field.is_vec;
                let field_type = if field.is_vec {
                    format!("Vec<{}>", inner_type)
                } else if field.is_optional || eg_needs_option {
                    format!("Option<{}>", inner_type)
                } else {
                    inner_type
                };

                // Gate the field behind a cargo feature when the feature
                // mappings name one for it.
                if let Some(ref feature) = self.get_field_feature(&rust_name, &field.xml_name) {
                    writeln!(code, " #[cfg(feature = \"{}\")]", feature).unwrap();
                }

                if is_eg_content {
                    // Element-group content is excluded from serde and
                    // defaulted instead of renamed.
                    writeln!(code, " #[serde(skip)]").unwrap();
                    writeln!(code, " #[serde(default)]").unwrap();
                } else {
                    let xml_name = &field.xml_name;
                    if field.is_text_content {
                        writeln!(code, " #[serde(rename = \"$text\")]").unwrap();
                    } else if field.is_attribute {
                        // Attributes are renamed with a leading `@`, keeping
                        // the namespace prefix when one is present.
                        if let Some(prefix) = &field.xml_prefix {
                            writeln!(code, " #[serde(rename = \"@{}:{}\")]", prefix, xml_name)
                                .unwrap();
                        } else {
                            writeln!(code, " #[serde(rename = \"@{}\")]", xml_name).unwrap();
                        }
                    } else {
                        writeln!(code, " #[serde(rename = \"{}\")]", xml_name).unwrap();
                    }
                }
                if field.is_optional {
                    if is_bool {
                        writeln!(
                            code,
                            " #[serde(default, skip_serializing_if = \"Option::is_none\", with = \"ooxml_xml::ooxml_bool\")]"
                        )
                        .unwrap();
                    } else if !is_eg_content {
                        writeln!(
                            code,
                            " #[serde(default, skip_serializing_if = \"Option::is_none\")]"
                        )
                        .unwrap();
                    }
                } else if field.is_vec && !is_eg_content {
                    writeln!(
                        code,
                        " #[serde(default, skip_serializing_if = \"Vec::is_empty\")]"
                    )
                    .unwrap();
                } else if is_bool {
                    writeln!(
                        code,
                        " #[serde(with = \"ooxml_xml::ooxml_bool_required\")]"
                    )
                    .unwrap();
                }
                writeln!(code, " pub {}: {},", field.name, field_type).unwrap();
            }

            // Capture field for unknown attributes, only when the struct has
            // at least one attribute field.
            let has_attrs = fields.iter().any(|f| f.is_attribute);
            if has_attrs {
                writeln!(
                    code,
                    " /// Unknown attributes captured for roundtrip fidelity."
                )
                .unwrap();
                writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                writeln!(code, " #[serde(skip)]").unwrap();
                writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                writeln!(code, " #[serde(default)]").unwrap();
                writeln!(code, " #[cfg(feature = \"extra-attrs\")]").unwrap();
                writeln!(
                    code,
                    " pub extra_attrs: std::collections::HashMap<String, String>,"
                )
                .unwrap();
            }

            // Capture field for unknown children, only when the struct has at
            // least one non-attribute field.
            let has_parsing_content = fields.iter().any(|f| !f.is_attribute);
            if has_parsing_content {
                writeln!(
                    code,
                    " /// Unknown child elements captured for roundtrip fidelity."
                )
                .unwrap();
                writeln!(code, " #[cfg(feature = \"extra-children\")]").unwrap();
                writeln!(code, " #[serde(skip)]").unwrap();
                writeln!(code, " #[cfg(feature = \"extra-children\")]").unwrap();
                writeln!(
                    code,
                    " pub extra_children: Vec<ooxml_xml::PositionedNode>,"
                )
                .unwrap();
            }

            writeln!(code, "}}").unwrap();
        }

        Some(code)
    }
924
925 fn extract_fields(&self, pattern: &Pattern) -> Vec<Field> {
926 let mut fields = Vec::new();
927 self.collect_fields(pattern, &mut fields, false);
928 let mut seen = std::collections::HashSet::new();
930 fields.retain(|f| seen.insert(f.name.clone()));
931 fields
932 }
933
    /// Recursively walk `pattern` and append one `Field` per attribute,
    /// element, text content or element-group reference found.
    ///
    /// `is_optional` is the optionality inherited from enclosing patterns; it
    /// is copied onto fields pushed directly by this call and forced to `true`
    /// beneath `Optional` and `Choice`. Elements named `_any` are skipped.
    /// Duplicates are not removed here (see `extract_fields`).
    fn collect_fields(&self, pattern: &Pattern, fields: &mut Vec<Field>, is_optional: bool) {
        match pattern {
            Pattern::Attribute { name, pattern } => {
                fields.push(Field {
                    name: self.qname_to_field_name(name),
                    xml_name: name.local.clone(),
                    xml_prefix: name.prefix.clone(),
                    pattern: pattern.as_ref().clone(),
                    is_optional,
                    is_attribute: true,
                    is_vec: false,
                    is_text_content: false,
                });
            }
            Pattern::Element { name, pattern } => {
                // `_any` elements produce no field.
                if name.local == "_any" {
                    return;
                }
                fields.push(Field {
                    name: self.qname_to_field_name(name),
                    xml_name: name.local.clone(),
                    xml_prefix: name.prefix.clone(),
                    pattern: pattern.as_ref().clone(),
                    is_optional,
                    is_attribute: false,
                    is_vec: false,
                    is_text_content: false,
                });
            }
            Pattern::Sequence(items) | Pattern::Interleave(items) => {
                for item in items {
                    self.collect_fields(item, fields, is_optional);
                }
            }
            Pattern::Optional(inner) => {
                self.collect_fields(inner, fields, true);
            }
            Pattern::ZeroOrMore(inner) | Pattern::OneOrMore(inner) => {
                // Repetition: what is repeated decides whether a `Vec` field
                // is produced.
                match inner.as_ref() {
                    Pattern::Element { name, pattern } if name.local != "_any" => {
                        fields.push(Field {
                            name: self.qname_to_field_name(name),
                            xml_name: name.local.clone(),
                            xml_prefix: name.prefix.clone(),
                            pattern: pattern.as_ref().clone(),
                            is_optional: false,
                            is_attribute: false,
                            is_vec: true,
                            is_text_content: false,
                        });
                    }
                    // This guarded arm must stay ahead of the catch-all
                    // `Pattern::Ref(_)` arm below.
                    Pattern::Ref(name) if name.contains("_EG_") => {
                        if let Some(def_pattern) = self.definitions.get(name.as_str()) {
                            if self.is_element_choice(def_pattern) {
                                // A repeated element-choice group becomes a
                                // single `Vec` field typed by the group.
                                fields.push(Field {
                                    name: self.eg_ref_to_field_name(name),
                                    xml_name: name.clone(),
                                    xml_prefix: None,
                                    pattern: Pattern::Ref(name.clone()),
                                    is_optional: false,
                                    is_attribute: false,
                                    is_vec: true,
                                    is_text_content: false,
                                });
                            } else {
                                // Other groups are expanded in place, with
                                // their members treated as optional.
                                self.collect_fields(def_pattern, fields, true);
                            }
                        }
                    }
                    Pattern::Choice(alternatives) => {
                        for alt in alternatives {
                            self.collect_fields_as_vec(alt, fields);
                        }
                    }
                    Pattern::Ref(_) => {
                        // Non-`_EG_` reference: resolved through the general
                        // `Ref` handling below, as a required field.
                        self.collect_fields(inner, fields, false);
                    }
                    Pattern::Group(group_inner) => {
                        if let Pattern::Choice(alternatives) = group_inner.as_ref() {
                            for alt in alternatives {
                                self.collect_fields_as_vec(alt, fields);
                            }
                        } else {
                            self.collect_fields(group_inner, fields, false);
                        }
                    }
                    _ => {}
                }
            }
            Pattern::Group(inner) => {
                self.collect_fields(inner, fields, is_optional);
            }
            Pattern::Ref(name) => {
                // References to unknown definitions contribute nothing.
                if let Some(def_pattern) = self.definitions.get(name.as_str()) {
                    if self.is_string_type(def_pattern) {
                        // A string-typed reference becomes an optional `text`
                        // field holding the text content.
                        fields.push(Field {
                            name: "text".to_string(),
                            xml_name: "$text".to_string(),
                            xml_prefix: None,
                            pattern: Pattern::Datatype {
                                library: "xsd".to_string(),
                                name: "string".to_string(),
                                params: vec![],
                            },
                            is_optional: true,
                            is_attribute: false,
                            is_vec: false,
                            is_text_content: true,
                        });
                    } else if name.contains("_EG_") {
                        if self.is_element_choice(def_pattern) {
                            // A single (non-repeated) element-choice group.
                            fields.push(Field {
                                name: self.eg_ref_to_field_name(name),
                                xml_name: name.clone(),
                                xml_prefix: None,
                                pattern: Pattern::Ref(name.clone()),
                                is_optional,
                                is_attribute: false,
                                is_vec: false,
                                is_text_content: false,
                            });
                        } else {
                            self.collect_fields(def_pattern, fields, is_optional);
                        }
                    } else if name.contains("_AG_") {
                        self.collect_fields(def_pattern, fields, is_optional);
                    } else {
                        // Same expansion as the `_AG_` branch above.
                        self.collect_fields(def_pattern, fields, is_optional);
                    }
                }
            }
            Pattern::Choice(alternatives) => {
                // Every alternative of a choice is treated as optional.
                for alt in alternatives {
                    self.collect_fields(alt, fields, true);
                }
            }
            _ => {}
        }
    }
1092
    /// Collect fields from one alternative of a repeated choice, marking
    /// elements and element-choice groups as `Vec` fields.
    ///
    /// Called by `collect_fields` for the alternatives of a choice found under
    /// `ZeroOrMore` / `OneOrMore`.
    fn collect_fields_as_vec(&self, pattern: &Pattern, fields: &mut Vec<Field>) {
        match pattern {
            Pattern::Element {
                name,
                pattern: inner_pattern,
            } if name.local != "_any" => {
                fields.push(Field {
                    name: self.qname_to_field_name(name),
                    xml_name: name.local.clone(),
                    xml_prefix: name.prefix.clone(),
                    pattern: inner_pattern.as_ref().clone(),
                    is_optional: false,
                    is_attribute: false,
                    is_vec: true,
                    is_text_content: false,
                });
            }
            Pattern::Optional(inner) => {
                // An optional alternative goes back through the regular
                // collector, so it yields optional fields rather than `Vec` fields.
                self.collect_fields(inner, fields, true);
            }
            Pattern::Group(inner) => {
                self.collect_fields_as_vec(inner, fields);
            }
            Pattern::Ref(name) => {
                // References to unknown definitions contribute nothing.
                if let Some(def_pattern) = self.definitions.get(name.as_str()) {
                    if name.contains("_EG_") && self.is_element_choice(def_pattern) {
                        fields.push(Field {
                            name: self.eg_ref_to_field_name(name),
                            xml_name: name.clone(),
                            xml_prefix: None,
                            pattern: Pattern::Ref(name.clone()),
                            is_optional: false,
                            is_attribute: false,
                            is_vec: true,
                            is_text_content: false,
                        });
                    } else if !name.contains("_AG_") && !name.contains("_CT_") {
                        // `_AG_` and `_CT_` references are not expanded here.
                        self.collect_fields_as_vec(def_pattern, fields);
                    }
                }
            }
            _ => {}
        }
    }
1144
1145 fn to_rust_type_name(&self, name: &str) -> String {
1146 let spec_name = strip_namespace_prefix(name);
1148
1149 if let Some(mappings) = &self.config.name_mappings
1151 && let Some(mapped) = mappings.resolve_type(&self.config.module_name, spec_name)
1152 {
1153 return mapped.to_string();
1154 }
1155
1156 if self.config.warn_unmapped && self.config.name_mappings.is_some() {
1158 eprintln!("warning: unmapped type '{}' (spec: {})", spec_name, name);
1159 }
1160
1161 to_pascal_case(spec_name)
1163 }
1164
1165 fn to_rust_variant_name(&self, name: &str) -> String {
1166 if name.is_empty() {
1168 return "Empty".to_string();
1169 }
1170
1171 if let Some(mappings) = &self.config.name_mappings
1173 && let Some(mapped) = mappings.resolve_variant(&self.config.module_name, name)
1174 {
1175 return mapped.to_string();
1176 }
1177
1178 let name = to_pascal_case(name);
1180 if name.chars().next().is_some_and(|c| c.is_ascii_digit()) {
1182 format!("_{}", name)
1183 } else {
1184 name
1185 }
1186 }
1187
1188 fn qname_to_field_name(&self, qname: &QName) -> String {
1189 if let Some(mappings) = &self.config.name_mappings
1191 && let Some(mapped) = mappings.resolve_field(&self.config.module_name, &qname.local)
1192 {
1193 return mapped.to_string();
1194 }
1195
1196 if self.config.warn_unmapped && self.config.name_mappings.is_some() {
1198 eprintln!("warning: unmapped field '{}'", qname.local);
1199 }
1200
1201 to_snake_case(&qname.local)
1203 }
1204
1205 fn get_element_name(&self, rust_type_name: &str) -> Option<String> {
1207 self.config
1208 .name_mappings
1209 .as_ref()
1210 .and_then(|m| m.resolve_element(&self.config.module_name, rust_type_name))
1211 .map(|s| s.to_string())
1212 }
1213
1214 fn get_field_feature(&self, struct_name: &str, xml_field_name: &str) -> Option<String> {
1218 self.config
1219 .feature_mappings
1220 .as_ref()
1221 .and_then(|fm| {
1222 fm.primary_feature(&self.config.module_name, struct_name, xml_field_name)
1223 })
1224 .map(|feature| format!("{}-{}", self.config.module_name, feature))
1225 }
1226
1227 fn xsd_to_rust(&self, library: &str, name: &str) -> &'static str {
1228 if library == "xsd" {
1229 match name {
1230 "string" => "String",
1231 "integer" => "i64",
1232 "int" => "i32",
1233 "long" => "i64",
1234 "short" => "i16",
1235 "byte" => "i8",
1236 "unsignedInt" => "u32",
1237 "unsignedLong" => "u64",
1238 "unsignedShort" => "u16",
1239 "unsignedByte" => "u8",
1240 "boolean" => "bool",
1241 "double" => "f64",
1242 "float" => "f32",
1243 "decimal" => "f64",
1244 "dateTime" => "String", "date" => "String",
1246 "time" => "String",
1247 "hexBinary" => "Vec<u8>",
1248 "base64Binary" => "Vec<u8>",
1249 "anyURI" => "String",
1250 "token" => "String",
1251 "NCName" => "String",
1252 "ID" => "String",
1253 "IDREF" => "String",
1254 _ => "String",
1255 }
1256 } else {
1257 "String"
1258 }
1259 }
1260
1261 fn resolve_cross_crate_type(&self, name: &str) -> Option<(String, bool)> {
1264 for (prefix, (crate_path, module_name)) in &self.config.cross_crate_type_prefix {
1265 if name.starts_with(prefix) {
1266 let spec_name = strip_namespace_prefix(name);
1268 let rust_type_name = if let Some(mappings) = &self.config.name_mappings
1269 && let Some(mapped) = mappings.resolve_type(module_name, spec_name)
1270 {
1271 mapped.to_string()
1272 } else {
1273 to_pascal_case(spec_name)
1274 };
1275 let full_path = format!("{}{}", crate_path, rust_type_name);
1276 let needs_box = name.contains("_CT_") || name.contains("_EG_");
1278 return Some((full_path, needs_box));
1279 }
1280 }
1281 None
1282 }
1283
1284 fn is_element_wrapper_type_alias(&self, name: &str) -> bool {
1287 if let Some(def_pattern) = self.definitions.get(name) {
1288 matches!(def_pattern, Pattern::Element { pattern, .. } if matches!(pattern.as_ref(), Pattern::Ref(_)))
1289 } else {
1290 false
1291 }
1292 }
1293
1294 fn pattern_to_rust_type(&self, pattern: &Pattern, is_optional: bool, is_vec: bool) -> String {
1298 let (inner, needs_box) = match pattern {
1299 Pattern::Ref(name) => {
1300 if self.definitions.contains_key(name.as_str()) {
1302 let type_name = self.to_rust_type_name(name);
1303 let is_complex = name.contains("_CT_") || name.contains("_EG_");
1306 let is_already_boxed = self.is_element_wrapper_type_alias(name);
1307 let needs_box = is_complex && !is_vec && !is_already_boxed;
1308 (type_name, needs_box)
1309 } else if let Some((cross_crate_type, cross_crate_needs_box)) =
1310 self.resolve_cross_crate_type(name)
1311 {
1312 let needs_box = !is_vec && cross_crate_needs_box;
1314 (cross_crate_type, needs_box)
1315 } else {
1316 ("String".to_string(), false)
1318 }
1319 }
1320 Pattern::Datatype { library, name, .. } => {
1321 (self.xsd_to_rust(library, name).to_string(), false)
1322 }
1323 Pattern::Empty => ("()".to_string(), false),
1324 Pattern::StringLiteral(_) => ("String".to_string(), false),
1325 Pattern::Choice(_) => ("String".to_string(), false),
1326 _ => ("String".to_string(), false),
1327 };
1328
1329 let inner = if needs_box {
1330 format!("Box<{}>", inner)
1331 } else {
1332 inner
1333 };
1334
1335 if is_optional {
1336 format!("Option<{}>", inner)
1337 } else {
1338 inner
1339 }
1340 }
1341}
1342
1343struct Field {
1344 name: String,
1345 xml_name: String,
1346 #[allow(dead_code)]
1347 xml_prefix: Option<String>,
1348 pattern: Pattern,
1349 is_optional: bool,
1350 is_attribute: bool,
1351 is_vec: bool,
1352 is_text_content: bool,
1353}
1354
/// Drops whatever precedes the first `CT_`, `ST_` or `EG_` marker in `name`.
///
/// The markers are tried in that order, and for each one only its first
/// occurrence is considered. A marker sitting at the very start of `name`
/// does not count, since there is nothing to strip. If no marker qualifies,
/// `name` is returned unchanged.
fn strip_namespace_prefix(name: &str) -> &str {
    ["CT_", "ST_", "EG_"]
        .iter()
        .find_map(|&marker| name.find(marker).filter(|&at| at > 0))
        .map_or(name, |at| &name[at..])
}
1374
/// Converts `s` to PascalCase.
///
/// `_` and `-` act as word separators and are removed. The first character
/// of every word is upper-cased; all other characters are copied verbatim,
/// so existing capitals (`"FOO"`, `"fooBar"`) are preserved.
pub(crate) fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());

    for word in s.split(['_', '-']) {
        let mut letters = word.chars();
        // Empty words (leading, trailing or doubled separators) add nothing.
        if let Some(initial) = letters.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(letters.as_str());
        }
    }

    pascal
}
1392
/// Converts `s` to a snake_case Rust identifier.
///
/// Every character is lower-cased, and each upper-case character other than
/// the first is preceded by `_` (`"fooBar"` and `"FooBar"` both become
/// `"foo_bar"`).
///
/// If the result collides with a Rust keyword it is escaped so it can be
/// used as an identifier:
///
/// * `self`, `super` and `crate` get a trailing underscore (`self_`),
///   because Rust does not accept them as raw identifiers;
/// * every other strict or reserved keyword becomes a raw identifier
///   (`r#type`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 2);

    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() && i > 0 {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }

    let escaped = match result.as_str() {
        // `r#self`, `r#super` and `r#crate` are rejected by the compiler.
        "self" | "super" | "crate" => Some(format!("{}_", result)),
        // Strict and reserved keywords, including `gen` (reserved in the
        // 2024 edition). `r#` is accepted on every edition.
        "abstract" | "as" | "async" | "await" | "become" | "box" | "break" | "const"
        | "continue" | "do" | "dyn" | "else" | "enum" | "extern" | "false" | "final" | "fn"
        | "for" | "gen" | "if" | "impl" | "in" | "let" | "loop" | "macro" | "match" | "mod"
        | "move" | "mut" | "override" | "priv" | "pub" | "ref" | "return" | "static"
        | "struct" | "trait" | "true" | "try" | "type" | "typeof" | "unsafe" | "unsized"
        | "use" | "virtual" | "where" | "while" | "yield" => Some(format!("r#{}", result)),
        _ => None,
    };

    escaped.unwrap_or(result)
}
1454
#[cfg(test)]
mod tests {
    use super::*;

    /// Separators are removed and word starts capitalised; existing capitals stay.
    #[test]
    fn test_to_pascal_case() {
        let cases = [("foo_bar", "FooBar"), ("fooBar", "FooBar"), ("FOO", "FOO")];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    /// Camel and Pascal case both flatten to snake_case; keywords are escaped.
    #[test]
    fn test_to_snake_case() {
        let cases = [("fooBar", "foo_bar"), ("FooBar", "foo_bar"), ("type", "r#type")];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }
}