Skip to main content

rh_codegen/generators/file_generator/trait_file.rs

1use std::fs;
2use std::path::Path;
3
4use crate::generators::import_manager::ImportManager;
5use crate::rust_types::RustTrait;
6use crate::{CodegenError, CodegenResult};
7
8use super::FileGenerator;
9
10impl<'a> FileGenerator<'a> {
11    pub fn generate_trait_to_file<P: AsRef<Path>>(
12        &self,
13        _structure_def: &crate::fhir_types::StructureDefinition,
14        output_path: P,
15        rust_trait: &RustTrait,
16    ) -> CodegenResult<()> {
17        let mut all_tokens = proc_macro2::TokenStream::new();
18
19        let mut imports = std::collections::HashSet::new();
20        ImportManager::collect_custom_types_from_trait(rust_trait, &mut imports);
21
22        for import_path in imports {
23            let import_stmt = format!("use {import_path};");
24            let import_tokens: proc_macro2::TokenStream =
25                import_stmt.parse().map_err(|e| CodegenError::Generation {
26                    message: format!("Failed to parse import statement '{import_stmt}': {e}"),
27                })?;
28            all_tokens.extend(import_tokens);
29        }
30
31        let trait_tokens = self.token_generator.generate_trait(rust_trait);
32        all_tokens.extend(trait_tokens);
33
34        if std::env::var("DEBUG_TOKENS").is_ok() {
35            eprintln!(
36                "DEBUG: Generated tokens for trait '{}': {}",
37                rust_trait.name, all_tokens
38            );
39        }
40
41        let syntax_tree = syn::parse2(all_tokens).map_err(|e| CodegenError::Generation {
42            message: format!(
43                "Failed to parse generated trait tokens for '{}': {e}",
44                rust_trait.name
45            ),
46        })?;
47
48        let formatted_code = prettyplease::unparse(&syntax_tree);
49
50        if output_path.as_ref().exists() {
51            eprintln!(
52                "Warning: Trait file '{}' already exists and will be overwritten.",
53                output_path.as_ref().display()
54            );
55        }
56
57        fs::write(output_path.as_ref(), formatted_code)?;
58
59        Ok(())
60    }
61
62    pub fn generate_traits_to_file<P: AsRef<Path>>(
63        &self,
64        _structure_def: &crate::fhir_types::StructureDefinition,
65        output_path: P,
66        rust_traits: &[&RustTrait],
67    ) -> CodegenResult<()> {
68        let mut all_tokens = proc_macro2::TokenStream::new();
69
70        let mut imports = std::collections::HashSet::new();
71        for rust_trait in rust_traits {
72            ImportManager::collect_custom_types_from_trait(rust_trait, &mut imports);
73        }
74
75        for import_path in imports {
76            let import_stmt = format!("use {import_path};");
77            let import_tokens: proc_macro2::TokenStream =
78                import_stmt.parse().map_err(|e| CodegenError::Generation {
79                    message: format!("Failed to parse import statement '{import_stmt}': {e}"),
80                })?;
81            all_tokens.extend(import_tokens);
82        }
83
84        for rust_trait in rust_traits {
85            let trait_tokens = self.token_generator.generate_trait(rust_trait);
86            all_tokens.extend(trait_tokens);
87        }
88
89        if std::env::var("DEBUG_TOKENS").is_ok() {
90            let trait_names: Vec<&str> = rust_traits.iter().map(|t| t.name.as_str()).collect();
91            eprintln!(
92                "DEBUG: Generated tokens for traits [{}]: {}",
93                trait_names.join(", "),
94                all_tokens
95            );
96        }
97
98        let syntax_tree = syn::parse2(all_tokens).map_err(|e| CodegenError::Generation {
99            message: format!("Failed to parse generated trait tokens: {e}"),
100        })?;
101
102        let formatted_code = prettyplease::unparse(&syntax_tree);
103
104        if output_path.as_ref().exists() {
105            eprintln!(
106                "Warning: Trait file '{}' already exists and will be overwritten.",
107                output_path.as_ref().display()
108            );
109        }
110
111        fs::write(output_path.as_ref(), formatted_code)?;
112
113        Ok(())
114    }
115
116    pub fn generate_trait_file_from_trait<P: AsRef<Path>>(
117        &self,
118        rust_trait: &RustTrait,
119        output_path: P,
120    ) -> CodegenResult<()> {
121        let mut all_tokens = proc_macro2::TokenStream::new();
122
123        let mut imports = std::collections::HashSet::new();
124        ImportManager::collect_custom_types_from_trait(rust_trait, &mut imports);
125
126        for import_path in imports {
127            let import_stmt = format!("use {import_path};");
128            let import_tokens: proc_macro2::TokenStream =
129                import_stmt.parse().map_err(|e| CodegenError::Generation {
130                    message: format!("Failed to parse import statement '{import_stmt}': {e}"),
131                })?;
132            all_tokens.extend(import_tokens);
133        }
134
135        let trait_tokens = self.token_generator.generate_trait(rust_trait);
136        all_tokens.extend(trait_tokens);
137
138        let syntax_tree = syn::parse2(all_tokens).map_err(|e| CodegenError::Generation {
139            message: format!("Failed to parse generated trait tokens: {e}"),
140        })?;
141
142        let formatted_code = prettyplease::unparse(&syntax_tree);
143
144        if output_path.as_ref().exists() {
145            eprintln!(
146                "Warning: Trait file '{}' already exists and will be overwritten.",
147                output_path.as_ref().display()
148            );
149        }
150
151        fs::write(output_path.as_ref(), formatted_code)?;
152
153        Ok(())
154    }
155
156    pub(crate) fn generate_trait_implementations(
157        &self,
158        structure_def: &crate::fhir_types::StructureDefinition,
159    ) -> String {
160        let trait_impl_generator =
161            crate::generators::trait_impl_generator::TraitImplGenerator::new();
162        let trait_impls = match trait_impl_generator.generate_trait_impls(structure_def) {
163            Ok(impls) => impls,
164            Err(e) => {
165                eprintln!(
166                    "Warning: Failed to generate trait implementations for {}: {}",
167                    structure_def.name, e
168                );
169                return String::new();
170            }
171        };
172
173        let mut implementations = Vec::new();
174
175        for trait_impl in trait_impls {
176            let impl_tokens = self.token_generator.generate_trait_impl(&trait_impl);
177
178            match syn::parse2(impl_tokens.clone()) {
179                Ok(syntax_tree) => {
180                    let formatted_impl = prettyplease::unparse(&syntax_tree);
181                    implementations.push(formatted_impl);
182                }
183                Err(e) => {
184                    eprintln!(
185                        "Warning: Failed to parse trait implementation for {}: {}",
186                        trait_impl.struct_name, e
187                    );
188                    eprintln!("Generated tokens:\n{impl_tokens}");
189                }
190            }
191        }
192
193        if implementations.is_empty() {
194            String::new()
195        } else {
196            format!("// Trait implementations\n{}", implementations.join("\n\n"))
197        }
198    }
199
200    pub(crate) fn generate_trait_reexports(
201        &self,
202        structure_def: &crate::fhir_types::StructureDefinition,
203    ) -> String {
204        let is_profile = crate::generators::type_registry::TypeRegistry::is_profile(structure_def);
205
206        let (trait_module_name, trait_prefix) = if is_profile {
207            let struct_name = crate::naming::Naming::struct_name(structure_def);
208            let snake_module = crate::naming::Naming::to_rust_identifier(
209                &crate::naming::Naming::to_snake_case(&struct_name),
210            );
211            (snake_module, struct_name)
212        } else {
213            let resource_name = crate::naming::Naming::to_rust_identifier(&structure_def.name);
214            let snake_name = crate::naming::Naming::to_rust_identifier(
215                &crate::naming::Naming::to_snake_case(&resource_name),
216            );
217            (snake_name, resource_name)
218        };
219
220        format!(
221            r#"// Re-export traits for convenient importing
222// This allows users to just import the resource module and get all associated traits
223pub use crate::traits::{trait_module_name}::{{
224    {trait_prefix}Mutators,
225    {trait_prefix}Accessors,
226    {trait_prefix}Existence,
227}};"#
228        )
229    }
230
231    pub(crate) fn generate_default_implementation(
232        &self,
233        structure_def: &crate::fhir_types::StructureDefinition,
234        rust_struct: &crate::rust_types::RustStruct,
235    ) -> String {
236        let is_profile = crate::generators::type_registry::TypeRegistry::is_profile(structure_def);
237        if is_profile {
238            return String::new();
239        }
240
241        let struct_name = &rust_struct.name;
242
243        if rust_struct.derives.iter().any(|d| d == "Default") {
244            return String::new();
245        }
246
247        let elements = if let Some(differential) = &structure_def.differential {
248            &differential.element
249        } else if let Some(snapshot) = &structure_def.snapshot {
250            &snapshot.element
251        } else {
252            &Vec::new()
253        };
254
255        let mut required_fields = Vec::new();
256        for element in elements {
257            let path_parts: Vec<&str> = element.path.split('.').collect();
258            if path_parts.len() == 2 && path_parts[0] == structure_def.name {
259                let field_name = path_parts[1];
260                if let Some(min) = element.min {
261                    if min >= 1 && !field_name.ends_with("[x]") {
262                        required_fields.push((field_name, element.clone()));
263                    }
264                }
265            }
266        }
267
268        let mut field_inits = Vec::new();
269
270        if let Some(base_def) = &rust_struct.base_definition {
271            let base_type = base_def.split('/').next_back().unwrap_or(base_def);
272            let base_type = crate::naming::Naming::to_rust_identifier(base_type);
273            let proper_base_type = if base_type
274                .chars()
275                .next()
276                .map(|c| c.is_lowercase())
277                .unwrap_or(false)
278            {
279                crate::naming::Naming::capitalize_first(&base_type)
280            } else {
281                base_type
282            };
283            field_inits.push(format!("base: {proper_base_type}::default()"));
284        }
285
286        for field in &rust_struct.fields {
287            let field_name = &field.name;
288
289            let is_required = required_fields.iter().any(|(name, _)| {
290                let snake_name = crate::naming::Naming::to_snake_case(name);
291                snake_name == *field_name
292            });
293
294            if is_required {
295                let default_value = match field.field_type.to_string().as_str() {
296                    s if s.contains("::") && !s.contains("Option") && !s.contains("Vec") => {
297                        format!("{s}::default()")
298                    }
299                    "String" => "String::new()".to_string(),
300                    "i32" | "i64" | "u32" | "u64" => "0".to_string(),
301                    "f32" | "f64" => "0.0".to_string(),
302                    "bool" => "false".to_string(),
303                    s if s.starts_with("Vec<") => "Vec::new()".to_string(),
304                    _ => format!("{}::default()", field.field_type.to_string()),
305                };
306                field_inits.push(format!("{field_name}: {default_value}"));
307            } else {
308                field_inits.push(format!("{field_name}: Default::default()"));
309            }
310        }
311
312        let impl_block = format!(
313            r#"impl Default for {} {{
314    fn default() -> Self {{
315        Self {{
316            {}
317        }}
318    }}
319}}"#,
320            struct_name,
321            field_inits.join(",\n            ")
322        );
323
324        impl_block
325    }
326
327    pub(crate) fn generate_nested_struct_default_implementation(
328        &self,
329        parent_structure_def: &crate::fhir_types::StructureDefinition,
330        nested_struct: &crate::rust_types::RustStruct,
331    ) -> String {
332        let struct_name = &nested_struct.name;
333
334        if nested_struct.derives.iter().any(|d| d == "Default") {
335            return String::new();
336        }
337
338        let parent_name = &parent_structure_def.name;
339        let nested_field_name = if struct_name.starts_with(parent_name) {
340            let suffix = &struct_name[parent_name.len()..];
341            crate::naming::Naming::to_snake_case(suffix)
342        } else {
343            return String::new();
344        };
345
346        let base_path = format!("{parent_name}.{nested_field_name}");
347
348        let elements = if let Some(differential) = &parent_structure_def.differential {
349            &differential.element
350        } else if let Some(snapshot) = &parent_structure_def.snapshot {
351            &snapshot.element
352        } else {
353            &Vec::new()
354        };
355
356        let mut required_fields = Vec::new();
357        for element in elements {
358            if element.path.starts_with(&format!("{base_path}.")) {
359                let field_path = element
360                    .path
361                    .strip_prefix(&format!("{base_path}."))
362                    .unwrap_or_else(|| {
363                        panic!(
364                            "codegen bug: element path '{}' does not start with '{base_path}.'",
365                            element.path
366                        )
367                    });
368                if !field_path.contains('.') && !field_path.ends_with("[x]") {
369                    if let Some(min) = element.min {
370                        if min >= 1 {
371                            required_fields.push((field_path, element.clone()));
372                        }
373                    }
374                }
375            }
376        }
377
378        let mut field_inits = Vec::new();
379
380        if let Some(base_def) = &nested_struct.base_definition {
381            let base_type = base_def.split('/').next_back().unwrap_or(base_def);
382            let base_type = crate::naming::Naming::to_rust_identifier(base_type);
383            let proper_base_type = if base_type
384                .chars()
385                .next()
386                .map(|c| c.is_lowercase())
387                .unwrap_or(false)
388            {
389                crate::naming::Naming::capitalize_first(&base_type)
390            } else {
391                base_type
392            };
393            field_inits.push(format!("base: {proper_base_type}::default()"));
394        }
395
396        for field in &nested_struct.fields {
397            let field_name = &field.name;
398
399            let is_required = required_fields.iter().any(|(name, _)| {
400                let snake_name = crate::naming::Naming::to_snake_case(name);
401                snake_name == *field_name
402            });
403
404            if is_required {
405                let default_value = match field.field_type.to_string().as_str() {
406                    s if s.contains("::") && !s.contains("Option") && !s.contains("Vec") => {
407                        format!("{s}::default()")
408                    }
409                    "String" => "String::new()".to_string(),
410                    "i32" | "i64" | "u32" | "u64" => "0".to_string(),
411                    "f32" | "f64" => "0.0".to_string(),
412                    "bool" => "false".to_string(),
413                    s if s.starts_with("Vec<") => "Vec::new()".to_string(),
414                    _ => format!("{}::default()", field.field_type.to_string()),
415                };
416                field_inits.push(format!("{field_name}: {default_value}"));
417            } else {
418                field_inits.push(format!("{field_name}: Default::default()"));
419            }
420        }
421
422        let impl_block = format!(
423            r#"impl Default for {} {{
424    fn default() -> Self {{
425        Self {{
426            {}
427        }}
428    }}
429}}"#,
430            struct_name,
431            field_inits.join(",\n            ")
432        );
433
434        impl_block
435    }
436}