1use std::{array::from_fn, mem::ManuallyDrop};
6
7use bytemuck::Pod;
8use rand::{rng, Rng};
9use rostl_oram::{
10 circuit_oram::CircuitORAM,
11 linear_oram::{oblivious_read_index, oblivious_write_index},
12 prelude::PositionType,
13 recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
/// Shorthand alias for [`FixedArray`].
pub type Array<T, const N: usize> = FixedArray<T, N>;
/// Shorthand alias for [`DynamicArray`].
pub type DArray<T> = DynamicArray<T>;
23
24#[repr(C)]
27#[derive(Debug)]
28pub struct ShortArray<T, const N: usize>
29{
31 pub(crate) data: [T; N],
33}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37 T: Cmov + Pod + Default,
38{
39 pub fn new() -> Self {
41 Self { data: [T::default(); N] }
42 }
43
44 pub fn read(&self, index: usize, out: &mut T) {
46 oblivious_read_index(&self.data, index, out);
47 }
48
49 pub fn write(&mut self, index: usize, value: T) {
51 oblivious_write_index(&mut self.data, index, value);
52 }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56 fn len(&self) -> usize {
57 N
58 }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63 T: Cmov + Pod + Default,
64{
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[repr(C)]
73#[derive(Debug)]
74pub struct LongArray<T, const N: usize>
75where
76 T: Cmov + Pod,
77{
78 data: CircuitORAM<T>,
80 pos_map: RecursivePositionMap,
82}
83impl<T, const N: usize> LongArray<T, N>
84where
85 T: Cmov + Pod + Default + std::fmt::Debug,
86{
87 pub fn new() -> Self {
89 Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N) }
90 }
91
92 pub fn read(&mut self, index: usize, out: &mut T) {
94 let new_pos = rng().random_range(0..N as PositionType);
95 let old_pos = self.pos_map.access_position(index, new_pos);
96 self.data.read(old_pos, new_pos, index, out);
97 }
98
99 pub fn write(&mut self, index: usize, value: T) {
101 let new_pos = rng().random_range(0..N as PositionType);
102 let old_pos = self.pos_map.access_position(index, new_pos);
103 self.data.write_or_insert(old_pos, new_pos, index, value);
104 }
105}
106
107impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
108 fn len(&self) -> usize {
109 N
110 }
111}
112
113impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119const SHORT_ARRAY_THRESHOLD: usize = 128;
121
122#[repr(C)]
129pub union FixedArray<T, const N: usize>
130where
131 T: Cmov + Pod,
132{
133 short: ManuallyDrop<ShortArray<T, N>>,
135 long: ManuallyDrop<LongArray<T, N>>,
137}
138
139impl<T, const N: usize> Drop for FixedArray<T, N>
140where
141 T: Cmov + Pod,
142{
143 fn drop(&mut self) {
144 if N <= SHORT_ARRAY_THRESHOLD {
145 unsafe {
146 ManuallyDrop::drop(&mut self.short);
147 }
148 } else {
149 unsafe {
150 ManuallyDrop::drop(&mut self.long);
151 }
152 }
153 }
154}
155
156impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
157where
158 T: Cmov + Pod + std::fmt::Debug,
159{
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 if N <= SHORT_ARRAY_THRESHOLD {
162 let short_array: &ManuallyDrop<ShortArray<T, N>>;
163 unsafe {
164 short_array = &self.short;
165 }
166 short_array.fmt(f)
167 } else {
168 let long_array: &ManuallyDrop<LongArray<T, N>>;
169 unsafe {
170 long_array = &self.long;
171 }
172 long_array.fmt(f)
173 }
174 }
175}
176
177impl<T, const N: usize> FixedArray<T, N>
178where
179 T: Cmov + Pod + Default + std::fmt::Debug,
180{
181 pub fn new() -> Self {
183 if N <= SHORT_ARRAY_THRESHOLD {
184 FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
185 } else {
186 FixedArray { long: ManuallyDrop::new(LongArray::new()) }
187 }
188 }
189
190 pub fn read(&mut self, index: usize, out: &mut T) {
192 if N <= SHORT_ARRAY_THRESHOLD {
193 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
195 unsafe {
196 short_array = &mut self.short;
197 }
198 short_array.read(index, out);
199 } else {
200 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
201 unsafe {
202 long_array = &mut self.long;
203 }
204 long_array.read(index, out);
205 }
206 }
207
208 pub fn write(&mut self, index: usize, value: T) {
210 if N <= SHORT_ARRAY_THRESHOLD {
211 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
213 unsafe {
214 short_array = &mut self.short;
215 }
216 short_array.write(index, value);
217 } else {
218 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
219 unsafe {
220 long_array = &mut self.long;
221 }
222 long_array.write(index, value);
223 }
224 }
225}
226
227impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
228 fn len(&self) -> usize {
229 N
230 }
231}
232
233impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239#[derive(Debug)]
262pub struct DynamicArray<T>
263where
264 T: Cmov + Pod,
265{
266 data: CircuitORAM<T>,
268 pos_map: RecursivePositionMap,
270}
271
272impl<T> DynamicArray<T>
273where
274 T: Cmov + Pod + Default + std::fmt::Debug,
275{
276 pub fn new(n: usize) -> Self {
278 Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n) }
279 }
280
281 pub fn resize(&mut self, n: usize) {
283 let mut new_array = Self::new(n);
284 for i in 0..self.len() {
285 let mut value = Default::default();
286 self.read(i, &mut value);
287 new_array.write(i, value);
288 }
289 *self = new_array;
291 }
292
293 pub fn read(&mut self, index: usize, out: &mut T) {
295 let new_pos = rng().random_range(0..self.len() as PositionType);
296 let old_pos = self.pos_map.access_position(index, new_pos);
297 self.data.read(old_pos, new_pos, index, out);
298 }
299
300 pub fn write(&mut self, index: usize, value: T) {
302 let new_pos = rng().random_range(0..self.len() as PositionType);
303 let old_pos = self.pos_map.access_position(index, new_pos);
304 self.data.write_or_insert(old_pos, new_pos, index, value);
305 }
306
307 pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
309 where
310 F: FnOnce(&mut T) -> R,
311 {
312 let new_pos = rng().random_range(0..self.len() as PositionType);
313 let old_pos = self.pos_map.access_position(index, new_pos);
314 self.data.update(old_pos, new_pos, index, update_func)
315 }
316}
317
318impl<T: Cmov + Pod> Length for DynamicArray<T> {
319 #[inline(always)]
320 fn len(&self) -> usize {
321 self.pos_map.n
322 }
323}
324
325#[derive(Debug)]
328pub struct MultiWayArray<T, const W: usize>
329where
330 T: Cmov + Pod,
331{
332 data: CircuitORAM<T>,
334 pos_map: [RecursivePositionMap; W],
336}
337
338impl<T, const W: usize> MultiWayArray<T, W>
339where
340 T: Cmov + Pod + Default + std::fmt::Debug,
341{
342 pub fn new(n: usize) -> Self {
344 assert!(W.is_power_of_two(), "W must be a power of two due to all the ilog2's here");
345 Self { data: CircuitORAM::new(n), pos_map: from_fn(|_| RecursivePositionMap::new(n)) }
346 }
347
348 fn get_real_index(&self, subarray: usize, index: usize) -> usize {
349 debug_assert!(subarray < W, "Subarray index out of bounds");
350 debug_assert!(index < self.len(), "Index out of bounds");
351 (index << W.ilog2()) | subarray
352 }
353
354 pub fn read(&mut self, subarray: usize, index: usize, out: &mut T) {
356 let new_pos = rng().random_range(0..self.len() as PositionType);
357 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
358 let real_index = self.get_real_index(subarray, index);
359 self.data.read(old_pos, new_pos, real_index, out);
360 }
361
362 pub fn write(&mut self, subarray: usize, index: usize, value: T) {
364 let new_pos = rng().random_range(0..self.len() as PositionType);
365 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
366 let real_index = self.get_real_index(subarray, index);
367 self.data.write_or_insert(old_pos, new_pos, real_index, value);
368 }
369
370 pub fn update<R, F>(&mut self, subarray: usize, index: usize, update_func: F) -> (bool, R)
372 where
373 F: FnOnce(&mut T) -> R,
374 {
375 let new_pos = rng().random_range(0..self.len() as PositionType);
376 let old_pos = self.pos_map[subarray].access_position(index, new_pos);
377 let real_index = self.get_real_index(subarray, index);
378 self.data.update(old_pos, new_pos, real_index, update_func)
379 }
380}
381
382impl<T: Cmov + Pod, const W: usize> Length for MultiWayArray<T, W> {
383 #[inline(always)]
384 fn len(&self) -> usize {
385 self.pos_map[0].n
386 }
387}
388
#[cfg(test)]
// NOTE(review): presumably needed because `0..($size / $ways)` below expands to a constant
// empty range such as `0..(2 / 4)`, which clippy's `reversed_empty_ranges` lint flags.
#[allow(clippy::reversed_empty_ranges)]
mod tests {
    use super::*;

