klvmr/
traverse_path.rs

1use crate::allocator::{Allocator, NodePtr, SExp};
2use crate::cost::Cost;
3use crate::error::EvalErr;
4use crate::reduction::{Reduction, Response};
5
// Cost model for path lookups.
// lowered from measured 147 per bit. It doesn't seem to take this long in
// practice
// NOTE(review): the remark above presumably refers to TRAVERSE_COST_PER_BIT;
// confirm against the benchmark it was derived from.
//
// Flat cost included in every path lookup, whatever the path looks like.
const TRAVERSE_BASE_COST: Cost = 40;
// Cost added for each zero byte that precedes the first non-zero byte of the
// path's byte representation.
const TRAVERSE_COST_PER_ZERO_BYTE: Cost = 4;
// Cost added once up front for every lookup, and once more for every step
// taken down the tree.
const TRAVERSE_COST_PER_BIT: Cost = 4;
11
12// `run_program` has two stacks: the operand stack (of `Node` objects) and the
13// operator stack (of Operation)
14
/// Returns a mask in which only the most significant set bit of `byte` is
/// kept, e.g. `0x2a` becomes `0x20`.
///
/// A zero input has no set bit and yields `0`.
pub(crate) fn msb_mask(byte: u8) -> u8 {
    // The number of leading zeros is exactly how far the top bit (0x80) has to
    // slide down to land on the highest set bit. For a zero byte the count is
    // 8, which is not a valid shift for a `u8`; `checked_shr` reports that as
    // `None` and we fall back to an empty mask.
    0x80u8.checked_shr(byte.leading_zeros()).unwrap_or(0)
}
24
/// Returns the position of the first byte in `buf` that is not zero.
///
/// When `buf` is empty or consists solely of zero bytes, the result is
/// `buf.len()`, i.e. one past the last valid index.
pub const fn first_non_zero(buf: &[u8]) -> usize {
    let mut remaining = buf;
    let mut skipped: usize = 0;
    // Peel zero bytes off the front one at a time; the pattern stops matching
    // on the first non-zero byte or when the slice runs out.
    while let [0, tail @ ..] = remaining {
        remaining = tail;
        skipped += 1;
    }
    skipped
}
34
35pub fn traverse_path(allocator: &Allocator, node_index: &[u8], args: NodePtr) -> Response {
36    let mut arg_list: NodePtr = args;
37
38    // find first non-zero byte
39    let first_bit_byte_index = first_non_zero(node_index);
40
41    let mut cost: Cost = TRAVERSE_BASE_COST
42        + (first_bit_byte_index as Cost) * TRAVERSE_COST_PER_ZERO_BYTE
43        + TRAVERSE_COST_PER_BIT;
44
45    if first_bit_byte_index >= node_index.len() {
46        return Ok(Reduction(cost, allocator.nil()));
47    }
48
49    // find first non-zero bit (the most significant bit is a sentinel)
50    let last_bitmask = msb_mask(node_index[first_bit_byte_index]);
51
52    // follow through the bits, moving left and right
53    let mut byte_idx = node_index.len() - 1;
54    let mut bitmask = 0x01;
55    while byte_idx > first_bit_byte_index || bitmask < last_bitmask {
56        let is_bit_set: bool = (node_index[byte_idx] & bitmask) != 0;
57        match allocator.sexp(arg_list) {
58            SExp::Atom => {
59                return Err(EvalErr::PathIntoAtom);
60            }
61            SExp::Pair(left, right) => {
62                arg_list = if is_bit_set { right } else { left };
63            }
64        }
65        if bitmask == 0x80 {
66            bitmask = 0x01;
67            byte_idx -= 1;
68        } else {
69            bitmask <<= 1;
70        }
71        cost += TRAVERSE_COST_PER_BIT;
72    }
73    Ok(Reduction(cost, arg_list))
74}
75
76// The cost calculation for this version of traverse_path assumes the node_index has the canonical
77// integer representation (which is true for SmallAtom in the allocator). If there are any
78// redundant leading zeros, the slow path must be used
79pub fn traverse_path_fast(allocator: &Allocator, mut node_index: u32, args: NodePtr) -> Response {
80    if node_index == 0 {
81        return Ok(Reduction(
82            TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT,
83            allocator.nil(),
84        ));
85    }
86
87    let mut arg_list: NodePtr = args;
88
89    let mut cost: Cost = TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT;
90    let mut num_bits = 0;
91    while node_index != 1 {
92        let SExp::Pair(left, right) = allocator.sexp(arg_list) else {
93            return Err(EvalErr::PathIntoAtom);
94        };
95
96        let is_bit_set: bool = (node_index & 0x01) != 0;
97        arg_list = if is_bit_set { right } else { left };
98        node_index >>= 1;
99        num_bits += 1
100    }
101
102    cost += num_bits * TRAVERSE_COST_PER_BIT;
103    // since positive numbers sometimes need a leading zero, e.g. 0x80, 0x8000 etc. We also
104    // need to add the cost of that leading zero byte
105    if num_bits == 7 || num_bits == 15 || num_bits == 23 || num_bits == 31 {
106        cost += TRAVERSE_COST_PER_ZERO_BYTE;
107    }
108
109    Ok(Reduction(cost, arg_list))
110}
111
// Unit tests for the helpers and both traversal implementations.
#[cfg(test)]
mod tests {
    use super::*;

