use primitives::algebra::elliptic_curve::{Curve, Point, Scalar};
use serde::Deserialize;
use super::{
constants::{
BaseFieldPlaintext,
BaseFieldPlaintextBatch,
BitPlaintext,
BitPlaintextBatch,
Mersenne107Plaintext,
Mersenne107PlaintextBatch,
PointPlaintext,
PointPlaintextBatch,
ScalarPlaintext,
ScalarPlaintextBatch,
},
gate::Gate,
ops::Input,
};
use crate::circuit::{v1::errors::ConversionErrorToV2, v2, AlgebraicType, BatchSize, Slice};
/// Position of a gate within a v1 circuit's gate list (`Circuit::ops`).
///
/// Stored as a `u32`; `#[repr(transparent)]` gives it exactly the layout of that `u32`.
#[derive(Deserialize)]
#[repr(transparent)]
pub struct GateIndex(u32);
impl GateIndex {
fn into_usize(self) -> usize {
self.0 as usize
}
}
/// A circuit in the v1 serialized format, convertible to v2 via `Circuit::into_v2`.
///
/// `#[repr(C)]` fixes the layout to declaration order, and positional serde formats also
/// read the fields in that order, so the fields must not be reordered.
#[derive(Deserialize)]
#[serde(bound(deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"))]
#[repr(C)]
pub struct Circuit<C: Curve> {
    /// Every gate of the circuit; gate operands refer to other gates by their position here.
    ops: Vec<Gate<C>>,
    /// Presumably the positions in `ops` of the input gates (not read by `into_v2`).
    input_gates: Vec<GateIndex>,
    /// Positions in `ops` of the gates whose values form the circuit outputs.
    output_gates: Vec<GateIndex>,
}
impl<C: Curve> Circuit<C> {
    /// Converts this v1 circuit into an equivalent [`v2::Circuit`].
    ///
    /// Gates are converted in their v1 order. Most v1 gates map to exactly one v2 gate, but
    /// `PlaintextKeccakF1600` expands into a `CollectToBatch` gate followed by the Keccak
    /// gate, so v1 and v2 gate indices can diverge. A lookup table therefore records, for
    /// every converted v1 gate, the v2 index of the gate carrying its value, and every
    /// operand and output is translated through that table.
    ///
    /// `input_gates` is not consulted: inputs are carried over as the `Input`, `Random` or
    /// `Constant` gates they convert to, in the order they appear in `ops`.
    ///
    /// # Errors
    ///
    /// Propagates any error from adding a gate or an output to the v2 circuit, and from
    /// converting a bit-plaintext operation (`try_into`) to its v2 counterpart.
    ///
    /// # Panics
    ///
    /// Panics if a gate operand or an output refers to a gate index that is out of range or
    /// that has not been converted yet (a forward or self reference). The v1 gate list must
    /// be topologically ordered.
    pub fn into_v2(self) -> Result<v2::Circuit<C>, ConversionErrorToV2<C>> {
        let mut circuit = v2::Circuit::new();
        // Entry `i` is the v2 index of v1 gate `i`. The table grows by one entry per converted
        // gate, so a lookup can only succeed for gates that were converted earlier.
        let mut old_to_new_idx = Vec::with_capacity(self.ops.len());
        for gate in self.ops {
            // Translates a v1 operand of the current gate into its v2 gate index.
            let wire = |idx: GateIndex| Self::remap(&old_to_new_idx, idx);
            let gate = match gate {
                Gate::Input { input_type } => match input_type {
                    Input::SecretPlaintext {
                        inputer,
                        algebraic_type,
                        batched,
                    } => v2::Gate::Input(v2::Input::SecretPlaintext {
                        inputer,
                        algebraic_type,
                        batch_size: batched.count() as BatchSize,
                    }),
                    Input::Share {
                        algebraic_type,
                        batched,
                    } => v2::Gate::Input(v2::Input::Share {
                        algebraic_type,
                        batch_size: batched.count() as BatchSize,
                    }),
                    Input::RandomShare {
                        algebraic_type,
                        batched,
                    } => v2::Gate::Random {
                        algebraic_type,
                        batch_size: batched.count() as BatchSize,
                    },
                    Input::Scalar(val) => match val {
                        ScalarPlaintext::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::Scalar(val))
                        }
                        ScalarPlaintext::<C>::Input(val) => v2::Gate::Input(v2::Input::Plaintext {
                            algebraic_type: AlgebraicType::ScalarField,
                            batch_size: val as BatchSize,
                        }),
                    },
                    Input::ScalarBatch(val) => match val {
                        ScalarPlaintextBatch::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::ScalarBatch(val))
                        }
                        ScalarPlaintextBatch::<C>::Input(val) => {
                            v2::Gate::Input(v2::Input::Plaintext {
                                algebraic_type: AlgebraicType::ScalarField,
                                batch_size: val as BatchSize,
                            })
                        }
                    },
                    Input::BaseField(val) => match val {
                        BaseFieldPlaintext::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::BaseField(val))
                        }
                        BaseFieldPlaintext::<C>::Input(val) => {
                            v2::Gate::Input(v2::Input::Plaintext {
                                algebraic_type: AlgebraicType::BaseField,
                                batch_size: val as BatchSize,
                            })
                        }
                    },
                    Input::BaseFieldBatch(val) => match val {
                        BaseFieldPlaintextBatch::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::BaseFieldBatch(val))
                        }
                        BaseFieldPlaintextBatch::<C>::Input(val) => {
                            v2::Gate::Input(v2::Input::Plaintext {
                                algebraic_type: AlgebraicType::BaseField,
                                batch_size: val as BatchSize,
                            })
                        }
                    },
                    Input::Mersenne107(val) => match val {
                        Mersenne107Plaintext::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::Mersenne107(val))
                        }
                        Mersenne107Plaintext::Input(val) => v2::Gate::Input(v2::Input::Plaintext {
                            algebraic_type: AlgebraicType::Mersenne107,
                            batch_size: val as BatchSize,
                        }),
                    },
                    Input::Mersenne107Batch(val) => match val {
                        Mersenne107PlaintextBatch::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::Mersenne107Batch(val))
                        }
                        Mersenne107PlaintextBatch::Input(val) => {
                            v2::Gate::Input(v2::Input::Plaintext {
                                algebraic_type: AlgebraicType::Mersenne107,
                                batch_size: val as BatchSize,
                            })
                        }
                    },
                    Input::Bit(val) => match val {
                        BitPlaintext::Fixed(val) => v2::Gate::Constant(v2::Constant::Bit(val)),
                        BitPlaintext::Input(val) => v2::Gate::Input(v2::Input::Plaintext {
                            algebraic_type: AlgebraicType::Bit,
                            batch_size: val as BatchSize,
                        }),
                    },
                    Input::BitBatch(val) => match val {
                        BitPlaintextBatch::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::BitBatch(val))
                        }
                        BitPlaintextBatch::Input(val) => v2::Gate::Input(v2::Input::Plaintext {
                            algebraic_type: AlgebraicType::Bit,
                            batch_size: val as BatchSize,
                        }),
                    },
                    Input::Point(val) => match val {
                        PointPlaintext::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::Point(val))
                        }
                        PointPlaintext::<C>::Input(val) => v2::Gate::Input(v2::Input::Plaintext {
                            algebraic_type: AlgebraicType::Point,
                            batch_size: val as BatchSize,
                        }),
                    },
                    Input::PointBatch(val) => match val {
                        PointPlaintextBatch::<C>::Fixed(val) => {
                            v2::Gate::Constant(v2::Constant::PointBatch(val))
                        }
                        PointPlaintextBatch::<C>::Input(val) => {
                            v2::Gate::Input(v2::Input::Plaintext {
                                algebraic_type: AlgebraicType::Point,
                                batch_size: val as BatchSize,
                            })
                        }
                    },
                },
                Gate::FieldShareUnaryOp { x, op, .. } => {
                    v2::Gate::FieldShareUnaryOp { x: wire(x), op }
                }
                Gate::FieldShareBinaryOp { x, y, op, .. } => v2::Gate::FieldShareBinaryOp {
                    x: wire(x),
                    y: wire(y),
                    op,
                },
                Gate::BatchSummation { x, .. } => v2::Gate::BatchSummation { x: wire(x) },
                Gate::BitShareUnaryOp { x, op } => v2::Gate::BitShareUnaryOp { x: wire(x), op },
                Gate::BitShareBinaryOp { x, y, op, .. } => v2::Gate::BitShareBinaryOp {
                    x: wire(x),
                    y: wire(y),
                    op,
                },
                Gate::PointShareUnaryOp { p, op } => {
                    v2::Gate::PointShareUnaryOp { p: wire(p), op }
                }
                Gate::PointShareBinaryOp { p, y, op, .. } => v2::Gate::PointShareBinaryOp {
                    p: wire(p),
                    y: wire(y),
                    op,
                },
                Gate::FieldPlaintextUnaryOp { x, op, .. } => {
                    v2::Gate::FieldPlaintextUnaryOp { x: wire(x), op }
                }
                Gate::FieldPlaintextBinaryOp { x, y, op, .. } => v2::Gate::FieldPlaintextBinaryOp {
                    x: wire(x),
                    y: wire(y),
                    op,
                },
                Gate::BitPlaintextUnaryOp { x, op } => v2::Gate::BitPlaintextUnaryOp {
                    x: wire(x),
                    op: op.try_into()?,
                },
                Gate::BitPlaintextBinaryOp { x, y, op } => v2::Gate::BitPlaintextBinaryOp {
                    x: wire(x),
                    y: wire(y),
                    op: op.try_into()?,
                },
                Gate::PointPlaintextUnaryOp { p, op } => {
                    v2::Gate::PointPlaintextUnaryOp { p: wire(p), op }
                }
                Gate::PointPlaintextBinaryOp { p, y, op } => v2::Gate::PointPlaintextBinaryOp {
                    p: wire(p),
                    y: wire(y),
                    op,
                },
                Gate::DaBit {
                    field_type,
                    batched,
                } => v2::Gate::DaBit {
                    field_type,
                    batch_size: batched.count() as u32,
                },
                Gate::GetDaBitFieldShare { x, .. } => v2::Gate::GetDaBitFieldShare { x: wire(x) },
                Gate::GetDaBitSharedBit { x, .. } => v2::Gate::GetDaBitSharedBit { x: wire(x) },
                Gate::BaseFieldPow { x, exp } => v2::Gate::BaseFieldPow { x: wire(x), exp },
                Gate::BitPlaintextToField { x, field_type } => v2::Gate::BitPlaintextToField {
                    x: wire(x),
                    field_type,
                },
                Gate::FieldPlaintextToBit { x, .. } => v2::Gate::FieldPlaintextToBit { x: wire(x) },
                Gate::BatchGetIndex { x, index, .. } => v2::Gate::ExtractFromBatch {
                    x: wire(x),
                    slice: Slice::single(index as u32),
                },
                Gate::CollectToBatch { wires, .. } => v2::Gate::CollectToBatch {
                    wires: wires.into_iter().map(wire).collect(),
                },
                Gate::PointFromPlaintextExtendedEdwards { wires } => {
                    v2::Gate::PointFromPlaintextExtendedEdwards {
                        wires: wires.into_iter().map(wire).collect(),
                    }
                }
                Gate::PlaintextPointToExtendedEdwards { point } => {
                    v2::Gate::PlaintextPointToExtendedEdwards { point: wire(point) }
                }
                Gate::PlaintextKeccakF1600 { wires } => {
                    // The v2 Keccak gate takes a single input wire, so the v1 wires are
                    // first gathered into one batch by an extra `CollectToBatch` gate.
                    let wires = wires.into_iter().map(wire).collect();
                    let x = circuit.add_gate(v2::Gate::CollectToBatch { wires })?;
                    v2::Gate::PlaintextKeccakF1600 { x }
                }
                Gate::CompressPlaintextPoint { point } => {
                    v2::Gate::CompressPlaintextPoint { point: wire(point) }
                }
                Gate::KeyRecoveryPlaintextComputeErrors {
                    d_minus_one,
                    syndromes,
                } => v2::Gate::KeyRecoveryPlaintextComputeErrors {
                    d_minus_one: wire(d_minus_one),
                    syndromes: wire(syndromes),
                },
            };
            let new_gate_idx = circuit.add_gate(gate)?;
            // Records the v2 index of this v1 gate at position `old_to_new_idx.len()`,
            // which equals the v1 gate's own index.
            old_to_new_idx.push(new_gate_idx);
        }
        for output in self.output_gates {
            circuit.add_output(Self::remap(&old_to_new_idx, output))?;
        }
        Ok(circuit)
    }

    /// Returns the v2 index recorded for the already-converted v1 gate `idx`.
    ///
    /// `converted` holds one entry per v1 gate converted so far. An index at or beyond its
    /// length is either out of range or a reference to a gate that does not precede the
    /// current one.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message when `idx` has no recorded v2 index.
    fn remap<T: Copy>(converted: &[T], idx: GateIndex) -> T {
        let idx = idx.into_usize();
        match converted.get(idx) {
            Some(&new_idx) => new_idx,
            None => panic!(
                "v1 circuit references gate {} but only {} gate(s) have been converted so far \
                 (out-of-range index or forward reference)",
                idx,
                converted.len()
            ),
        }
    }
}