1use std::{array::from_fn, mem::ManuallyDrop};
6
7use bytemuck::Pod;
8use rand::{rngs::ThreadRng, Rng};
9use rostl_oram::{
10 circuit_oram::CircuitORAM,
11 linear_oram::{oblivious_read_index, oblivious_write_index},
12 prelude::PositionType,
13 recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
17pub type Array<T, const N: usize> = FixedArray<T, N>;
20pub type DArray<T> = DynamicArray<T>;
23
/// Fixed-size array of `N` elements stored inline.
///
/// Reads and writes go through `linear_oram::oblivious_read_index` /
/// `oblivious_write_index`. [`FixedArray`] selects this representation when
/// `N <= SHORT_ARRAY_THRESHOLD`.
#[repr(C)]
#[derive(Debug)]
pub struct ShortArray<T, const N: usize> {
    /// The `N` elements, stored inline.
    pub(crate) data: [T; N],
}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37 T: Cmov + Pod + Default,
38{
39 pub fn new() -> Self {
41 Self { data: [T::default(); N] }
42 }
43
44 pub fn read(&self, index: usize, out: &mut T) {
46 oblivious_read_index(&self.data, index, out);
47 }
48
49 pub fn write(&mut self, index: usize, value: T) {
51 oblivious_write_index(&mut self.data, index, value);
52 }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56 fn len(&self) -> usize {
57 N
58 }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63 T: Cmov + Pod + Default,
64{
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[repr(C)]
73#[derive(Debug)]
74pub struct LongArray<T, const N: usize>
75where
76 T: Cmov + Pod,
77{
78 data: CircuitORAM<T>,
80 pos_map: RecursivePositionMap,
82 rng: ThreadRng,
84}
85impl<T, const N: usize> LongArray<T, N>
86where
87 T: Cmov + Pod + Default + std::fmt::Debug,
88{
89 pub fn new() -> Self {
91 Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N), rng: rand::rng() }
92 }
93
94 pub fn read(&mut self, index: usize, out: &mut T) {
96 let new_pos = self.rng.random_range(0..N as PositionType);
97 let old_pos = self.pos_map.access_position(index, new_pos);
98 self.data.read(old_pos, new_pos, index, out);
99 }
100
101 pub fn write(&mut self, index: usize, value: T) {
103 let new_pos = self.rng.random_range(0..N as PositionType);
104 let old_pos = self.pos_map.access_position(index, new_pos);
105 self.data.write_or_insert(old_pos, new_pos, index, value);
106 }
107}
108
109impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
110 fn len(&self) -> usize {
111 N
112 }
113}
114
115impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
/// Largest `N` for which [`FixedArray`] uses the inline [`ShortArray`] representation;
/// larger `N` use the ORAM-backed [`LongArray`].
const SHORT_ARRAY_THRESHOLD: usize = 128;
123
124#[repr(C)]
131pub union FixedArray<T, const N: usize>
132where
133 T: Cmov + Pod,
134{
135 short: ManuallyDrop<ShortArray<T, N>>,
137 long: ManuallyDrop<LongArray<T, N>>,
139}
140
141impl<T, const N: usize> Drop for FixedArray<T, N>
142where
143 T: Cmov + Pod,
144{
145 fn drop(&mut self) {
146 if N <= SHORT_ARRAY_THRESHOLD {
147 unsafe {
148 ManuallyDrop::drop(&mut self.short);
149 }
150 } else {
151 unsafe {
152 ManuallyDrop::drop(&mut self.long);
153 }
154 }
155 }
156}
157
158impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
159where
160 T: Cmov + Pod + std::fmt::Debug,
161{
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 if N <= SHORT_ARRAY_THRESHOLD {
164 let short_array: &ManuallyDrop<ShortArray<T, N>>;
165 unsafe {
166 short_array = &self.short;
167 }
168 short_array.fmt(f)
169 } else {
170 let long_array: &ManuallyDrop<LongArray<T, N>>;
171 unsafe {
172 long_array = &self.long;
173 }
174 long_array.fmt(f)
175 }
176 }
177}
178
179impl<T, const N: usize> FixedArray<T, N>
180where
181 T: Cmov + Pod + Default + std::fmt::Debug,
182{
183 pub fn new() -> Self {
185 if N <= SHORT_ARRAY_THRESHOLD {
186 FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
187 } else {
188 FixedArray { long: ManuallyDrop::new(LongArray::new()) }
189 }
190 }
191
192 pub fn read(&mut self, index: usize, out: &mut T) {
194 if N <= SHORT_ARRAY_THRESHOLD {
195 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
197 unsafe {
198 short_array = &mut self.short;
199 }
200 short_array.read(index, out);
201 } else {
202 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
203 unsafe {
204 long_array = &mut self.long;
205 }
206 long_array.read(index, out);
207 }
208 }
209
210 pub fn write(&mut self, index: usize, value: T) {
212 if N <= SHORT_ARRAY_THRESHOLD {
213 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
215 unsafe {
216 short_array = &mut self.short;
217 }
218 short_array.write(index, value);
219 } else {
220 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
221 unsafe {
222 long_array = &mut self.long;
223 }
224 long_array.write(index, value);
225 }
226 }
227}
228
229impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
230 fn len(&self) -> usize {
231 N
232 }
233}
234
235impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241#[derive(Debug)]
264pub struct DynamicArray<T>
265where
266 T: Cmov + Pod,
267{
268 data: CircuitORAM<T>,
270 pos_map: RecursivePositionMap,
272 rng: ThreadRng,
274}
275
276impl<T> DynamicArray<T>
277where
278 T: Cmov + Pod + Default + std::fmt::Debug,
279{
280 pub fn new(n: usize) -> Self {
282 Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n), rng: rand::rng() }
283 }
284
285 pub fn resize(&mut self, n: usize) {
287 let mut new_array = Self::new(n);
288 for i in 0..self.len() {
289 let mut value = Default::default();
290 self.read(i, &mut value);
291 new_array.write(i, value);
292 }
293 *self = new_array;
295 }
296
297 pub fn read(&mut self, index: usize, out: &mut T) {
299 let new_pos = self.rng.random_range(0..self.len() as PositionType);
300 let old_pos = self.pos_map.access_position(index, new_pos);
301 self.data.read(old_pos, new_pos, index, out);
302 }
303
304 pub fn write(&mut self, index: usize, value: T) {
306 let new_pos = self.rng.random_range(0..self.len() as PositionType);
307 let old_pos = self.pos_map.access_position(index, new_pos);
308 self.data.write_or_insert(old_pos, new_pos, index, value);
309 }
310
311 pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
313 where
314 F: FnOnce(&mut T) -> R,
315 {
316 let new_pos = self.rng.random_range(0..self.len() as PositionType);
317 let old_pos = self.pos_map.access_position(index, new_pos);
318 self.data.update(old_pos, new_pos, index, update_func)
319 }
320}
321
322impl<T: Cmov + Pod> Length for DynamicArray<T> {
323 #[inline(always)]
324 fn len(&self) -> usize {
325 self.pos_map.n
326 }
327}
328
329#[derive(Debug)]
332pub struct MultiWayArray<T, const W: usize>
333where
334 T: Cmov + Pod,
335{
336 data: CircuitORAM<T>,
338 pos_map: [RecursivePositionMap; W],
340 rng: ThreadRng,
342}
343
344impl<T, const W: usize> MultiWayArray<T, W>
345where
346 T: Cmov + Pod + Default + std::fmt::Debug,
347{
348 pub fn new(n: usize) -> Self {
350 assert!(W.is_power_of_two(), "W must be a power of two due to all the ilog2's here");
351 Self {
352 data: CircuitORAM::new(n),
353 pos_map: from_fn(|_| RecursivePositionMap::new(n)),
354 rng: rand::rng(),
355 }
356 }
357
358 fn get_real_index(&self, subarray: usize, index: usize) -> usize {
359 debug_assert!(subarray < W, "Subarray index out of bounds");
360 debug_assert!(index < self.len(), "Index out of bounds");
361 (index << W.ilog2()) | subarray
362 }
363
364 pub fn read(&mut self, subarray: usize, index: usize, out: &mut T) {
366 let new_pos = self.rng.random_range(0..self.len() as PositionType);
367 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
368 let real_index = self.get_real_index(subarray, index);
369 self.data.read(old_pos, new_pos, real_index, out);
370 }
371
372 pub fn write(&mut self, subarray: usize, index: usize, value: T) {
374 let new_pos = self.rng.random_range(0..self.len() as PositionType);
375 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
376 let real_index = self.get_real_index(subarray, index);
377 self.data.write_or_insert(old_pos, new_pos, real_index, value);
378 }
379
380 pub fn update<R, F>(&mut self, subarray: usize, index: usize, update_func: F) -> (bool, R)
382 where
383 F: FnOnce(&mut T) -> R,
384 {
385 let new_pos = self.rng.random_range(0..self.len() as PositionType);
386 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
387 let real_index = self.get_real_index(subarray, index);
388 self.data.update(old_pos, new_pos, real_index, update_func)
389 }
390}
391
392impl<T: Cmov + Pod, const W: usize> Length for MultiWayArray<T, W> {
393 #[inline(always)]
394 fn len(&self) -> usize {
395 self.pos_map[0].n
396 }
397}
398
#[cfg(test)]
// Presumably needed because `0..($size / $ways)` in the multi-way tests becomes `0..0` when
// `$size < $ways` (e.g. sizes 2 and 3 with 4 ways), which clippy flags as an empty loop range.
#[allow(clippy::reversed_empty_ranges)]
mod tests {
    use super::*;