    // Each power of two maps to itself; mixed bit patterns collapse to their
    // highest set bit; zero stays zero.
    #[test]
    fn test_msb_mask() {
        assert_eq!(msb_mask(0x0), 0x0);
        assert_eq!(msb_mask(0x01), 0x01);
        assert_eq!(msb_mask(0x02), 0x02);
        assert_eq!(msb_mask(0x04), 0x04);
        assert_eq!(msb_mask(0x08), 0x08);
        assert_eq!(msb_mask(0x10), 0x10);
        assert_eq!(msb_mask(0x20), 0x20);
        assert_eq!(msb_mask(0x40), 0x40);
        assert_eq!(msb_mask(0x80), 0x80);

        assert_eq!(msb_mask(0x44), 0x40);
        assert_eq!(msb_mask(0x2a), 0x20);
        assert_eq!(msb_mask(0xff), 0x80);
        assert_eq!(msb_mask(0x0f), 0x08);
    }

    #[test]
    fn test_first_non_zero() {
        assert_eq!(first_non_zero(&[]), 0);
        assert_eq!(first_non_zero(&[1]), 0);
        assert_eq!(first_non_zero(&[0]), 1);
        assert_eq!(first_non_zero(&[0, 0, 0, 1, 1, 1]), 3);
        assert_eq!(first_non_zero(&[0, 0, 0, 0, 0, 0]), 6);
        assert_eq!(first_non_zero(&[1, 0, 0, 0, 0, 0]), 0);
    }

