use std::{
io,
io::{Read, Write},
};
use ff::PrimeField;
use group::{Curve, GroupEncoding};
use midnight_curves::serde::SerdeObject;
use crate::poly::Polynomial;
/// Encoding used when reading or writing curve points and field elements in this module.
#[derive(Debug, Clone, Copy)]
pub enum SerdeFormat {
    /// For curve points: the compressed `GroupEncoding` bytes, validated on read with
    /// `GroupEncoding::from_bytes`. For field elements (`read_f`), this is currently read
    /// with `SerdeObject::read_raw`, exactly like `RawBytes`.
    Processed,
    /// Raw `SerdeObject` bytes, read with `read_raw` and written with `write_raw`; for curve
    /// points this is the raw encoding of the affine representation.
    RawBytes,
    /// Same byte layout as `RawBytes`, but read with `read_raw_unchecked`. Presumably that
    /// skips validity checks, so it should only be used on trusted data (confirm against
    /// the `SerdeObject` implementation).
    RawBytesUnchecked,
}
/// Group elements that can be (de)serialized in any [`SerdeFormat`].
pub trait ProcessedSerdeObject: GroupEncoding {
    /// Decodes one value from `reader`, interpreting the bytes according to `format`.
    fn read<R>(reader: &mut R, format: SerdeFormat) -> io::Result<Self>
    where
        R: io::Read;

