#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
pub use qp_plonky2_core::poseidon::{
Permuter, Poseidon, PoseidonHash, PoseidonPermutation, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS,
N_FULL_ROUNDS_TOTAL, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_CAPACITY, SPONGE_RATE, SPONGE_WIDTH,
};
use crate::field::extension::Extendable;
use crate::field::types::Field;
use crate::gates::gate::Gate;
use crate::gates::poseidon::PoseidonGate;
use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::hash::hash_types::RichField;
use crate::hash::hashing::PlonkyPermutation;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::AlgebraicHasher;
/// In-circuit counterparts of the Poseidon permutation's building blocks, operating on
/// `ExtensionTarget`s inside a `CircuitBuilder`.
///
/// All methods are default methods. Their constants come from the `Poseidon` supertrait
/// (`MDS_MATRIX_CIRC`, `MDS_MATRIX_DIAG`, `FAST_PARTIAL_*`) and from the re-exported
/// `ALL_ROUND_CONSTANTS` / `SPONGE_WIDTH`.
///
/// The formulas in the method docs assume the usual `CircuitBuilder` semantics
/// `mul_const_add_extension(c, x, y) = c*x + y` and `mul_add_extension(a, b, y) = a*b + y`
/// (NOTE(review): confirm against the builder).
///
/// NOTE(review): every method mutates `builder` (gates, constants, copy constraints), so the
/// order and number of builder calls presumably affect the generated circuit. Treat
/// statement order here as significant when editing.
pub trait PoseidonCircuit: Poseidon {
/// Computes output row `r` of the MDS layer applied to `v`:
/// `sum_i CIRC[i] * v[(i + r) % SPONGE_WIDTH] + DIAG[r] * v[r]`,
/// where `CIRC` / `DIAG` are `MDS_MATRIX_CIRC` / `MDS_MATRIX_DIAG` lifted into `Self`.
///
/// `mds_layer_circuit` uses this when the `PoseidonMdsGate` path is not available.
/// The precondition `r < SPONGE_WIDTH` is only checked by `debug_assert!`.
fn mds_row_shf_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
r: usize,
v: &[ExtensionTarget<D>; SPONGE_WIDTH],
) -> ExtensionTarget<D>
where
Self: RichField + Extendable<D>,
{
debug_assert!(r < SPONGE_WIDTH);
let mut res = builder.zero_extension();
// Circulant part: coefficient `CIRC[i]` multiplies `v` rotated left by `r`.
for i in 0..SPONGE_WIDTH {
let c = Self::from_canonical_u64(<Self as Poseidon>::MDS_MATRIX_CIRC[i]);
res = builder.mul_const_add_extension(c, v[(i + r) % SPONGE_WIDTH], res);
}
// Diagonal part: only element `r` receives the extra `DIAG[r]` term.
{
let c = Self::from_canonical_u64(<Self as Poseidon>::MDS_MATRIX_DIAG[r]);
res = builder.mul_const_add_extension(c, v[r], res);
}
res
}
/// Applies the full MDS layer to `state` and returns the new state.
///
/// - If `builder.config.num_routed_wires` is at least the `PoseidonMdsGate`'s wire count,
///   a single `PoseidonMdsGate` is added. `state` is connected to its input wires and its
///   output wires are returned.
/// - Otherwise each output row is built from arithmetic operations via
///   `mds_row_shf_circuit`.
fn mds_layer_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &[ExtensionTarget<D>; SPONGE_WIDTH],
) -> [ExtensionTarget<D>; SPONGE_WIDTH]
where
Self: RichField + Extendable<D>,
{
let mds_gate = PoseidonMdsGate::<Self, D>::new();
// The gate path connects to the gate's wires through copy constraints, presumably
// requiring all of them to be routed wires. Hence the comparison against the routed count.
if builder.config.num_routed_wires >= mds_gate.num_wires() {
let index = builder.add_gate(mds_gate, vec![]);
for i in 0..SPONGE_WIDTH {
let input_wire = PoseidonMdsGate::<Self, D>::wires_input(i);
builder.connect_extension(state[i], ExtensionTarget::from_range(index, input_wire));
}
// Exactly SPONGE_WIDTH items are collected, so the array conversion cannot fail.
(0..SPONGE_WIDTH)
.map(|i| {
let output_wire = PoseidonMdsGate::<Self, D>::wires_output(i);
ExtensionTarget::from_range(index, output_wire)
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
} else {
// Fallback: one arithmetic dot product per output row.
let mut result = [builder.zero_extension(); SPONGE_WIDTH];
for r in 0..SPONGE_WIDTH {
result[r] = Self::mds_row_shf_circuit(builder, r, state);
}
result
}
}
/// Adds `FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]` to `state[i]` for every lane, in place.
fn partial_first_constant_layer_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &mut [ExtensionTarget<D>; SPONGE_WIDTH],
) where
Self: RichField + Extendable<D>,
{
for i in 0..SPONGE_WIDTH {
let c = <Self as Poseidon>::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i];
let c = Self::Extension::from_canonical_u64(c);
let c = builder.constant_extension(c);
state[i] = builder.add_extension(state[i], c);
}
}
/// Multiplies `state` by `M = FAST_PARTIAL_ROUND_INITIAL_MATRIX` embedded in the
/// lower-right block:
/// - `result[0] = state[0]` (lane 0 passes through);
/// - for `c >= 1`, `result[c] = sum_{r >= 1} M[r - 1][c - 1] * state[r]`.
fn mds_partial_layer_init_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &[ExtensionTarget<D>; SPONGE_WIDTH],
) -> [ExtensionTarget<D>; SPONGE_WIDTH]
where
Self: RichField + Extendable<D>,
{
let mut result = [builder.zero_extension(); SPONGE_WIDTH];
result[0] = state[0];
// Indexing `[r - 1][c - 1]` treats `state[1..]` as a row vector multiplied on the
// left of `M`, i.e. result column `c` accumulates over input index `r`.
for r in 1..SPONGE_WIDTH {
for c in 1..SPONGE_WIDTH {
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
result[c] = builder.mul_add_extension(t, state[r], result[c]);
}
}
result
}
/// Sparse MDS step selected by `r`. `r` indexes `FAST_PARTIAL_ROUND_W_HATS` and
/// `FAST_PARTIAL_ROUND_VS`; presumably it is the partial-round number (confirm with callers).
///
/// - `result[0] = (CIRC[0] + DIAG[0]) * state[0] + sum_{i >= 1} W_HATS[r][i - 1] * state[i]`
/// - `result[i] = state[i] + VS[r][i - 1] * state[0]` for `i >= 1`
fn mds_partial_layer_fast_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &[ExtensionTarget<D>; SPONGE_WIDTH],
r: usize,
) -> [ExtensionTarget<D>; SPONGE_WIDTH]
where
Self: RichField + Extendable<D>,
{
let s0 = state[0];
// NOTE(review): plain u64 sum of two constants; assumes it neither overflows nor
// exceeds the field order, since `from_canonical_u64` expects a canonical value.
let mds0to0 = Self::MDS_MATRIX_CIRC[0] + Self::MDS_MATRIX_DIAG[0];
let mut d = builder.mul_const_extension(Self::from_canonical_u64(mds0to0), s0);
// First output lane: dense dot product with the `W_HATS` row.
for i in 1..SPONGE_WIDTH {
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
d = builder.mul_add_extension(t, state[i], d);
}
let mut result = [builder.zero_extension(); SPONGE_WIDTH];
result[0] = d;
// Remaining lanes: each only picks up a multiple of the (original) `state[0]`.
for i in 1..SPONGE_WIDTH {
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_VS[r][i - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
result[i] = builder.mul_add_extension(t, state[0], state[i]);
}
result
}
/// Adds `ALL_ROUND_CONSTANTS[i + SPONGE_WIDTH * round_ctr]` to `state[i]` for every lane,
/// in place, i.e. the `round_ctr`-th chunk of `SPONGE_WIDTH` constants.
///
/// Panics (array indexing) if that chunk lies past the end of `ALL_ROUND_CONSTANTS`.
fn constant_layer_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &mut [ExtensionTarget<D>; SPONGE_WIDTH],
round_ctr: usize,
) where
Self: RichField + Extendable<D>,
{
for i in 0..SPONGE_WIDTH {
let c = ALL_ROUND_CONSTANTS[i + SPONGE_WIDTH * round_ctr];
let c = Self::Extension::from_canonical_u64(c);
let c = builder.constant_extension(c);
state[i] = builder.add_extension(state[i], c);
}
}
/// S-box `x -> x^7`, built with `exp_u64_extension`.
fn sbox_monomial_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
x: ExtensionTarget<D>,
) -> ExtensionTarget<D>
where
Self: RichField + Extendable<D>,
{
builder.exp_u64_extension(x, 7)
}
/// Applies `sbox_monomial_circuit` to every lane of `state`, in place (full S-box layer).
fn sbox_layer_circuit<const D: usize>(
builder: &mut CircuitBuilder<Self, D>,
state: &mut [ExtensionTarget<D>; SPONGE_WIDTH],
) where
Self: RichField + Extendable<D>,
{
for i in 0..SPONGE_WIDTH {
state[i] = <Self as PoseidonCircuit>::sbox_monomial_circuit(builder, state[i]);
}
}
}
/// Blanket implementation: every `Poseidon` type gets the in-circuit helpers through the
/// trait's default methods, with nothing to override.
impl<P> PoseidonCircuit for P
where
    P: Poseidon,
{
}
/// In-circuit hashing for [`PoseidonHash`]: every permutation call is realised by a single
/// [`PoseidonGate`] whose wires are connected to the caller's targets.
impl<F: RichField> AlgebraicHasher<F> for PoseidonHash {
    type AlgebraicPermutation = PoseidonPermutation<Target>;