    #[test]
    fn test_traverse_path() {
        use crate::allocator::Allocator;

        let mut a = Allocator::new();
        let nul = a.nil();
        let n1 = a.new_atom(&[0, 1, 2]).unwrap();
        let n2 = a.new_atom(&[4, 5, 6]).unwrap();

        assert_eq!(traverse_path(&a, &[], n1).unwrap(), Reduction(44, nul));
        assert_eq!(traverse_path(&a, &[0b1], n1).unwrap(), Reduction(44, n1));
        assert_eq!(traverse_path(&a, &[0b1], n2).unwrap(), Reduction(44, n2));

        // cost for leading zeros
        assert_eq!(traverse_path(&a, &[0], n1).unwrap(), Reduction(48, nul));
        assert_eq!(traverse_path(&a, &[0, 0], n1).unwrap(), Reduction(52, nul));
        assert_eq!(
            traverse_path(&a, &[0, 0, 0], n1).unwrap(),
            Reduction(56, nul)
        );
        assert_eq!(
            traverse_path(&a, &[0, 0, 0, 0], n1).unwrap(),
            Reduction(60, nul)
        );

        let n3 = a.new_pair(n1, n2).unwrap();
        assert_eq!(traverse_path(&a, &[0b1], n3).unwrap(), Reduction(44, n3));
        assert_eq!(traverse_path(&a, &[0b10], n3).unwrap(), Reduction(48, n1));
        assert_eq!(traverse_path(&a, &[0b11], n3).unwrap(), Reduction(48, n2));

        let list = a.new_pair(n1, nul).unwrap();
        let list = a.new_pair(n2, list).unwrap();

        assert_eq!(traverse_path(&a, &[0b10], list).unwrap(), Reduction(48, n2));
        assert_eq!(
            traverse_path(&a, &[0b101], list).unwrap(),
            Reduction(52, n1)
        );
        assert_eq!(
            traverse_path(&a, &[0b111], list).unwrap(),
            Reduction(52, nul)
        );

        // errors
        assert_eq!(
            traverse_path(&a, &[0b1011], list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path(&a, &[0b1101], list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path(&a, &[0b1001], list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path(&a, &[0b1010], list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path(&a, &[0b1110], list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
    }

    #[test]
    fn test_traverse_path_fast() {
        use crate::allocator::Allocator;

        let mut a = Allocator::new();
        let nul = a.nil();
        let n1 = a.new_atom(&[0, 1, 2]).unwrap();
        let n2 = a.new_atom(&[4, 5, 6]).unwrap();

        assert_eq!(traverse_path_fast(&a, 0, n1).unwrap(), Reduction(44, nul));
        assert_eq!(traverse_path_fast(&a, 0b1, n1).unwrap(), Reduction(44, n1));
        assert_eq!(traverse_path_fast(&a, 0b1, n2).unwrap(), Reduction(44, n2));

        let n3 = a.new_pair(n1, n2).unwrap();
        assert_eq!(traverse_path_fast(&a, 0b1, n3).unwrap(), Reduction(44, n3));
        assert_eq!(traverse_path_fast(&a, 0b10, n3).unwrap(), Reduction(48, n1));
        assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2));

        let list = a.new_pair(n1, nul).unwrap();
        let list = a.new_pair(n2, list).unwrap();

        assert_eq!(
            traverse_path_fast(&a, 0b10, list).unwrap(),
            Reduction(48, n2)
        );
        assert_eq!(
            traverse_path_fast(&a, 0b101, list).unwrap(),
            Reduction(52, n1)
        );
        assert_eq!(
            traverse_path_fast(&a, 0b111, list).unwrap(),
            Reduction(52, nul)
        );

        // errors
        assert_eq!(
            traverse_path_fast(&a, 0b1011, list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path_fast(&a, 0b1101, list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path_fast(&a, 0b1001, list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path_fast(&a, 0b1010, list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
        assert_eq!(
            traverse_path_fast(&a, 0b1110, list).unwrap_err(),
            EvalErr::PathIntoAtom
        );
    }

    // The fast path has to charge for the zero byte that the canonical
    // encoding prepends whenever the sentinel is the top bit of a byte, and
    // must agree with the byte-wise implementation on both cost and node.
    #[test]
    fn test_traverse_path_fast_matches_slow_path() {
        use crate::allocator::Allocator;

        let mut a = Allocator::new();
        let nul = a.nil();

        // Build a right-leaning chain of 32 pairs. After the reversal,
        // `spine[k]` is the node reached from `spine[0]` by stepping into the
        // second element of the pair k times.
        let mut spine = vec![nul];
        for _ in 0..32 {
            let tail = *spine.last().unwrap();
            spine.push(a.new_pair(nul, tail).unwrap());
        }
        spine.reverse();
        let root = spine[0];

        // (path, canonical encoding of the path, number of steps, cost)
        let cases: &[(u32, &[u8], usize, Cost)] = &[
            // sentinel below the top bit of the byte: no extra zero byte
            (0x7f, &[0x7f], 6, 68),
            (0x01ff, &[0x01, 0xff], 8, 76),
            (0x7fff, &[0x7f, 0xff], 14, 100),
            // sentinel is the top bit of a byte: one leading zero byte
            (0xff, &[0x00, 0xff], 7, 76),
            (0xffff, &[0x00, 0xff, 0xff], 15, 108),
            (0x00ff_ffff, &[0x00, 0xff, 0xff, 0xff], 23, 140),
            (0xffff_ffff, &[0x00, 0xff, 0xff, 0xff, 0xff], 31, 172),
        ];

        for &(path, encoded, steps, cost) in cases {
            let expected = Reduction(cost, spine[steps]);
            assert_eq!(traverse_path_fast(&a, path, root).unwrap(), expected);
            assert_eq!(traverse_path(&a, encoded, root).unwrap(), expected);
        }
    }
}