use crate::{ArcisCiphertext, ArcisType};
use arcis_compiler::{
utils::{
field::BaseField,
packing::{DataSize, PackLocation},
},
ArcisField,
EvalValue,
};
use ff::Field;
use std::marker::PhantomData;
fn pack<T: ArcisType>(a: &T) -> Vec<BaseField> {
let mut data_sizes = Vec::new();
T::data_size(&mut data_sizes);
let (locations, v) = DataSize::pack_arcis(data_sizes);
let location_ptr = &mut &locations[..];
let mut res = vec![BaseField::ZERO; v.len()];
a.pack(location_ptr, &mut res);
res
}
fn unpack<T: ArcisType>(input: &[BaseField]) -> T {
let mut data_sizes = Vec::new();
T::data_size(&mut data_sizes);
let (locations, v) = DataSize::pack_arcis(data_sizes);
assert_eq!(input.len(), v.len(), "Unpack unexpected length");
let location_ptr = &mut &locations[..];
T::unpack(location_ptr, input)
}
fn n_packed_values<T: ArcisType>() -> usize {
let mut data_sizes = Vec::new();
T::data_size(&mut data_sizes);
DataSize::pack_arcis(data_sizes).1.len()
}
/// A value of type `T` held only in its packed form: the vector of field
/// containers obtained by packing `T` with its packing layout.
///
/// No `T` is stored. The `PhantomData` marker records the type so that the
/// containers can later be unpacked back into a `T`.
#[derive(Debug, PartialEq)]
pub struct Pack<T: ArcisType> {
    /// Packed containers for `T`, in layout order.
    data: Vec<ArcisCiphertext>,
    /// Ties this pack to `T` without storing a `T`.
    phantom: PhantomData<T>,
}
impl<T: ArcisType> Pack<T> {
    /// Packs `data` into containers and wraps them in a `Pack`.
    pub fn new(data: T) -> Self {
        Self {
            data: pack(&data),
            phantom: PhantomData,
        }
    }

    /// Unpacks the stored containers back into a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the number of stored containers does not match `T`'s layout.
    pub fn unpack(&self) -> T {
        unpack::<T>(self.data.as_slice())
    }
}
/// Consumes the next entry of `locations` and returns the container index it
/// points to.
///
/// Every item of a `Pack` is declared as `DataSize::Full` (see `data_size`
/// below). Each location is therefore expected to address a whole container,
/// with a zero bit offset.
///
/// # Panics
///
/// Panics if `locations` is empty, or if the location has a non-zero bit offset.
fn take_full_location(locations: &mut &[PackLocation]) -> usize {
    let remaining = *locations;
    let (location, rest) = remaining
        .split_first()
        .expect("Pack: no pack location left for a BaseField item.");
    *locations = rest;
    assert_eq!(location.bit_offset, 0, "BaseField item has a bit offset.");
    location.index
}

impl<T: ArcisType> ArcisType for Pack<T> {
    /// One value per packed container of `T`.
    fn n_values() -> usize {
        n_packed_values::<T>()
    }

    /// Builds an input `T` via `T::gen_input`, packs it, and appends the packed
    /// containers to `values` as `EvalValue::Base`.
    fn gen_input(values: &mut Vec<EvalValue>) -> Self {
        // The unpacked values produced for `T` are intentionally discarded.
        // Only the packed containers are appended to `values`.
        let mut gen_values = Vec::new();
        let data = pack(&T::gen_input(&mut gen_values));
        values.extend(data.iter().copied().map(EvalValue::Base));
        Self {
            data,
            phantom: PhantomData,
        }
    }

    /// Rebuilds a pack from the first `Self::n_values()` entries of `values`.
    ///
    /// # Panics
    ///
    /// Panics if `values` has fewer than `Self::n_values()` entries.
    fn from_values(values: &[EvalValue]) -> Self {
        let len = Self::n_values();
        let data = values[..len]
            .iter()
            .map(|x| ArcisField::from(x.to_signed_number()))
            .collect();
        Pack {
            data,
            phantom: PhantomData,
        }
    }

    /// Emits every packed container as an `EvalValue::Base` output.
    fn handle_outputs(&self, outputs: &mut Vec<EvalValue>) {
        outputs.extend(self.data.iter().copied().map(EvalValue::Base));
    }

