codama_korok_visitors/
combine_modules_visitor.rs

1use crate::KorokVisitor;
2use codama_errors::CodamaResult;
3use codama_koroks::KorokTrait;
4use codama_nodes::{HasName, Node, ProgramNode, RootNode};
5
/// A [`KorokVisitor`] that combines the nodes of child koroks into a `RootNode` at every level
/// of the korok tree: the items of modules, file modules and crates, and finally the crates of
/// the root korok.
///
/// By default, a root korok that already has a node is left untouched. Use
/// `CombineModulesVisitor::force()` to combine its crates anyway.
#[derive(Debug, Default, Clone)]
pub struct CombineModulesVisitor {
    /// When `true`, the root korok's crates are combined even if the root already has a node.
    force: bool,
}
10
11impl CombineModulesVisitor {
12    pub fn new() -> Self {
13        Self { force: false }
14    }
15
16    pub fn force() -> Self {
17        Self { force: true }
18    }
19}
20
21impl KorokVisitor for CombineModulesVisitor {
22    fn visit_root(&mut self, korok: &mut codama_koroks::RootKorok) -> CodamaResult<()> {
23        // Unless forced, if the root node is already set, do not combine modules.
24        if !self.force && korok.node.is_some() {
25            return Ok(());
26        }
27
28        self.visit_children(korok)?;
29        korok.node = combine_koroks(&korok.node, &korok.crates);
30        Ok(())
31    }
32
33    fn visit_crate(&mut self, korok: &mut codama_koroks::CrateKorok) -> CodamaResult<()> {
34        self.visit_children(korok)?;
35        korok.node = combine_koroks(&korok.node, &korok.items);
36        Ok(())
37    }
38
39    fn visit_file_module(
40        &mut self,
41        korok: &mut codama_koroks::FileModuleKorok,
42    ) -> CodamaResult<()> {
43        self.visit_children(korok)?;
44        korok.node = combine_koroks(&korok.node, &korok.items);
45        Ok(())
46    }
47
48    fn visit_module(&mut self, korok: &mut codama_koroks::ModuleKorok) -> CodamaResult<()> {
49        self.visit_children(korok)?;
50        korok.node = combine_koroks(&korok.node, &korok.items);
51        Ok(())
52    }
53}
54
55/// Create a single RootNode from an initial node and a list of nodes to merge.
56fn combine_koroks<T: KorokTrait>(initial_node: &Option<Node>, koroks: &[T]) -> Option<Node> {
57    // Create the new RootNode to bind all items together from the exisiting node, in any.
58    // - If there is already a RootNode or ProgramNode, use this as a starting point.
59    // - If there is no existing node, use None and let the merging create a new one if needed.
60    // - If there is any other node, return it as-is without combining the nodes.
61    let mut this_root_node = match initial_node {
62        Some(Node::Root(root)) => Some(root.clone()),
63        Some(Node::Program(program)) => Some(RootNode::new(program.clone())),
64        None => None,
65        _ => return initial_node.clone(),
66    };
67
68    // Get all nodes from the koroks to merge.
69    let nodes_to_merge = koroks
70        .iter()
71        .filter_map(|item| item.node().clone())
72        .collect::<Vec<_>>();
73
74    // Convert all nodes into RootNodes and merge them with the binding root node.
75    let from_parent = this_root_node.is_some();
76    for that_root_node in get_root_nodes_to_merge(nodes_to_merge) {
77        merge_root_nodes(&mut this_root_node, that_root_node, from_parent);
78    }
79
80    this_root_node.map(Into::into)
81}
82
83/// Convert all nodes to merge into RootNodes.
84fn get_root_nodes_to_merge(nodes: Vec<Node>) -> Vec<RootNode> {
85    // Split the nodes into:
86    // - Nodes that can be converted into RootNodes (Root, Program).
87    // - All other nodes that we will refer to as scraps.
88    let (roots_and_programs, scraps) = nodes
89        .into_iter()
90        .partition::<Vec<Node>, _>(|node| matches!(node, Node::Root(_) | Node::Program(_)));
91
92    // Convert all "rootable" nodes into RootNodes.
93    let mut roots = roots_and_programs
94        .into_iter()
95        .filter_map(|node| match node {
96            Node::Root(node) => Some(node),
97            Node::Program(node) => Some(RootNode::new(node)),
98            _ => None,
99        })
100        .collect::<Vec<_>>();
101
102    // Try to get a RootNode from all the scraps.
103    if let Some(root) = get_scraps_root_node(scraps) {
104        roots.push(root)
105    }
106
107    roots
108}
109
110/// Go through all "scraps" nodes and try to get a shared RootNode from them.
111fn get_scraps_root_node(nodes: Vec<Node>) -> Option<RootNode> {
112    let mut has_scraps = false;
113    let mut root = RootNode::default();
114
115    for node in nodes {
116        match node {
117            Node::Account(node) => {
118                add_or_replace_node_with_name(&mut root.program.accounts, node);
119                has_scraps = true
120            }
121            Node::Instruction(node) => {
122                add_or_replace_node_with_name(&mut root.program.instructions, node);
123                has_scraps = true
124            }
125            Node::DefinedType(node) => {
126                add_or_replace_node_with_name(&mut root.program.defined_types, node);
127                has_scraps = true
128            }
129            Node::Error(node) => {
130                add_or_replace_node_with_name(&mut root.program.errors, node);
131                has_scraps = true
132            }
133            Node::Pda(node) => {
134                add_or_replace_node_with_name(&mut root.program.pdas, node);
135                has_scraps = true
136            }
137            _ => (),
138        }
139    }
140
141    has_scraps.then_some(root)
142}
143
144/// Merge `that` RootNode into `this` RootNode.
145fn merge_root_nodes(this: &mut Option<RootNode>, that: RootNode, from_parent: bool) {
146    // If there is no root node yet, set it to the one provided.
147    let Some(this) = this else {
148        *this = Some(that);
149        return;
150    };
151
152    // Get an array of all programs to merge.
153    let mut those_programs = Vec::new();
154    those_programs.push(that.program);
155    those_programs.extend(that.additional_programs);
156
157    // For each program to merge.
158    for that_program in those_programs {
159        // Check if it can be merged with the main root program.
160        if should_merge_program_nodes(&this.program, &that_program, from_parent) {
161            merge_program_nodes(&mut this.program, that_program);
162            continue;
163        }
164
165        // Then, check if it can be merged with any additional program.
166        let found = this
167            .additional_programs
168            .iter_mut()
169            .find(|p| should_merge_program_nodes(p, &that_program, from_parent));
170
171        if let Some(additional_program) = found {
172            // If so, merge it with the additional program found.
173            merge_program_nodes(additional_program, that_program);
174        } else {
175            // Otherwise, add it as another additional program.
176            this.additional_programs.push(that_program);
177        }
178    }
179}
180
181/// Check if two ProgramNodes should be merged together.
182fn should_merge_program_nodes(this: &ProgramNode, that: &ProgramNode, from_parent: bool) -> bool {
183    this.public_key == that.public_key || (from_parent && that.public_key.is_empty())
184}
185
186/// Merge `that` ProgramNode into `this` ProgramNode.
187fn merge_program_nodes(this: &mut ProgramNode, that: ProgramNode) {
188    if this.name.is_empty() {
189        this.name = that.name;
190    }
191    if this.public_key.is_empty() {
192        this.public_key = that.public_key;
193    }
194    if this.version.is_empty() {
195        this.version = that.version;
196    }
197    if this.origin.is_none() {
198        this.origin = that.origin;
199    }
200    if this.docs.is_empty() {
201        this.docs = that.docs;
202    }
203    merge_nodes_with_name(&mut this.accounts, that.accounts);
204    merge_nodes_with_name(&mut this.instructions, that.instructions);
205    merge_nodes_with_name(&mut this.defined_types, that.defined_types);
206    merge_nodes_with_name(&mut this.errors, that.errors);
207    merge_nodes_with_name(&mut this.pdas, that.pdas);
208}
209
210fn merge_nodes_with_name<T>(nodes: &mut Vec<T>, new_nodes: Vec<T>)
211where
212    T: HasName,
213{
214    for that in new_nodes {
215        add_or_replace_node_with_name(nodes, that);
216    }
217}
218
219fn add_or_replace_node_with_name<T>(nodes: &mut Vec<T>, new_node: T)
220where
221    T: HasName,
222{
223    if let Some(existing) = nodes.iter_mut().find(|d| d.name() == new_node.name()) {
224        *existing = new_node;
225    } else {
226        nodes.push(new_node);
227    }
228}