    // Exercises a compile-time-sized array type (`ShortArray`, `LongArray`, `FixedArray`):
    // 1. every index initially reads as the default value;
    // 2. writes `i` to index `i`;
    // 3. reads everything back.
    // `len()` is checked between phases.
    macro_rules! m_test_fixed_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp, $size>::new();
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let value = i as $valtp;
                arr.write(i, value);
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size);
        }};
    }

    // Exercises `MultiWayArray` with `$ways` sub-arrays of length `$size`:
    // 1. all indices of every way initially read as the default value;
    // 2. writes `i + w` to the first `$size / $ways` indices of each way;
    // 3. reads them back.
    // Only `$size / $ways` elements per way are written, presumably to keep the total number of
    // stored elements within the shared ORAM's capacity of `$size`.
    macro_rules! m_test_multiway_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr, $ways:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp, $ways>::new($size);
            assert_eq!(arr.len(), $size);
            for w in 0..$ways {
                for i in 0..$size {
                    let mut value = Default::default();
                    arr.read(w, i, &mut value);
                    assert_eq!(value, Default::default());
                }
            }
            assert_eq!(arr.len(), $size);

            for w in 0..$ways {
                for i in 0..($size / $ways) {
                    let value = (i + w) as $valtp;
                    arr.write(w, i, value);
                }
            }
            assert_eq!(arr.len(), $size);
            for w in 0..$ways {
                for i in 0..($size / $ways) {
                    let mut value = Default::default();
                    arr.read(w, i, &mut value);
                    let v = (i + w) as $valtp;
                    assert_eq!(value, v);
                }
            }
            assert_eq!(arr.len(), $size);
        }};
    }

    // Exercises `DynamicArray`: the same fill-and-verify pass as the fixed-size tests, then two
    // growing resizes (to `$size + 1` and to `2 * $size`). After each resize, the old prefix must
    // keep its values and the new tail must read as the default value. Shrinking is not covered.
    macro_rules! m_test_dynamic_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp>::new($size);
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let value = i as $valtp;
                arr.write(i, value);
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size);
            arr.resize($size + 1);
            assert_eq!(arr.len(), $size + 1);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size + 1);
            for i in $size..($size + 1) {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size + 1);
            arr.resize(2 * $size);
            assert_eq!(arr.len(), 2 * $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), 2 * $size);
            for i in $size..(2 * $size) {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), 2 * $size);
        }};
    }

    // `FixedArray` sizes straddle `SHORT_ARRAY_THRESHOLD` (128), so both representations run.
    #[test]
    fn test_fixed_arrays() {
        m_test_fixed_array_exhaustive!(ShortArray, u32, 1);
        m_test_fixed_array_exhaustive!(ShortArray, u32, 2);
        m_test_fixed_array_exhaustive!(ShortArray, u32, 3);
        m_test_fixed_array_exhaustive!(ShortArray, u64, 15);
        m_test_fixed_array_exhaustive!(ShortArray, u8, 33);
        m_test_fixed_array_exhaustive!(ShortArray, u64, 200);

        m_test_fixed_array_exhaustive!(LongArray, u32, 2);
        m_test_fixed_array_exhaustive!(LongArray, u32, 3);
        m_test_fixed_array_exhaustive!(LongArray, u64, 15);
        m_test_fixed_array_exhaustive!(LongArray, u8, 33);

        m_test_fixed_array_exhaustive!(FixedArray, u32, 1);
        m_test_fixed_array_exhaustive!(FixedArray, u32, 2);
        m_test_fixed_array_exhaustive!(FixedArray, u32, 3);
        m_test_fixed_array_exhaustive!(FixedArray, u64, 15);
        m_test_fixed_array_exhaustive!(FixedArray, u8, 33);
        m_test_fixed_array_exhaustive!(FixedArray, u64, 200);
    }

    // Covers 1, 2 and 4 ways, including sizes smaller than the number of ways.
    #[test]
    fn test_multiway_array() {
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 1);

        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 2);

        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 4);
    }

    #[test]
    fn test_dynamic_array() {
        m_test_dynamic_array_exhaustive!(DynamicArray, u32, 2);
        m_test_dynamic_array_exhaustive!(DynamicArray, u32, 3);
        m_test_dynamic_array_exhaustive!(DynamicArray, u64, 15);
        m_test_dynamic_array_exhaustive!(DynamicArray, u8, 33);
        m_test_dynamic_array_exhaustive!(DynamicArray, u64, 200);
    }
}