1use std::mem::ManuallyDrop;
6
7use bytemuck::Pod;
8use rand::{rngs::ThreadRng, Rng};
9use rostl_oram::{
10 circuit_oram::CircuitORAM,
11 linear_oram::{oblivious_read_index, oblivious_write_index},
12 prelude::PositionType,
13 recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
/// Short alias for [`FixedArray`], the fixed-size array whose backend is chosen from `N`.
pub type Array<T, const N: usize> = FixedArray<T, N>;
/// Short alias for [`DynamicArray`], the resizable ORAM-backed array.
pub type DArray<T> = DynamicArray<T>;
23
24#[repr(C)]
27#[derive(Debug)]
28pub struct ShortArray<T, const N: usize>
29{
31 pub(crate) data: [T; N],
33}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37 T: Cmov + Pod + Default,
38{
39 pub fn new() -> Self {
41 Self { data: [T::default(); N] }
42 }
43
44 pub fn read(&self, index: usize, out: &mut T) {
46 oblivious_read_index(&self.data, index, out);
47 }
48
49 pub fn write(&mut self, index: usize, value: T) {
51 oblivious_write_index(&mut self.data, index, value);
52 }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56 fn len(&self) -> usize {
57 N
58 }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63 T: Cmov + Pod + Default,
64{
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[repr(C)]
73#[derive(Debug)]
74pub struct LongArray<T, const N: usize>
75where
76 T: Cmov + Pod,
77{
78 data: CircuitORAM<T>,
80 pos_map: RecursivePositionMap,
82 rng: ThreadRng,
84}
85impl<T, const N: usize> LongArray<T, N>
86where
87 T: Cmov + Pod + Default + std::fmt::Debug,
88{
89 pub fn new() -> Self {
91 Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N), rng: rand::rng() }
92 }
93
94 pub fn read(&mut self, index: usize, out: &mut T) {
96 let new_pos = self.rng.random_range(0..N as PositionType);
97 let old_pos = self.pos_map.access_position(index, new_pos);
98 self.data.read(old_pos, new_pos, index, out);
99 }
100
101 pub fn write(&mut self, index: usize, value: T) {
103 let new_pos = self.rng.random_range(0..N as PositionType);
104 let old_pos = self.pos_map.access_position(index, new_pos);
105 self.data.write_or_insert(old_pos, new_pos, index, value);
106 }
107}
108
109impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
110 fn len(&self) -> usize {
111 N
112 }
113}
114
115impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121const SHORT_ARRAY_THRESHOLD: usize = 128;
123
124#[repr(C)]
131pub union FixedArray<T, const N: usize>
132where
133 T: Cmov + Pod,
134{
135 short: ManuallyDrop<ShortArray<T, N>>,
137 long: ManuallyDrop<LongArray<T, N>>,
139}
140
141impl<T, const N: usize> Drop for FixedArray<T, N>
142where
143 T: Cmov + Pod,
144{
145 fn drop(&mut self) {
146 if N <= SHORT_ARRAY_THRESHOLD {
147 unsafe {
148 ManuallyDrop::drop(&mut self.short);
149 }
150 } else {
151 unsafe {
152 ManuallyDrop::drop(&mut self.long);
153 }
154 }
155 }
156}
157
158impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
159where
160 T: Cmov + Pod + std::fmt::Debug,
161{
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 if N <= SHORT_ARRAY_THRESHOLD {
164 let short_array: &ManuallyDrop<ShortArray<T, N>>;
165 unsafe {
166 short_array = &self.short;
167 }
168 short_array.fmt(f)
169 } else {
170 let long_array: &ManuallyDrop<LongArray<T, N>>;
171 unsafe {
172 long_array = &self.long;
173 }
174 long_array.fmt(f)
175 }
176 }
177}
178
179impl<T, const N: usize> FixedArray<T, N>
180where
181 T: Cmov + Pod + Default + std::fmt::Debug,
182{
183 pub fn new() -> Self {
185 if N <= SHORT_ARRAY_THRESHOLD {
186 FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
187 } else {
188 FixedArray { long: ManuallyDrop::new(LongArray::new()) }
189 }
190 }
191
192 pub fn read(&mut self, index: usize, out: &mut T) {
194 if N <= SHORT_ARRAY_THRESHOLD {
195 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
197 unsafe {
198 short_array = &mut self.short;
199 }
200 short_array.read(index, out);
201 } else {
202 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
203 unsafe {
204 long_array = &mut self.long;
205 }
206 long_array.read(index, out);
207 }
208 }
209
210 pub fn write(&mut self, index: usize, value: T) {
212 if N <= SHORT_ARRAY_THRESHOLD {
213 let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
215 unsafe {
216 short_array = &mut self.short;
217 }
218 short_array.write(index, value);
219 } else {
220 let long_array: &mut ManuallyDrop<LongArray<T, N>>;
221 unsafe {
222 long_array = &mut self.long;
223 }
224 long_array.write(index, value);
225 }
226 }
227}
228
229impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
230 fn len(&self) -> usize {
231 N
232 }
233}
234
235impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241#[derive(Debug)]
264pub struct DynamicArray<T>
265where
266 T: Cmov + Pod,
267{
268 data: CircuitORAM<T>,
270 pos_map: RecursivePositionMap,
272 rng: ThreadRng,
274}
275
276impl<T> DynamicArray<T>
277where
278 T: Cmov + Pod + Default + std::fmt::Debug,
279{
280 pub fn new(n: usize) -> Self {
282 Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n), rng: rand::rng() }
283 }
284
285 pub fn resize(&mut self, n: usize) {
287 let mut new_array = Self::new(n);
288 for i in 0..self.len() {
289 let mut value = Default::default();
290 self.read(i, &mut value);
291 new_array.write(i, value);
292 }
293 *self = new_array;
295 }
296
297 pub fn read(&mut self, index: usize, out: &mut T) {
299 let new_pos = self.rng.random_range(0..self.len() as PositionType);
300 let old_pos = self.pos_map.access_position(index, new_pos);
301 self.data.read(old_pos, new_pos, index, out);
302 }
303
304 pub fn write(&mut self, index: usize, value: T) {
306 let new_pos = self.rng.random_range(0..self.len() as PositionType);
307 let old_pos = self.pos_map.access_position(index, new_pos);
308 self.data.write_or_insert(old_pos, new_pos, index, value);
309 }
310
311 pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
313 where
314 F: FnOnce(&mut T) -> R,
315 {
316 let new_pos = self.rng.random_range(0..self.len() as PositionType);
317 let old_pos = self.pos_map.access_position(index, new_pos);
318 self.data.update(old_pos, new_pos, index, update_func)
319 }
320}
321
322impl<T: Cmov + Pod> Length for DynamicArray<T> {
323 #[inline(always)]
324 fn len(&self) -> usize {
325 self.pos_map.n
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads every index in `$range` from `$arr` and checks that it holds the default value.
    macro_rules! assert_all_default {
        ($arr:ident, $valtp:ty, $range:expr) => {
            for idx in $range {
                let mut got = <$valtp>::default();
                $arr.read(idx, &mut got);
                assert_eq!(got, <$valtp>::default());
            }
        };
    }

    /// Reads every index in `$range` from `$arr` and checks that it holds `idx as $valtp`.
    macro_rules! assert_identity {
        ($arr:ident, $valtp:ty, $range:expr) => {
            for idx in $range {
                let mut got = <$valtp>::default();
                $arr.read(idx, &mut got);
                assert_eq!(got, idx as $valtp);
            }
        };
    }

    /// Exercises a const-sized array type:
    /// default contents, then an identity fill, then read-back.
    macro_rules! check_fixed_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp, $size>::new();
            assert_eq!(arr.len(), $size);
            assert_all_default!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), $size);
            for idx in 0..$size {
                arr.write(idx, idx as $valtp);
            }
            assert_eq!(arr.len(), $size);
            assert_identity!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), $size);
        }};
    }

    /// Same as the fixed-array check, then grows the array twice.
    /// After each growth it verifies the old prefix survives and the new tail is default.
    macro_rules! check_dynamic_array_exhaustive {
        ($arraytp:ident, $valtp:ty, $size:expr) => {{
            println!("Testing {} with size {}", stringify!($arraytp), $size);
            let mut arr = $arraytp::<$valtp>::new($size);
            assert_eq!(arr.len(), $size);
            assert_all_default!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), $size);
            for idx in 0..$size {
                arr.write(idx, idx as $valtp);
            }
            assert_eq!(arr.len(), $size);
            assert_identity!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), $size);

            // Grow by a single slot.
            arr.resize($size + 1);
            assert_eq!(arr.len(), $size + 1);
            assert_identity!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), $size + 1);
            assert_all_default!(arr, $valtp, $size..($size + 1));
            assert_eq!(arr.len(), $size + 1);

            // Grow to double the original size.
            arr.resize(2 * $size);
            assert_eq!(arr.len(), 2 * $size);
            assert_identity!(arr, $valtp, 0..$size);
            assert_eq!(arr.len(), 2 * $size);
            assert_all_default!(arr, $valtp, $size..(2 * $size));
            assert_eq!(arr.len(), 2 * $size);
        }};
    }

    #[test]
    fn test_fixed_arrays() {
        check_fixed_array_exhaustive!(ShortArray, u32, 1);
        check_fixed_array_exhaustive!(ShortArray, u32, 2);
        check_fixed_array_exhaustive!(ShortArray, u32, 3);
        check_fixed_array_exhaustive!(ShortArray, u64, 15);
        check_fixed_array_exhaustive!(ShortArray, u8, 33);
        check_fixed_array_exhaustive!(ShortArray, u64, 200);

        check_fixed_array_exhaustive!(LongArray, u32, 2);
        check_fixed_array_exhaustive!(LongArray, u32, 3);
        check_fixed_array_exhaustive!(LongArray, u64, 15);
        check_fixed_array_exhaustive!(LongArray, u8, 33);

        check_fixed_array_exhaustive!(FixedArray, u32, 1);
        check_fixed_array_exhaustive!(FixedArray, u32, 2);
        check_fixed_array_exhaustive!(FixedArray, u32, 3);
        check_fixed_array_exhaustive!(FixedArray, u64, 15);
        check_fixed_array_exhaustive!(FixedArray, u8, 33);
        check_fixed_array_exhaustive!(FixedArray, u64, 200);
    }

    #[test]
    fn test_dynamic_array() {
        check_dynamic_array_exhaustive!(DynamicArray, u32, 2);
        check_dynamic_array_exhaustive!(DynamicArray, u32, 3);
        check_dynamic_array_exhaustive!(DynamicArray, u64, 15);
        check_dynamic_array_exhaustive!(DynamicArray, u8, 33);
        check_dynamic_array_exhaustive!(DynamicArray, u64, 200);
    }
}
456}