codama_korok_visitors/
combine_modules_visitor.rs1use crate::KorokVisitor;
2use codama_errors::CodamaResult;
3use codama_koroks::KorokTrait;
4use codama_nodes::{HasName, Node, ProgramNode, RootNode};
5
/// Visitor that combines the nodes carried by child koroks into the node of
/// their parent korok (root, crate, file module and module).
///
/// The derived `Default` yields a visitor whose `force` flag is `false`.
#[derive(Default)]
pub struct CombineModulesVisitor {
    /// When `false`, `visit_root` returns early for a root korok that already
    /// has a node; when `true`, the root is always recombined.
    force: bool,
}
10
11impl CombineModulesVisitor {
12 pub fn new() -> Self {
13 Self { force: false }
14 }
15
16 pub fn force() -> Self {
17 Self { force: true }
18 }
19}
20
21impl KorokVisitor for CombineModulesVisitor {
22 fn visit_root(&mut self, korok: &mut codama_koroks::RootKorok) -> CodamaResult<()> {
23 if !self.force && korok.node.is_some() {
25 return Ok(());
26 }
27
28 self.visit_children(korok)?;
29 korok.node = combine_koroks(&korok.node, &korok.crates);
30 Ok(())
31 }
32
33 fn visit_crate(&mut self, korok: &mut codama_koroks::CrateKorok) -> CodamaResult<()> {
34 self.visit_children(korok)?;
35 korok.node = combine_koroks(&korok.node, &korok.items);
36 Ok(())
37 }
38
39 fn visit_file_module(
40 &mut self,
41 korok: &mut codama_koroks::FileModuleKorok,
42 ) -> CodamaResult<()> {
43 self.visit_children(korok)?;
44 korok.node = combine_koroks(&korok.node, &korok.items);
45 Ok(())
46 }
47
48 fn visit_module(&mut self, korok: &mut codama_koroks::ModuleKorok) -> CodamaResult<()> {
49 self.visit_children(korok)?;
50 korok.node = combine_koroks(&korok.node, &korok.items);
51 Ok(())
52 }
53}
54
55fn combine_koroks<T: KorokTrait>(initial_node: &Option<Node>, koroks: &[T]) -> Option<Node> {
57 let mut this_root_node = match initial_node {
62 Some(Node::Root(root)) => Some(root.clone()),
63 Some(Node::Program(program)) => Some(RootNode::new(program.clone())),
64 None => None,
65 _ => return initial_node.clone(),
66 };
67
68 let nodes_to_merge = koroks
70 .iter()
71 .filter_map(|item| item.node().clone())
72 .collect::<Vec<_>>();
73
74 let from_parent = this_root_node.is_some();
76 for that_root_node in get_root_nodes_to_merge(nodes_to_merge) {
77 merge_root_nodes(&mut this_root_node, that_root_node, from_parent);
78 }
79
80 this_root_node.map(Into::into)
81}
82
83fn get_root_nodes_to_merge(nodes: Vec<Node>) -> Vec<RootNode> {
85 let (roots_and_programs, scraps) = nodes
89 .into_iter()
90 .partition::<Vec<Node>, _>(|node| matches!(node, Node::Root(_) | Node::Program(_)));
91
92 let mut roots = roots_and_programs
94 .into_iter()
95 .filter_map(|node| match node {
96 Node::Root(node) => Some(node),
97 Node::Program(node) => Some(RootNode::new(node)),
98 _ => None,
99 })
100 .collect::<Vec<_>>();
101
102 if let Some(root) = get_scraps_root_node(scraps) {
104 roots.push(root)
105 }
106
107 roots
108}
109
110fn get_scraps_root_node(nodes: Vec<Node>) -> Option<RootNode> {
112 let mut has_scraps = false;
113 let mut root = RootNode::default();
114
115 for node in nodes {
116 match node {
117 Node::Account(node) => {
118 add_or_replace_node_with_name(&mut root.program.accounts, node);
119 has_scraps = true
120 }
121 Node::Instruction(node) => {
122 add_or_replace_node_with_name(&mut root.program.instructions, node);
123 has_scraps = true
124 }
125 Node::DefinedType(node) => {
126 add_or_replace_node_with_name(&mut root.program.defined_types, node);
127 has_scraps = true
128 }
129 Node::Error(node) => {
130 add_or_replace_node_with_name(&mut root.program.errors, node);
131 has_scraps = true
132 }
133 Node::Pda(node) => {
134 add_or_replace_node_with_name(&mut root.program.pdas, node);
135 has_scraps = true
136 }
137 _ => (),
138 }
139 }
140
141 has_scraps.then_some(root)
142}
143
144fn merge_root_nodes(this: &mut Option<RootNode>, that: RootNode, from_parent: bool) {
146 let Some(this) = this else {
148 *this = Some(that);
149 return;
150 };
151
152 let mut those_programs = Vec::new();
154 those_programs.push(that.program);
155 those_programs.extend(that.additional_programs);
156
157 for that_program in those_programs {
159 if should_merge_program_nodes(&this.program, &that_program, from_parent) {
161 merge_program_nodes(&mut this.program, that_program);
162 continue;
163 }
164
165 let found = this
167 .additional_programs
168 .iter_mut()
169 .find(|p| should_merge_program_nodes(p, &that_program, from_parent));
170
171 if let Some(additional_program) = found {
172 merge_program_nodes(additional_program, that_program);
174 } else {
175 this.additional_programs.push(that_program);
177 }
178 }
179}
180
181fn should_merge_program_nodes(this: &ProgramNode, that: &ProgramNode, from_parent: bool) -> bool {
183 this.public_key == that.public_key || (from_parent && that.public_key.is_empty())
184}
185
186fn merge_program_nodes(this: &mut ProgramNode, that: ProgramNode) {
188 if this.name.is_empty() {
189 this.name = that.name;
190 }
191 if this.public_key.is_empty() {
192 this.public_key = that.public_key;
193 }
194 if this.version.is_empty() {
195 this.version = that.version;
196 }
197 if this.origin.is_none() {
198 this.origin = that.origin;
199 }
200 if this.docs.is_empty() {
201 this.docs = that.docs;
202 }
203 merge_nodes_with_name(&mut this.accounts, that.accounts);
204 merge_nodes_with_name(&mut this.instructions, that.instructions);
205 merge_nodes_with_name(&mut this.defined_types, that.defined_types);
206 merge_nodes_with_name(&mut this.errors, that.errors);
207 merge_nodes_with_name(&mut this.pdas, that.pdas);
208}
209
210fn merge_nodes_with_name<T>(nodes: &mut Vec<T>, new_nodes: Vec<T>)
211where
212 T: HasName,
213{
214 for that in new_nodes {
215 add_or_replace_node_with_name(nodes, that);
216 }
217}
218
219fn add_or_replace_node_with_name<T>(nodes: &mut Vec<T>, new_node: T)
220where
221 T: HasName,
222{
223 if let Some(existing) = nodes.iter_mut().find(|d| d.name() == new_node.name()) {
224 *existing = new_node;
225 } else {
226 nodes.push(new_node);
227 }
228}