maniac_runtime/utils/
bits.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2
/// Returns how many bits are currently set in `value`.
///
/// The word is read once with `Ordering::Relaxed`, so the count describes a
/// single snapshot of the atomic.
#[inline(always)]
pub fn size(value: &AtomicU64) -> u64 {
    let snapshot = value.load(Ordering::Relaxed);
    u64::from(snapshot.count_ones())
}
7
/// Returns `true` when no bit is set in `value`.
///
/// The word is read once with `Ordering::Relaxed`.
#[inline(always)]
pub fn is_empty(value: &AtomicU64) -> bool {
    // A word has zero set bits exactly when the word itself is zero.
    value.load(Ordering::Relaxed) == 0
}
12
/// Atomically sets bit `index` (bit 0 is the least significant bit) in `value`.
///
/// Returns `(was_empty, newly_set)`:
/// * `was_empty` - the whole word was zero before this call.
/// * `newly_set` - the targeted bit was clear before this call, i.e. this
///   call is the one that turned it on.
///
/// The update is a single `fetch_or` with `Ordering::AcqRel`.
/// `index` is expected to be below 64; a larger value overflows the shift.
#[inline(always)]
pub fn set(value: &AtomicU64, index: u64) -> (bool, bool) {
    let mask = 1u64 << index;
    let before = value.fetch_or(mask, Ordering::AcqRel);

    let was_empty = before == 0;
    let newly_set = before & mask == 0;
    (was_empty, newly_set)
}
21
/// Atomically ORs the pre-computed mask `bit` into `value`.
///
/// Returns the word as it was immediately before the update. The update is a
/// single `fetch_or` with `Ordering::AcqRel`.
#[inline(always)]
pub fn set_with_bit(value: &AtomicU64, bit: u64) -> u64 {
    let before = value.fetch_or(bit, Ordering::AcqRel);
    before
}
26
27#[inline(always)]
28pub fn acquire(value: &AtomicU64, index: u64) -> bool {
29    // let bit = 0x8000000000000000u64 >> index;
30    if !is_set(value, index) {
31        return false;
32    }
33    let bit = 1u64 << index;
34    let previous = value.fetch_and(!bit, Ordering::AcqRel);
35    (previous & bit) == bit
36}
37
38#[inline(always)]
39pub fn try_acquire(value: &AtomicU64, index: u64) -> (u64, u64, bool) {
40    if !is_set(value, index) {
41        return (0, 0, false);
42    }
43    let bit = 1u64 << index;
44    let previous = value.fetch_and(!bit, Ordering::AcqRel);
45    (bit, previous, (previous & bit) == bit)
46}
47
/// Returns `true` when bit `index` (bit 0 is the least significant bit) is
/// set in `value`.
///
/// The word is read once with `Ordering::Relaxed`. `index` is expected to be
/// below 64; a larger value overflows the shift.
#[inline(always)]
pub fn is_set(value: &AtomicU64, index: u64) -> bool {
    let mask = 1u64 << index;
    let snapshot = value.load(Ordering::Relaxed);
    snapshot & mask == mask
}
54
/// Finds a set bit in `value`, preferring positions at or above `start_index`.
///
/// Search order:
/// 1. The lowest set bit whose index is `>= start_index`.
/// 2. Otherwise, the highest set bit whose index is `< start_index`.
///
/// Returns the index of the bit found, or `64` when `value` has no set bits
/// or `start_index` is `>= 64`.
pub fn find_nearest_set_bit(value: u64, start_index: u64) -> u64 {
    if start_index >= 64 {
        return 64;
    }

    // Mask covering every bit strictly below `start_index`.
    let below = (1u64 << start_index) - 1;

    let at_or_above = value & !below;
    if at_or_above != 0 {
        return u64::from(at_or_above.trailing_zeros());
    }

    match value & below {
        0 => 64,
        lower => 63 - u64::from(lower.leading_zeros()),
    }
}
81
/// Finds the set bit in `value` closest to `start_index`, preferring the
/// forward (higher-index) candidate when both sides are equally far away.
///
/// Candidates are the lowest set bit at or above `start_index` and the
/// highest set bit strictly below it. Returns the chosen bit index, or `64`
/// when `value` has no set bits or `start_index` is `>= 64`.
///
/// Both candidates are computed unconditionally and filtered afterwards, so
/// the intermediate values for a missing candidate are meaningless. They are
/// produced with wrapping arithmetic so that a missing backward candidate
/// cannot trigger an arithmetic-overflow panic.
pub fn find_nearest_by_distance0(value: u64, start_index: u64) -> u64 {
    let out_of_bounds = start_index >= 64;
    // Keep the shift amount in range even for out-of-bounds input; the result
    // is overridden at the end in that case.
    let idx = start_index & 63;

    let below = (1u64 << idx) - 1;
    let forward_bits = value & !below;
    let backward_bits = value & below;

    let f_valid = forward_bits != 0;
    let b_valid = backward_bits != 0;

    // `trailing_zeros()` of 0 is 64, which is always >= idx.
    let f_idx = forward_bits.trailing_zeros() as u64;
    // `leading_zeros()` of 0 is 64, so `63 - 64` would underflow: wrap
    // instead. The wrapped value is only ever used when `b_valid` is true.
    let b_idx = 63u64.wrapping_sub(backward_bits.leading_zeros() as u64);

    let f_dist = f_idx - idx;
    let b_dist = idx.wrapping_sub(b_idx);

    // Prefer forward on a tie; ignore distances of invalid candidates.
    let use_forward = f_valid && (!b_valid || f_dist <= b_dist);

    let result = if use_forward {
        f_idx
    } else if b_valid {
        b_idx
    } else {
        64
    };

    if out_of_bounds { 64 } else { result }
}
112
/// Branch-free search for the set bit in `value` closest to `start_index`,
/// preferring the forward (higher-index) candidate on a tie.
///
/// Candidates are the lowest set bit at or above `start_index` and the
/// highest set bit strictly below it. Returns the chosen bit index, or `64`
/// when `value` has no set bits or `start_index` is `>= 64`.
///
/// All selections are done with 0/1 flags and arithmetic instead of `if`.
/// Values belonging to a missing candidate are computed with wrapping
/// arithmetic and are always multiplied by a zero flag before use.
pub fn find_nearest_by_distance_branchless(value: u64, start_index: u64) -> u64 {
    // 1 when the start index is usable, 0 otherwise.
    let in_range = (start_index < 64) as u64;
    // Keep the shift amount legal even for out-of-range input.
    let pivot = start_index & 63;

    let below = (1u64 << pivot) - 1;
    let forward_bits = value & !below;
    let backward_bits = value & below;

    let forward_valid = (forward_bits != 0) as u64;
    let backward_valid = (backward_bits != 0) as u64;

    // `trailing_zeros()` of 0 is 64, so a missing forward candidate already
    // yields an index that is >= pivot.
    let forward_index = forward_bits.trailing_zeros() as u64;
    // `leading_zeros()` of 0 is 64; wrap rather than underflow. The wrapped
    // value is masked out below whenever `backward_valid` is 0.
    let backward_index = 63u64.wrapping_sub(backward_bits.leading_zeros() as u64);

    let forward_dist = forward_index - pivot;
    let backward_dist = pivot.wrapping_sub(backward_index);

    // Forward wins when it exists and either there is no backward candidate
    // or forward is at least as close. Backward wins whenever it exists and
    // forward did not win. Distances of missing candidates never decide.
    let forward_closer = (forward_dist <= backward_dist) as u64;
    let choose_forward = forward_valid & ((1 - backward_valid) | forward_closer);
    let choose_backward = backward_valid & (1 - choose_forward);
    let found = choose_forward | choose_backward;

    let result = forward_index * choose_forward
        + backward_index * choose_backward
        + 64 * (1 - found);

    // Replace (not OR) the result with 64 when the start index is out of range.
    result * in_range + 64 * (1 - in_range)
}
148
/// Finds the set bit in `value` closest to `start_index`.
///
/// Candidates are the lowest set bit at or above `start_index` ("ahead") and
/// the highest set bit strictly below it ("behind"). When both exist, the
/// ahead candidate is chosen only if it is strictly closer; on a tie the
/// behind candidate is returned.
///
/// Returns the chosen bit index, or `64` when `value` has no set bits or
/// `start_index` is `>= 64`.
pub fn find_nearest_by_distance(value: u64, start_index: u64) -> u64 {
    if start_index >= 64 {
        return 64;
    }

    // Mask covering every bit strictly below `start_index`.
    let below = (1u64 << start_index) - 1;
    let upper = value & !below;
    let lower = value & below;

    // Index of the most significant set bit; only called on non-zero input.
    let highest_set = |bits: u64| 63 - u64::from(bits.leading_zeros());

    match (upper != 0, lower != 0) {
        (false, false) => 64,
        (true, false) => u64::from(upper.trailing_zeros()),
        (false, true) => highest_set(lower),
        (true, true) => {
            let ahead = u64::from(upper.trailing_zeros());
            let behind = highest_set(lower);
            if ahead - start_index < start_index - behind {
                ahead
            } else {
                behind
            }
        }
    }
}
182
183// // Alternative version that returns the distance as well
184// pub fn find_nearest_set_bit_with_distance(value: u64, start_index: u64) -> Option<(u64, u64)> {
185//     if start_index >= 64 {
186//         return None;
187//     }
188//
189//     // Search forward
190//     let mask_forward = !((1u64 << start_index) - 1);
191//     let forward_bits = value & mask_forward;
192//
193//     let forward_result = if forward_bits != 0 {
194//         let index = forward_bits.trailing_zeros() as u64;
195//         Some((index, index - start_index))
196//     } else {
197//         None
198//     };
199//
200//     // Search backward
201//     let mask_backward = (1u64 << start_index) - 1;
202//     let backward_bits = value & mask_backward;
203//
204//     let backward_result = if backward_bits != 0 {
205//         let index = 63 - backward_bits.leading_zeros() as u64;
206//         Some((index, start_index - index))
207//     } else {
208//         None
209//     };
210//
211//     // Return the closer one, preferring forward in case of tie
212//     match (forward_result, backward_result) {
213//         (Some((f_idx, f_dist)), Some((b_idx, b_dist))) => {
214//             if f_dist <= b_dist {
215//                 Some((f_idx, f_dist))
216//             } else {
217//                 Some((b_idx, b_dist))
218//             }
219//         }
220//         (Some(forward), None) => Some(forward),
221//         (None, Some(backward)) => Some(backward),
222//         (None, None) => None,
223//     }
224// }
225//
226// // Version that prioritizes forward search (like your original Java code)
227// pub fn find_nearest_set_bit_forward_priority(value: u64, start_index: u64) -> Option<u64> {
228//     if start_index >= 64 {
229//         return None;
230//     }
231//
232//     // Try forward first (including start_index)
233//     if start_index < 64 {
234//         let forward_mask = value >> start_index;
235//         if forward_mask != 0 {
236//             let offset = forward_mask.trailing_zeros() as u64;
237//             let found = start_index + offset;
238//             if found < 64 {
239//                 return Some(found);
240//             }
241//         }
242//     }
243//
244//     // If forward search failed, try backward
245//     if start_index > 0 {
246//         let backward_mask = value << (64 - start_index);
247//         if backward_mask != 0 {
248//             let leading_zeros = backward_mask.leading_zeros() as u64;
249//             return Some(start_index - 1 - leading_zeros);
250//         }
251//     }
252//
253//     None
254// }
255
256pub fn find_nearest(value: u64, signal_index: u64) -> u64 {
257    find_nearest_by_distance(value, signal_index)
258    // find_nearest_set_bit(value, signal_index)
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the bit helpers and the nearest-bit searches on a word with
    /// bits 33 and 63 set, asserting the results instead of only printing
    /// them.
    #[test]
    fn nearest_test() {
        let signal = AtomicU64::new(0);
        let _ = set(&signal, 33u64);
        assert!(is_set(&signal, 33));
        assert!(!is_set(&signal, 32));
        let _ = set(&signal, 63);

        let word = signal.load(Ordering::Relaxed);

        // Distance-based searches: bit 33 is 2 away from 35, bit 63 is 28 away.
        assert_eq!(find_nearest_by_distance(word, 35), 33);
        assert_eq!(find_nearest_by_distance_branchless(word, 35), 33);

        // Forward-first search: the first set bit at or above the start wins.
        assert_eq!(find_nearest_set_bit(word, 35), 63);
        assert_eq!(find_nearest_set_bit(word, 31), 33);
        assert_eq!(find_nearest_set_bit(word, 1), 33);
    }
}