    /// Encodes `self` into `writer` using the byte layout selected by `format`.
    fn write<W>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()>
    where
        W: io::Write;
}
/// Number of bytes one `T` occupies when serialized with `format`.
///
/// `Processed` uses the length of the compressed `GroupEncoding::Repr`. Both raw formats
/// are taken to be exactly twice that size (presumably the uncompressed affine form).
pub fn byte_length<T: ProcessedSerdeObject>(format: SerdeFormat) -> usize {
    let compressed_len = <T as GroupEncoding>::Repr::default().as_ref().len();
    match format {
        SerdeFormat::Processed => compressed_len,
        SerdeFormat::RawBytes | SerdeFormat::RawBytesUnchecked => 2 * compressed_len,
    }
}
/// Reads a single field element from `reader` according to `format`.
///
/// `RawBytesUnchecked` uses `SerdeObject::read_raw_unchecked`, which returns the element
/// directly rather than a `Result`. Every other format goes through `SerdeObject::read_raw`.
///
/// NOTE(review): `Processed` deliberately(?) reads the raw encoding rather than the
/// canonical `PrimeField::Repr`. This must stay in sync with whatever writes these
/// elements, so confirm before changing it.
pub(crate) fn read_f<F: PrimeField + SerdeObject, R: io::Read>(
    reader: &mut R,
    format: SerdeFormat,
) -> io::Result<F> {
    if let SerdeFormat::RawBytesUnchecked = format {
        return Ok(<F as SerdeObject>::read_raw_unchecked(reader));
    }
    <F as SerdeObject>::read_raw(reader)
}
/// Blanket implementation for any curve type that converts from its affine representation.
///
/// - `Processed`: compressed `GroupEncoding` bytes, validated via `from_bytes`.
/// - `RawBytes` / `RawBytesUnchecked`: the affine point's raw `SerdeObject` encoding.
impl<C> ProcessedSerdeObject for C
where
    C: Curve + Default + GroupEncoding + From<C::AffineRepr>,
    C::AffineRepr: SerdeObject,
{
    fn read<R: Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
        match format {
            SerdeFormat::RawBytesUnchecked => {
                let affine = <Self as Curve>::AffineRepr::read_raw_unchecked(reader);
                Ok(Self::from(affine))
            }
            SerdeFormat::RawBytes => {
                let affine = <Self as Curve>::AffineRepr::read_raw(reader)?;
                Ok(Self::from(affine))
            }
            SerdeFormat::Processed => {
                let mut encoded = <Self as GroupEncoding>::Repr::default();
                reader.read_exact(encoded.as_mut())?;
                // `from_bytes` yields a `CtOption`; an invalid encoding becomes `None`.
                let decoded: Option<Self> = Self::from_bytes(&encoded).into();
                decoded.ok_or_else(|| io::Error::other("Invalid point encoding in proof"))
            }
        }
    }

    fn write<W: Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
        if matches!(format, SerdeFormat::Processed) {
            let encoded = self.to_bytes();
            writer.write_all(encoded.as_ref())
        } else {
            // Both raw formats share the same on-wire layout.
            self.to_affine().write_raw(writer)
        }
    }
}
/// Packs up to eight booleans into a byte, least-significant bit first:
/// `bits[i]` becomes bit `i` of the result, and missing high bits are zero.
///
/// # Panics
///
/// Panics if `bits.len() > 8`.
pub fn pack(bits: &[bool]) -> u8 {
    assert!(bits.len() <= 8);
    bits.iter()
        .enumerate()
        .fold(0u8, |acc, (shift, &flag)| acc | (u8::from(flag) << shift))
}
/// Writes the low `bits.len()` bits of `byte` into `bits`, least-significant bit first,
/// so `bits[i]` becomes bit `i` of `byte`. This is the inverse of `pack`.
///
/// # Panics
///
/// Panics if `bits.len() > 8`, matching `pack`. Without the check, shifting a `u8` by 8 or
/// more would overflow-panic in debug builds and silently mask the shift amount (producing
/// wrong bits) in release builds.
pub fn unpack(byte: u8, bits: &mut [bool]) {
    assert!(bits.len() <= 8);
    for (bit_index, bit) in bits.iter_mut().enumerate() {
        *bit = (byte >> bit_index) & 1 == 1;
    }
}
/// Reads a big-endian `u32` count followed by that many polynomials, each decoded with
/// `Polynomial::read` using `format`.
///
/// # Errors
///
/// Propagates any I/O error, including a short read of the 4-byte count, and stops at the
/// first polynomial that fails to decode.
pub(crate) fn read_polynomial_vec<R: io::Read, F: PrimeField + SerdeObject, B>(
    reader: &mut R,
    format: SerdeFormat,
) -> io::Result<Vec<Polynomial<F, B>>> {
    let mut count_bytes = [0u8; 4];
    reader.read_exact(&mut count_bytes)?;
    let count = u32::from_be_bytes(count_bytes);

    // Capacity is intentionally not reserved from `count`, which comes straight from the stream.
    let mut polys = Vec::new();
    for _ in 0..count {
        polys.push(Polynomial::<F, B>::read(reader, format)?);
    }
    Ok(polys)
}
/// Writes `slice` as a big-endian `u32` element count followed by each polynomial's own
/// serialization (via `Polynomial::write`), in order.
///
/// # Errors
///
/// Returns an `InvalidInput` error, before writing anything, if `slice.len()` does not fit
/// in a `u32`. Otherwise propagates any I/O error from `writer` or `Polynomial::write`.
pub(crate) fn write_polynomial_slice<W: io::Write, F: PrimeField + SerdeObject, B>(
    slice: &[Polynomial<F, B>],
    writer: &mut W,
) -> io::Result<()> {
    // A plain `as u32` cast would silently truncate the prefix and leave the stream
    // unreadable by `read_polynomial_vec`; fail loudly instead.
    let len = u32::try_from(slice.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "polynomial slice has more than u32::MAX elements",
        )
    })?;
    writer.write_all(&len.to_be_bytes())?;
    for poly in slice {
        poly.write(writer)?;
    }
    Ok(())
}
/// Expected number of bytes `write_polynomial_slice` emits for `slice`: a 4-byte count
/// prefix, then for each polynomial a 4-byte prefix followed by its coefficients.
///
/// Each polynomial's own length is used, so the result is correct even when the
/// polynomials in `slice` have different numbers of coefficients. For uniform slices it
/// equals `4 + slice.len() * (4 + field_len * poly_len)`.
///
/// NOTE(review): assumes `Polynomial::write` emits a 4-byte length prefix and each
/// coefficient in exactly `F::Repr`-many bytes. Confirm against its implementation.
pub(crate) fn polynomial_slice_byte_length<F: PrimeField, B>(slice: &[Polynomial<F, B>]) -> usize {
    let field_len = F::default().to_repr().as_ref().len();
    4 + slice
        .iter()
        .map(|poly| 4 + field_len * poly.len())
        .sum::<usize>()
}