Skip to main content

trident/neural/inference/
execute.rs

1//! Parallel validation and ranking of beam search candidates.
2//!
3//! Takes K candidate token sequences from beam search, decodes them
4//! to TASM, validates equivalence with baseline using rayon parallel
5//! iteration, and returns the cheapest valid candidate.
6
7use rayon::prelude::*;
8
9use crate::cost::scorer::profile_tasm;
10use crate::cost::stack_verifier::verify_equivalent;
11use crate::neural::model::vocab::Vocab;
12
/// Result of validating and ranking beam candidates.
///
/// Only produced when at least one candidate passed validation, so
/// `tasm_lines` always holds a real (validated) sequence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RankedResult {
    /// Best (lowest-cost) valid TASM sequence.
    pub tasm_lines: Vec<String>,
    /// Clock cycles (table cost) of the best candidate, as reported by the profiler.
    pub cost: u64,
    /// How many candidates were valid out of `total_count`.
    pub valid_count: usize,
    /// Total candidates evaluated.
    pub total_count: usize,
}
24
25/// Validate beam search candidates against a baseline and return the best.
26///
27/// Each candidate is decoded from token IDs to TASM strings, then verified
28/// for equivalence with the baseline TASM using the stack verifier.
29/// Valid candidates are profiled for cost, and the cheapest is returned.
30///
31/// Uses rayon for parallel validation across all K candidates.
32///
33/// Returns `None` if no valid candidate is found (fallback to compiler).
34pub fn validate_and_rank(
35    candidates: &[Vec<u32>],
36    vocab: &Vocab,
37    baseline_tasm: &[String],
38    seed: u64,
39) -> Option<RankedResult> {
40    if candidates.is_empty() || baseline_tasm.is_empty() {
41        return None;
42    }
43
44    let results: Vec<Option<(Vec<String>, u64)>> = candidates
45        .par_iter()
46        .map(|token_ids| {
47            // Decode tokens to TASM lines
48            let tasm_lines = vocab.decode_sequence(token_ids);
49            if tasm_lines.is_empty() {
50                return None;
51            }
52
53            // Verify equivalence with baseline
54            if !verify_equivalent(baseline_tasm, &tasm_lines, seed) {
55                return None;
56            }
57
58            // Profile for cost
59            let line_refs: Vec<&str> = tasm_lines.iter().map(|s| s.as_str()).collect();
60            let profile = profile_tasm(&line_refs);
61
62            Some((tasm_lines, profile.cost()))
63        })
64        .collect();
65
66    let valid_count = results.iter().filter(|r| r.is_some()).count();
67    let total_count = candidates.len();
68
69    // Find cheapest valid candidate
70    let best = results
71        .into_iter()
72        .flatten()
73        .min_by_key(|(_, cost)| *cost)?;
74
75    Some(RankedResult {
76        tasm_lines: best.0,
77        cost: best.1,
78        valid_count,
79        total_count,
80    })
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline `push 1; push 2; add`, which leaves 3 on the stack.
    fn add_baseline() -> Vec<String> {
        vec!["push 1".into(), "push 2".into(), "add".into()]
    }

    #[test]
    fn validate_empty_candidates() {
        let vocab = Vocab::new();
        let result = validate_and_rank(&[], &vocab, &["push 1".into()], 42);
        assert!(result.is_none());
    }

    #[test]
    fn validate_empty_baseline() {
        let vocab = Vocab::new();
        let result = validate_and_rank(&[vec![3, 0]], &vocab, &[], 42);
        assert!(result.is_none());
    }

    #[test]
    fn validate_equivalent_candidate() {
        let vocab = Vocab::new();
        // Candidate: push 3 (token 5) -> same result as the baseline.
        let candidates = vec![vec![5]];
        let r = validate_and_rank(&candidates, &vocab, &add_baseline(), 42)
            .expect("`push 3` should be accepted as equivalent to `push 1; push 2; add`");
        assert_eq!(r.valid_count, 1);
        assert_eq!(r.total_count, 1);
        assert_eq!(r.tasm_lines, vec!["push 3"]);
    }

    #[test]
    fn validate_picks_cheapest() {
        let vocab = Vocab::new();
        // Baseline: push 3
        let baseline: Vec<String> = vec!["push 3".into()];
        // Two equivalent candidates:
        //   push 3 (1 instruction) — token 5
        //   push 3, nop (2 instructions) — tokens 5, 96
        let candidates = vec![
            vec![5, 96], // push 3, nop
            vec![5],     // push 3
        ];
        let r = validate_and_rank(&candidates, &vocab, &baseline, 42)
            .expect("both candidates are equivalent to the baseline");
        assert_eq!(r.valid_count, 2);
        assert_eq!(r.total_count, 2);
        // Cheapest should be the 1-instruction version
        assert_eq!(r.tasm_lines, vec!["push 3"]);
    }

    #[test]
    fn validate_counts_mixed_candidates() {
        let vocab = Vocab::new();
        // One invalid and one valid candidate: the valid one is returned
        // and the counters reflect both.
        let candidates = vec![
            vec![6], // push 4 — wrong result
            vec![5], // push 3 — equivalent
        ];
        let r = validate_and_rank(&candidates, &vocab, &add_baseline(), 42)
            .expect("the `push 3` candidate should survive validation");
        assert_eq!(r.valid_count, 1);
        assert_eq!(r.total_count, 2);
        assert_eq!(r.tasm_lines, vec!["push 3"]);
    }

    #[test]
    fn validate_rejects_invalid() {
        let vocab = Vocab::new();
        // Candidate: push 4 (wrong result)
        let candidates = vec![vec![6]]; // push 4
        let result = validate_and_rank(&candidates, &vocab, &add_baseline(), 42);
        assert!(result.is_none());
    }
}