Skip to main content

vexil_codegen_rust/
backend.rs

1// crates/vexil-codegen-rust/src/backend.rs
2
3use std::collections::{BTreeMap, HashMap, HashSet};
4use std::path::PathBuf;
5
6use vexil_lang::codegen::{CodegenBackend, CodegenError};
7use vexil_lang::ir::{CompiledSchema, ResolvedType, TypeDef, TypeId};
8use vexil_lang::project::ProjectResult;
9
/// Rust code-generation backend for Vexil schemas.
///
/// Generates idiomatic Rust structs, enums, and encode/decode implementations
/// using the `vexil-runtime` crate.
///
/// The backend carries no state, so it is a freely copyable unit struct that
/// can also be obtained through [`Default`] in generic code.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RustBackend;
16
17impl CodegenBackend for RustBackend {
18    fn name(&self) -> &str {
19        "rust"
20    }
21
22    fn file_extension(&self) -> &str {
23        "rs"
24    }
25
26    fn generate(&self, compiled: &CompiledSchema) -> Result<String, CodegenError> {
27        crate::generate(compiled).map_err(|e| CodegenError::BackendSpecific(Box::new(e)))
28    }
29
30    fn generate_project(
31        &self,
32        result: &ProjectResult,
33    ) -> Result<BTreeMap<PathBuf, String>, CodegenError> {
34        let mut files = BTreeMap::new();
35        let mut mod_tree: BTreeMap<String, Vec<String>> = BTreeMap::new();
36
37        // Step 1: Build a global type_name -> Rust path map from all schemas' declarations.
38        let mut global_type_map: HashMap<String, String> = HashMap::new();
39        for (ns, compiled) in &result.schemas {
40            let segments: Vec<&str> = ns.split('.').collect();
41            let rust_module = segments.join("::");
42            for &type_id in &compiled.declarations {
43                if let Some(typedef) = compiled.registry.get(type_id) {
44                    let name = crate::type_name_of(typedef);
45                    let rust_path = format!("crate::{rust_module}::{name}");
46                    global_type_map.insert(name.to_string(), rust_path);
47                }
48            }
49        }
50
51        for (ns, compiled) in &result.schemas {
52            let segments: Vec<&str> = ns.split('.').collect();
53            if segments.is_empty() {
54                continue;
55            }
56            let file_name = segments[segments.len() - 1];
57            let dir_segments = &segments[..segments.len() - 1];
58
59            // Track mod.rs entries
60            for i in 0..segments.len() - 1 {
61                let parent_key = segments[..i].join("/");
62                let child = segments[i].to_string();
63                let entry = mod_tree.entry(parent_key).or_default();
64                if !entry.contains(&child) {
65                    entry.push(child);
66                }
67            }
68            if segments.len() >= 2 {
69                let parent_key = dir_segments.join("/");
70                let child = file_name.to_string();
71                let entry = mod_tree.entry(parent_key).or_default();
72                if !entry.contains(&child) {
73                    entry.push(child);
74                }
75            } else {
76                let entry = mod_tree.entry(String::new()).or_default();
77                let child = file_name.to_string();
78                if !entry.contains(&child) {
79                    entry.push(child);
80                }
81            }
82
83            // Step 2: Build import_paths for this schema.
84            let declared_ids: HashSet<TypeId> = compiled.declarations.iter().copied().collect();
85
86            // Collect all Named TypeIds referenced by declared types.
87            let mut import_paths: HashMap<TypeId, String> = HashMap::new();
88            for &type_id in &compiled.declarations {
89                if let Some(typedef) = compiled.registry.get(type_id) {
90                    collect_named_ids_from_typedef(typedef, &declared_ids, |imported_id| {
91                        if let Some(imported_def) = compiled.registry.get(imported_id) {
92                            let name = crate::type_name_of(imported_def);
93                            if let Some(rust_path) = global_type_map.get(name) {
94                                import_paths.insert(imported_id, rust_path.clone());
95                            }
96                        }
97                    });
98                }
99            }
100
101            // Generate code with cross-file imports.
102            let imports = if import_paths.is_empty() {
103                None
104            } else {
105                Some(&import_paths)
106            };
107            let code = crate::generate_with_imports(compiled, imports)
108                .map_err(|e| CodegenError::BackendSpecific(Box::new(e)))?;
109
110            let mut file_path = PathBuf::new();
111            for seg in dir_segments {
112                file_path.push(seg);
113            }
114            file_path.push(format!("{file_name}.rs"));
115            files.insert(file_path, code);
116        }
117
118        // Generate mod.rs files
119        for (dir_key, children) in &mod_tree {
120            let mut mod_path = PathBuf::new();
121            if !dir_key.is_empty() {
122                for seg in dir_key.split('/') {
123                    mod_path.push(seg);
124                }
125            }
126            mod_path.push("mod.rs");
127
128            let child_refs: Vec<&str> = children.iter().map(|s| s.as_str()).collect();
129            let mod_content = crate::generate_mod_file(&child_refs);
130            files.insert(mod_path, mod_content);
131        }
132
133        Ok(files)
134    }
135}
136
137/// Collect all `ResolvedType::Named(id)` from a TypeDef where `id` is NOT in
138/// the declared set (i.e., it's an imported type). Calls `on_import` for each.
139fn collect_named_ids_from_typedef(
140    typedef: &TypeDef,
141    declared: &HashSet<TypeId>,
142    mut on_import: impl FnMut(TypeId),
143) {
144    match typedef {
145        TypeDef::Message(msg) => {
146            for f in &msg.fields {
147                collect_named_ids_from_resolved(&f.resolved_type, declared, &mut on_import);
148            }
149        }
150        TypeDef::Union(un) => {
151            for v in &un.variants {
152                for f in &v.fields {
153                    collect_named_ids_from_resolved(&f.resolved_type, declared, &mut on_import);
154                }
155            }
156        }
157        TypeDef::Newtype(nt) => {
158            collect_named_ids_from_resolved(&nt.inner_type, declared, &mut on_import);
159        }
160        TypeDef::Config(cfg) => {
161            for f in &cfg.fields {
162                collect_named_ids_from_resolved(&f.resolved_type, declared, &mut on_import);
163            }
164        }
165        _ => {}
166    }
167}
168
169fn collect_named_ids_from_resolved(
170    ty: &ResolvedType,
171    declared: &HashSet<TypeId>,
172    on_import: &mut impl FnMut(TypeId),
173) {
174    match ty {
175        ResolvedType::Named(id) => {
176            if !declared.contains(id) {
177                on_import(*id);
178            }
179        }
180        ResolvedType::Optional(inner) | ResolvedType::Array(inner) => {
181            collect_named_ids_from_resolved(inner, declared, on_import);
182        }
183        ResolvedType::Map(k, v) | ResolvedType::Result(k, v) => {
184            collect_named_ids_from_resolved(k, declared, on_import);
185            collect_named_ids_from_resolved(v, declared, on_import);
186        }
187        _ => {}
188    }
189}