use std::iter::once;
use crate::curve::{EmbeddedFr, EmbeddedGroupAffine};
use crate::curve::{FR_BYTES, FR_BYTES_STORED, Fr};
use crate::hash::transient_commit;
use crate::merkle_tree::{MerklePath, MerkleTreeDigest};
use crate::repr::{FieldRepr, bytes_from_field_repr};
use base_crypto::fab::{
Aligned, AlignedValue, Alignment, AlignmentAtom, AlignmentSegment, DynAligned,
FIELD_BYTE_LIMIT, InvalidBuiltinDecode, Value, ValueAtom, ValueSlice, int_size,
};
use base_crypto::repr::{BinaryHashRepr, MemWrite};
use rand::Rng;
use rand::distributions::Standard;
use rand::prelude::Distribution;
/// Crate-internal extension methods on [`Value`] for emitting its field-element
/// and binary representations under the guidance of an [`Alignment`].
pub(crate) trait ValueExt {
    /// Folds `acc` through `f` for every (alignment atom, value atom) pair,
    /// consuming traversed atoms from `atom_slice` in place.
    ///
    /// * `f` — combines the accumulator with one atom and its alignment.
    /// * `len` — measures an alignment's output length (field or byte count).
    /// * `pad` — folds in padding after the shorter branch of an option segment.
    fn repr_traverse<
        T,
        F: Fn(T, &AlignmentAtom, &ValueAtom) -> T,
        L: Fn(&Alignment) -> usize,
        P: Fn(T, usize) -> T,
    >(
        atom_slice: &mut &[ValueAtom],
        align: &Alignment,
        f: &F,
        len: &L,
        pad: &P,
        acc: T,
    ) -> T;
    /// Writes the field-element representation of `self`; "unchecked" — the
    /// caller must ensure `self` actually fits `align`.
    fn field_repr_unchecked<W: MemWrite<Fr>>(&self, align: &Alignment, writer: &mut W);
    /// Writes the binary representation of `self`; same fit precondition as
    /// [`ValueExt::field_repr_unchecked`].
    fn binary_repr_unchecked<W: MemWrite<u8>>(&self, align: &Alignment, writer: &mut W);
}
impl From<MerkleTreeDigest> for ValueAtom {
fn from(val: MerkleTreeDigest) -> ValueAtom {
Fr::from(val).into()
}
}
impl TryFrom<&ValueAtom> for MerkleTreeDigest {
type Error = InvalidBuiltinDecode;
fn try_from(value: &ValueAtom) -> Result<MerkleTreeDigest, InvalidBuiltinDecode> {
Ok(Fr::try_from(value)?.into())
}
}
impl<T: Into<Value>> From<MerklePath<T>> for Value {
    /// Flattens a Merkle path into `[leaf, sibling_0, goes_left_0, sibling_1, ...]`.
    fn from(path: MerklePath<T>) -> Value {
        let pieces: Vec<Value> = std::iter::once(path.leaf.into())
            .chain(
                path.path
                    .iter()
                    .flat_map(|entry| [entry.sibling.into(), entry.goes_left.into()]),
            )
            .collect();
        Value::concat(pieces.iter())
    }
}
impl From<MerkleTreeDigest> for Value {
fn from(val: MerkleTreeDigest) -> Value {
Value(vec![val.into()])
}
}
impl TryFrom<&ValueSlice> for MerkleTreeDigest {
    type Error = InvalidBuiltinDecode;
    /// Decodes a digest from a value slice that must contain exactly one atom.
    ///
    /// # Errors
    /// Returns `InvalidBuiltinDecode` if the slice length is not 1 or the lone
    /// atom does not decode as a field element.
    fn try_from(value: &ValueSlice) -> Result<MerkleTreeDigest, InvalidBuiltinDecode> {
        if value.0.len() == 1 {
            Ok(MerkleTreeDigest::try_from(&value.0[0])?)
        } else {
            // Fix: this previously said `stringify!($ty)` outside any macro,
            // which compiles to the literal (and misleading) error label "$ty"
            // — a leftover from copy-pasting out of `forward_primitive_value!`.
            Err(InvalidBuiltinDecode("MerkleTreeDigest"))
        }
    }
}
/// A Merkle tree digest is laid out as a single field element.
impl Aligned for MerkleTreeDigest {
    fn alignment() -> Alignment {
        Alignment::singleton(AlignmentAtom::Field)
    }
}
impl<T: DynAligned> DynAligned for MerklePath<T> {
    /// Alignment of a flattened path: the leaf's own alignment followed by one
    /// `(digest, bool)` pair per path entry.
    fn dyn_alignment(&self) -> Alignment {
        let leaf = self.leaf.dyn_alignment();
        let step = Alignment::concat([&MerkleTreeDigest::alignment(), &bool::alignment()]);
        let steps = std::iter::repeat(&step).take(self.path.len());
        Alignment::concat(std::iter::once(&leaf).chain(steps))
    }
}
impl From<EmbeddedGroupAffine> for Value {
    /// Encodes a curve point as its two affine coordinates; a point without
    /// coordinates (the identity) maps to (0, 0).
    fn from(point: EmbeddedGroupAffine) -> Value {
        let x = point.x().unwrap_or(0.into());
        let y = point.y().unwrap_or(0.into());
        Value(vec![x.into(), y.into()])
    }
}
impl From<EmbeddedFr> for Value {
fn from(val: EmbeddedFr) -> Value {
Value(vec![val.into()])
}
}
impl TryFrom<&ValueSlice> for EmbeddedGroupAffine {
type Error = InvalidBuiltinDecode;
fn try_from(value: &ValueSlice) -> Result<EmbeddedGroupAffine, InvalidBuiltinDecode> {
if value.0.len() == 2 {
let x: Fr = (&value.0[0]).try_into()?;
let y: Fr = (&value.0[1]).try_into()?;
let is_identity = EmbeddedGroupAffine::HAS_INFINITY && x == 0.into() && y == 0.into();
if is_identity {
Ok(EmbeddedGroupAffine::identity())
} else {
Ok(EmbeddedGroupAffine::new(x, y)
.ok_or(InvalidBuiltinDecode("EmbeddedGroupAffine"))?)
}
} else {
Err(InvalidBuiltinDecode("EmbeddedGroupAffine"))
}
}
}
impl TryFrom<&ValueSlice> for EmbeddedFr {
type Error = InvalidBuiltinDecode;
fn try_from(value: &ValueSlice) -> Result<EmbeddedFr, InvalidBuiltinDecode> {
if value.0.len() == 1 {
Ok(<EmbeddedFr>::try_from(&value.0[0])?)
} else {
Err(InvalidBuiltinDecode(stringify!(EmbeddedFr)))
}
}
}
impl From<EmbeddedFr> for ValueAtom {
    /// Serializes the scalar as little-endian bytes and normalizes the atom.
    fn from(scalar: EmbeddedFr) -> ValueAtom {
        let le_bytes = scalar.as_le_bytes();
        ValueAtom(le_bytes).normalize()
    }
}
impl TryFrom<&ValueAtom> for EmbeddedFr {
    type Error = InvalidBuiltinDecode;
    /// Decodes a scalar from at most `FR_BYTES` little-endian bytes.
    fn try_from(atom: &ValueAtom) -> Result<EmbeddedFr, InvalidBuiltinDecode> {
        if atom.0.len() > FR_BYTES {
            return Err(InvalidBuiltinDecode("EmbeddedFr"));
        }
        EmbeddedFr::from_le_bytes(&atom.0).ok_or(InvalidBuiltinDecode("EmbeddedFr"))
    }
}
/// Implements `From<$ty> for Value` and `TryFrom<&ValueSlice> for $ty` for
/// types that already convert to/from a single `ValueAtom`: the value is the
/// one-atom vector, and decoding demands exactly one atom.
macro_rules! forward_primitive_value {
    ($($ty:ty),*) => {
        $(
            impl From<$ty> for Value {
                fn from(val: $ty) -> Value {
                    Value(vec![val.into()])
                }
            }
            impl TryFrom<&ValueSlice> for $ty {
                type Error = InvalidBuiltinDecode;
                fn try_from(value: &ValueSlice) -> Result<$ty, InvalidBuiltinDecode> {
                    if value.0.len() == 1 {
                        Ok(<$ty>::try_from(&value.0[0])?)
                    } else {
                        // Error label is the type's name, e.g. "Fr".
                        Err(InvalidBuiltinDecode(stringify!($ty)))
                    }
                }
            }
        )*
    }
}
// `Fr` round-trips through `Value` as a single atom.
forward_primitive_value!(Fr);
impl From<Fr> for ValueAtom {
    /// Serializes the field element as little-endian bytes and normalizes the atom.
    fn from(field: Fr) -> ValueAtom {
        let le_bytes = field.as_le_bytes();
        ValueAtom(le_bytes).normalize()
    }
}
impl TryFrom<&ValueAtom> for Fr {
    type Error = InvalidBuiltinDecode;
    /// Decodes a field element from at most `FR_BYTES` little-endian bytes.
    fn try_from(atom: &ValueAtom) -> Result<Fr, InvalidBuiltinDecode> {
        if atom.0.len() > FR_BYTES {
            return Err(InvalidBuiltinDecode("Fr"));
        }
        Fr::from_le_bytes(&atom.0).ok_or(InvalidBuiltinDecode("Fr"))
    }
}
impl ValueExt for Value {
    /// Generic traversal driving both the field and binary encoders: walks the
    /// alignment segment by segment, consuming matching atoms from the front of
    /// `atom_slice` and folding them into `acc` via `f`.
    fn repr_traverse<
        T,
        F: Fn(T, &AlignmentAtom, &ValueAtom) -> T,
        L: Fn(&Alignment) -> usize,
        P: Fn(T, usize) -> T,
    >(
        atom_slice: &mut &[ValueAtom],
        align: &Alignment,
        f: &F,
        len: &L,
        pad: &P,
        mut acc: T,
    ) -> T {
        for segment in align.0.iter() {
            match segment {
                AlignmentSegment::Atom(atom) => {
                    // One alignment atom consumes exactly one value atom.
                    acc = f(acc, atom, &atom_slice[0]);
                    *atom_slice = &atom_slice[1..];
                }
                AlignmentSegment::Option(options) => {
                    // The first atom is the variant discriminant; this is an
                    // "unchecked" path, so a malformed discriminant panics.
                    let discriminant = u16::try_from(&atom_slice[0])
                        .expect("unchecked discriminant should decode");
                    let choice = &options[discriminant as usize];
                    // Emit the discriminant itself as a 2-byte atom.
                    acc = f(acc, &AlignmentAtom::Bytes { length: 2 }, &atom_slice[0]);
                    *atom_slice = &atom_slice[1..];
                    // Recurse into the selected branch's alignment.
                    acc = Value::repr_traverse(atom_slice, choice, f, len, pad, acc);
                    // Pad up to the longest branch so every variant occupies
                    // the same encoded length. `choice` is one of `options`,
                    // so the subtraction cannot underflow.
                    let padding = options.iter().map(len).max().unwrap_or(0) - len(choice);
                    acc = pad(acc, padding);
                }
            }
        }
        acc
    }
    /// Writes the field-element representation; precondition: `self` fits `align`.
    fn field_repr_unchecked<W: MemWrite<Fr>>(&self, align: &Alignment, writer: &mut W) {
        let mut unconsumed_value = &self.0[..];
        Value::repr_traverse(
            &mut unconsumed_value,
            align,
            &|mut w: &mut W, a, v| {
                v.field_repr_unchecked(a, &mut w);
                w
            },
            &Alignment::field_len,
            // Option padding is written as zero field elements.
            &|w, n| {
                w.write(&vec![Fr::from(0); n]);
                w
            },
            writer,
        );
        // A fitting value is consumed exactly; leftovers indicate a caller bug.
        debug_assert!(unconsumed_value.is_empty());
    }
    /// Writes the binary representation; precondition: `self` fits `align`.
    fn binary_repr_unchecked<W: MemWrite<u8>>(&self, align: &Alignment, writer: &mut W) {
        let mut unconsumed_value = &self.0[..];
        Value::repr_traverse(
            &mut unconsumed_value,
            align,
            &|mut w: &mut W, a, v| {
                v.binary_repr_unchecked(a, &mut w);
                w
            },
            &Alignment::bin_len,
            // Option padding is written as zero bytes.
            &|w, n| {
                w.write(&vec![0u8; n]);
                w
            },
            writer,
        );
        debug_assert!(unconsumed_value.is_empty());
    }
}
/// Extension methods on [`Alignment`]: parsing a field representation back
/// into a value, plus size accounting for the various encodings.
pub trait AlignmentExt {
    /// Parses `repr` into an [`AlignedValue`] carrying this alignment, or
    /// `None` if the representation does not match it.
    fn parse_field_repr(&self, repr: &[Fr]) -> Option<AlignedValue>;
    /// Maximum serialized size of a value with this alignment
    /// (presumably in bytes, mirroring `AlignmentAtom::max_aligned_size` —
    /// confirm against the serializer).
    fn max_aligned_size(&self) -> usize;
    /// Number of field elements in the field representation.
    fn field_len(&self) -> usize;
    /// Number of bytes in the binary representation.
    fn bin_len(&self) -> usize;
}
/// Recursive worker for [`AlignmentExt::parse_field_repr`]: consumes field
/// elements from the front of `repr` according to `segments`, appending the
/// parsed atoms to `val`. Returns `None` on any mismatch.
fn parse_field_repr_inner(
    segments: &[AlignmentSegment],
    repr: &mut &[Fr],
    val: &mut Vec<ValueAtom>,
) -> Option<()> {
    for segment in segments.iter() {
        match segment {
            AlignmentSegment::Atom(atom) => val.push(atom.parse_field_repr(repr)?),
            AlignmentSegment::Option(options) => {
                // First element is the variant discriminant; it must fit u16
                // and select an existing branch.
                let variant = u16::try_from(*repr.first()?).ok()?;
                *repr = &repr[1..];
                val.push(variant.into());
                let choice = options.get(variant as usize)?;
                parse_field_repr_inner(&choice.0, repr, val)?;
                // Branches are padded to the longest option; the padding must
                // be present and all-zero for the parse to be canonical.
                let padding = options.iter().map(Alignment::field_len).max().unwrap_or(0)
                    - choice.field_len();
                if repr.len() < padding || repr[..padding].iter().any(|f| *f != Fr::from(0)) {
                    return None;
                }
                *repr = &repr[padding..];
            }
        }
    }
    Some(())
}
impl AlignmentExt for Alignment {
    /// Parses `repr` against this alignment; see [`parse_field_repr_inner`].
    fn parse_field_repr(&self, mut repr: &[Fr]) -> Option<AlignedValue> {
        let mut atoms = Vec::new();
        parse_field_repr_inner(&self.0, &mut repr, &mut atoms)?;
        Some(AlignedValue {
            value: Value(atoms),
            alignment: self.clone(),
        })
    }
    fn max_aligned_size(&self) -> usize {
        // Header (1 + encoded segment count) plus every segment's bound.
        let segments: usize = self
            .0
            .iter()
            .map(AlignmentSegment::max_aligned_size)
            .sum();
        1 + int_size(self.0.len()) + segments
    }
    fn field_len(&self) -> usize {
        self.0.iter().map(AlignmentSegment::field_len).sum()
    }
    fn bin_len(&self) -> usize {
        self.0.iter().fold(0, |total, segment| total + segment.bin_len())
    }
}
impl FieldRepr for Alignment {
    /// Length-prefixed encoding: segment count first, then each segment.
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        (self.0.len() as u32).field_repr(writer);
        self.0.iter().for_each(|segment| segment.field_repr(writer));
    }
    fn field_size(&self) -> usize {
        // One element for the length prefix plus every segment's size.
        self.0.iter().fold(1, |total, segment| total + segment.field_size())
    }
}
impl FieldRepr for AlignedValue {
    /// Self-describing encoding: the alignment header, then the raw value.
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        self.alignment.field_repr(writer);
        self.value.field_repr_unchecked(&self.alignment, writer);
    }
    fn field_size(&self) -> usize {
        let header = self.alignment.field_size();
        let payload = self.alignment.field_len();
        header + payload
    }
}
/// Extension methods on [`AlignedValue`] emitting only the value payload,
/// without the self-describing alignment header.
pub trait AlignedValueExt {
    /// Writes only the value's field representation (no alignment header).
    fn value_only_field_repr<W: MemWrite<Fr>>(&self, writer: &mut W);
    /// Field-element count of the value-only representation.
    fn value_only_field_size(&self) -> usize;
}
impl AlignedValueExt for AlignedValue {
    fn value_only_field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        let Self { value, alignment } = self;
        value.field_repr_unchecked(alignment, writer)
    }
    fn value_only_field_size(&self) -> usize {
        self.alignment.field_len()
    }
}
/// Newtype over [`AlignedValue`] whose `FieldRepr`/`BinaryHashRepr` impls
/// encode only the value payload, omitting the alignment header.
pub struct ValueReprAlignedValue(pub AlignedValue);
impl From<ValueReprAlignedValue> for Value {
    fn from(wrapper: ValueReprAlignedValue) -> Value {
        let ValueReprAlignedValue(aligned) = wrapper;
        aligned.value
    }
}
impl FieldRepr for ValueReprAlignedValue {
    /// Delegates to the value-only encoding (no alignment header).
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        self.0.value_only_field_repr(writer)
    }
    fn field_size(&self) -> usize {
        self.0.value_only_field_size()
    }
}
impl BinaryHashRepr for ValueReprAlignedValue {
    /// Binary form of the payload only, guided by its own alignment.
    fn binary_repr<W: MemWrite<u8>>(&self, writer: &mut W) {
        let inner = &self.0;
        inner.value.binary_repr_unchecked(&inner.alignment, writer);
    }
    fn binary_len(&self) -> usize {
        self.0.alignment.bin_len()
    }
}
impl DynAligned for ValueReprAlignedValue {
fn dyn_alignment(&self) -> Alignment {
self.0.dyn_alignment()
}
}
/// Crate-internal extension methods on [`ValueAtom`] for encoding a single
/// atom under a given [`AlignmentAtom`].
pub(crate) trait ValueAtomExt {
    /// Checked variant: returns `false` (writing nothing) if `self` does not
    /// fit `ty`, otherwise writes the field representation and returns `true`.
    #[allow(dead_code)]
    fn field_repr<W: MemWrite<Fr>>(&self, ty: &AlignmentAtom, writer: &mut W) -> bool;
    /// Writes the field representation; caller must ensure `self` fits `ty`.
    fn field_repr_unchecked<W: MemWrite<Fr>>(&self, ty: &AlignmentAtom, writer: &mut W);
    /// Writes the binary representation; caller must ensure `self` fits `ty`.
    fn binary_repr_unchecked<W: MemWrite<u8>>(&self, ty: &AlignmentAtom, writer: &mut W);
}
/// Interprets an atom's bytes as a field element via uniform-bytes reduction.
///
/// # Panics
/// Panics if the atom is longer than `FIELD_BYTE_LIMIT` bytes.
fn value_atom_as_field(atom: &ValueAtom) -> Fr {
    assert!(atom.0.len() <= FIELD_BYTE_LIMIT);
    let mut padded = [0u8; FIELD_BYTE_LIMIT];
    padded[..atom.0.len()].copy_from_slice(&atom.0);
    Fr::from_uniform_bytes(&padded)
}
impl ValueAtomExt for ValueAtom {
    /// Checked encode: verifies the atom fits the alignment before writing.
    fn field_repr<W: MemWrite<Fr>>(&self, ty: &AlignmentAtom, writer: &mut W) -> bool {
        if !ty.fits(self) {
            false
        } else {
            self.field_repr_unchecked(ty, writer);
            true
        }
    }
    fn field_repr_unchecked<W: MemWrite<Fr>>(&self, ty: &AlignmentAtom, writer: &mut W) {
        match ty {
            AlignmentAtom::Compress => {
                // Empty atoms are encoded as literal zero; non-empty atoms as a
                // commitment keyed on the byte length. NOTE(review): the binary
                // path below does not special-case the empty atom — confirm the
                // asymmetry is intentional.
                if self.0.is_empty() {
                    writer.write(&[0.into()]);
                } else {
                    writer.write(&[transient_commit(&self.0[..], (self.0.len() as u64).into())])
                }
            }
            AlignmentAtom::Bytes { length } => {
                // Pack bytes into field elements, FR_BYTES_STORED bytes each,
                // zero-padded at the front (big-endian element order: `.rev()`).
                let prepend_zeros = (*length as usize).div_ceil(FR_BYTES_STORED)
                    - self.0.len().div_ceil(FR_BYTES_STORED);
                let raw = self
                    .0
                    .chunks(FR_BYTES_STORED)
                    .map(|bytes| {
                        Fr::from_le_bytes(bytes).expect("Bytes must fit into FR_BYTES_STORED chunk")
                    })
                    .rev();
                writer.write(&vec![Fr::from(0); prepend_zeros]);
                writer.write(&raw.collect::<Vec<_>>());
            }
            AlignmentAtom::Field => writer.write(&[value_atom_as_field(self)]),
        }
    }
    fn binary_repr_unchecked<W: MemWrite<u8>>(&self, ty: &AlignmentAtom, writer: &mut W) {
        match ty {
            AlignmentAtom::Compress => {
                transient_commit(&self.0[..], (self.0.len() as u64).into()).binary_repr(writer);
            }
            AlignmentAtom::Bytes { length } => {
                // Raw bytes, zero-padded at the back to the declared length.
                // `length >= self.0.len()` is the caller's "unchecked" precondition.
                writer.write(&self.0);
                let missing_bytes = (*length as usize) - self.0.len();
                let zeroes = vec![0u8; missing_bytes];
                writer.write(&zeroes);
            }
            AlignmentAtom::Field => {
                value_atom_as_field(self).binary_repr(writer);
            }
        }
    }
}
/// Crate-internal extension methods on [`AlignmentAtom`]: parsing, sampling,
/// and size accounting for a single atom.
pub(crate) trait AlignmentAtomExt {
    /// Parses one atom's worth of field elements from the front of `repr`.
    fn parse_field_repr(&self, repr: &mut &[Fr]) -> Option<ValueAtom>;
    /// Generates a random atom matching this alignment (test/property use).
    #[allow(dead_code)]
    fn sample_value_atom<R: Rng + ?Sized>(&self, rng: &mut R) -> ValueAtom;
    /// Maximum serialized size of an atom with this alignment.
    fn max_aligned_size(&self) -> usize;
    /// Field elements occupied in the field representation.
    fn field_len(&self) -> usize;
    /// Bytes occupied in the binary representation.
    fn bin_len(&self) -> usize;
}
impl AlignmentAtomExt for AlignmentAtom {
    fn parse_field_repr(&self, repr: &mut &[Fr]) -> Option<ValueAtom> {
        match self {
            // Compressed atoms are one-way (a commitment) and cannot be parsed back.
            AlignmentAtom::Compress => None,
            AlignmentAtom::Field => {
                let res = repr.first()?;
                *repr = &repr[1..];
                Some(ValueAtom(res.as_le_bytes()).normalize())
            }
            AlignmentAtom::Bytes { length } => {
                bytes_from_field_repr(repr, *length as usize).map(ValueAtom)
            }
        }
    }
    // NOTE: the extra `where` bound is trivially satisfied for the concrete
    // `Fr`, which is why it is permitted despite being absent from the trait.
    fn sample_value_atom<R: Rng + ?Sized>(&self, rng: &mut R) -> ValueAtom
    where
        Standard: Distribution<Fr>,
    {
        match self {
            Self::Compress | Self::Field => {
                // `r#gen`: raw identifier, since `gen` is reserved in edition 2024.
                let val = rng.r#gen::<Fr>().as_le_bytes();
                ValueAtom(val).normalize()
            }
            Self::Bytes { length } => {
                let mut bytes: Vec<u8> = vec![0; *length as usize];
                rng.fill_bytes(&mut bytes);
                ValueAtom(bytes).normalize()
            }
        }
    }
    fn max_aligned_size(&self) -> usize {
        match self {
            AlignmentAtom::Compress | AlignmentAtom::Field => 2 + FR_BYTES,
            AlignmentAtom::Bytes { length } => 2 + int_size(*length as usize) + *length as usize,
        }
    }
    fn field_len(&self) -> usize {
        match self {
            AlignmentAtom::Compress | AlignmentAtom::Field => 1,
            // Bytes are packed FR_BYTES_STORED per field element.
            AlignmentAtom::Bytes { length } => length.div_ceil(FR_BYTES_STORED as u32) as usize,
        }
    }
    fn bin_len(&self) -> usize {
        match self {
            AlignmentAtom::Compress | AlignmentAtom::Field => FR_BYTES,
            AlignmentAtom::Bytes { length } => *length as usize,
        }
    }
}
impl FieldRepr for AlignmentAtom {
    /// Encodes the atom tag as one field element: byte lengths directly, with
    /// -1 and -2 reserved for `Compress` and `Field` respectively.
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        let tag: Fr = match self {
            AlignmentAtom::Bytes { length } => (*length).into(),
            AlignmentAtom::Compress => (-1).into(),
            AlignmentAtom::Field => (-2).into(),
        };
        writer.write(&[tag]);
    }
    fn field_size(&self) -> usize {
        1
    }
}
/// Crate-internal size accounting for a single [`AlignmentSegment`].
pub(crate) trait AlignmentSegmentExt {
    /// Maximum serialized size of a value matching this segment.
    fn max_aligned_size(&self) -> usize;
    /// Field elements occupied in the field representation.
    fn field_len(&self) -> usize;
    /// Bytes occupied in the binary representation.
    fn bin_len(&self) -> usize;
}
impl AlignmentSegmentExt for AlignmentSegment {
    fn max_aligned_size(&self) -> usize {
        match self {
            Self::Atom(atom) => atom.max_aligned_size(),
            // An option is bounded by its widest branch.
            Self::Option(choices) => choices
                .iter()
                .map(Alignment::max_aligned_size)
                .max()
                .unwrap_or(0),
        }
    }
    fn field_len(&self) -> usize {
        match self {
            Self::Atom(atom) => atom.field_len(),
            // One element for the discriminant plus the widest branch.
            Self::Option(choices) => {
                let widest = choices.iter().map(Alignment::field_len).max().unwrap_or(0);
                1 + widest
            }
        }
    }
    fn bin_len(&self) -> usize {
        match self {
            Self::Atom(atom) => atom.bin_len(),
            // Two bytes for the discriminant plus the widest branch.
            Self::Option(choices) => {
                let widest = choices.iter().map(Alignment::bin_len).max().unwrap_or(0);
                2 + widest
            }
        }
    }
}
impl FieldRepr for AlignmentSegment {
    fn field_repr<W: MemWrite<Fr>>(&self, writer: &mut W) {
        match self {
            Self::Atom(atom) => atom.field_repr(writer),
            Self::Option(choices) => {
                // -3 tags an option segment, followed by the branch count,
                // then each branch's alignment in order.
                writer.write(&[(-3).into(), (choices.len() as u32).into()]);
                choices.iter().for_each(|choice| choice.field_repr(writer));
            }
        }
    }
    fn field_size(&self) -> usize {
        match self {
            Self::Atom(atom) => atom.field_size(),
            Self::Option(choices) => {
                // Two elements of header (tag + count) plus every branch.
                let branches: usize = choices.iter().map(FieldRepr::field_size).sum();
                2 + branches
            }
        }
    }
}
#[cfg(all(test, feature = "proptest"))]
mod tests {
    use super::*;
    use base_crypto::fab::AlignedValue;
    use proptest::prelude::*;
    proptest! {
        // Property: for any generated aligned value, the value fits its own
        // alignment, and both the field and binary encoders emit exactly the
        // number of elements their size accessors predict.
        #[test]
        fn test_fab_repr_consistency(value in AlignedValue::arbitrary()) {
            assert!(value.alignment.fits(&value.value));
            assert_eq!(value.field_vec().len(), value.field_size());
            let value_repr = ValueReprAlignedValue(value);
            assert_eq!(value_repr.binary_vec().len(), value_repr.binary_len());
        }
    }
}