use std::hash::Hash;
use primitives::algebra::{
elliptic_curve::{Curve, Point, Scalar},
BoxedUint,
};
use serde::{Deserialize, Serialize};
use wincode::{SchemaRead, SchemaWrite};
use crate::circuit::{
errors::CircuitError,
AlgebraicType,
BatchSize,
BitPlaintextBinaryOp,
BitPlaintextUnaryOp,
BitShareBinaryOp,
BitShareUnaryOp,
Constant,
FieldPlaintextBinaryOp,
FieldPlaintextUnaryOp,
FieldShareBinaryOp,
FieldShareUnaryOp,
FieldType,
GateIndex,
Input,
PointPlaintextBinaryOp,
PointPlaintextUnaryOp,
PointShareBinaryOp,
PointShareUnaryOp,
Slice,
};
/// A single gate of a circuit over the curve `C`.
///
/// Gates refer to the outputs of other gates through [`GateIndex`] values; which fields of
/// each variant count as inputs (and in what order) is defined by [`Gate::get_inputs`].
/// Variants `Input`, `Constant`, `Random` and `DaBit` read no other gate.
///
/// NOTE(review): serde's derive encodes enum variants by their declaration index, so with
/// index-based formats such as bincode, reordering or inserting variants (or fields) changes
/// the serialized form. Presumably the same holds for the `wincode` schema derives — confirm
/// before reordering anything here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
// Only the curve's scalar and point types need to be (de)serializable, not `C` itself.
#[serde(bound(
serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
// NOTE(review): the reason for `repr(C)` is not visible in this file — confirm before changing.
#[repr(C)]
pub enum Gate<C: Curve> {
/// A circuit input; reads no other gate.
Input(Input),
/// A constant value parameterized by the curve `C`; reads no other gate.
Constant(Constant<C>),
/// A random value of `algebraic_type` with batch size `batch_size`; reads no other gate.
/// NOTE(review): presumably drawn from preprocessing randomness — confirm.
Random {
algebraic_type: AlgebraicType,
batch_size: BatchSize,
},
/// Applies the unary field-share operation `op` to the output of gate `x`.
FieldShareUnaryOp {
x: GateIndex,
op: FieldShareUnaryOp,
},
/// Applies the binary field-share operation `op` to the outputs of gates `x` and `y`.
FieldShareBinaryOp {
x: GateIndex,
y: GateIndex,
op: FieldShareBinaryOp,
},
/// Takes the output of gate `x` as its single input.
/// NOTE(review): presumably sums the elements of that batch — confirm.
BatchSummation {
x: GateIndex,
},
/// Applies the unary bit-share operation `op` to the output of gate `x`.
BitShareUnaryOp {
x: GateIndex,
op: BitShareUnaryOp,
},
/// Applies the binary bit-share operation `op` to the outputs of gates `x` and `y`.
BitShareBinaryOp {
x: GateIndex,
y: GateIndex,
op: BitShareBinaryOp,
},
/// Applies the unary point-share operation `op` to the output of gate `p`.
PointShareUnaryOp {
p: GateIndex,
op: PointShareUnaryOp,
},
/// Applies the binary point-share operation `op`; inputs are `[p, y]` in that order.
PointShareBinaryOp {
p: GateIndex,
y: GateIndex,
op: PointShareBinaryOp,
},
/// Applies the unary plaintext field operation `op` to the output of gate `x`.
FieldPlaintextUnaryOp {
x: GateIndex,
op: FieldPlaintextUnaryOp,
},
/// Applies the binary plaintext field operation `op` to the outputs of gates `x` and `y`.
FieldPlaintextBinaryOp {
x: GateIndex,
y: GateIndex,
op: FieldPlaintextBinaryOp,
},
/// Applies the unary plaintext bit operation `op` to the output of gate `x`.
BitPlaintextUnaryOp {
x: GateIndex,
op: BitPlaintextUnaryOp,
},
/// Applies the binary plaintext bit operation `op` to the outputs of gates `x` and `y`.
BitPlaintextBinaryOp {
x: GateIndex,
y: GateIndex,
op: BitPlaintextBinaryOp,
},
/// Applies the unary plaintext point operation `op` to the output of gate `p`.
PointPlaintextUnaryOp {
p: GateIndex,
op: PointPlaintextUnaryOp,
},
/// Applies the binary plaintext point operation `op`; inputs are `[p, y]` in that order.
PointPlaintextBinaryOp {
p: GateIndex,
y: GateIndex,
op: PointPlaintextBinaryOp,
},
/// Produces a daBit for `field_type` with batch size `batch_size`; reads no other gate.
/// Its components are read back through `GetDaBitFieldShare` / `GetDaBitSharedBit`.
DaBit {
field_type: FieldType,
batch_size: BatchSize,
},
/// Extracts the field-share component from the output of gate `x`.
/// NOTE(review): presumably `x` must be a `DaBit` gate — confirm where this is enforced.
GetDaBitFieldShare {
x: GateIndex,
},
/// Extracts the shared-bit component from the output of gate `x`.
/// NOTE(review): presumably `x` must be a `DaBit` gate — confirm where this is enforced.
GetDaBitSharedBit {
x: GateIndex,
},
/// Raises the output of gate `x` to the power `exp`. Only `x` is a gate input; `exp` is a
/// fixed value stored in the gate itself.
BaseFieldPow {
x: GateIndex,
exp: BoxedUint,
},
/// Converts the plaintext bit output of gate `x` into a value of `field_type`.
BitPlaintextToField {
x: GateIndex,
field_type: FieldType,
},
/// Converts the plaintext field output of gate `x` into a bit.
/// NOTE(review): the exact conversion rule is not visible here — confirm.
FieldPlaintextToBit {
x: GateIndex,
},
/// Selects the `slice` portion of the (batched) output of gate `x`.
ExtractFromBatch {
x: GateIndex,
slice: Slice,
},
/// Combines the outputs of every gate in `wires`, in order, into one batch.
/// Every entry of `wires` is an input.
CollectToBatch {
wires: Vec<GateIndex>,
},
/// Builds a point from the plaintext outputs of the gates in `wires` (all are inputs).
/// NOTE(review): presumably these are extended Edwards coordinates — confirm count/order.
PointFromPlaintextExtendedEdwards {
wires: Vec<GateIndex>,
},
/// Converts the plaintext point output of gate `point` to extended Edwards form.
/// NOTE(review): output shape (e.g. batch of coordinates) not visible here — confirm.
PlaintextPointToExtendedEdwards {
point: GateIndex,
},
/// Applies Keccak-f[1600] to the plaintext output of gate `x`.
/// NOTE(review): input/state encoding is not visible here — confirm.
PlaintextKeccakF1600 {
x: GateIndex,
},
/// Compresses the plaintext point output of gate `point`.
/// NOTE(review): compressed encoding is not visible here — confirm.
CompressPlaintextPoint {
point: GateIndex,
},
/// Plaintext key-recovery step with inputs `[d_minus_one, syndromes]` in that order.
/// NOTE(review): presumably computes error values from syndromes — confirm semantics.
KeyRecoveryPlaintextComputeErrors {
d_minus_one: GateIndex,
syndromes: GateIndex,
},
}
impl<C: Curve> Gate<C> {
pub fn is_input(&self) -> bool {
matches!(self, Gate::Input { .. })
}
pub fn get_inputs(&self) -> Vec<GateIndex> {
match &self {
Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => {
Vec::new()
}
Gate::FieldShareUnaryOp { x, .. }
| Gate::BatchSummation { x, .. }
| Gate::BitShareUnaryOp { x, .. }
| Gate::PointShareUnaryOp { p: x, .. }
| Gate::FieldPlaintextUnaryOp { x, .. }
| Gate::BitPlaintextUnaryOp { x, .. }
| Gate::PointPlaintextUnaryOp { p: x, .. }
| Gate::GetDaBitFieldShare { x, .. }
| Gate::GetDaBitSharedBit { x, .. }
| Gate::BaseFieldPow { x, .. }
| Gate::BitPlaintextToField { x, .. }
| Gate::FieldPlaintextToBit { x, .. }
| Gate::ExtractFromBatch { x, .. }
| Gate::PlaintextPointToExtendedEdwards { point: x, .. }
| Gate::CompressPlaintextPoint { point: x, .. }
| Gate::PlaintextKeccakF1600 { x } => {
vec![*x]
}
Gate::FieldShareBinaryOp { x, y, .. }
| Gate::BitShareBinaryOp { x, y, .. }
| Gate::PointShareBinaryOp { p: x, y, .. }
| Gate::FieldPlaintextBinaryOp { x, y, .. }
| Gate::BitPlaintextBinaryOp { x, y, .. }
| Gate::PointPlaintextBinaryOp { p: x, y, .. }
| Gate::KeyRecoveryPlaintextComputeErrors {
d_minus_one: x,
syndromes: y,
..
} => {
vec![*x, *y]
}
Gate::CollectToBatch { wires, .. }
| Gate::PointFromPlaintextExtendedEdwards { wires, .. } => wires.clone(),
}
}
pub fn map_inputs<F: FnMut(GateIndex) -> GateIndex>(mut self, mut f: F) -> Self {
match &mut self {
Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
Gate::FieldShareUnaryOp { x, .. }
| Gate::BatchSummation { x, .. }
| Gate::BitShareUnaryOp { x, .. }
| Gate::PointShareUnaryOp { p: x, .. }
| Gate::FieldPlaintextUnaryOp { x, .. }
| Gate::BitPlaintextUnaryOp { x, .. }
| Gate::PointPlaintextUnaryOp { p: x, .. }
| Gate::GetDaBitFieldShare { x, .. }
| Gate::GetDaBitSharedBit { x, .. }
| Gate::BaseFieldPow { x, .. }
| Gate::BitPlaintextToField { x, .. }
| Gate::FieldPlaintextToBit { x, .. }
| Gate::ExtractFromBatch { x, .. }
| Gate::PlaintextPointToExtendedEdwards { point: x, .. }
| Gate::CompressPlaintextPoint { point: x, .. }
| Gate::PlaintextKeccakF1600 { x } => {
*x = f(*x);
}
Gate::FieldShareBinaryOp { x, y, .. }
| Gate::BitShareBinaryOp { x, y, .. }
| Gate::PointShareBinaryOp { p: x, y, .. }
| Gate::FieldPlaintextBinaryOp { x, y, .. }
| Gate::BitPlaintextBinaryOp { x, y, .. }
| Gate::PointPlaintextBinaryOp { p: x, y, .. }
| Gate::KeyRecoveryPlaintextComputeErrors {
d_minus_one: x,
syndromes: y,
..
} => {
*x = f(*x);
*y = f(*y);
}
Gate::CollectToBatch { wires, .. }
| Gate::PointFromPlaintextExtendedEdwards { wires, .. } => {
wires.iter_mut().for_each(|x| *x = f(*x))
}
};
self
}
pub fn try_replace_inputs(mut self, inputs: Vec<GateIndex>) -> Result<Self, CircuitError<C>> {
if inputs.len() != self.get_inputs().len() {
return Err(CircuitError::InvalidGateInputCount {
expected: self.get_inputs().len(),
found: inputs.len(),
});
}
match &mut self {
Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
Gate::FieldShareUnaryOp { x, .. }
| Gate::BatchSummation { x, .. }
| Gate::BitShareUnaryOp { x, .. }
| Gate::PointShareUnaryOp { p: x, .. }
| Gate::FieldPlaintextUnaryOp { x, .. }
| Gate::BitPlaintextUnaryOp { x, .. }
| Gate::PointPlaintextUnaryOp { p: x, .. }
| Gate::GetDaBitFieldShare { x, .. }
| Gate::GetDaBitSharedBit { x, .. }
| Gate::BaseFieldPow { x, .. }
| Gate::BitPlaintextToField { x, .. }
| Gate::FieldPlaintextToBit { x, .. }
| Gate::ExtractFromBatch { x, .. }
| Gate::PlaintextPointToExtendedEdwards { point: x, .. }
| Gate::CompressPlaintextPoint { point: x, .. }
| Gate::PlaintextKeccakF1600 { x } => {
*x = inputs[0];
}
Gate::FieldShareBinaryOp { x, y, .. }
| Gate::BitShareBinaryOp { x, y, .. }
| Gate::PointShareBinaryOp { p: x, y, .. }
| Gate::FieldPlaintextBinaryOp { x, y, .. }
| Gate::BitPlaintextBinaryOp { x, y, .. }
| Gate::PointPlaintextBinaryOp { p: x, y, .. }
| Gate::KeyRecoveryPlaintextComputeErrors {
d_minus_one: x,
syndromes: y,
..
} => {
*x = inputs[0];
*y = inputs[1];
}
Gate::CollectToBatch { wires, .. }
| Gate::PointFromPlaintextExtendedEdwards { wires, .. } => *wires = inputs,
};
Ok(self)
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::{
        preprocessing::{CircuitPreprocessing, FieldCircuitPreprocessing},
        FieldShareBinaryOp,
    };

    /// Encodes `gate` with bincode and decodes it straight back.
    fn bincode_roundtrip(gate: &Gate<C>) -> Gate<C> {
        let encoded = bincode::serialize(gate).unwrap();
        bincode::deserialize(&encoded).unwrap()
    }

    /// A field-share gate and a point-share gate survive a bincode round trip, and equal
    /// gates hash identically.
    #[test]
    fn test_ser_gate() {
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: 1,
            y: 3,
            op: FieldShareBinaryOp::Add,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: 1,
            y: 3,
            op: PointShareBinaryOp::Add,
        };

        let scalar_gate_de = bincode_roundtrip(&scalar_gate);
        let point_gate_de = bincode_roundtrip(&point_gate);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);