    // Checks a const-generic array type `$arraytp::<$valtp, $size>`. Every slot must start as
    // the default value. Then slot `i` is written with `i as $valtp` and read back, and `len()`
    // must stay `$size` throughout.
    macro_rules! m_test_fixed_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp, $size>::new();
            assert_eq!(arr.len(), $size);
            // A fresh array reads back the default value at every index.
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let value = i as $valtp;
                arr.write(i, value);
            }
            assert_eq!(arr.len(), $size);
            // Each index reads back the value written to it above.
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size);
        }};
    }

    // Checks `MultiWayArray::<$valtp, $ways>::new($size)`. Every index of every subarray must
    // start as the default. Then the first `$size / $ways` indices of subarray `w` are written
    // with `i + w` and read back.
    // NOTE(review): only `$size / $ways` entries per subarray are written, presumably because all
    // subarrays share one ORAM created with `$size` — confirm against `MultiWayArray::new`.
    macro_rules! m_test_multiway_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr, $ways:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp, $ways>::new($size);
            assert_eq!(arr.len(), $size);
            for w in 0..$ways {
                for i in 0..$size {
                    let mut value = Default::default();
                    arr.read(w, i, &mut value);
                    assert_eq!(value, Default::default());
                }
            }
            assert_eq!(arr.len(), $size);

            for w in 0..$ways {
                for i in 0..($size / $ways) {
                    let value = (i + w) as $valtp;
                    arr.write(w, i, value);
                }
            }
            assert_eq!(arr.len(), $size);
            // Subarrays must not interfere: each reads back its own `i + w` values.
            for w in 0..$ways {
                for i in 0..($size / $ways) {
                    let mut value = Default::default();
                    arr.read(w, i, &mut value);
                    let v = (i + w) as $valtp;
                    assert_eq!(value, v);
                }
            }
            assert_eq!(arr.len(), $size);
        }};
    }

    // Checks `DynamicArray::<$valtp>::new($size)` like the fixed-array test, then grows it to
    // `$size + 1` and to `2 * $size`. After each resize it verifies that old values survive and
    // new slots read back as the default. Shrinking is not exercised here.
    macro_rules! m_test_dynamic_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp>::new($size);
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let value = i as $valtp;
                arr.write(i, value);
            }
            assert_eq!(arr.len(), $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size);
            // Grow by one: existing values are preserved and the new slot is default.
            arr.resize($size + 1);
            assert_eq!(arr.len(), $size + 1);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), $size + 1);
            for i in $size..($size + 1) {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), $size + 1);
            // Grow to double the original size: the same invariants must hold.
            arr.resize(2 * $size);
            assert_eq!(arr.len(), 2 * $size);
            for i in 0..$size {
                let mut value = Default::default();
                arr.read(i, &mut value);
                let v = i as $valtp;
                assert_eq!(value, v);
            }
            assert_eq!(arr.len(), 2 * $size);
            for i in $size..(2 * $size) {
                let mut value = Default::default();
                arr.read(i, &mut value);
                assert_eq!(value, Default::default());
            }
            assert_eq!(arr.len(), 2 * $size);
        }};
    }

    // Covers the ShortArray, LongArray and FixedArray types. Sizes 15/33/200 with FixedArray
    // stay at or below the 128-element threshold except 200, which exercises the ORAM backend.
    // NOTE(review): a size-1 case is not exercised for LongArray.
    #[test]
    fn test_fixed_arrays() {
        m_test_fixed_array_exhaustive!(ShortArray, u32, 1);
        m_test_fixed_array_exhaustive!(ShortArray, u32, 2);
        m_test_fixed_array_exhaustive!(ShortArray, u32, 3);
        m_test_fixed_array_exhaustive!(ShortArray, u64, 15);
        m_test_fixed_array_exhaustive!(ShortArray, u8, 33);
        m_test_fixed_array_exhaustive!(ShortArray, u64, 200);

        m_test_fixed_array_exhaustive!(LongArray, u32, 2);
        m_test_fixed_array_exhaustive!(LongArray, u32, 3);
        m_test_fixed_array_exhaustive!(LongArray, u64, 15);
        m_test_fixed_array_exhaustive!(LongArray, u8, 33);

        m_test_fixed_array_exhaustive!(FixedArray, u32, 1);
        m_test_fixed_array_exhaustive!(FixedArray, u32, 2);
        m_test_fixed_array_exhaustive!(FixedArray, u32, 3);
        m_test_fixed_array_exhaustive!(FixedArray, u64, 15);
        m_test_fixed_array_exhaustive!(FixedArray, u8, 33);
        m_test_fixed_array_exhaustive!(FixedArray, u64, 200);
    }

    // Covers MultiWayArray with 1, 2 and 4 subarrays across several sizes and value types.
    #[test]
    fn test_multiway_array() {
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 1);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 1);

        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 2);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 2);

        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 4);
        m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 4);
    }

    // Covers DynamicArray creation, reads/writes and growth via `resize`.
    #[test]
    fn test_dynamic_array() {
        m_test_dynamic_array_exhaustive!(DynamicArray, u32, 2);
        m_test_dynamic_array_exhaustive!(DynamicArray, u32, 3);
        m_test_dynamic_array_exhaustive!(DynamicArray, u64, 15);
        m_test_dynamic_array_exhaustive!(DynamicArray, u8, 33);
        m_test_dynamic_array_exhaustive!(DynamicArray, u64, 200);
    }
}