use core::hash::Hash;
use std::ffi::{c_void, CString};
use std::ptr::null_mut;
use crate::error::*;
use crate::{bindgen, serialization::CompressionType, Context, FromBytes, ToBytes};
use serde::ser::Error;
use serde::{Serialize, Serializer};
/// A SEAL plaintext polynomial, held through an opaque handle obtained from the
/// C bindings and released when the value is dropped.
#[derive(Eq, Debug)]
pub struct Plaintext {
    /// Opaque pointer to the native SEAL plaintext object.
    handle: *mut c_void,
}

// SAFETY: NOTE(review): these impls assert that the native plaintext may be moved
// to and shared between threads. Only `&mut self` methods mutate it, so `Sync`
// relies on SEAL's read-only accessors being thread-safe — confirm against SEAL.
unsafe impl Send for Plaintext {}
unsafe impl Sync for Plaintext {}
impl Clone for Plaintext {
fn clone(&self) -> Self {
let mut copy = null_mut();
convert_seal_error(unsafe { bindgen::Plaintext_Create5(self.handle, &mut copy) })
.expect("Internal error: Failed to copy plaintext.");
Self { handle: copy }
}
}
impl PartialEq for Plaintext {
    /// Two plaintexts are equal when they hold the same number of coefficients
    /// and every coefficient matches at the same index.
    fn eq(&self, other: &Self) -> bool {
        let count = self.len();
        if count != other.len() {
            return false;
        }
        (0..count).all(|idx| self.get_coefficient(idx) == other.get_coefficient(idx))
    }
}
impl Hash for Plaintext {
    /// Feeds every coefficient, in index order, into `state` as a `u64`.
    /// This agrees with `PartialEq`, which also compares coefficient by coefficient.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        (0..self.len())
            .map(|idx| self.get_coefficient(idx))
            .for_each(|coeff| state.write_u64(coeff));
    }
}
impl Serialize for Plaintext {
    /// Serializes the plaintext as a serde byte string.
    ///
    /// The bytes are those produced by `ToBytes::as_bytes`, i.e. SEAL's
    /// `Plaintext_Save` output with ZStd compression.
    ///
    /// # Errors
    /// Returns a serializer error if the native save fails.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `as_bytes` already queries the serialized size, so no separate
        // `Plaintext_SaveSize` call is needed here.
        let bytes = self
            .as_bytes()
            .map_err(|e| S::Error::custom(format!("Failed to serialize bytes: {}", e)))?;
        serializer.serialize_bytes(&bytes)
    }
}
impl FromBytes for Plaintext {
    /// Loads a plaintext from `data` with `Plaintext_Load`, using `context`.
    ///
    /// # Errors
    /// Returns an error if creating the empty plaintext or the native load fails.
    fn from_bytes(context: &Context, data: &[u8]) -> Result<Self> {
        let loaded = Plaintext::new()?;
        let mut consumed = 0;
        // SAFETY: `loaded.handle` was just created, and the pointer/length pair
        // describes `data` exactly. The `*mut` cast only satisfies the binding's
        // signature — NOTE(review): assumes SEAL only reads this buffer; confirm.
        let status = unsafe {
            bindgen::Plaintext_Load(
                loaded.handle,
                context.get_handle(),
                data.as_ptr() as *mut u8,
                data.len() as u64,
                &mut consumed,
            )
        };
        convert_seal_error(status)?;
        Ok(loaded)
    }
}
impl ToBytes for Plaintext {
    /// Saves the plaintext with ZStd compression into a new byte vector.
    ///
    /// The vector is sized by `Plaintext_SaveSize` and then truncated to the
    /// number of bytes `Plaintext_Save` reports as written.
    ///
    /// # Errors
    /// Returns an error if either native call fails.
    fn as_bytes(&self) -> Result<Vec<u8>> {
        let compression = CompressionType::ZStd as u8;

        let mut capacity: i64 = 0;
        convert_seal_error(unsafe {
            bindgen::Plaintext_SaveSize(self.handle, compression, &mut capacity)
        })?;

        let mut buffer: Vec<u8> = Vec::with_capacity(capacity as usize);
        let mut written: i64 = 0;
        // SAFETY: `buffer` owns at least `capacity` bytes of spare capacity, which
        // is the limit passed to SEAL.
        convert_seal_error(unsafe {
            bindgen::Plaintext_Save(
                self.handle,
                buffer.as_mut_ptr(),
                capacity as u64,
                compression,
                &mut written,
            )
        })?;

        // SAFETY: SEAL reported initializing `written` bytes within the buffer.
        unsafe { buffer.set_len(written as usize) };
        Ok(buffer)
    }
}
impl Plaintext {
    /// Returns the raw handle to the native SEAL plaintext.
    ///
    /// The pointer stays owned by this `Plaintext` and is destroyed when it is dropped.
    pub fn get_handle(&self) -> *mut c_void {
        self.handle
    }