    /// Appends one `PoseidonGate` to `builder`, connects `swap` to the gate's `WIRE_SWAP`
    /// wire and `inputs[0..SPONGE_WIDTH]` to its input wires, and returns the gate's output
    /// wires (in lane order) as the permuted state.
    ///
    /// How the swap flag affects the inputs is defined by `PoseidonGate`, not here.
    ///
    /// # Panics
    ///
    /// If `inputs.as_ref()` holds fewer than `SPONGE_WIDTH` targets (out-of-bounds index).
    fn permute_swapped<const D: usize>(
        inputs: Self::AlgebraicPermutation,
        swap: BoolTarget,
        builder: &mut CircuitBuilder<F, D>,
    ) -> Self::AlgebraicPermutation
    where
        F: RichField + Extendable<D>,
    {
        let gate_index = builder.add_gate(PoseidonGate::<F, D>::new(), vec![]);

        // Keep the connection order: swap flag first, then input lanes in ascending order.
        let swap_dst = Target::wire(gate_index, PoseidonGate::<F, D>::WIRE_SWAP);
        builder.connect(swap.target, swap_dst);

        let state_in: &[Target] = inputs.as_ref();
        (0..SPONGE_WIDTH).for_each(|lane| {
            let lane_dst = Target::wire(gate_index, PoseidonGate::<F, D>::wire_input(lane));
            builder.connect(state_in[lane], lane_dst);
        });

        let state_out = (0..SPONGE_WIDTH)
            .map(|lane| Target::wire(gate_index, PoseidonGate::<F, D>::wire_output(lane)));
        Self::AlgebraicPermutation::new(state_out)
    }
}