use std::ffi::c_void;
use std::ptr::null_mut;
use crate::bindgen;
use crate::error::*;
use serde::{Deserialize, Serialize};
/// Owning wrapper around a native SEAL `Modulus` object, referenced through an
/// opaque FFI handle.
///
/// The native object is released with `bindgen::Modulus_Destroy` when this value
/// is dropped, and `Clone` creates an independent native copy through
/// `bindgen::Modulus_Create2`.
pub struct Modulus {
/// Opaque pointer to the native object. Passed to `bindgen` calls and exposed
/// raw (without transferring ownership) through `get_handle`.
handle: *mut c_void,
}
// SAFETY: NOTE(review): these impls assert that the native object behind
// `handle` may be moved to another thread and shared between threads. Within
// this file the handle is only read through `&self` (`Modulus_Value`,
// `Modulus_Create2`) and destroyed in `Drop` under `&mut self`. Soundness still
// depends on those native calls being safe for concurrent readers; confirm this
// against the SEAL native library.
unsafe impl Sync for Modulus {}
unsafe impl Send for Modulus {}
impl std::fmt::Debug for Modulus {
    /// Writes only the numeric value of the modulus (as returned by
    /// `Modulus::value`), never the raw native handle.
    ///
    /// # Panics
    /// Panics if reading the value from the native object fails.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = self.value();
        write!(f, "{}", value)
    }
}
impl PartialEq for Modulus {
    /// Two moduli compare equal when their numeric values match. Whether they
    /// refer to the same native object is irrelevant.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.value(), other.value());
        lhs == rhs
    }
}
/// Target security level used when selecting coefficient-modulus parameters.
///
/// The discriminant is the level in bits. It is also the integer passed to the
/// native `CoeffModulus_*` functions (`security_level as i32`).
/// NOTE(review): presumably mirrors SEAL's `sec_level_type` (tc128/tc192/tc256);
/// confirm against the native library.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[repr(i32)]
pub enum SecurityLevel {
/// 128-bit security level (discriminant 128).
TC128 = 128,
/// 192-bit security level (discriminant 192).
TC192 = 192,
/// 256-bit security level (discriminant 256).
TC256 = 256,
}
impl TryFrom<i32> for SecurityLevel {
type Error = Error;
fn try_from(val: i32) -> Result<SecurityLevel> {
Ok(match val {
128 => SecurityLevel::TC128,
192 => SecurityLevel::TC192,
256 => SecurityLevel::TC256,
_ => Err(Error::SerializationError(Box::new(format!(
"Invalid security level: {}",
val
))))?,
})
}
}
impl From<SecurityLevel> for i32 {
fn from(val: SecurityLevel) -> Self {
match val {
SecurityLevel::TC128 => 128,
SecurityLevel::TC192 => 192,
SecurityLevel::TC256 => 256,
}
}
}
impl Default for SecurityLevel {
fn default() -> Self {
Self::TC128
}
}
/// Wraps an existing native handle in a [`Modulus`] without any validation.
///
/// # Safety
///
/// `handle` must point to a live native `Modulus` object that the caller owns.
/// Ownership transfers to the returned value, which calls
/// `bindgen::Modulus_Destroy` on the handle when dropped. The caller must
/// therefore neither destroy the handle itself nor wrap it in a second
/// `Modulus`. No checks are performed, including for a null pointer.
pub unsafe fn unchecked_from_handle(handle: *mut c_void) -> Modulus {
Modulus { handle }
}
impl Modulus {
pub fn new(value: u64) -> Result<Self> {
let mut handle: *mut c_void = null_mut();
convert_seal_error(unsafe { bindgen::Modulus_Create1(value, &mut handle) })?;
Ok(Modulus { handle })
}
pub fn value(&self) -> u64 {
let mut val: u64 = 0;
convert_seal_error(unsafe { bindgen::Modulus_Value(self.handle, &mut val) })
.expect("Internal error. Could not get modulus value.");
val
}
pub fn get_handle(&self) -> *mut c_void {
self.handle
}
}
impl Drop for Modulus {
    /// Releases the native object behind `handle` via `bindgen::Modulus_Destroy`.
    ///
    /// Any status returned by the native call is ignored, since `drop` has no
    /// way to report failure.
    fn drop(&mut self) {
        // SAFETY: `self` owns `handle`, and this is its final use.
        unsafe { bindgen::Modulus_Destroy(self.handle) };
    }
}
impl Clone for Modulus {
fn clone(&self) -> Self {
let mut copy = null_mut();
unsafe {
convert_seal_error(bindgen::Modulus_Create2(self.handle, &mut copy))
.expect("Failed to clone modulus")
};
Self { handle: copy }
}
}
/// Factory for coefficient-modulus chains, each a `Vec` of [`Modulus`] values.
pub struct CoefficientModulus;