    /// Creates a new, empty plaintext.
    ///
    /// # Errors
    /// Returns an error if the native `Plaintext_Create1` call fails.
    pub fn new() -> Result<Self> {
        let mut handle: *mut c_void = null_mut();
        // The first argument is passed as null — presumably selecting SEAL's
        // default memory pool; confirm against the C API.
        convert_seal_error(unsafe { bindgen::Plaintext_Create1(null_mut(), &mut handle) })?;
        Ok(Self { handle })
    }

    /// Parses a plaintext from SEAL's polynomial string form (e.g. `"1234x^2 + 4321"`).
    ///
    /// # Errors
    /// Returns an error if SEAL rejects the string.
    ///
    /// # Panics
    /// Panics if `hex_str` contains an interior NUL byte, because it cannot be
    /// passed to C as a `CString`.
    pub fn from_hex_string(hex_str: &str) -> Result<Self> {
        let mut handle: *mut c_void = null_mut();
        let hex_string = CString::new(hex_str)
            .expect("Plaintext::from_hex_string: input contains an interior NUL byte");
        convert_seal_error(unsafe {
            bindgen::Plaintext_Create4(hex_string.as_ptr() as *mut i8, null_mut(), &mut handle)
        })?;
        Ok(Self { handle })
    }

    /// Returns the coefficient at `index`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`, or if the native call reports an error.
    pub fn get_coefficient(&self, index: usize) -> u64 {
        let len = self.len();
        // Valid indices are 0..len; `index == len` is already out of bounds.
        if index >= len {
            panic!("Index {} out of bounds {}", index, len);
        }
        let mut coeff: u64 = 0;
        convert_seal_error(unsafe {
            bindgen::Plaintext_CoeffAt(self.handle, index as u64, &mut coeff)
        })
        .expect("Fatal error in Plaintext::get_coefficient().");
        coeff
    }

    /// Overwrites the coefficient at `index` with `value`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`, or if the native call reports an error.
    pub fn set_coefficient(&mut self, index: usize, value: u64) {
        let len = self.len();
        if index >= len {
            panic!("Index {} out of bounds {}", index, len);
        }
        convert_seal_error(unsafe {
            bindgen::Plaintext_SetCoeffAt(self.handle, index as u64, value)
        })
        .expect("Fatal error in Plaintext::set_coefficient().");
    }

    /// Resizes the plaintext to hold `count` coefficients.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    pub fn resize(&mut self, count: usize) {
        convert_seal_error(unsafe { bindgen::Plaintext_Resize(self.handle, count as u64) })
            .expect("Fatal error in Plaintext::resize().");
    }

    /// Returns the number of coefficients, as reported by `Plaintext_CoeffCount`.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    // No `is_empty` is provided; the allow silences clippy's `len_without_is_empty`.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        let mut size: u64 = 0;
        convert_seal_error(unsafe { bindgen::Plaintext_CoeffCount(self.handle, &mut size) })
            .expect("Fatal error in Plaintext::len().");
        size as usize
    }

    /// Returns whether SEAL reports this plaintext as being in NTT form.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    pub fn is_ntt_form(&self) -> bool {
        let mut result = false;
        convert_seal_error(unsafe { bindgen::Plaintext_IsNTTForm(self.handle, &mut result) })
            .expect("Fatal error in Plaintext::is_ntt_form().");
        result
    }
}
impl Drop for Plaintext {
fn drop(&mut self) {
convert_seal_error(unsafe { bindgen::Plaintext_Destroy(self.handle) })
.expect("Internal error in Plaintext::drop.");
}
}
/// A SEAL ciphertext, held through an opaque handle obtained from the C
/// bindings and released when the value is dropped.
pub struct Ciphertext {
    /// Opaque pointer to the native SEAL ciphertext object.
    handle: *mut c_void,
}

