use std::ops::{Index, RangeTo};
use ff::PrimeField;
use num_bigint::BigUint;
/// A prime field usable inside circuits, with conversions to and from
/// [`BigUint`], fixed-size byte arrays and little-endian bit vectors.
pub trait CircuitField: PrimeField {
    /// Length in bytes of [`Self::Bytes`].
    const NUM_BYTES: usize;

    /// Fixed-size byte buffer returned by [`Self::to_bytes_le`] and
    /// [`Self::to_bytes_be`].
    type Bytes: Copy
        + Send
        + Sync
        + 'static
        + AsRef<[u8]>
        + AsMut<[u8]>
        + Index<usize, Output = u8>
        + Index<RangeTo<usize>, Output = [u8]>;

    /// Returns the integer value of `self`.
    fn to_biguint(&self) -> BigUint;

    /// Converts an integer into a field element, returning `None` when `n`
    /// does not correspond to a valid field element (e.g. `n >= p`).
    fn from_biguint(n: &BigUint) -> Option<Self>;

    /// Returns the field modulus `p`, computed as the integer value of `-1`
    /// plus one.
    fn modulus() -> BigUint {
        (-Self::ONE).to_biguint() + 1u64
    }

    /// Little-endian byte encoding of `self`.
    fn to_bytes_le(&self) -> Self::Bytes;

    /// Big-endian byte encoding of `self`.
    fn to_bytes_be(&self) -> Self::Bytes;

    /// Decodes a little-endian byte string, returning `None` when it does not
    /// encode a valid field element. Accepted input lengths are
    /// implementation-defined.
    fn from_bytes_le(bytes: &[u8]) -> Option<Self>;

    /// Decodes a big-endian byte string by reversing it and delegating to
    /// [`Self::from_bytes_le`].
    fn from_bytes_be(bytes: &[u8]) -> Option<Self> {
        let bytes_le: Vec<u8> = bytes.iter().rev().copied().collect();
        Self::from_bytes_le(&bytes_le)
    }

    /// Returns the little-endian bits of `self`.
    ///
    /// With `nb_bits = Some(n)` exactly `n` bits are returned, zero-extended
    /// when `n` exceeds the width of the byte encoding. With `None`, trailing
    /// zero bits are dropped but at least one bit is always returned, so zero
    /// yields `[false]`.
    ///
    /// # Panics
    ///
    /// If `nb_bits` is `Some(0)`, or `Some(n)` while `self` has a set bit at
    /// position `n` or above.
    fn to_bits_le(&self, nb_bits: Option<usize>) -> Vec<bool> {
        let bytes = self.to_bytes_le();
        let mut bits: Vec<bool> = bytes
            .as_ref()
            .iter()
            .flat_map(|byte| (0..8).map(move |j| byte & (1 << j) != 0))
            .collect();
        let len = match nb_bits {
            Some(n) => {
                // `skip` (rather than slicing) tolerates `n` past the end.
                assert!(
                    n > 0 && bits.iter().skip(n).all(|b| !b),
                    "field element does not fit in {n} bits"
                );
                n
            }
            // Keep everything up to and including the most significant set bit.
            None => bits.iter().rposition(|b| *b).map_or(1, |msb| msb + 1),
        };
        // Truncates the (all-zero) tail or zero-extends to the requested width.
        bits.resize(len, false);
        bits
    }