impl CoefficientModulus {
    /// Creates a coefficient-modulus chain for polynomial degree `degree`, with
    /// one modulus per entry of `bit_sizes`, each of the requested bit length.
    ///
    /// # Errors
    /// Returns the error mapped by `convert_seal_error` if
    /// `bindgen::CoeffModulus_Create1` fails.
    pub fn create(degree: u64, bit_sizes: &[i32]) -> Result<Vec<Modulus>> {
        // The native API takes a mutable pointer, so work on an owned copy.
        let mut bit_sizes = bit_sizes.to_owned();
        let length = bit_sizes.len();
        let mut handles: Vec<*mut c_void> = Vec::with_capacity(length);

        // SAFETY: `bit_sizes` holds `length` i32s and `handles` has capacity for
        // `length` pointers; both buffers outlive the call.
        convert_seal_error(unsafe {
            bindgen::CoeffModulus_Create1(
                degree,
                length as u64,
                bit_sizes.as_mut_ptr(),
                handles.as_mut_ptr(),
            )
        })?;

        // SAFETY: assumes that, on success, the native call wrote one handle per
        // requested bit size, so the first `length` slots are initialized.
        unsafe { handles.set_len(length) };

        Ok(handles.into_iter().map(|handle| Modulus { handle }).collect())
    }

    /// Returns the default BFV coefficient-modulus chain for `degree` at the
    /// given security level.
    ///
    /// The native function is called twice. The first call (with a null output
    /// buffer) only reports the chain length; the second fills a buffer of that
    /// size.
    ///
    /// # Errors
    /// Returns the error mapped by `convert_seal_error` if either native call
    /// fails.
    ///
    /// # Panics
    /// Panics if the second call reports more moduli than the first one did.
    pub fn bfv_default(degree: u64, security_level: SecurityLevel) -> Result<Vec<Modulus>> {
        let level = security_level as i32;

        let mut len: u64 = 0;
        convert_seal_error(unsafe {
            bindgen::CoeffModulus_BFVDefault(degree, level, &mut len, null_mut())
        })?;

        let requested = len as usize;
        let mut handles: Vec<*mut c_void> = Vec::with_capacity(requested);
        // SAFETY: `handles` has capacity for `requested` pointers, the count the
        // native side reported in the sizing call above.
        convert_seal_error(unsafe {
            bindgen::CoeffModulus_BFVDefault(degree, level, &mut len, handles.as_mut_ptr())
        })?;

        // `set_len` beyond the allocation would expose memory we never owned, so
        // fail loudly if the two native calls disagree.
        let written = len as usize;
        assert!(
            written <= requested,
            "CoeffModulus_BFVDefault reported {} moduli after sizing the buffer for {}",
            written,
            requested
        );
        // SAFETY: `written <= requested <= capacity`, and the native call
        // initialized those slots.
        unsafe { handles.set_len(written) };

        Ok(handles.into_iter().map(|handle| Modulus { handle }).collect())
    }

    /// Returns the maximum total coefficient-modulus bit count for `degree` at
    /// the given security level.
    ///
    /// # Panics
    /// Panics if the native call reports an error, or if it yields a
    /// non-positive bit count (presumably an unsupported degree; confirm against
    /// SEAL).
    pub fn max_bit_count(degree: u64, security_level: SecurityLevel) -> u32 {
        let mut bits: i32 = 0;
        // The status is checked instead of silently discarded.
        convert_seal_error(unsafe {
            bindgen::CoeffModulus_MaxBitCount(degree, security_level as i32, &mut bits)
        })
        .expect("Internal error. Could not query maximum coefficient modulus bit count.");
        assert!(
            bits > 0,
            "No coefficient modulus bit budget for degree {} at {:?}",
            degree,
            security_level
        );
        bits as u32
    }
}
pub struct PlainModulus;
impl PlainModulus {
pub fn batching(degree: u64, bit_size: u32) -> Result<Modulus> {
let bit_sizes = vec![bit_size as i32];
let modulus_chain = CoefficientModulus::create(degree, bit_sizes.as_slice())?;
Ok(modulus_chain.first().ok_or(Error::Unexpected)?.clone())
}
pub fn raw(val: u64) -> Result<Modulus> {
Modulus::new(val)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 20-bit batching modulus for degree 1024 has a known value.
    #[test]
    fn can_create_plain_modulus() {
        let plain = PlainModulus::batching(1024, 20).unwrap();
        assert_eq!(plain.value(), 1038337);
    }

    /// Each security level yields a single-prime default chain at degree 1024.
    #[test]
    fn can_create_default_coefficient_modulus() {
        let cases: [(SecurityLevel, u64); 3] = [
            (SecurityLevel::TC128, 132120577),
            (SecurityLevel::TC192, 520193),
            (SecurityLevel::TC256, 12289),
        ];
        for (level, expected) in cases {
            let chain = CoefficientModulus::bfv_default(1024, level).unwrap();
            assert_eq!(chain.len(), 1);
            assert_eq!(chain[0].value(), expected);
        }
    }

    /// A custom bit-size list produces one modulus per entry, in order.
    #[test]
    fn can_create_custom_coefficient_modulus() {
        let expected: [u64; 5] = [
            1125899905744897,
            1073643521,
            1073692673,
            1125899906629633,
            1125899906826241,
        ];
        let chain = CoefficientModulus::create(8192, &[50, 30, 30, 50, 50]).unwrap();
        let values: Vec<u64> = chain.iter().map(Modulus::value).collect();
        assert_eq!(values, expected);
    }

    /// Converting to `i32` and back preserves every security level.
    #[test]
    fn can_roundtrip_security_level() {
        let levels = [
            SecurityLevel::TC128,
            SecurityLevel::TC192,
            SecurityLevel::TC256,
        ];
        for level in levels {
            let raw = i32::from(level);
            let back = SecurityLevel::try_from(raw).unwrap();
            assert_eq!(level, back);
        }
    }
}