// SAFETY: NOTE(review): these impls assert that the native ciphertext may be moved
// to and shared between threads. They rely on SEAL's read-only accessors being
// thread-safe — confirm against SEAL.
unsafe impl Send for Ciphertext {}
unsafe impl Sync for Ciphertext {}
impl Clone for Ciphertext {
fn clone(&self) -> Self {
let mut handle = null_mut();
convert_seal_error(unsafe { bindgen::Ciphertext_Create2(self.handle, &mut handle) })
.expect("Fatal error: Failed to clone ciphertext");
Self { handle }
}
}
impl Ciphertext {
    /// Returns the raw handle to the native SEAL ciphertext.
    ///
    /// The pointer stays owned by this `Ciphertext` and is destroyed when it is dropped.
    pub fn get_handle(&self) -> *mut c_void {
        self.handle
    }

    /// Creates a new, empty ciphertext.
    ///
    /// # Errors
    /// Returns an error if the native `Ciphertext_Create1` call fails.
    pub fn new() -> Result<Self> {
        let mut handle: *mut c_void = null_mut();
        // The first argument is passed as null — presumably selecting SEAL's
        // default memory pool; confirm against the C API.
        convert_seal_error(unsafe { bindgen::Ciphertext_Create1(null_mut(), &mut handle) })?;
        Ok(Self { handle })
    }