    /// Builds a field element from little-endian bits.
    ///
    /// # Panics
    ///
    /// If `bits.len()` exceeds `Self::NUM_BITS`, or if the bits encode an
    /// integer that [`Self::from_bytes_le`] rejects (e.g. one `>= p`).
    fn from_bits_le(bits: &[bool]) -> Self {
        // Widen NUM_BITS rather than narrowing the length, so huge inputs are
        // never truncated into passing the check.
        assert!(
            bits.len() <= Self::NUM_BITS as usize,
            "too many bits for field element"
        );
        let mut bytes = vec![0u8; Self::NUM_BYTES];
        for (i, chunk) in bits.chunks(8).enumerate() {
            bytes[i] = chunk
                .iter()
                .enumerate()
                .fold(0u8, |acc, (j, &b)| acc | (u8::from(b) << j));
        }
        Self::from_bytes_le(&bytes).expect("bits do not encode a valid field element")
    }
}
/// Implements [`CircuitField`] for a field whose `PrimeField::to_repr()` is
/// read as a little-endian byte string of the given size.
macro_rules! impl_circuit_field_le {
    ($field:ty, $repr_size:expr) => {
        impl CircuitField for $field {
            const NUM_BYTES: usize = $repr_size;
            type Bytes = [u8; $repr_size];

            fn to_biguint(&self) -> BigUint {
                BigUint::from_bytes_le(self.to_repr().as_ref())
            }

            fn from_biguint(n: &BigUint) -> Option<Self> {
                // `BigUint::to_bytes_le` is a minimal encoding; `from_bytes_le`
                // below rejects oversized inputs and zero-extends shorter ones.
                // Fully-qualified so an inherent `from_bytes_le` on `$field`
                // cannot shadow the trait method.
                <Self as CircuitField>::from_bytes_le(&n.to_bytes_le())
            }

            fn from_bytes_le(bytes: &[u8]) -> Option<Self> {
                // Inputs longer than the representation cannot be decoded;
                // shorter ones are zero-extended at the most-significant end.
                if bytes.len() > $repr_size {
                    return None;
                }
                let mut repr = [0u8; $repr_size];
                repr[..bytes.len()].copy_from_slice(bytes);
                <$field as PrimeField>::from_repr(repr.into()).into_option()
            }

            fn to_bytes_le(&self) -> Self::Bytes {
                let mut bytes = [0u8; $repr_size];
                bytes.copy_from_slice(self.to_repr().as_ref());
                bytes
            }

            fn to_bytes_be(&self) -> Self::Bytes {
                let mut bytes = [0u8; $repr_size];
                bytes.copy_from_slice(self.to_repr().as_ref());
                bytes.reverse();
                bytes
            }
        }
    };
}
/// Implements [`CircuitField`] for a field whose `PrimeField::to_repr()` is
/// read as a big-endian byte string of the given size.
macro_rules! impl_circuit_field_be {
    ($field:ty, $repr_size:expr) => {
        impl CircuitField for $field {
            const NUM_BYTES: usize = $repr_size;
            type Bytes = [u8; $repr_size];

            fn to_biguint(&self) -> BigUint {
                BigUint::from_bytes_be(self.to_repr().as_ref())
            }

            fn from_biguint(n: &BigUint) -> Option<Self> {
                // `BigUint::to_bytes_le` is a minimal encoding; `from_bytes_le`
                // below rejects oversized inputs, zero-extends shorter ones and
                // converts to the big-endian repr. Fully-qualified so an
                // inherent `from_bytes_le` on `$field` cannot shadow it.
                <Self as CircuitField>::from_bytes_le(&n.to_bytes_le())
            }

            fn from_bytes_le(bytes: &[u8]) -> Option<Self> {
                // Inputs longer than the representation cannot be decoded;
                // shorter ones are zero-extended at the most-significant end.
                if bytes.len() > $repr_size {
                    return None;
                }
                let mut repr = [0u8; $repr_size];
                repr[..bytes.len()].copy_from_slice(bytes);
                // The repr of this field is big-endian.
                repr.reverse();
                <$field as PrimeField>::from_repr(repr.into()).into_option()
            }

            fn to_bytes_le(&self) -> Self::Bytes {
                let mut bytes = [0u8; $repr_size];
                bytes.copy_from_slice(self.to_repr().as_ref());
                bytes.reverse();
                bytes
            }

            fn to_bytes_be(&self) -> Self::Bytes {
                let mut bytes = [0u8; $repr_size];
                bytes.copy_from_slice(self.to_repr().as_ref());
                bytes
            }
        }
    };
}
// Fields handled by `impl_circuit_field_le!`, which interprets `to_repr()` as
// little-endian bytes.
impl_circuit_field_le!(midnight_curves::Fr, 32);
impl_circuit_field_le!(midnight_curves::Fq, 32);
impl_circuit_field_le!(midnight_curves::Fp, 48);
impl_circuit_field_le!(midnight_curves::curve25519::Fp, 32);
impl_circuit_field_le!(midnight_curves::curve25519::Scalar, 32);
#[cfg(feature = "dev-curves")]
impl_circuit_field_le!(midnight_curves::bn256::Fq, 32);
#[cfg(feature = "dev-curves")]
impl_circuit_field_le!(midnight_curves::bn256::Fr, 32);

// Fields handled by `impl_circuit_field_be!`, which interprets `to_repr()` as
// big-endian bytes.
impl_circuit_field_be!(midnight_curves::k256::Fp, 32);
impl_circuit_field_be!(midnight_curves::k256::Fq, 32);
impl_circuit_field_be!(midnight_curves::p256::Fp, 32);
impl_circuit_field_be!(midnight_curves::p256::Fq, 32);
#[cfg(test)]
mod tests {
    use ff::Field;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    use super::*;

    type F = midnight_curves::Fq;

