trident/neural/inference/
execute.rs1use rayon::prelude::*;
8
9use crate::cost::scorer::profile_tasm;
10use crate::cost::stack_verifier::verify_equivalent;
11use crate::neural::model::vocab::Vocab;
12
/// Outcome of ranking a batch of decoded candidates against a baseline.
///
/// Produced by `validate_and_rank`, which fills in the cheapest valid
/// candidate together with bookkeeping about how many candidates survived.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RankedResult {
    /// Decoded TASM lines of the cheapest valid candidate.
    pub tasm_lines: Vec<String>,
    /// Cost of `tasm_lines`, as reported by `profile_tasm(..).cost()`.
    pub cost: u64,
    /// Number of candidates that decoded to a non-empty program and passed verification.
    pub valid_count: usize,
    /// Number of candidates that were submitted for ranking.
    pub total_count: usize,
}
24
25pub fn validate_and_rank(
35 candidates: &[Vec<u32>],
36 vocab: &Vocab,
37 baseline_tasm: &[String],
38 seed: u64,
39) -> Option<RankedResult> {
40 if candidates.is_empty() || baseline_tasm.is_empty() {
41 return None;
42 }
43
44 let results: Vec<Option<(Vec<String>, u64)>> = candidates
45 .par_iter()
46 .map(|token_ids| {
47 let tasm_lines = vocab.decode_sequence(token_ids);
49 if tasm_lines.is_empty() {
50 return None;
51 }
52
53 if !verify_equivalent(baseline_tasm, &tasm_lines, seed) {
55 return None;
56 }
57
58 let line_refs: Vec<&str> = tasm_lines.iter().map(|s| s.as_str()).collect();
60 let profile = profile_tasm(&line_refs);
61
62 Some((tasm_lines, profile.cost()))
63 })
64 .collect();
65
66 let valid_count = results.iter().filter(|r| r.is_some()).count();
67 let total_count = candidates.len();
68
69 let best = results
71 .into_iter()
72 .flatten()
73 .min_by_key(|(_, cost)| *cost)?;
74
75 Some(RankedResult {
76 tasm_lines: best.0,
77 cost: best.1,
78 valid_count,
79 total_count,
80 })
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// No candidates means nothing to rank, regardless of the baseline.
    #[test]
    fn validate_empty_candidates() {
        let vocab = Vocab::new();
        let result = validate_and_rank(&[], &vocab, &["push 1".into()], 42);
        assert!(result.is_none());
    }

    /// An empty baseline is rejected up front, before any candidate is decoded.
    #[test]
    fn validate_empty_baseline() {
        let vocab = Vocab::new();
        let result = validate_and_rank(&[vec![3, 0]], &vocab, &[], 42);
        assert!(result.is_none());
    }

    /// A single candidate equivalent to `push 1; push 2; add` is accepted.
    #[test]
    fn validate_equivalent_candidate() {
        let vocab = Vocab::new();
        let baseline: Vec<String> = vec!["push 1".into(), "push 2".into(), "add".into()];
        // The assertions below expect token 5 to decode to `push 3`.
        let candidates = vec![vec![5]];
        let r = validate_and_rank(&candidates, &vocab, &baseline, 42)
            .expect("equivalent candidate should be accepted");
        assert_eq!(r.valid_count, 1);
        assert_eq!(r.total_count, 1);
        assert_eq!(r.tasm_lines, vec!["push 3"]);
    }

    /// With two valid candidates, the one returned is `push 3`.
    #[test]
    fn validate_picks_cheapest() {
        let vocab = Vocab::new();
        let baseline: Vec<String> = vec!["push 3".into()];
        // NOTE(review): `[5, 96]` is assumed to decode to a longer program that is
        // still equivalent to `push 3` (valid_count == 2 relies on it) -- confirm
        // what token 96 decodes to.
        let candidates = vec![vec![5, 96], vec![5]];
        let r = validate_and_rank(&candidates, &vocab, &baseline, 42)
            .expect("both candidates should be accepted");
        assert_eq!(r.valid_count, 2);
        assert_eq!(r.total_count, 2);
        assert_eq!(r.tasm_lines, vec!["push 3"]);
    }

    /// A candidate that is not equivalent to the baseline is rejected.
    #[test]
    fn validate_rejects_invalid() {
        let vocab = Vocab::new();
        let baseline: Vec<String> = vec!["push 1".into(), "push 2".into(), "add".into()];
        // Token 6 is expected not to decode to anything equivalent to the baseline.
        let candidates = vec![vec![6]];
        let result = validate_and_rank(&candidates, &vocab, &baseline, 42);
        assert!(result.is_none());
    }
}