Skip to main content

trident/cost/
stack_verifier.rs

//! Block-level TASM stack verifier for neural training.
//!
//! Executes straight-line TASM blocks on concrete u64 values using
//! Goldilocks field arithmetic. Used to check that neural-generated TASM
//! produces the same stack transformation as the classical TASM it replaces.
//!
//! This is not a full Triton VM: it only handles the instructions that appear
//! in straight-line blocks. Crypto, I/O, and memory ops are modeled by their
//! stack effects only (correct push/pop counts, dummy values), so the values
//! those ops would really produce are not checked here. Full verification
//! uses trisha (Triton VM execution).

12use crate::field::goldilocks::{Goldilocks, MODULUS};
13use crate::field::PrimeField;
14
/// Stack state after executing a TASM sequence.
///
/// Alongside the operand stack it records side-channel logs, so verification
/// can detect removal or substitution of I/O, assertion, and divine ops.
/// Deriving `PartialEq`/`Eq` lets two complete states be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StackState {
    /// Operand stack; the last element is the top of the stack.
    pub stack: Vec<u64>,
    /// Set when execution fails (e.g. stack underflow, a failed assertion, or
    /// an instruction that cannot be simulated); execution stops afterwards.
    pub error: bool,
    /// Set by `halt`; execution stops afterwards.
    pub halted: bool,
    /// Values written by `write_io`, top of stack first within each call.
    pub io_output: Vec<u64>,
    /// Word count requested by each `divine` call, in order.
    pub divine_log: Vec<usize>,
    /// Every value popped by `assert`, in order (including a failing one).
    pub assert_log: Vec<u64>,
    /// The top five elements (top first) seen by each `assert_vector`.
    pub assert_vector_log: Vec<Vec<u64>>,
}
28
29impl StackState {
30    pub fn new(initial: Vec<u64>) -> Self {
31        Self {
32            stack: initial,
33            error: false,
34            halted: false,
35            io_output: Vec::new(),
36            divine_log: Vec::new(),
37            assert_log: Vec::new(),
38            assert_vector_log: Vec::new(),
39        }
40    }
41
42    /// Execute a sequence of TASM lines. Stops on error or halt.
43    pub fn execute(&mut self, lines: &[String]) {
44        for line in lines {
45            if self.error || self.halted {
46                return;
47            }
48            self.execute_line(line);
49        }
50    }
51
52    /// Execute a single TASM instruction line.
53    pub fn execute_line(&mut self, line: &str) {
54        let t = line.trim();
55        if t.is_empty() || t.starts_with("//") || t.ends_with(':') {
56            return;
57        }
58        let parts: Vec<&str> = t.split_whitespace().collect();
59        if parts.is_empty() {
60            return;
61        }
62        let op = parts[0];
63        let arg = parts.get(1).and_then(|s| s.parse::<i64>().ok());
64        let arg_u = parts.get(1).and_then(|s| s.parse::<u64>().ok());
65
66        match op {
67            // --- Literals ---
68            "push" => {
69                let val = if let Some(v) = arg {
70                    if v < 0 {
71                        Goldilocks::from_u64(0)
72                            .sub(Goldilocks::from_u64((-v) as u64))
73                            .to_u64()
74                    } else {
75                        Goldilocks::from_u64(v as u64).to_u64()
76                    }
77                } else if let Some(v) = arg_u {
78                    // Large positive literal (exceeds i64 range)
79                    Goldilocks::from_u64(v).to_u64()
80                } else {
81                    0
82                };
83                self.stack.push(val);
84            }
85
86            // --- Stack manipulation ---
87            "pop" => {
88                let n = arg_u.unwrap_or(1) as usize;
89                if self.stack.len() < n {
90                    self.error = true;
91                    return;
92                }
93                self.stack.truncate(self.stack.len() - n);
94            }
95            "dup" => {
96                let depth = arg_u.unwrap_or(0) as usize;
97                if self.stack.len() <= depth {
98                    self.error = true;
99                    return;
100                }
101                let idx = self.stack.len() - 1 - depth;
102                let val = self.stack[idx];
103                self.stack.push(val);
104            }
105            "swap" => {
106                let depth = arg_u.unwrap_or(1) as usize;
107                if depth == 0 || self.stack.len() <= depth {
108                    self.error = true;
109                    return;
110                }
111                let top = self.stack.len() - 1;
112                self.stack.swap(top, top - depth);
113            }
114            "pick" => {
115                let depth = arg_u.unwrap_or(0) as usize;
116                if self.stack.len() <= depth {
117                    self.error = true;
118                    return;
119                }
120                let idx = self.stack.len() - 1 - depth;
121                let val = self.stack.remove(idx);
122                self.stack.push(val);
123            }
124            "place" => {
125                let depth = arg_u.unwrap_or(0) as usize;
126                if self.stack.is_empty() || self.stack.len() <= depth {
127                    self.error = true;
128                    return;
129                }
130                let val = self.stack.pop().unwrap();
131                let idx = self.stack.len() - depth;
132                self.stack.insert(idx, val);
133            }
134
135            // --- Arithmetic (Goldilocks field) ---
136            "add" => {
137                if self.stack.len() < 2 {
138                    self.error = true;
139                    return;
140                }
141                let b = Goldilocks(self.stack.pop().unwrap());
142                let a = Goldilocks(self.stack.pop().unwrap());
143                self.stack.push(a.add(b).to_u64());
144            }
145            "mul" => {
146                if self.stack.len() < 2 {
147                    self.error = true;
148                    return;
149                }
150                let b = Goldilocks(self.stack.pop().unwrap());
151                let a = Goldilocks(self.stack.pop().unwrap());
152                self.stack.push(a.mul(b).to_u64());
153            }
154            "invert" => {
155                // BUG: this implements negation, but Triton VM invert is
156                // multiplicative inverse (1/x mod p). Kept as-is for
157                // baseline simulation; excluded from ALLOWED candidate list.
158                if self.stack.is_empty() {
159                    self.error = true;
160                    return;
161                }
162                let a = Goldilocks(self.stack.pop().unwrap());
163                self.stack.push(a.neg().to_u64());
164            }
165
166            // --- Comparison ---
167            "eq" => {
168                if self.stack.len() < 2 {
169                    self.error = true;
170                    return;
171                }
172                let b = self.stack.pop().unwrap();
173                let a = self.stack.pop().unwrap();
174                self.stack.push(if a == b { 1 } else { 0 });
175            }
176            "lt" => {
177                if self.stack.len() < 2 {
178                    self.error = true;
179                    return;
180                }
181                let b = self.stack.pop().unwrap();
182                let a = self.stack.pop().unwrap();
183                self.stack.push(if a < b { 1 } else { 0 });
184            }
185
186            // --- Bitwise ---
187            "and" => {
188                if self.stack.len() < 2 {
189                    self.error = true;
190                    return;
191                }
192                let b = self.stack.pop().unwrap();
193                let a = self.stack.pop().unwrap();
194                self.stack.push(a & b);
195            }
196            "xor" => {
197                if self.stack.len() < 2 {
198                    self.error = true;
199                    return;
200                }
201                let b = self.stack.pop().unwrap();
202                let a = self.stack.pop().unwrap();
203                self.stack.push(a ^ b);
204            }
205            "split" => {
206                // x → (hi, lo) where hi = x >> 32, lo = x & 0xFFFFFFFF
207                if self.stack.is_empty() {
208                    self.error = true;
209                    return;
210                }
211                let x = self.stack.pop().unwrap();
212                let lo = x & 0xFFFF_FFFF;
213                let hi = x >> 32;
214                self.stack.push(hi);
215                self.stack.push(lo);
216            }
217            "div_mod" => {
218                // (n, d) → (q, r) where q = n/d, r = n%d
219                if self.stack.len() < 2 {
220                    self.error = true;
221                    return;
222                }
223                let d = self.stack.pop().unwrap();
224                let n = self.stack.pop().unwrap();
225                if d == 0 {
226                    self.error = true;
227                    return;
228                }
229                self.stack.push(n / d);
230                self.stack.push(n % d);
231            }
232            "pow" => {
233                // (base, exp) → base^exp mod p
234                if self.stack.len() < 2 {
235                    self.error = true;
236                    return;
237                }
238                let exp = self.stack.pop().unwrap();
239                let base = Goldilocks(self.stack.pop().unwrap());
240                let mut result = Goldilocks::ONE;
241                let mut b = base;
242                let mut e = exp;
243                while e > 0 {
244                    if e & 1 == 1 {
245                        result = result.mul(b);
246                    }
247                    b = b.mul(b);
248                    e >>= 1;
249                }
250                self.stack.push(result.to_u64());
251            }
252            "log_2_floor" => {
253                if self.stack.is_empty() {
254                    self.error = true;
255                    return;
256                }
257                let x = self.stack.pop().unwrap();
258                if x == 0 {
259                    self.error = true;
260                    return;
261                }
262                self.stack.push(63 - x.leading_zeros() as u64);
263            }
264            "pop_count" => {
265                if self.stack.is_empty() {
266                    self.error = true;
267                    return;
268                }
269                let x = self.stack.pop().unwrap();
270                self.stack.push(x.count_ones() as u64);
271            }
272
273            // --- Control (straight-line only) ---
274            "nop" => {}
275            "halt" => {
276                self.halted = true;
277                return;
278            }
279            "assert" => {
280                if self.stack.is_empty() {
281                    self.error = true;
282                    return;
283                }
284                let v = self.stack.pop().unwrap();
285                self.assert_log.push(v);
286                if v != 1 {
287                    self.error = true;
288                }
289            }
290            "assert_vector" => {
291                // Assert top 5 elements equal next 5
292                if self.stack.len() < 10 {
293                    self.error = true;
294                    return;
295                }
296                let len = self.stack.len();
297                let asserted: Vec<u64> = (0..5).map(|i| self.stack[len - 1 - i]).collect();
298                self.assert_vector_log.push(asserted);
299                for i in 0..5 {
300                    if self.stack[len - 1 - i] != self.stack[len - 6 - i] {
301                        self.error = true;
302                        return;
303                    }
304                }
305                // Pop top 5
306                self.stack.truncate(len - 5);
307            }
308
309            // --- I/O (modeled stack effects, dummy values) ---
310            "read_io" => {
311                let n = arg_u.unwrap_or(1) as usize;
312                for _ in 0..n {
313                    self.stack.push(0);
314                }
315            }
316            "write_io" => {
317                let n = arg_u.unwrap_or(1) as usize;
318                if self.stack.len() < n {
319                    self.error = true;
320                    return;
321                }
322                // Log values being written (TOS first = reverse of slice order)
323                let start = self.stack.len() - n;
324                for i in (start..self.stack.len()).rev() {
325                    self.io_output.push(self.stack[i]);
326                }
327                self.stack.truncate(start);
328            }
329            "divine" => {
330                let n = arg_u.unwrap_or(1) as usize;
331                self.divine_log.push(n);
332                for _ in 0..n {
333                    self.stack.push(0);
334                }
335            }
336
337            // --- Memory (modeled stack effects) ---
338            "read_mem" => {
339                // pop address, push N values + adjusted address
340                let n = arg_u.unwrap_or(1) as usize;
341                if self.stack.is_empty() {
342                    self.error = true;
343                    return;
344                }
345                let _addr = self.stack.pop().unwrap();
346                for _ in 0..n {
347                    self.stack.push(0); // dummy values
348                }
349                self.stack.push(0); // adjusted address
350            }
351            "write_mem" => {
352                // pop N values + address, push adjusted address
353                let n = arg_u.unwrap_or(1) as usize;
354                if self.stack.len() < n + 1 {
355                    self.error = true;
356                    return;
357                }
358                self.stack.truncate(self.stack.len() - n - 1);
359                self.stack.push(0); // adjusted address
360            }
361
362            // --- Crypto (modeled stack effects only) ---
363            "hash" => {
364                // pop 10, push 5
365                if self.stack.len() < 10 {
366                    self.error = true;
367                    return;
368                }
369                self.stack.truncate(self.stack.len() - 10);
370                for _ in 0..5 {
371                    self.stack.push(0);
372                }
373            }
374            "sponge_init" => {}
375            "sponge_absorb" => {
376                if self.stack.len() < 10 {
377                    self.error = true;
378                    return;
379                }
380                self.stack.truncate(self.stack.len() - 10);
381            }
382            "sponge_squeeze" => {
383                for _ in 0..10 {
384                    self.stack.push(0);
385                }
386            }
387            "sponge_absorb_mem" => {
388                // Absorb from memory: pop address, push adjusted address
389                if self.stack.is_empty() {
390                    self.error = true;
391                    return;
392                }
393                let _addr = self.stack.pop().unwrap();
394                self.stack.push(0);
395            }
396            "merkle_step" | "merkle_step_mem" => {
397                // Complex stack effects — skip in block verifier
398            }
399
400            // --- Extension field (modeled as nops for stack) ---
401            "xb_mul" | "x_invert" | "xx_dot_step" | "xb_dot_step" => {}
402
403            // --- Control flow ---
404            // return/recurse are stack-transparent for isolated function verification.
405            // call/skiz/recurse_or_return require branch simulation — unsimulable.
406            "return" | "recurse" => {}
407            "call" | "recurse_or_return" | "skiz" => {
408                self.error = true;
409            }
410
411            // Unknown instruction — ignore (conservative)
412            _ => {}
413        }
414    }
415
416    /// Check if execution completed without errors.
417    pub fn is_valid(&self) -> bool {
418        !self.error
419    }
420}
421
422/// Generate a deterministic test stack for a given seed.
423pub fn generate_test_stack(seed: u64, size: usize) -> Vec<u64> {
424    let mut stack = Vec::with_capacity(size);
425    let mut state = seed
426        .wrapping_mul(6364136223846793005)
427        .wrapping_add(1442695040888963407);
428    for _ in 0..size {
429        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
430        // Keep values in valid Goldilocks range
431        let val = state % MODULUS;
432        stack.push(val);
433    }
434    stack
435}
436
437/// Verify that candidate TASM produces the same stack as baseline TASM.
438/// Tests with 40 stacks (32 random + 8 structured) — all must pass.
439/// A single concrete test case is trivially gamed by the neural optimizer;
440/// diverse stacks catch wrong operand order, off-by-one dup/swap depth,
441/// missing operations, and false positives on complex blocks.
442/// Conservative: rejects candidates when baseline can't be simulated.
443pub fn verify_equivalent(baseline_tasm: &[String], candidate_tasm: &[String], seed: u64) -> bool {
444    // Instructions the verifier can simulate with exact semantics.
445    // Side-effect ops (write_io, assert, divine, halt, split) are now
446    // supported via side-channel logging — the verifier compares I/O
447    // output, assertion values, divine call patterns, and halt state
448    // alongside the stack.
449    const ALLOWED: &[&str] = &[
450        "push",
451        "pop",
452        "dup",
453        "swap",
454        "pick",
455        "place",
456        "add",
457        "mul",
458        // invert is NOT allowed — verifier implements negation,
459        // Triton VM does multiplicative inverse (1/x mod p).
460        "eq",
461        "lt",
462        "and",
463        "xor",
464        "split",
465        "div_mod",
466        "pow",
467        "log_2_floor",
468        "pop_count",
469        "nop",
470        "halt",
471        "write_io",
472        "read_io",
473        "divine",
474        "assert",
475        "assert_vector",
476        // Memory ops — simulated with dummy values (correct stack effects)
477        "read_mem",
478        "write_mem",
479        // Crypto ops — simulated with dummy values (correct stack effects)
480        "hash",
481        "sponge_init",
482        "sponge_absorb",
483        "sponge_squeeze",
484        "sponge_absorb_mem",
485        "merkle_step",
486        "merkle_step_mem",
487        // Extension field ops — simulated as nops (correct stack effects)
488        "xb_mul",
489        "x_invert",
490        "xx_dot_step",
491        "xb_dot_step",
492        // Control flow — return/recurse are no-ops for stack in isolated functions
493        "return",
494        "recurse",
495    ];
496    for line in candidate_tasm {
497        let op = line.trim().split_whitespace().next().unwrap_or("");
498        if op.is_empty() || op.starts_with("//") || op.ends_with(':') {
499            continue;
500        }
501        if !ALLOWED.contains(&op) {
502            return false;
503        }
504    }
505
506    // Test stacks must include structured values, not just random ones.
507    // Random Goldilocks values make eq/lt comparisons near-deterministic:
508    //   P(random a == random b) ≈ 2^-64 → "push 0" fakes "dup 1 | dup 1 | eq"
509    // Structured stacks expose these exploits by including zeros, duplicates,
510    // and ordered values where comparisons actually produce different results.
511    let test_stacks: Vec<Vec<u64>> = {
512        let mut stacks = Vec::with_capacity(40);
513        // 32 random seeds — 4 was too few; complex blocks (keccak rounds,
514        // field arithmetic) can produce false positives when random stacks
515        // happen to collide. 32 brings P(false positive) to ~2^(-64*32).
516        for i in 0..32u64 {
517            let s = seed.wrapping_mul(6364136223846793005).wrapping_add(i);
518            stacks.push(generate_test_stack(s, 16));
519        }
520        // All zeros (eq always returns 1, catches "push 0" faking eq)
521        stacks.push(vec![0; 16]);
522        // All ones
523        stacks.push(vec![1; 16]);
524        // Adjacent pairs equal: [5,5,3,3,7,7,...] (catches dup+eq exploits)
525        stacks.push(vec![5, 5, 3, 3, 7, 7, 2, 2, 9, 9, 1, 1, 4, 4, 8, 8]);
526        // Same value everywhere (catches any eq/dup combination)
527        stacks.push(vec![42; 16]);
528        // Ascending small values (catches lt/comparison order exploits)
529        stacks.push((0..16).collect());
530        // Descending (opposite lt behavior)
531        stacks.push((0..16).rev().collect());
532        // Mixed: zeros and large values (catches split/div_mod edge cases)
533        stacks.push(vec![
534            0,
535            MODULUS - 1,
536            0,
537            1,
538            0,
539            MODULUS - 1,
540            0,
541            1,
542            0,
543            MODULUS - 1,
544            0,
545            1,
546            0,
547            MODULUS - 1,
548            0,
549            1,
550        ]);
551        // Powers of 2 (catches pop_count, log_2_floor, split edge cases)
552        stacks.push(vec![
553            1,
554            2,
555            4,
556            8,
557            16,
558            32,
559            64,
560            128,
561            256,
562            512,
563            1024,
564            2048,
565            1u64 << 32,
566            1u64 << 33,
567            1u64 << 48,
568            1u64 << 63,
569        ]);
570        // Split-specific: values with non-trivial hi parts (all < MODULUS)
571        stacks.push(vec![
572            0x0000_0001_0000_0001, // hi=1, lo=1
573            0x0000_0002_FFFF_FFFF, // hi=2, lo=max32
574            0x0000_FFFF_0000_0000, // hi=65535, lo=0
575            0x0000_0000_FFFF_FFFF, // hi=0, lo=max32
576            0x0000_0001_0000_0000, // hi=1, lo=0
577            0x0000_0000_0000_0001, // hi=0, lo=1
578            0x0000_ABCD_1234_5678, // mixed
579            0x0000_0000_8000_0000, // hi=0, lo=2^31
580            0x0000_0001_8000_0000, // hi=1, lo=2^31
581            0x0000_7FFF_7FFF_FFFF, // hi=32767, lo=max31+
582            0x0000_0010_0000_0010, // hi=16, lo=16
583            0x0000_0100_0001_0000, // hi=256, lo=65536
584            0,
585            1,
586            MODULUS - 1,
587            MODULUS - 2,
588        ]);
589        stacks
590    };
591
592    for test_stack in &test_stacks {
593        let mut baseline_state = StackState::new(test_stack.clone());
594        baseline_state.execute(baseline_tasm);
595
596        // If baseline can't be simulated, we can't verify — reject candidate.
597        if baseline_state.error {
598            return false;
599        }
600
601        let mut candidate_state = StackState::new(test_stack.clone());
602        candidate_state.execute(candidate_tasm);
603
604        if candidate_state.error {
605            return false;
606        }
607
608        if baseline_state.stack != candidate_state.stack {
609            return false;
610        }
611        if baseline_state.halted != candidate_state.halted {
612            return false;
613        }
614        if baseline_state.io_output != candidate_state.io_output {
615            return false;
616        }
617        if baseline_state.divine_log != candidate_state.divine_log {
618            return false;
619        }
620        if baseline_state.assert_log != candidate_state.assert_log {
621            return false;
622        }
623        if baseline_state.assert_vector_log != candidate_state.assert_vector_log {
624            return false;
625        }
626    }
627    true
628}
629
630/// Score how close a candidate is to matching the baseline.
631/// Runs both on one test stack and returns a shaped fitness score:
632///   0   = candidate crashes
633///   100 = doesn't crash
634///   200 = stack depth matches
635///   400 = 50%+ of stack values match
636///   600 = 90%+ of stack values match
637///   800 = all stack values match
638///   900 = stack + all side-channels match on this stack
639pub fn score_candidate(baseline_tasm: &[String], candidate_tasm: &[String], seed: u64) -> i64 {
640    let test_stack = generate_test_stack(seed.wrapping_mul(6364136223846793005), 16);
641
642    let mut bl = StackState::new(test_stack.clone());
643    bl.execute(baseline_tasm);
644    if bl.error {
645        return 0;
646    }
647
648    let mut cd = StackState::new(test_stack);
649    cd.execute(candidate_tasm);
650
651    if cd.error {
652        return 0;
653    }
654    let mut score: i64 = 100;
655
656    if bl.stack.len() == cd.stack.len() {
657        score = 200;
658        let matches = bl
659            .stack
660            .iter()
661            .zip(&cd.stack)
662            .filter(|(a, b)| a == b)
663            .count();
664        let total = bl.stack.len().max(1);
665        let ratio = matches as f64 / total as f64;
666        if ratio >= 0.5 {
667            score = 400;
668        }
669        if ratio >= 0.9 {
670            score = 600;
671        }
672        if matches == total {
673            score = 800;
674        }
675    }
676
677    if score >= 800
678        && bl.halted == cd.halted
679        && bl.io_output == cd.io_output
680        && bl.divine_log == cd.divine_log
681        && bl.assert_log == cd.assert_log
682        && bl.assert_vector_log == cd.assert_vector_log
683    {
684        score = 900;
685    }
686
687    score
688}
689
690/// Diagnose why verification failed for a candidate vs baseline.
691/// Runs both on the first test stack and reports what differs.
692/// Returns a short human-readable reason string.
693pub fn diagnose_failure(baseline_tasm: &[String], candidate_tasm: &[String], seed: u64) -> String {
694    let test_stack = generate_test_stack(seed.wrapping_mul(6364136223846793005), 16);
695
696    let mut bl = StackState::new(test_stack.clone());
697    bl.execute(baseline_tasm);
698    if bl.error {
699        return "baseline errors on test stack".into();
700    }
701
702    let mut cd = StackState::new(test_stack);
703    cd.execute(candidate_tasm);
704    if cd.error {
705        return "candidate errors on test stack".into();
706    }
707
708    if bl.stack != cd.stack {
709        let bl_len = bl.stack.len();
710        let cd_len = cd.stack.len();
711        if bl_len != cd_len {
712            return format!("stack depth: baseline={} candidate={}", bl_len, cd_len);
713        }
714        for i in 0..bl_len {
715            if bl.stack[i] != cd.stack[i] {
716                return format!(
717                    "stack[{}]: baseline={} candidate={} (depth {})",
718                    i, bl.stack[i], cd.stack[i], bl_len
719                );
720            }
721        }
722    }
723    if bl.halted != cd.halted {
724        return format!("halted: baseline={} candidate={}", bl.halted, cd.halted);
725    }
726    if bl.io_output != cd.io_output {
727        return format!(
728            "io_output: baseline={:?} candidate={:?}",
729            &bl.io_output[..bl.io_output.len().min(5)],
730            &cd.io_output[..cd.io_output.len().min(5)]
731        );
732    }
733    if bl.divine_log != cd.divine_log {
734        return format!(
735            "divine_log: baseline={:?} candidate={:?}",
736            bl.divine_log, cd.divine_log
737        );
738    }
739    if bl.assert_log != cd.assert_log {
740        return format!(
741            "assert_log: baseline={:?} candidate={:?}",
742            &bl.assert_log[..bl.assert_log.len().min(5)],
743            &cd.assert_log[..cd.assert_log.len().min(5)]
744        );
745    }
746    if bl.assert_vector_log != cd.assert_vector_log {
747        return "assert_vector_log differs".into();
748    }
749    // Passed on this test stack but fails on others
750    "passes first stack, fails on structured stacks".into()
751}
752
753/// Score a neural model's raw output against a baseline block.
754/// Decodes the output, verifies equivalence, and returns the lower cost
755/// (or baseline cost if candidate is invalid/worse).
756pub fn score_neural_output(
757    raw_codes: &[u32],
758    block_baseline: u64,
759    baseline_tasm: &[String],
760    block_seed: u64,
761) -> u64 {
762    use crate::ir::tir::lower::decode_output;
763
764    let codes: Vec<u64> = raw_codes
765        .iter()
766        .take_while(|&&c| c != 0)
767        .map(|&c| c as u64)
768        .collect();
769    if codes.is_empty() {
770        return block_baseline;
771    }
772    let candidate_lines = decode_output(&codes);
773    if candidate_lines.is_empty() {
774        return block_baseline;
775    }
776    // No baseline = nothing to verify against = reject.
777    if baseline_tasm.is_empty() || !verify_equivalent(baseline_tasm, &candidate_lines, block_seed) {
778        return block_baseline;
779    }
780    let profile = crate::cost::scorer::profile_tasm(
781        &candidate_lines
782            .iter()
783            .map(|s| s.as_str())
784            .collect::<Vec<_>>(),
785    );
786    profile.cost().min(block_baseline)
787}
788
789/// Score improvement of a neural candidate over baseline.
790/// Returns 0 for failures or equal/worse cost, positive value for genuine wins.
791/// Used by training to reward only actual improvement (not negated cost).
792pub fn score_neural_improvement(
793    raw_codes: &[u32],
794    block_baseline: u64,
795    baseline_tasm: &[String],
796    block_seed: u64,
797) -> u64 {
798    let cost = score_neural_output(raw_codes, block_baseline, baseline_tasm, block_seed);
799    block_baseline.saturating_sub(cost)
800}
801
802#[cfg(test)]
803mod tests {
804    use super::*;
805
806    fn lines(s: &[&str]) -> Vec<String> {
807        s.iter().map(|l| l.to_string()).collect()
808    }
809
810    #[test]
811    fn push_add() {
812        let mut s = StackState::new(vec![]);
813        s.execute(&lines(&["push 1", "push 2", "add"]));
814        assert!(s.is_valid());
815        assert_eq!(s.stack, vec![3]);
816    }
817
818    #[test]
819    fn dup_swap() {
820        let mut s = StackState::new(vec![10, 20]);
821        s.execute(&lines(&["dup 1", "swap 1"]));
822        assert!(s.is_valid());
823        // [10, 20] → dup 1 → [10, 20, 10] → swap 1 → [10, 10, 20]
824        assert_eq!(s.stack, vec![10, 10, 20]);
825    }
826
827    #[test]
828    fn underflow_is_error() {
829        let mut s = StackState::new(vec![]);
830        s.execute(&lines(&["add"]));
831        assert!(!s.is_valid());
832    }
833
834    #[test]
835    fn goldilocks_arithmetic() {
836        let mut s = StackState::new(vec![]);
837        // push p-1, push 2, add → should wrap to 0 (since (p-1)+2 = p+1 ≡ 1 mod p... wait)
838        // Actually (p-1) + 1 = p ≡ 0 mod p
839        s.execute(&lines(&["push 18446744069414584320", "push 1", "add"]));
840        assert!(s.is_valid());
841        assert_eq!(s.stack, vec![0]); // (MODULUS - 1) + 1 = 0 mod p
842    }
843
844    #[test]
845    fn mul_field() {
846        let mut s = StackState::new(vec![]);
847        s.execute(&lines(&["push 3", "push 5", "mul"]));
848        assert!(s.is_valid());
849        assert_eq!(s.stack, vec![15]);
850    }
851
852    #[test]
853    fn split_instruction() {
854        let mut s = StackState::new(vec![]);
855        // 0x0000_0003_0000_0005 = 3 * 2^32 + 5
856        let val = 3u64 * (1u64 << 32) + 5;
857        s.stack.push(val);
858        s.execute(&lines(&["split"]));
859        assert!(s.is_valid());
860        assert_eq!(s.stack, vec![3, 5]); // hi=3, lo=5
861    }
862
863    #[test]
864    fn eq_comparison() {
865        let mut s = StackState::new(vec![42, 42]);
866        s.execute(&lines(&["eq"]));
867        assert!(s.is_valid());
868        assert_eq!(s.stack, vec![1]);
869
870        let mut s2 = StackState::new(vec![42, 43]);
871        s2.execute(&lines(&["eq"]));
872        assert!(s2.is_valid());
873        assert_eq!(s2.stack, vec![0]);
874    }
875
876    #[test]
877    fn negative_push() {
878        let mut s = StackState::new(vec![]);
879        s.execute(&lines(&["push 5", "push -1", "add"]));
880        assert!(s.is_valid());
881        assert_eq!(s.stack, vec![4]);
882    }
883
884    #[test]
885    fn control_flow_is_error() {
886        let mut s = StackState::new(vec![1]);
887        s.execute(&lines(&["skiz"]));
888        assert!(!s.is_valid());
889    }
890
891    #[test]
892    fn comments_and_labels_ignored() {
893        let mut s = StackState::new(vec![]);
894        s.execute(&lines(&["// comment", "__label:", "push 1", ""]));
895        assert!(s.is_valid());
896        assert_eq!(s.stack, vec![1]);
897    }
898
899    #[test]
900    fn verify_equivalent_same() {
901        let baseline = lines(&["push 1", "push 2", "add"]);
902        let candidate = lines(&["push 3"]); // same result, different path
903        assert!(verify_equivalent(&baseline, &candidate, 42));
904    }
905
906    #[test]
907    fn verify_equivalent_different() {
908        let baseline = lines(&["push 1", "push 2", "add"]);
909        let candidate = lines(&["push 4"]); // different result
910        assert!(!verify_equivalent(&baseline, &candidate, 42));
911    }
912
913    #[test]
914    fn verify_with_stack_input() {
915        // Both should add TOS to second element
916        let baseline = lines(&["dup 0", "dup 2", "add"]);
917        let candidate = lines(&["dup 0", "dup 2", "add"]);
918        assert!(verify_equivalent(&baseline, &candidate, 123));
919    }
920
921    #[test]
922    fn pow_instruction() {
923        let mut s = StackState::new(vec![]);
924        s.execute(&lines(&["push 2", "push 10", "pow"]));
925        assert!(s.is_valid());
926        assert_eq!(s.stack, vec![1024]); // 2^10
927    }
928
929    #[test]
930    fn pop_count_instruction() {
931        let mut s = StackState::new(vec![0b1010_1010]);
932        s.execute(&lines(&["pop_count"]));
933        assert!(s.is_valid());
934        assert_eq!(s.stack, vec![4]);
935    }
936
937    #[test]
938    fn sbox_pattern() {
939        // x^5 via dup/mul chain (from poseidon baseline)
940        let x = 7u64;
941        let mut s = StackState::new(vec![x]);
942        s.execute(&lines(&[
943            "dup 0", "dup 0", "mul", // x, x^2
944            "dup 0", "mul", // x, x^4
945            "mul", // x^5
946        ]));
947        assert!(s.is_valid());
948        // 7^5 = 16807
949        assert_eq!(s.stack, vec![16807]);
950    }
951
952    #[test]
953    fn verify_rejects_when_baseline_errors() {
954        // Baseline has control flow (can't simulate) — candidate must be rejected
955        let baseline = lines(&["push 1", "call some_fn", "add"]);
956        let candidate = lines(&["push 42"]);
957        assert!(!verify_equivalent(&baseline, &candidate, 42));
958    }
959
960    #[test]
961    fn verify_rejects_when_candidate_errors() {
962        let baseline = lines(&["push 1", "push 2", "add"]);
963        let _candidate = lines(&["add"]); // underflow on empty-ish stack (after 16 elements, add pops 2 and pushes 1 — actually this succeeds)
964                                          // Use a candidate that definitely errors
965        let bad_candidate = lines(&["pop 100"]); // underflow
966        assert!(!verify_equivalent(&baseline, &bad_candidate, 42));
967    }
968
969    #[test]
970    fn verify_rejects_both_error() {
971        // Both error — conservative: reject (no free passes)
972        let baseline = lines(&["call foo"]); // errors
973        let candidate = lines(&["call bar"]); // also errors
974        assert!(!verify_equivalent(&baseline, &candidate, 42));
975    }
976
977    #[test]
978    fn generate_test_stack_deterministic() {
979        let a = generate_test_stack(42, 8);
980        let b = generate_test_stack(42, 8);
981        assert_eq!(a, b);
982        // Different seed → different stack
983        let c = generate_test_stack(43, 8);
984        assert_ne!(a, c);
985    }
986
987    #[test]
988    fn generate_test_stack_in_range() {
989        let stack = generate_test_stack(99, 100);
990        for val in &stack {
991            assert!(*val < MODULUS, "value {} >= MODULUS", val);
992        }
993    }
994
995    // --- Side-channel verification tests ---
996
997    #[test]
998    fn write_io_removal_caught() {
999        // Baseline writes TOS; candidate just pops — same stack, different I/O
1000        let baseline = lines(&["write_io 1"]);
1001        let candidate = lines(&["pop 1"]);
1002        assert!(!verify_equivalent(&baseline, &candidate, 77));
1003    }
1004
1005    #[test]
1006    fn write_io_equivalent_accepted() {
1007        // Both write the same value
1008        let baseline = lines(&["write_io 1"]);
1009        let candidate = lines(&["write_io 1"]);
1010        assert!(verify_equivalent(&baseline, &candidate, 77));
1011    }
1012
1013    #[test]
1014    fn assert_removal_caught() {
1015        // Baseline asserts TOS==1; candidate just pops — different assert_log
1016        let baseline = lines(&["push 1", "assert"]);
1017        let candidate = lines(&["push 1", "pop 1"]);
1018        assert!(!verify_equivalent(&baseline, &candidate, 88));
1019    }
1020
1021    #[test]
1022    fn assert_equivalent_accepted() {
1023        let baseline = lines(&["push 1", "assert"]);
1024        let candidate = lines(&["push 1", "assert"]);
1025        assert!(verify_equivalent(&baseline, &candidate, 88));
1026    }
1027
1028    #[test]
1029    fn divine_replacement_caught() {
1030        // Baseline uses divine 1; candidate uses push 0 — same stack (both push 0)
1031        // but divine_log differs
1032        let baseline = lines(&["divine 1"]);
1033        let candidate = lines(&["push 0"]);
1034        assert!(!verify_equivalent(&baseline, &candidate, 99));
1035    }
1036
1037    #[test]
1038    fn divine_equivalent_accepted() {
1039        let baseline = lines(&["divine 1"]);
1040        let candidate = lines(&["divine 1"]);
1041        assert!(verify_equivalent(&baseline, &candidate, 99));
1042    }
1043
1044    #[test]
1045    fn halt_removal_caught() {
1046        // Baseline halts before push 99; candidate executes push 99
1047        let baseline = lines(&["halt", "push 99"]);
1048        let candidate = lines(&["push 99"]);
1049        assert!(!verify_equivalent(&baseline, &candidate, 55));
1050    }
1051
1052    #[test]
1053    fn halt_equivalent_accepted() {
1054        let baseline = lines(&["push 1", "halt"]);
1055        let candidate = lines(&["push 1", "halt"]);
1056        assert!(verify_equivalent(&baseline, &candidate, 55));
1057    }
1058
1059    #[test]
1060    fn split_now_verifiable() {
1061        // split should work now — same input, same deterministic output
1062        let baseline = lines(&["split"]);
1063        let candidate = lines(&["split"]);
1064        assert!(verify_equivalent(&baseline, &candidate, 42));
1065    }
1066
1067    #[test]
1068    fn split_wrong_replacement_caught() {
1069        // Candidate fakes split with wrong stack effect
1070        let baseline = lines(&["split"]);
1071        let candidate = lines(&["dup 0"]); // wrong: pushes copy instead of hi/lo
1072        assert!(!verify_equivalent(&baseline, &candidate, 42));
1073    }
1074
1075    #[test]
1076    fn assert_vector_removal_caught() {
1077        // Baseline: assert_vector (compare top 5 with next 5, pop 5)
1078        // Candidate: pop 5 (same stack effect, no assertion)
1079        let baseline = lines(&[
1080            "push 1",
1081            "push 2",
1082            "push 3",
1083            "push 4",
1084            "push 5",
1085            "push 1",
1086            "push 2",
1087            "push 3",
1088            "push 4",
1089            "push 5",
1090            "assert_vector",
1091        ]);
1092        let candidate = lines(&[
1093            "push 1", "push 2", "push 3", "push 4", "push 5", "push 1", "push 2", "push 3",
1094            "push 4", "push 5", "pop 5",
1095        ]);
1096        assert!(!verify_equivalent(&baseline, &candidate, 66));
1097    }
1098
1099    #[test]
1100    fn score_candidate_crash_returns_zero() {
1101        let baseline = lines(&["push 1", "push 2", "add"]);
1102        // Pop everything (16 initial + 0 pushed), then pop again → underflow
1103        let candidate = lines(&["pop 5", "pop 5", "pop 5", "pop 5", "pop 1"]);
1104        assert_eq!(score_candidate(&baseline, &candidate, 42), 0);
1105    }
1106
1107    #[test]
1108    fn score_candidate_no_crash_wrong_depth() {
1109        let baseline = lines(&["push 1", "push 2", "add"]); // depth +1
1110        let candidate = lines(&["push 1", "push 2"]); // depth +2
1111        let score = score_candidate(&baseline, &candidate, 42);
1112        assert_eq!(score, 100); // doesn't crash, but wrong depth
1113    }
1114
1115    #[test]
1116    fn score_candidate_right_depth_wrong_values() {
1117        let baseline = lines(&["push 1"]); // pushes 1
1118        let candidate = lines(&["push 2"]); // pushes 2, same depth
1119        let score = score_candidate(&baseline, &candidate, 42);
1120        assert!(
1121            score >= 200,
1122            "right depth should score >= 200, got {}",
1123            score
1124        );
1125    }
1126
1127    #[test]
1128    fn score_candidate_identical_scores_900() {
1129        let baseline = lines(&["push 1", "push 2", "add"]);
1130        let candidate = lines(&["push 1", "push 2", "add"]);
1131        let score = score_candidate(&baseline, &candidate, 42);
1132        assert_eq!(score, 900);
1133    }
1134
1135    #[test]
1136    fn score_candidate_nop_equivalent_scores_900() {
1137        let baseline = lines(&["push 5"]);
1138        let candidate = lines(&["push 5", "nop"]);
1139        let score = score_candidate(&baseline, &candidate, 42);
1140        assert_eq!(score, 900);
1141    }
1142}