extern crate alloc;
use alloc::{vec, vec::Vec};
use crate::common::Proof;
use solana_bn254::prelude::{alt_bn128_addition, alt_bn128_pairing};
use solana_program::{entrypoint::ProgramResult, program_error::ProgramError};
/// Callback that derives the scalar coefficients `(a2, b2, c2)` from the public
/// input values. The pairing checks in this module map any `Err` it returns to
/// `ProgramError::InvalidInstructionData`.
pub type ComputeCoeffsFn = fn(&[i64]) -> Result<(i64, i64, i64), crate::ZkError>;
pub fn process_check_pairing(
proof: Proof,
public_input_values: &[i64],
compute_coeffs: ComputeCoeffsFn,
) -> ProgramResult {
let (computed_a2, computed_b2, computed_c2) =
compute_coeffs(public_input_values).map_err(|_e| ProgramError::InvalidInstructionData)?;
if computed_b2 == 0 {
solana_program::msg!("ERROR: b2 coefficient is zero - invalid constraint system");
return Err(ProgramError::InvalidInstructionData);
}
let g1_a2_calc = if computed_a2 == 0 {
vec![0u8; 64]
} else {
scalar_mult_g1(&G1_GENERATOR, computed_a2)?
};
let g1_c2_calc = if computed_c2 == 0 {
vec![0u8; 64]
} else {
scalar_mult_g1(&G1_GENERATOR, computed_c2)?
};
let b2 = computed_b2;
let g1_b2: [u8; 64] = if b2 == 0 {
[0u8; 64]
} else {
let tmp = scalar_mult_g1(&G1_GENERATOR, b2)?;
let mut arr = [0u8; 64];
arr.copy_from_slice(&tmp);
arr
};
let a_curve = &proof.a_curve; let b_pub = &proof.g2_b2; let c_private_adjusted = &proof.c_curve; let hz_total = &proof.g1_hz;
if proof.g1_a2.as_ref() != g1_a2_calc.as_slice() {
#[allow(unused_macros)]
solana_program::msg!("g1_a2 mismatch");
return Err(ProgramError::InvalidInstructionData);
}
if proof.g1_c2.as_ref() != g1_c2_calc.as_slice() {
solana_program::msg!("g1_c2 mismatch");
return Err(ProgramError::InvalidInstructionData);
}
let a_total = if !is_infinity_g1(proof.g1_a2.as_ref()) {
add_g1_points(a_curve.as_ref(), proof.g1_a2.as_ref())?
} else {
a_curve.to_vec()
};
let c_total_prime = if !is_infinity_g1(proof.g1_c2.as_ref()) {
add_g1_points(c_private_adjusted.as_ref(), proof.g1_c2.as_ref())?
} else {
c_private_adjusted.to_vec()
};
let sum_a_with_g1 = add_g1_points(&a_total, &G1_GENERATOR)?;
let mut sum_g2: Option<Vec<u8>> = None;
if !is_infinity_g1(&c_total_prime[..]) {
sum_g2 = Some(c_total_prime.clone());
}
if !is_infinity_g1(hz_total.as_ref()) {
sum_g2 = Some(match sum_g2 {
Some(acc) => add_g1_points(&acc, hz_total.as_ref())?,
None => hz_total.as_ref().to_vec(),
});
}
if !is_infinity_g1(&g1_b2) {
sum_g2 = Some(match sum_g2 {
Some(acc) => add_g1_points(&acc, &g1_b2)?,
None => g1_b2.to_vec(),
});
}
let sum_g2_neg = sum_g2.as_ref().map(|v| {
let mut arr = [0u8; 64];
arr.copy_from_slice(v);
negate_g1_uncompressed(&arr)
});
let mut pairing_input = Vec::with_capacity(2 * (64 + 128));
pairing_input.extend_from_slice(&sum_a_with_g1);
pairing_input.extend_from_slice(b_pub.as_ref());
if let Some(neg) = sum_g2_neg {
pairing_input.extend_from_slice(&neg);
pairing_input.extend_from_slice(&G2_GENERATOR);
}
let pairing_result = bn254_pairing(&pairing_input)?;
if !pairing_result {
solana_program::msg!("pairing failed");
return Err(ProgramError::InvalidInstructionData);
}
Ok(())
}
pub fn process_check_pairing_relaxed(
proof: Proof,
public_input_values: &[i64],
compute_coeffs: ComputeCoeffsFn,
) -> ProgramResult {
let (_computed_a2, computed_b2, _computed_c2) =
compute_coeffs(public_input_values).map_err(|_e| ProgramError::InvalidInstructionData)?;
let b2 = computed_b2;
let g1_b2: [u8; 64] = if b2 == 0 {
[0u8; 64]
} else {
let tmp = scalar_mult_g1(&G1_GENERATOR, b2)?;
let mut arr = [0u8; 64];
arr.copy_from_slice(&tmp);
arr
};
let a_curve = &proof.a_curve;
let b_pub = &proof.g2_b2;
let c_private_adjusted = &proof.c_curve;
let hz_total = &proof.g1_hz;
let a_total = if !is_infinity_g1(proof.g1_a2.as_ref()) {
add_g1_points(a_curve.as_ref(), proof.g1_a2.as_ref())?
} else {
a_curve.to_vec()
};
let c_total_prime = if !is_infinity_g1(proof.g1_c2.as_ref()) {
add_g1_points(c_private_adjusted.as_ref(), proof.g1_c2.as_ref())?
} else {
c_private_adjusted.to_vec()
};
let sum_a_with_g1 = add_g1_points(&a_total, &G1_GENERATOR)?;
let mut sum_g2: Option<Vec<u8>> = None;
if !is_infinity_g1(&c_total_prime[..]) {
sum_g2 = Some(c_total_prime.clone());
}
if !is_infinity_g1(hz_total.as_ref()) {
sum_g2 = Some(match sum_g2 {
Some(acc) => add_g1_points(&acc, hz_total.as_ref())?,
None => hz_total.as_ref().to_vec(),
});
}
if !is_infinity_g1(&g1_b2) {
sum_g2 = Some(match sum_g2 {
Some(acc) => add_g1_points(&acc, &g1_b2)?,
None => g1_b2.to_vec(),
});
}
let sum_g2_neg = sum_g2.as_ref().map(|v| {
let mut arr = [0u8; 64];
arr.copy_from_slice(v);
negate_g1_uncompressed(&arr)
});
let mut pairing_input = Vec::with_capacity(2 * (64 + 128));
pairing_input.extend_from_slice(&sum_a_with_g1);
pairing_input.extend_from_slice(b_pub.as_ref());
if let Some(neg) = sum_g2_neg {
pairing_input.extend_from_slice(&neg);
pairing_input.extend_from_slice(&G2_GENERATOR);
}
let pairing_result = bn254_pairing(&pairing_input)?;
if !pairing_result {
solana_program::msg!("pairing failed");
return Err(ProgramError::InvalidInstructionData);
}
Ok(())
}
/// Generator of BN254 G1, the point `(1, 2)`, as 64 bytes: a 32-byte big-endian `x`
/// followed by a 32-byte big-endian `y`.
const G1_GENERATOR: [u8; 64] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
];
/// Generator of BN254 G2 as 128 bytes: `x` then `y`. Each is an Fp2 element written
/// as two 32-byte big-endian limbs with the imaginary part first (`x.c1, x.c0, y.c1,
/// y.c0`), which matches the EIP-197 layout the `alt_bn128` pairing syscall consumes.
const G2_GENERATOR: [u8; 128] = [
25, 142, 147, 147, 146, 13, 72, 58, 114, 96, 191, 183, 49, 251, 93, 37, 241, 170, 73, 51, 53,
169, 231, 18, 151, 228, 133, 183, 174, 243, 18, 194,
24, 0, 222, 239, 18, 31, 30, 118, 66, 106, 0, 102, 94, 92, 68, 121, 103, 67, 34, 212, 247, 94,
218, 221, 70, 222, 189, 92, 217, 146, 246, 237,
9, 6, 137, 208, 88, 95, 240, 117, 236, 158, 153, 173, 105, 12, 51, 149, 188, 75, 49, 51, 112,
179, 142, 243, 85, 172, 218, 220, 209, 34, 151, 91,
18, 200, 94, 165, 219, 140, 109, 235, 74, 171, 113, 128, 141, 203, 64, 143, 227, 209, 231, 105,
12, 67, 211, 123, 76, 230, 204, 1, 102, 250, 125, 170,
];
/// BN254 base-field modulus `p`, 32 bytes big-endian. Used to negate G1 `y`
/// coordinates (`y -> p - y`).
const BN254_P_BE: [u8; 32] = [
0x30, 0x64, 0x4e, 0x72, 0xe1, 0x31, 0xa0, 0x29, 0xb8, 0x50, 0x45, 0xb6, 0x81, 0x81, 0x58, 0x5d,
0x97, 0x81, 0x6a, 0x91, 0x68, 0x71, 0xca, 0x8d, 0x3c, 0x20, 0x8c, 0x16, 0xd8, 0x7c, 0xfd, 0x47,
];
/// BN254 scalar-field order `r` (the order of G1), 32 bytes big-endian. Used to map
/// negative `i64` scalars to `r - |s|`.
const BN254_R_BE: [u8; 32] = [
0x30, 0x64, 0x4e, 0x72, 0xe1, 0x31, 0xa0, 0x29, 0xb8, 0x50, 0x45, 0xb6, 0x81, 0x81, 0x58, 0x5d,
0x28, 0x33, 0xe8, 0x48, 0x79, 0xb9, 0x70, 0x91, 0x43, 0xe1, 0xf5, 0x93, 0xf0, 0x00, 0x00, 0x01,
];
/// Returns `true` when `p` is the all-zero encoding this module uses for the G1
/// point at infinity. `p` is expected to be 64 bytes (checked in debug builds only).
fn is_infinity_g1(p: &[u8]) -> bool {
    debug_assert!(p.len() == 64);
    !p.iter().any(|&byte| byte != 0)
}
/// Negates a G1 point encoded as 64 bytes: big-endian `x` followed by big-endian `y`.
///
/// The all-zero point at infinity is returned unchanged. Any other point becomes
/// `(x, p - y)`, with `p` = `BN254_P_BE`.
/// NOTE(review): assumes `y` is canonical (`y < p`); a larger `y` wraps silently
/// instead of erroring — confirm callers / the syscall reject such points.
fn negate_g1_uncompressed(point: &[u8; 64]) -> [u8; 64] {
    if is_infinity_g1(point) {
        return *point;
    }
    let mut negated = *point;
    let mut borrow = false;
    // Multi-byte subtraction p - y, walking from the least-significant byte upward.
    for (y_byte, &p_byte) in negated[32..].iter_mut().zip(BN254_P_BE.iter()).rev() {
        let (partial, borrow_from_y) = p_byte.overflowing_sub(*y_byte);
        let (diff, borrow_from_carry) = partial.overflowing_sub(u8::from(borrow));
        *y_byte = diff;
        borrow = borrow_from_y || borrow_from_carry;
    }
    negated
}
/// Encodes `scalar` as a 32-byte big-endian element of the BN254 scalar field.
///
/// Non-negative values are written directly into the low 8 bytes. A negative value
/// `-k` becomes `r - k` (with `r` = `BN254_R_BE`), its representative modulo `r`.
/// Since `|scalar| <= 2^63` is far below `r`, the subtraction never underflows.
fn reduce_scalar_mod_r(scalar: i64) -> [u8; 32] {
    // `unsigned_abs` also covers i64::MIN (magnitude 2^63) without overflow.
    let magnitude = scalar.unsigned_abs().to_be_bytes();
    if scalar >= 0 {
        let mut encoded = [0u8; 32];
        encoded[24..].copy_from_slice(&magnitude);
        return encoded;
    }
    // Big-endian subtraction r - |scalar|, least-significant byte first.
    let mut encoded = BN254_R_BE;
    let mut borrow = false;
    for (idx, byte) in encoded.iter_mut().enumerate().rev() {
        let subtrahend = if idx >= 24 { magnitude[idx - 24] } else { 0 };
        let (partial, borrow_a) = byte.overflowing_sub(subtrahend);
        let (diff, borrow_b) = partial.overflowing_sub(u8::from(borrow));
        *byte = diff;
        borrow = borrow_a || borrow_b;
    }
    encoded
}
/// Computes `scalar · base` for a 64-byte uncompressed G1 point using the
/// `alt_bn128_multiplication` syscall wrapper.
///
/// A zero scalar short-circuits to the all-zero (infinity) encoding without a
/// syscall. Other scalars are encoded by `reduce_scalar_mod_r`, so a negative value
/// multiplies by its representative mod r. Syscall failures map to
/// `ProgramError::InvalidInstructionData`.
fn scalar_mult_g1(base: &[u8], scalar: i64) -> Result<Vec<u8>, ProgramError> {
    use solana_bn254::prelude::alt_bn128_multiplication;
    if scalar == 0 {
        return Ok(vec![0u8; 64]);
    }
    // Syscall input layout: point (64 bytes) || scalar (32 bytes, big-endian).
    let input = [base, &reduce_scalar_mod_r(scalar)[..]].concat();
    alt_bn128_multiplication(&input).map_err(|_| ProgramError::InvalidInstructionData)
}
/// Adds two 64-byte uncompressed G1 points with the `alt_bn128_addition` syscall
/// wrapper and returns the encoded sum. Syscall failures map to
/// `ProgramError::InvalidInstructionData`.
fn add_g1_points(point1: &[u8], point2: &[u8]) -> Result<Vec<u8>, ProgramError> {
    let input = [point1, point2].concat();
    alt_bn128_addition(&input).map_err(|_e| ProgramError::InvalidInstructionData)
}
/// Runs the `alt_bn128_pairing` syscall wrapper on `input` (concatenated G1/G2
/// pairs) and reports whether the product of the pairings is the identity.
///
/// The syscall result is expected to be a 32-byte big-endian word equal to 1 on
/// success. Any other value, including a result of unexpected length, is reported as
/// `false` rather than indexed into (the previous `result[31]` could panic).
///
/// # Errors
///
/// Returns `ProgramError::InvalidInstructionData` if the syscall rejects `input`.
fn bn254_pairing(input: &[u8]) -> Result<bool, ProgramError> {
    let result = alt_bn128_pairing(input).map_err(|_e| ProgramError::InvalidInstructionData)?;
    let is_one = match result.split_last() {
        Some((&last, rest)) => rest.len() == 31 && last == 1 && rest.iter().all(|&b| b == 0),
        None => false,
    };
    Ok(is_one)
}