1#![allow(clippy::needless_bitwise_bool)]
4
5use bytemuck::{Pod, Zeroable};
7use rostl_primitives::{
8 cmov_body, cxchg_body, impl_cmov_for_generic_pod,
9 traits::{Cmov, _Cmovbase},
10};
11
12use crate::heap_tree::HeapTree;
13use crate::prelude::{PositionType, DUMMY_POS, K};
14
/// Number of block slots in each bucket (tree node).
pub const Z: usize = 2;
/// Number of slots in the stash proper (`CircuitORAM::stash[..S]`); the rest of
/// `CircuitORAM::stash` is a scratch buffer for the currently accessed path.
pub const S: usize = 20;
/// Number of deterministic evictions performed at the end of every access.
const EVICTIONS_PER_OP: usize = 2;
// Layout attribute for the `Block` struct that follows.
#[repr(C)]
31#[derive(Clone, Copy, Zeroable)]
32pub struct Block<V>
33where
34 V: Cmov + Pod,
35{
36 pub pos: PositionType,
38 pub key: K,
40 pub value: V,
42}
43unsafe impl<V: Cmov + Pod> Pod for Block<V> {}
44
45impl<T: Cmov + Pod> Default for Block<T> {
46 fn default() -> Self {
47 Self { pos: DUMMY_POS, key: 0, value: T::zeroed() }
48 }
49}
50
51impl<T: Cmov + Pod + std::fmt::Debug> std::fmt::Debug for Block<T> {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 if self.pos == DUMMY_POS {
54 write!(f, ".")
55 } else {
56 write!(f, "Block {{ pos: {}, key: {}, value: {:?} }}", self.pos, self.key, self.value)
57 }
58 }
59}
60
61impl_cmov_for_generic_pod!(Block<V>; where V: Cmov + Pod);
62
63impl<V: Cmov + Pod> Block<V> {
64 pub const fn is_empty(&self) -> bool {
66 self.pos == DUMMY_POS
67 }
68}
69
70#[repr(C)]
72#[derive(Debug, Default, Clone, Copy, Zeroable)]
73pub struct Bucket<V>([Block<V>; Z])
74where
75 V: Cmov + Pod;
76unsafe impl<V: Cmov + Pod> Pod for Bucket<V> {}
77impl_cmov_for_generic_pod!(Bucket<V>; where V: Cmov + Pod);
78
79impl<V: Cmov + Pod> HeapTree<Bucket<V>> {
80 #[inline]
83 pub fn read_path(&mut self, path: PositionType, out: &mut [Block<V>]) {
84 debug_assert!((path as usize) < (1 << self.height));
85 debug_assert!(out.len() == self.height * Z);
86 for i in 0..self.height {
87 let index = self.get_index(i, path);
88 let bucket = &self.tree[index];
89 out[i * Z..(i + 1) * Z].copy_from_slice(&bucket.0);
90 }
91 }
92
93 #[inline]
96 pub fn write_path(&mut self, path: PositionType, in_: &[Block<V>]) {
97 debug_assert!((path as usize) < (1 << self.height));
98 debug_assert!(in_.len() == self.height * Z);
99 for i in 0..self.height {
100 let index = self.get_index(i, path);
101 let bucket = &mut self.tree[index];
102 bucket.0.copy_from_slice(&in_[i * Z..(i + 1) * Z]);
103 }
104 }
105}
106
/// Circuit ORAM storing blocks with values of type `V` in a bucket tree.
#[derive(Debug)]
pub struct CircuitORAM<V: Cmov + Pod> {
    /// Leaf count `2^(h - 1)` (set by `new`); valid positions are `0..max_n`.
    pub max_n: usize,
    /// Number of tree levels, root and leaves included.
    pub h: usize,
    /// `stash[..S]` is the stash proper; `stash[S..S + h * Z]` is a scratch
    /// buffer holding the buckets of the path currently being accessed.
    pub stash: Vec<Block<V>>,
    /// The bucket tree.
    pub tree: HeapTree<Bucket<V>>,
    /// Next position used by the deterministic eviction schedule (wraps modulo `max_n`).
    pub evict_counter: PositionType,
}
125
126#[inline]
127fn read_and_remove_element<V: Cmov + Pod>(arr: &mut [Block<V>], k: K, ret: &mut V) -> bool {
128 let mut rv = false;
129
130 for item in arr {
131 let matched = (!item.is_empty()) & (item.key == k);
132 debug_assert!((!matched) | (!rv));
133
134 ret.cmov(&item.value, matched);
135 item.pos.cmov(&DUMMY_POS, matched);
136 rv.cmov(&true, matched);
137 }
138
139 rv
140}
141
142#[inline]
146pub fn remove_element<V: Cmov + Pod>(arr: &mut [Block<V>], k: K) -> bool {
147 let mut rv = false;
148
149 for item in arr {
150 let matched = (!item.is_empty()) & (item.key == k);
151 debug_assert!((!matched) | (!rv));
152
153 item.pos.cmov(&DUMMY_POS, matched);
154 rv.cmov(&true, matched);
155 }
156
157 rv
158}
159
160#[inline]
163pub fn write_block_to_empty_slot<V: Cmov + Pod>(arr: &mut [Block<V>], val: &Block<V>) -> bool {
164 let mut rv = false;
165
166 for item in arr {
167 let matched = (item.is_empty()) & (!rv);
168 debug_assert!((!matched) | (!rv));
169
170 item.cmov(val, matched);
171 rv.cmov(&true, matched);
172 }
173
174 rv
175}
176
/// Reverses the lowest `bits` bits of `n`; bits of `n` above that are ignored.
///
/// Example: `reverse_bits(0b0011, 4) == 0b1100`. With `bits == 0` the result is `0`.
#[inline]
pub fn reverse_bits(n: usize, bits: usize) -> usize {
    let (reversed, _) = (0..bits).fold((0usize, n), |(acc, rest), _| {
        // Shift the lowest remaining bit of `rest` into the bottom of `acc`.
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
199
200#[inline]
201const fn common_suffix_length(a: PositionType, b: PositionType) -> u32 {
202 let w = a ^ b;
203 w.trailing_zeros()
204}
205
206impl<V: Cmov + Pod + Default + Clone + std::fmt::Debug> CircuitORAM<V> {
207 pub fn new(max_n: usize) -> Self {
218 debug_assert!(max_n > 1);
220 debug_assert!(max_n <= u32::MAX as usize);
221
222 let h = {
223 let h0 = (max_n).ilog2() as usize;
224 if (1 << h0) < max_n {
225 h0 + 2
226 } else {
227 h0 + 1
228 }
229 };
230 let tree = HeapTree::new(h);
231 let stash = vec![Block::<V>::default(); S + h * Z];
232
233 let max_n = 2usize.pow((h - 1) as u32);
234 Self { max_n, h, stash, tree, evict_counter: 0 }
235 }
236
237 pub fn new_with_positions_and_values(
249 max_n: usize,
250 keys: &[K],
251 values: &[V],
252 positions: &[PositionType],
253 ) -> Self {
254 let mut oram = Self::new(max_n);
255 debug_assert!(keys.len() == values.len());
256 debug_assert!(keys.len() == positions.len());
257 debug_assert!(keys.len() <= max_n);
258
259 for (i, ((key, value), pos)) in keys.iter().zip(values.iter()).zip(positions.iter()).enumerate()
260 {
261 oram.write_or_insert(i as PositionType, *pos, *key, *value);
262 }
263 oram
264 }
265
266 pub fn read_path_and_get_nodes(&mut self, pos: PositionType) {
268 debug_assert!((pos as usize) < self.max_n);
269 self.tree.read_path(pos, &mut self.stash[S..S + self.h * Z]);
270 }
271
272 pub fn write_back_path(&mut self, pos: PositionType) {
274 debug_assert!((pos as usize) < self.max_n);
275 self.tree.write_path(pos, &self.stash[S..S + self.h * Z]);
276 }
277
278 pub fn evict_once_fast(&mut self, pos: PositionType) {
280 let mut deepest: [i32; 64] = [-1; 64];
284 let mut deepest_idx: [i32; 64] = [0; 64];
285 let mut target: [i32; 64] = [-1; 64];
286 let mut has_empty: [bool; 64] = [false; 64];
287
288 let mut src = -1;
289 let mut dst: i32 = -1;
290
291 for idx in 0..S + Z {
296 let deepest_level = common_suffix_length(self.stash[idx].pos, pos) as i32;
297 let deeper_flag = (!self.stash[idx].is_empty()) & (deepest_level > dst);
298 dst.cmov(&deepest_level, deeper_flag);
299 deepest_idx[0].cmov(&(idx as i32), deeper_flag);
300 }
301 src.cmov(&0, dst != -1);
302
303 let mut idx = S + Z;
304 for i in 1..self.h {
307 deepest[i].cmov(&src, dst >= i as i32);
308 let mut bucket_deepest_level: i32 = -1;
309 for _ in 0..Z {
310 let deepest_level = common_suffix_length(self.stash[idx].pos, pos) as i32;
311 let is_empty = self.stash[idx].is_empty();
312 has_empty[i].cmov(&true, is_empty);
313
314 let deeper_flag = (!is_empty) & (deepest_level > bucket_deepest_level);
315 bucket_deepest_level.cmov(&deepest_level, deeper_flag);
316 deepest_idx[i].cmov(&(idx as i32), deeper_flag);
317
318 idx += 1;
319 }
320
321 let deepper_flag = bucket_deepest_level > dst;
322 src.cmov(&(i as i32), deepper_flag);
323 dst.cmov(&bucket_deepest_level, deepper_flag);
324 }
325
326 src = -1;
329 dst = -1;
330 for i in (1..self.h).rev() {
331 let is_src = (i as i32) == src;
332 target[i].cmov(&dst, is_src);
333 src.cmov(&-1, is_src);
334 dst.cmov(&-1, is_src);
335 let change_flag = (((dst == -1) & has_empty[i]) | (target[i] != -1)) & (deepest[i] != -1);
336 src.cmov(&deepest[i], change_flag);
337 dst.cmov(&(i as i32), change_flag);
338 }
339 target[0].cmov(&dst, src == 0);
340
341 let mut hold = Block::<V>::default();
345 for idx in 0..S + Z {
346 let is_deepest = deepest_idx[0] == idx as i32;
347 let read_and_remove_flag = is_deepest & (target[0] != -1);
348 hold.cmov(&self.stash[idx], read_and_remove_flag);
349 self.stash[idx].pos.cmov(&DUMMY_POS, read_and_remove_flag);
350 }
351 dst = target[0];
352
353 let mut idx = S + Z;
355 for i in 1..(self.h - 1) {
356 let has_target_flag = target[i] != -1;
357 let place_dummy_flag = (i as i32 == dst) & (!has_target_flag);
358 for _ in 0..Z {
359 let is_deepest = deepest_idx[i] == idx as i32;
377 let read_and_remove_flag = is_deepest & has_target_flag;
378 let write_flag = (self.stash[idx].is_empty()) & place_dummy_flag;
379 let swap_flag = read_and_remove_flag | write_flag;
380 hold.cxchg(&mut self.stash[idx], swap_flag);
381 idx += 1;
382 }
383
384 dst.cmov(&target[i], has_target_flag | place_dummy_flag);
385 }
386
387 let place_dummy_flag = ((self.h - 1) as i32) == dst;
389 let mut written = false;
390 for _ in 0..Z {
391 let write_flag = (self.stash[idx].is_empty()) & place_dummy_flag & (!written);
392 written |= write_flag;
393 self.stash[idx].cmov(&hold, write_flag);
394 idx += 1;
395 }
396 }
397
398 fn perform_eviction(&mut self, pos: PositionType) {
400 debug_assert!((pos as usize) < self.max_n);
401 self.read_path_and_get_nodes(pos);
402 self.evict_once_fast(pos);
403 self.write_back_path(pos);
404 }
405
406 fn perform_deterministic_evictions(&mut self) {
410 for _ in 0..EVICTIONS_PER_OP {
412 let evict_pos = self.evict_counter;
414 self.perform_eviction(evict_pos);
415 self.evict_counter = (self.evict_counter + 1) % (self.max_n as PositionType);
416 }
417 let mut ok = false;
421 for elem in &self.stash[..S] {
422 ok.cmov(&true, elem.is_empty());
423 }
424 debug_assert!(ok);
425 }
427
428 pub fn read(&mut self, pos: PositionType, new_pos: PositionType, key: K, ret: &mut V) -> bool {
441 debug_assert!((pos as usize) < self.max_n);
442 debug_assert!((new_pos as usize) < self.max_n || new_pos == DUMMY_POS);
443
444 self.read_path_and_get_nodes(pos);
445
446 let found = read_and_remove_element(&mut self.stash, key, ret);
447 let mut to_write = Block { pos: new_pos, key, value: *ret };
448 to_write.pos.cmov(&DUMMY_POS, !found);
449 write_block_to_empty_slot(&mut self.stash[..S], &to_write); self.evict_once_fast(pos);
452 self.write_back_path(pos);
453 self.perform_deterministic_evictions();
454
455 found
456 }
457
458 pub fn write(&mut self, pos: PositionType, new_pos: PositionType, key: K, val: V) -> bool {
472 debug_assert!((pos as usize) < self.max_n);
473 debug_assert!((new_pos as usize) < self.max_n);
475
476 self.read_path_and_get_nodes(pos);
477
478 let found = remove_element(&mut self.stash, key);
479
480 let mut target_pos = DUMMY_POS;
481 target_pos.cmov(&new_pos, found);
482
483 write_block_to_empty_slot(
484 &mut self.stash[..S],
485 &Block::<V> { pos: target_pos, key, value: val },
486 ); self.evict_once_fast(pos);
489 self.write_back_path(pos);
490 self.perform_deterministic_evictions();
491
492 found
493 }
494
495 pub fn write_or_insert(
509 &mut self,
510 pos: PositionType,
511 new_pos: PositionType,
512 key: K,
513 val: V,
514 ) -> bool {
515 debug_assert!((pos as usize) < self.max_n);
516 debug_assert!((new_pos as usize) < self.max_n || new_pos == DUMMY_POS);
517
518 self.read_path_and_get_nodes(pos);
519 let found = remove_element(&mut self.stash, key);
522 write_block_to_empty_slot(&mut self.stash[..S], &Block::<V> { pos: new_pos, key, value: val }); self.evict_once_fast(pos);
528 self.write_back_path(pos);
529 self.perform_deterministic_evictions();
530
531 found
532 }
533
534 pub fn update<T, F>(
545 &mut self,
546 pos: PositionType,
547 new_pos: PositionType,
548 key: K,
549 update_func: F,
550 ) -> (bool, T)
551 where
552 F: FnOnce(&mut V) -> T,
553 {
554 debug_assert!((pos as usize) < self.max_n);
555 debug_assert!((new_pos as usize) < self.max_n);
556
557 self.read_path_and_get_nodes(pos);
558
559 let mut val = V::default();
560 let found = read_and_remove_element(&mut self.stash, key, &mut val);
561 let rv = update_func(&mut val);
562
563 write_block_to_empty_slot(&mut self.stash[..S], &Block::<V> { pos: new_pos, key, value: val }); self.evict_once_fast(pos);
566 self.write_back_path(pos);
567 self.perform_deterministic_evictions();
568
569 (found, rv)
570 }
571
572 #[cfg(test)]
573 pub(crate) fn print_for_debug(&self) {
574 println!("self.h: {}", self.h);
575 println!("Stash: {:?}", self.stash);
576 for i in 0..self.h {
577 print!("Level {}: ", i);
578 for j in 0..(1 << i) {
579 let w_j = reverse_bits(j, i);
580 print!(
581 "{:?} ",
582 self.tree.get_path_at_depth(
583 i,
584 reverse_bits(w_j * (1 << (self.h - 1 - i)), self.h - 1) as PositionType
585 )
586 );
587 }
588 println!();
589 }
590 }
591}
592
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use rand::{rng, Rng};

    /// Asserts that no live block remains in the stash proper (`stash[..S]`).
    ///
    /// Uses `assert!` rather than `debug_assert!` so the check is not compiled
    /// out when tests are built without debug assertions (e.g. `--release`).
    fn assert_empty_stash(oram: &CircuitORAM<u64>) {
        for elem in &oram.stash[..S] {
            assert!(elem.is_empty());
        }
    }

    #[test]
    fn test_print_for_debug() {
        let mut oram = CircuitORAM::<u64>::new(4);
        oram.perform_deterministic_evictions();
        assert_empty_stash(&oram);
        oram.print_for_debug();
        oram.write_or_insert(0, 0, 0, 0);
        oram.print_for_debug();
        oram.write_or_insert(0, 1, 1, 1);
        oram.print_for_debug();
        oram.write_or_insert(0, 2, 2, 2);
        oram.print_for_debug();
        oram.write_or_insert(0, 3, 3, 3);
        oram.print_for_debug();
        oram.perform_deterministic_evictions();
        oram.print_for_debug();
        oram.write_or_insert(0, 0, 4, 0);
        oram.print_for_debug();
        oram.write_or_insert(0, 1, 5, 1);
        oram.print_for_debug();
        oram.write_or_insert(0, 2, 6, 2);
        oram.print_for_debug();
        oram.write_or_insert(0, 3, 7, 3);
        oram.print_for_debug();
        oram.perform_deterministic_evictions();
        oram.print_for_debug();
        oram.write_or_insert(0, 0, 10, 0);
        oram.print_for_debug();
        oram.write_or_insert(0, 1, 11, 1);
        oram.print_for_debug();
        oram.write_or_insert(0, 2, 12, 2);
        oram.print_for_debug();
        oram.write_or_insert(0, 3, 13, 3);
        oram.print_for_debug();
        oram.perform_deterministic_evictions();
        oram.print_for_debug();
        oram.write_or_insert(0, 0, 20, 0);
        oram.print_for_debug();
        oram.write_or_insert(0, 1, 21, 1);
        oram.print_for_debug();
        oram.write_or_insert(0, 2, 22, 2);
        oram.print_for_debug();
        oram.write_or_insert(0, 3, 23, 3);
        oram.print_for_debug();
        oram.perform_deterministic_evictions();
        oram.perform_deterministic_evictions();
        oram.perform_deterministic_evictions();
        oram.perform_deterministic_evictions();
        oram.perform_deterministic_evictions();
        oram.print_for_debug();
    }

    #[test]
    fn test_circuitoram_simple() {
        let mut oram = CircuitORAM::<u64>::new(16);
        oram.perform_deterministic_evictions();
        assert_empty_stash(&oram);

        oram.write_or_insert(0, 0, 1, 1);
        assert_empty_stash(&oram);

        let mut v = 0;
        let found = oram.read(0, 0, 1, &mut v);
        assert!(found);
        assert_eq!(v, 1);
        assert_empty_stash(&oram);
        oram.print_for_debug();

        oram.write_or_insert(0, 0, 2, 2);
        assert_empty_stash(&oram);
        let found = oram.read(0, 0, 2, &mut v);
        assert!(found);
        assert_eq!(v, 2);
        assert_empty_stash(&oram);
        oram.print_for_debug();

        let found = oram.read(0, 0, 3, &mut v);
        assert!(!found);
        assert_empty_stash(&oram);
        oram.print_for_debug();

        oram.write_or_insert(0, 0, 1, 3);
        let found = oram.read(0, 0, 1, &mut v);
        assert!(found);
        assert_eq!(v, 3);
    }

    #[test]
    fn test_circuitoram_simple_2() {
        const TOTAL_KEYS: usize = 8;
        let mut oram = CircuitORAM::<u64>::new(TOTAL_KEYS);
        let mut val = 0;
        let found = oram.write_or_insert(0, 4, 0, 123);
        oram.print_for_debug();
        assert!(!found);
        let found = oram.read(4, 7, 0, &mut val);
        oram.print_for_debug();
        assert!(found);
        assert_eq!(val, 123);
    }

    /// Runs a random mix of `read` / `write` / `write_or_insert` / `update`
    /// operations on random keys. Results are checked against a plain reference
    /// model: `pmap` holds each key's current position, `vals` its last value,
    /// and `used` whether the key is live.
    fn test_circuitoram_repetitive_generic<const TOTAL_KEYS: PositionType>() {
        let mut oram = CircuitORAM::<u64>::new(TOTAL_KEYS as usize);
        let mut pmap = vec![0; TOTAL_KEYS as usize];
        let mut vals = vec![0; TOTAL_KEYS as usize];
        let mut used = vec![false; TOTAL_KEYS as usize];
        let mut rng = rng();

        for _ in 0..2_000 {
            let new_pos = rng.random_range(0..TOTAL_KEYS);
            let key = rng.random_range(0..TOTAL_KEYS) as usize;
            let val = rng.random::<u64>();
            // 0 = read, 1 = write, 2 = write_or_insert, 3 = update.
            match rng.random_range(0..4) {
                0 => {
                    let mut v = 0;
                    let found = oram.read(pmap[key], new_pos, key as K, &mut v);
                    assert_eq!(found, used[key]);
                    if used[key] {
                        assert_eq!(v, vals[key]);
                    }
                }
                1 => {
                    let found = oram.write(pmap[key], new_pos, key as K, val);
                    assert_eq!(found, used[key]);
                    vals[key] = val;
                }
                2 => {
                    let found = oram.write_or_insert(pmap[key], new_pos, key as K, val);
                    assert_eq!(found, used[key]);
                    used[key] = true;
                    vals[key] = val;
                }
                _ => {
                    // Return the value seen *before* the update so it can be
                    // checked against the reference model.
                    let (found, old) = oram.update(pmap[key], new_pos, key as K, |v| {
                        let old = *v;
                        *v = val;
                        old
                    });
                    assert_eq!(found, used[key]);
                    if used[key] {
                        assert_eq!(old, vals[key]);
                    }
                    used[key] = true;
                    vals[key] = val;
                }
            }

            pmap[key] = new_pos;
        }
    }

    #[test]
    fn test_circuitoram_repetitive() {
        test_circuitoram_repetitive_generic::<8>();
        test_circuitoram_repetitive_generic::<16>();
        test_circuitoram_repetitive_generic::<1024>();
    }
}