    /// Two packs are similar when their unpacked `T` values are similar.
    fn is_similar(&self, other: &Self) -> bool {
        self.unpack().is_similar(&other.unpack())
    }

    /// Random packs are not supported.
    fn n_bools() -> usize {
        panic!("Cannot generate random packs.")
    }

    /// Random packs are not supported.
    fn from_bools(_bools: &[bool]) -> Self {
        panic!("Cannot generate random packs.")
    }

    /// Each packed container is an already-full field element, so it is
    /// declared as `DataSize::Full`.
    fn data_size(acc: &mut Vec<DataSize>) {
        acc.extend(std::iter::repeat_n(DataSize::Full, Self::n_values()));
    }

    /// Writes each stored container into the slot named by the next location.
    fn pack(&self, locations: &mut &[PackLocation], containers: &mut [BaseField]) {
        for &item in &self.data {
            containers[take_full_location(locations)] = item;
        }
    }

    /// Reads `Self::n_values()` containers, one per consumed location.
    fn unpack(locations: &mut &[PackLocation], containers: &[BaseField]) -> Self {
        let data = (0..Self::n_values())
            .map(|_| containers[take_full_location(locations)])
            .collect::<Vec<_>>();
        Pack {
            data,
            phantom: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arcis_compiler::utils::used_field::UsedField;

    /// Round-trips several `ArcisType`s through `pack`/`unpack`. It also checks
    /// the exact container contents for a packed byte array.
    #[test]
    fn test_packing() {
        use crate::*;
        const A_SIZE: usize = 256;
        // Number of whole bytes this test expects one BaseField container to hold.
        const BASE_FIELD_FULL_BYTES: usize = 26;

        // --- Byte array: check the exact packed representation. ---
        #[derive(Debug, PartialEq, ArcisType)]
        struct A {
            arr: [i8; A_SIZE],
        }
        let a = A {
            arr: [-127; A_SIZE],
        };
        let mut packed = pack(&a);
        let unpacked = unpack(&packed);
        assert_eq!(a, unpacked);
        assert_eq!(packed.len(), A_SIZE.div_ceil(BASE_FIELD_FULL_BYTES));

        // Each -127 is expected to occupy one byte holding the value 1.
        // A container holding `n_bytes` such bytes therefore equals
        // sum_{i < n_bytes} 2^(8 i).
        let bytes_of_ones = |n_bytes: usize| {
            let mut acc = BaseField::ZERO;
            for i in 0..n_bytes {
                acc += BaseField::power_of_two(8 * i);
            }
            acc
        };
        let last = packed.pop().unwrap();
        // Whatever the full leading containers do not hold lands in the last one.
        // Unlike `A_SIZE % BASE_FIELD_FULL_BYTES`, this stays correct when A_SIZE
        // is an exact multiple of BASE_FIELD_FULL_BYTES.
        let last_len = A_SIZE - packed.len() * BASE_FIELD_FULL_BYTES;
        assert_eq!(last, bytes_of_ones(last_len));
        let full = bytes_of_ones(BASE_FIELD_FULL_BYTES);
        for container in packed {
            assert_eq!(container, full);
        }

        // --- Mixed integer arrays: random round-trips. ---
        #[derive(Debug, Clone, Copy, PartialEq, ArcisType)]
        struct B {
            i8: [i8; 17],
            i16: [i16; 23],
            i32: [i32; 43],
            i64: [i64; 35],
            i128: [i128; 22],
            u8: [u8; 12],
            u16: [u16; 47],
            u32: [u32; 28],
            u64: [u64; 39],
            u128: [u128; 24],
        }
        for _ in 0..16 {
            let original = ArcisRNG::gen_uniform::<B>();
            let data = pack(&original);
            let roundtrip = unpack(&data);
            assert_eq!(original, roundtrip);
        }

        // --- Nested structs together with an encrypted field. ---
        #[derive(Debug, PartialEq, ArcisType)]
        struct C {
            arr: [B; 3],
            enc: Enc<Mxe, B>,
        }
        for _ in 0..16 {
            let bs = ArcisRNG::gen_uniform::<[B; 4]>();
            let enc = Mxe::get().from_arcis(bs[3]);
            let original = C {
                arr: [bs[0], bs[1], bs[2]],
                enc,
            };
            let data = pack(&original);
            let roundtrip = unpack(&data);
            assert_eq!(original, roundtrip);
        }
    }
}