    #[test]
    fn test_biguint_roundtrip() {
        let mut rng = ChaCha8Rng::seed_from_u64(0xCAFE);
        for _ in 0..100 {
            let fe = F::random(&mut rng);
            let big = fe.to_biguint();
            let recovered = F::from_biguint(&big).unwrap();
            assert_eq!(fe, recovered);
        }
    }

    #[test]
    fn test_modulus_rejected() {
        let modulus = F::modulus();
        assert!(F::from_biguint(&modulus).is_none());
        let too_large = &modulus + 1u64;
        assert!(F::from_biguint(&too_large).is_none());
    }

    #[test]
    fn test_zero() {
        let zero = F::ZERO;
        let big = zero.to_biguint();
        assert_eq!(big, BigUint::from(0u64));
        let recovered = F::from_biguint(&big).unwrap();
        assert_eq!(zero, recovered);
    }

    #[test]
    fn test_one() {
        let one = F::ONE;
        let big = one.to_biguint();
        assert_eq!(big, BigUint::from(1u64));
        let recovered = F::from_biguint(&big).unwrap();
        assert_eq!(one, recovered);
    }

    #[test]
    fn test_bytes_le_roundtrip() {
        let mut rng = ChaCha8Rng::seed_from_u64(0xBEEF);
        for _ in 0..100 {
            let fe = F::random(&mut rng);
            let bytes = fe.to_bytes_le();
            assert_eq!(bytes.len(), 32);
            let recovered = F::from_bytes_le(&bytes).unwrap();
            assert_eq!(fe, recovered);
        }
    }

    #[test]
    fn test_bits_le_roundtrip() {
        let mut rng = ChaCha8Rng::seed_from_u64(0xFACE);
        for _ in 0..100 {
            let fe = F::random(&mut rng);
            let bits = fe.to_bits_le(None);
            let recovered = F::from_bits_le(&bits);
            assert_eq!(fe, recovered);
        }
        for _ in 0..100 {
            let fe = F::random(&mut rng);
            let bits = fe.to_bits_le(Some(F::NUM_BITS as usize));
            assert_eq!(bits.len(), F::NUM_BITS as usize);
            let recovered = F::from_bits_le(&bits);
            assert_eq!(fe, recovered);
        }
        let bits = F::ZERO.to_bits_le(None);
        assert_eq!(bits, vec![false]);
        assert_eq!(F::from_bits_le(&bits), F::ZERO);
        let bits = F::ONE.to_bits_le(None);
        assert_eq!(bits, vec![true]);
        assert_eq!(F::from_bits_le(&bits), F::ONE);
    }

    /// 2 = 0b10 needs two bits, so requesting one must panic.
    #[test]
    #[should_panic(expected = "does not fit")]
    fn test_bits_le_too_narrow() {
        let two = F::ONE + F::ONE;
        let _ = two.to_bits_le(Some(1));
    }

    #[test]
    fn test_bytes_be_roundtrip() {
        let mut rng = ChaCha8Rng::seed_from_u64(0xDEAD);
        for _ in 0..100 {
            let fe = F::random(&mut rng);
            let bytes = fe.to_bytes_be();
            assert_eq!(bytes.len(), 32);
            let recovered = F::from_bytes_be(&bytes).unwrap();
            assert_eq!(fe, recovered);
        }
    }

    /// `k256::Fp` goes through `impl_circuit_field_be!`; check that its byte
    /// orders, integer conversion and bit conversion all agree.
    #[test]
    fn test_big_endian_repr_field() {
        type K = midnight_curves::k256::Fp;
        let mut rng = ChaCha8Rng::seed_from_u64(0x5EED);
        for _ in 0..100 {
            let fe = <K as Field>::random(&mut rng);
            let big = <K as CircuitField>::to_biguint(&fe);
            assert_eq!(<K as CircuitField>::from_biguint(&big).unwrap(), fe);

            let le = <K as CircuitField>::to_bytes_le(&fe);
            let be = <K as CircuitField>::to_bytes_be(&fe);
            let mut reversed = be;
            reversed.reverse();
            assert_eq!(le, reversed);
            assert_eq!(BigUint::from_bytes_le(&le), big);
            assert_eq!(<K as CircuitField>::from_bytes_le(&le).unwrap(), fe);
            assert_eq!(<K as CircuitField>::from_bytes_be(&be).unwrap(), fe);

            let bits = <K as CircuitField>::to_bits_le(&fe, None);
            assert_eq!(<K as CircuitField>::from_bits_le(&bits), fe);
        }
        let one = <K as Field>::ONE;
        assert_eq!(<K as CircuitField>::to_biguint(&one), BigUint::from(1u64));
        let modulus = <K as CircuitField>::modulus();
        assert!(<K as CircuitField>::from_biguint(&modulus).is_none());
    }
}