Skip to main content

chipi_core/
tree.rs

1//! Decision tree construction for optimal instruction dispatch.
2//!
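//! Builds a compact decision tree that dispatches on bit patterns to select the
//! matching instruction. At each node the builder greedily picks the bit range
//! whose split minimizes the largest resulting partition. It prefers splits with
//! more arms and wider ranges. Candidates that no bit range can separate are kept
//! as a priority-ordered list.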
6
7use std::collections::BTreeMap;
8
9use crate::types::*;
10
/// A node in the dispatch/decision tree.
///
/// Trees are produced by [`build_tree`]. Every instruction index stored in a node
/// is a position in the `instructions` list of the definition the tree was built from.
#[derive(Debug, Clone)]
pub enum DecodeNode {
    /// Branch on a range of bits, with arms for each value.
    ///
    /// `arms` maps the value of the bits in `range` to the subtree for that value.
    /// The bits are packed the same way `extract_fixed_value` packs them.
    /// Any value without an explicit arm falls through to `default`.
    Branch {
        range: BitRange,
        arms: BTreeMap<u64, DecodeNode>,
        default: Box<DecodeNode>,
    },
    /// Leaf: matched this instruction (index into the instruction list).
    Leaf { instruction_index: usize },
    /// Multiple candidates to try in priority order (most specific to least specific).
    ///
    /// Used when patterns overlap and can't be distinguished by bit splits alone.
    PriorityLeaves { candidates: Vec<usize> },
    /// No instruction matches.
    Fail,
}
28
29/// Build an optimal dispatch tree from validated instructions.
30///
31/// The tree is constructed recursively, choosing bit ranges that best
32/// partition the instruction candidates at each level.
33pub fn build_tree(def: &ValidatedDef) -> DecodeNode {
34    let candidates: Vec<usize> = (0..def.instructions.len()).collect();
35    build_node(&def.instructions, &candidates, def.config.width)
36}
37
38fn build_node(
39    instructions: &[ValidatedInstruction],
40    candidates: &[usize],
41    width: u32,
42) -> DecodeNode {
43    match candidates.len() {
44        0 => DecodeNode::Fail,
45        1 => DecodeNode::Leaf {
46            instruction_index: candidates[0],
47        },
48        _ => {
49            // Find the best bit range to split on
50            let groups = find_useful_bit_groups(instructions, candidates, width);
51
52            if groups.is_empty() {
53                // No bits can distinguish candidates.
54                // Separate specific patterns from wildcards and preserve both.
55                let (specifics, wildcards) =
56                    separate_specific_and_wildcards(instructions, candidates, width);
57
58                if !specifics.is_empty() && !wildcards.is_empty() {
59                    // Both specific and wildcard patterns exist
60                    // Return a PriorityLeaves node with specifics first, then wildcards
61                    let mut priority_order = specifics;
62                    priority_order.extend(wildcards);
63                    return DecodeNode::PriorityLeaves {
64                        candidates: priority_order,
65                    };
66                } else if !specifics.is_empty() {
67                    // Only specific patterns
68                    if specifics.len() == 1 {
69                        return DecodeNode::Leaf {
70                            instruction_index: specifics[0],
71                        };
72                    } else {
73                        return DecodeNode::PriorityLeaves {
74                            candidates: specifics,
75                        };
76                    }
77                } else {
78                    // Only wildcards
79                    if wildcards.len() == 1 {
80                        return DecodeNode::Leaf {
81                            instruction_index: wildcards[0],
82                        };
83                    } else {
84                        return DecodeNode::PriorityLeaves {
85                            candidates: wildcards,
86                        };
87                    }
88                }
89            }
90
91            // Try all candidate ranges and pick the best
92            let mut best_range: Option<BitRange> = None;
93            let mut best_score: Option<(usize, usize, u32)> = None;
94
95            for group in &groups {
96                let ranges_to_try = generate_sub_ranges(group);
97                for range in ranges_to_try {
98                    let partitions = partition_by_range(instructions, candidates, range);
99
100                    // Count effective partitions (values that actually split)
101                    let num_values = partitions.len();
102                    if num_values <= 1 {
103                        continue;
104                    }
105
106                    let max_part = partitions.values().map(|v| v.len()).max().unwrap_or(0);
107
108                    // If this split doesn't actually reduce any partition, skip
109                    if max_part >= candidates.len() {
110                        continue;
111                    }
112
113                    // Score: (max_partition, inv_num_partitions, inv_width) - lower is better
114                    let score = (max_part, usize::MAX - num_values, u32::MAX - range.width());
115
116                    let is_better = match &best_score {
117                        None => true,
118                        Some(prev) => score < *prev,
119                    };
120
121                    if is_better {
122                        best_score = Some(score);
123                        best_range = Some(range);
124                    }
125                }
126            }
127
128            let range = match best_range {
129                Some(r) => r,
130                None => {
131                    // Can't split further, separate specifics from wildcards
132                    let (specifics, wildcards) =
133                        separate_specific_and_wildcards(instructions, candidates, width);
134
135                    if !specifics.is_empty() && !wildcards.is_empty() {
136                        let mut priority_order = specifics;
137                        priority_order.extend(wildcards);
138                        return DecodeNode::PriorityLeaves {
139                            candidates: priority_order,
140                        };
141                    } else if !specifics.is_empty() {
142                        if specifics.len() == 1 {
143                            return DecodeNode::Leaf {
144                                instruction_index: specifics[0],
145                            };
146                        } else {
147                            return DecodeNode::PriorityLeaves {
148                                candidates: specifics,
149                            };
150                        }
151                    } else {
152                        if wildcards.len() == 1 {
153                            return DecodeNode::Leaf {
154                                instruction_index: wildcards[0],
155                            };
156                        } else {
157                            return DecodeNode::PriorityLeaves {
158                                candidates: wildcards,
159                            };
160                        }
161                    }
162                }
163            };
164
165            let partitions = partition_by_range(instructions, candidates, range);
166
167            // Collect wildcards: candidates that don't have all fixed bits in range
168            let wildcards: Vec<usize> = candidates
169                .iter()
170                .copied()
171                .filter(|&idx| !has_all_fixed_at(&instructions[idx], range))
172                .collect();
173
174            let mut arms = BTreeMap::new();
175            for (value, sub_candidates) in partitions {
176                // Guard: if partition didn't reduce candidate count, avoid infinite recursion
177                if sub_candidates.len() >= candidates.len() {
178                    // No progress, separate specifics from wildcards
179                    let (specifics, wildcards) =
180                        separate_specific_and_wildcards(instructions, &sub_candidates, width);
181
182                    if !specifics.is_empty() && !wildcards.is_empty() {
183                        let mut priority_order = specifics;
184                        priority_order.extend(wildcards);
185                        arms.insert(
186                            value,
187                            DecodeNode::PriorityLeaves {
188                                candidates: priority_order,
189                            },
190                        );
191                    } else if !specifics.is_empty() {
192                        if specifics.len() == 1 {
193                            arms.insert(
194                                value,
195                                DecodeNode::Leaf {
196                                    instruction_index: specifics[0],
197                                },
198                            );
199                        } else {
200                            arms.insert(
201                                value,
202                                DecodeNode::PriorityLeaves {
203                                    candidates: specifics,
204                                },
205                            );
206                        }
207                    } else {
208                        if wildcards.len() == 1 {
209                            arms.insert(
210                                value,
211                                DecodeNode::Leaf {
212                                    instruction_index: wildcards[0],
213                                },
214                            );
215                        } else {
216                            arms.insert(
217                                value,
218                                DecodeNode::PriorityLeaves {
219                                    candidates: wildcards,
220                                },
221                            );
222                        }
223                    }
224                } else {
225                    let child = build_node(instructions, &sub_candidates, width);
226                    arms.insert(value, child);
227                }
228            }
229
230            // Default arm for values not explicitly matched
231            let default = if wildcards.is_empty() {
232                Box::new(DecodeNode::Fail)
233            } else {
234                Box::new(build_node(instructions, &wildcards, width))
235            };
236
237            DecodeNode::Branch {
238                range,
239                arms,
240                default,
241            }
242        }
243    }
244}
245
246/// Check if an instruction has fixed bits at ALL positions in a range.
247fn has_all_fixed_at(instr: &ValidatedInstruction, range: BitRange) -> bool {
248    range.bits().all(|bit| instr.fixed_bit_at(bit).is_some())
249}
250
251/// Find contiguous bit groups useful for splitting.
252/// A bit position is useful if at least 2 candidates have variation at that position.
253fn find_useful_bit_groups(
254    instructions: &[ValidatedInstruction],
255    candidates: &[usize],
256    width: u32,
257) -> Vec<BitRange> {
258    let mut useful = vec![false; width as usize];
259
260    for bit in 0..width {
261        // Collect fixed values at this bit (skip candidates without fixed bits here)
262        let fixed_values: Vec<Bit> = candidates
263            .iter()
264            .filter_map(|&idx| instructions[idx].fixed_bit_at(bit))
265            .collect();
266
267        // Need at least 2 fixed values with variation
268        if fixed_values.len() >= 2 {
269            let has_variation = fixed_values.iter().any(|&v| v != fixed_values[0]);
270            useful[bit as usize] = has_variation;
271        }
272    }
273
274    // Group contiguous useful bits into ranges
275    let mut groups = Vec::new();
276    let mut i = 0u32;
277    while i < width {
278        if useful[i as usize] {
279            let start = i;
280            while i < width && useful[i as usize] {
281                i += 1;
282            }
283            let end = i - 1;
284            // BitRange: start is MSB, end is LSB. Iteration goes from LSB to MSB.
285            groups.push(BitRange::new(end, start));
286        } else {
287            i += 1;
288        }
289    }
290
291    groups
292}
293
294/// Generate sub-ranges from a full range for finer-grained splitting.
295fn generate_sub_ranges(range: &BitRange) -> Vec<BitRange> {
296    let mut ranges = vec![*range];
297
298    let w = range.width();
299    if w > 1 {
300        if w <= 10 {
301            for sub_width in (1..w).rev() {
302                for start_offset in 0..=(w - sub_width) {
303                    let sub_start = range.start.saturating_sub(start_offset);
304                    let sub_end_needed = sub_start + 1 - sub_width;
305                    if sub_end_needed >= range.end && sub_start >= sub_end_needed {
306                        let sub = BitRange::new(sub_start, sub_end_needed);
307                        if sub != *range && !ranges.contains(&sub) {
308                            ranges.push(sub);
309                        }
310                    }
311                }
312            }
313        } else {
314            let mid = range.end + w / 2;
315            ranges.push(BitRange::new(range.start, mid));
316            if mid > range.end {
317                ranges.push(BitRange::new(mid - 1, range.end));
318            }
319
320            for chunk_size in [4u32, 6, 8] {
321                if chunk_size < w {
322                    if range.start >= chunk_size - 1 {
323                        let sub = BitRange::new(range.start, range.start - chunk_size + 1);
324                        if !ranges.contains(&sub) {
325                            ranges.push(sub);
326                        }
327                    }
328                    let sub = BitRange::new(range.end + chunk_size - 1, range.end);
329                    if !ranges.contains(&sub) {
330                        ranges.push(sub);
331                    }
332                }
333            }
334        }
335    }
336
337    ranges
338}
339
340/// Partition candidates by fixed bit values at a given range.
341/// Instructions with all fixed bits at the range go in their specific bucket.
342/// Instructions with any non-fixed bits (wildcards) go in every bucket.
343fn partition_by_range(
344    instructions: &[ValidatedInstruction],
345    candidates: &[usize],
346    range: BitRange,
347) -> BTreeMap<u64, Vec<usize>> {
348    let mut fixed_map: BTreeMap<u64, Vec<usize>> = BTreeMap::new();
349    let mut wildcards: Vec<usize> = Vec::new();
350
351    for &idx in candidates {
352        if has_all_fixed_at(&instructions[idx], range) {
353            let value = extract_fixed_value(&instructions[idx], range);
354            fixed_map.entry(value).or_default().push(idx);
355        } else {
356            wildcards.push(idx);
357        }
358    }
359
360    // Add wildcards to every partition
361    if !wildcards.is_empty() {
362        for bucket in fixed_map.values_mut() {
363            bucket.extend_from_slice(&wildcards);
364        }
365    }
366
367    fixed_map
368}
369
370/// Extract the fixed bit value of an instruction at a given range.
371fn extract_fixed_value(instr: &ValidatedInstruction, range: BitRange) -> u64 {
372    let mut value: u64 = 0;
373    for bit_pos in range.bits() {
374        value <<= 1;
375        if let Some(Bit::One) = instr.fixed_bit_at(bit_pos) {
376            value |= 1;
377        }
378    }
379    value
380}
381
382/// Separate candidates into specific (all bits fixed in unit 0) vs wildcards (some don't-care bits).
383/// Returns (specific_candidates, wildcard_candidates).
384/// Specific candidates are ordered by number of fixed bits (most specific first).
385fn separate_specific_and_wildcards(
386    instructions: &[ValidatedInstruction],
387    candidates: &[usize],
388    width: u32,
389) -> (Vec<usize>, Vec<usize>) {
390    let mut specifics = Vec::new();
391    let mut wildcards = Vec::new();
392
393    for &idx in candidates {
394        // Check if this instruction has a fixed bit at EVERY position in unit 0
395        let all_fixed = (0..width).all(|bit| instructions[idx].fixed_bit_at(bit).is_some());
396
397        if all_fixed {
398            specifics.push(idx);
399        } else {
400            wildcards.push(idx);
401        }
402    }
403
404    // Sort specifics by number of fixed bits across ALL units (most to least)
405    specifics.sort_by_key(|&idx| std::cmp::Reverse(instructions[idx].fixed_bits().len()));
406
407    // Sort wildcards by number of fixed bits too (for consistent prioritization)
408    wildcards.sort_by_key(|&idx| std::cmp::Reverse(instructions[idx].fixed_bits().len()));
409
410    (specifics, wildcards)
411}