    /// Returns the number of polynomials, as reported by `Ciphertext_Size`.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    pub fn num_polynomials(&self) -> u64 {
        let mut size: u64 = 0;
        convert_seal_error(unsafe { bindgen::Ciphertext_Size(self.handle, &mut size) })
            .expect("Fatal error in Ciphertext::num_polynomials().");
        size
    }

    /// Returns the coefficient-modulus size, as reported by `Ciphertext_CoeffModulusSize`.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    pub fn coeff_modulus_size(&self) -> u64 {
        let mut size: u64 = 0;
        convert_seal_error(unsafe { bindgen::Ciphertext_CoeffModulusSize(self.handle, &mut size) })
            .expect("Fatal error in Ciphertext::coeff_modulus_size().");
        size
    }

    /// Reads the raw `u64` at flat position `index` of the ciphertext data.
    ///
    /// # Errors
    /// Returns an error if the native call fails (e.g. presumably for an
    /// out-of-range index — confirm against SEAL).
    // Not called from this file; kept for crate-internal use.
    #[allow(dead_code)]
    pub(crate) fn get_data(&self, index: usize) -> Result<u64> {
        let mut value: u64 = 0;
        convert_seal_error(unsafe {
            bindgen::Ciphertext_GetDataAt1(self.handle, index as u64, &mut value)
        })?;
        Ok(value)
    }

    /// Returns coefficient `coeff_index` of polynomial `poly_index`.
    ///
    /// The result holds one `u64` per coefficient-modulus element; its length
    /// is `coeff_modulus_size()`.
    ///
    /// # Errors
    /// Returns an error if the native `Ciphertext_GetDataAt2` call fails.
    pub fn get_coefficient(&self, poly_index: usize, coeff_index: usize) -> Result<Vec<u64>> {
        let size = self.coeff_modulus_size() as usize;
        // Zero-initialize so the vector never exposes memory the native call
        // did not write. The native call fills it in place, so no copy is needed.
        let mut data = vec![0u64; size];
        convert_seal_error(unsafe {
            bindgen::Ciphertext_GetDataAt2(
                self.handle,
                poly_index as u64,
                coeff_index as u64,
                data.as_mut_ptr(),
            )
        })?;
        Ok(data)
    }

    /// Returns whether SEAL reports this ciphertext as being in NTT form.
    ///
    /// # Panics
    /// Panics if the native call reports an error.
    pub fn is_ntt_form(&self) -> bool {
        let mut result = false;
        convert_seal_error(unsafe { bindgen::Ciphertext_IsNTTForm(self.handle, &mut result) })
            .expect("Fatal error in Ciphertext::is_ntt_form().");
        result
    }
}
impl PartialEq for Ciphertext {
    /// Two ciphertexts are equal when both serialize successfully (via
    /// `ToBytes::as_bytes`) to identical byte sequences.
    ///
    /// If either serialization fails, the ciphertexts compare unequal. Comparing
    /// the raw `Result`s would instead treat two failures with equal errors as equal.
    fn eq(&self, other: &Self) -> bool {
        match (self.as_bytes(), other.as_bytes()) {
            (Ok(lhs), Ok(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}
impl ToBytes for Ciphertext {
    /// Saves the ciphertext with ZStd compression into a new byte vector.
    ///
    /// The vector is sized by `Ciphertext_SaveSize` and then truncated to the
    /// number of bytes `Ciphertext_Save` reports as written.
    ///
    /// # Errors
    /// Returns an error if either native call fails.
    fn as_bytes(&self) -> Result<Vec<u8>> {
        let compression = CompressionType::ZStd as u8;

        let mut capacity: i64 = 0;
        convert_seal_error(unsafe {
            bindgen::Ciphertext_SaveSize(self.handle, compression, &mut capacity)
        })?;

        let mut buffer: Vec<u8> = Vec::with_capacity(capacity as usize);
        let mut written: i64 = 0;
        // SAFETY: `buffer` owns at least `capacity` bytes of spare capacity, which
        // is the limit passed to SEAL.
        convert_seal_error(unsafe {
            bindgen::Ciphertext_Save(
                self.handle,
                buffer.as_mut_ptr(),
                capacity as u64,
                compression,
                &mut written,
            )
        })?;

        // SAFETY: SEAL reported initializing `written` bytes within the buffer.
        unsafe { buffer.set_len(written as usize) };
        Ok(buffer)
    }
}
impl FromBytes for Ciphertext {
    /// Loads a ciphertext from `bytes` with `Ciphertext_Load`, using `context`.
    ///
    /// # Errors
    /// Returns an error if creating the empty ciphertext or the native load
    /// fails. On failure the partially built ciphertext is dropped, which
    /// releases its handle.
    fn from_bytes(context: &Context, bytes: &[u8]) -> Result<Self> {
        let ciphertext = Self::new()?;
        let mut bytes_read = 0i64;
        // SAFETY: `ciphertext.handle` was just created, and the pointer/length pair
        // describes `bytes` exactly. The `*mut` cast only satisfies the binding's
        // signature — NOTE(review): assumes SEAL only reads this buffer; confirm.
        convert_seal_error(unsafe {
            bindgen::Ciphertext_Load(
                ciphertext.handle,
                // Use the accessor, as `Plaintext::from_bytes` does, rather than
                // reaching into `Context`'s field.
                context.get_handle(),
                bytes.as_ptr() as *mut u8,
                bytes.len() as u64,
                &mut bytes_read,
            )
        })?;
        Ok(ciphertext)
    }
}
impl Drop for Ciphertext {
fn drop(&mut self) {
convert_seal_error(unsafe { bindgen::Ciphertext_Destroy(self.handle) })
.expect("Internal error in Ciphertext::drop");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing and immediately dropping a ciphertext must not fail or panic.
    #[test]
    fn can_create_and_destroy_ciphertext() {
        let ciphertext = Ciphertext::new().unwrap();
        drop(ciphertext);
    }

    /// Constructing and immediately dropping a plaintext must not fail or panic.
    #[test]
    fn can_create_and_destroy_plaintext() {
        let plaintext = Plaintext::new().unwrap();
        drop(plaintext);
    }

    /// Coefficient `i` of the parsed polynomial is the coefficient of `x^i`.
    #[test]
    fn plaintext_coefficients_in_increasing_order() {
        let plaintext = Plaintext::from_hex_string("1234x^2 + 4321").unwrap();
        let expected: [u64; 3] = [0x4321, 0, 0x1234];
        for (degree, &want) in expected.iter().enumerate() {
            assert_eq!(plaintext.get_coefficient(degree), want);
        }
    }
}