        // Each gate equals its decoded copy, so the four values collapse to two set entries.
        let unique = HashSet::from([scalar_gate, scalar_gate_de, point_gate, point_gate_de]);
        assert_eq!(unique.len(), 2)
    }

    /// Adding two preprocessing requirements sums every counter field-wise and merges the
    /// `base_field_pow_pairs` maps, summing the counts of keys present in both.
    #[test]
    fn test_circuit_preprocessing_add() {
        let field = |singlets, triples, dabits| FieldCircuitPreprocessing {
            singlets,
            triples,
            dabits,
        };

        let a = CircuitPreprocessing {
            bit_singlets: 0,
            bit_triples: 1,
            base_field: field(3, 4, 2),
            scalar: field(1, 2, 1),
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
            mersenne107: field(0, 0, 0),
        };
        let b = CircuitPreprocessing {
            bit_singlets: 3,
            bit_triples: 4,
            base_field: field(0, 5, 3),
            scalar: field(2, 3, 2),
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
            mersenne107: field(3, 2, 0),
        };

        let c = a + b;

        assert_eq!(c.bit_singlets, 3);
        assert_eq!(c.bit_triples, 5);

        assert_eq!(c.base_field.singlets, 3);
        assert_eq!(c.base_field.triples, 9);
        assert_eq!(c.base_field.dabits, 5);

        assert_eq!(c.scalar.singlets, 3);
        assert_eq!(c.scalar.triples, 5);
        assert_eq!(c.scalar.dabits, 3);

        assert_eq!(c.mersenne107.singlets, 3);
        assert_eq!(c.mersenne107.triples, 2);
        assert_eq!(c.mersenne107.dabits, 0);

        // Key 21 appears in both operands (5 + 6); 14 and 13 each appear in only one.
        for (key, expected_count) in [(21, 11), (14, 6), (13, 7)] {
            assert_eq!(
                c.base_field_pow_pairs.get(&BoxedUint::from(vec![key])),
                Some(&expected_count)
            );
        }
    }
}