use anyhow::{anyhow, Result};
use k256::elliptic_curve::ops::{MulByGenerator, Reduce};
use k256::U256 as K256U256;
use k256::{ProjectivePoint, Scalar};
/// Precomputed data for restricting a key search to keys of the form `k = R + j*M`.
///
/// Produced by [`ModConstraint::new`]. If the private key is `k = R + j*M`, then
/// `transformed_pubkey == j * base_point`, so the remaining unknown is `j` alone.
#[derive(Debug, Clone)]
pub struct ModConstraint {
    /// `M * G`: the point advanced by one step of `j`.
    pub base_point: ProjectivePoint,
    /// `pubkey - R * G` (equal to `pubkey` when `R == 0`).
    pub transformed_pubkey: ProjectivePoint,
    /// Smallest `j >= 0` with `R + j*M >= start`, as a 256-bit little-endian integer.
    pub j_start: [u8; 32],
    /// `range_bits` minus `floor(log2 M)`, clamped to at least 1.
    pub effective_range_bits: u32,
    /// `M` reduced into the scalar field.
    pub mod_step: Scalar,
    /// `R` reduced into the scalar field.
    pub mod_start: Scalar,
}
impl ModConstraint {
pub fn new(
mod_step_hex: &str,
mod_start_hex: &str,
pubkey: &ProjectivePoint,
start: &[u8; 32],
range_bits: u32,
) -> Result<Option<Self>> {
let m_u64 = u64::from_str_radix(mod_step_hex.trim_start_matches("0x"), 16)
.map_err(|e| anyhow!("Invalid hex for mod_step M: {e}"))?;
let r_u64 = u64::from_str_radix(mod_start_hex.trim_start_matches("0x"), 16)
.map_err(|e| anyhow!("Invalid hex for mod_start R: {e}"))?;
if m_u64 == 0 {
return Err(anyhow!("mod_step M must be >= 1, got 0"));
}
if r_u64 >= m_u64 {
return Err(anyhow!(
"mod_start R must be < M, got R={r_u64} >= M={m_u64}"
));
}
if m_u64 == 1 {
return Ok(None);
}
let mod_step = <Scalar as Reduce<K256U256>>::reduce(K256U256::from(m_u64));
let mod_start = <Scalar as Reduce<K256U256>>::reduce(K256U256::from(r_u64));
let base_point = ProjectivePoint::mul_by_generator(&mod_step);
let transformed_pubkey = if r_u64 == 0 {
*pubkey
} else {
let r_g = ProjectivePoint::mul_by_generator(&mod_start);
*pubkey - r_g
};
let diff = sub_u64_from_u256_le(start, r_u64);
let (quotient, remainder) = div_u256_le_by_u64(&diff, m_u64);
let j_start = if remainder > 0 {
add_one_u256_le("ient)
} else {
quotient
};
let log2_m = 63u32 - m_u64.leading_zeros();
let effective_range_bits = range_bits.saturating_sub(log2_m).max(1);
Ok(Some(Self {
base_point,
transformed_pubkey,
j_start,
effective_range_bits,
mod_step,
mod_start,
}))
}
}
/// Subtracts `val` from a 256-bit little-endian integer, saturating at zero.
///
/// If `val` is larger than the input, the all-zero value is returned instead of
/// wrapping around.
fn sub_u64_from_u256_le(le_bytes: &[u8; 32], val: u64) -> [u8; 32] {
    let mut out = *le_bytes;
    let mut owed = val;
    for limb_bytes in out.chunks_exact_mut(8) {
        if owed == 0 {
            break;
        }
        let mut limb = [0u8; 8];
        limb.copy_from_slice(limb_bytes);
        let (low, wrapped) = u64::from_le_bytes(limb).overflowing_sub(owed);
        limb_bytes.copy_from_slice(&low.to_le_bytes());
        // A wrap means this limb borrowed one unit from the next limb up.
        owed = u64::from(wrapped);
    }
    // Still owing after the top limb means `val` exceeded the input: clamp to zero.
    match owed {
        0 => out,
        _ => [0u8; 32],
    }
}
/// Divides a 256-bit little-endian integer by `divisor`.
///
/// Returns `(quotient, remainder)`, with the quotient in the same little-endian
/// layout. Performs schoolbook long division over four 64-bit limbs, most
/// significant limb first.
///
/// # Panics
/// Panics if `divisor` is zero.
fn div_u256_le_by_u64(le_bytes: &[u8; 32], divisor: u64) -> ([u8; 32], u64) {
    let wide_divisor = u128::from(divisor);
    let mut quotient = [0u8; 32];
    let mut carry = 0u128;
    let limbs = le_bytes.chunks_exact(8).zip(quotient.chunks_exact_mut(8));
    for (src, dst) in limbs.rev() {
        let mut limb = [0u8; 8];
        limb.copy_from_slice(src);
        // carry < divisor, so this fits in u128 and the per-limb quotient fits in u64.
        let acc = (carry << 64) | u128::from(u64::from_le_bytes(limb));
        dst.copy_from_slice(&((acc / wide_divisor) as u64).to_le_bytes());
        carry = acc % wide_divisor;
    }
    (quotient, carry as u64)
}
/// Increments a 256-bit little-endian integer by one, wrapping to zero on overflow.
fn add_one_u256_le(le_bytes: &[u8; 32]) -> [u8; 32] {
    let mut out = *le_bytes;
    for byte in out.iter_mut() {
        *byte = byte.wrapping_add(1);
        if *byte != 0 {
            // No carry out of this byte, so every higher byte is unchanged.
            break;
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::parse_pubkey;

    /// Compressed public key shared by every constraint test below.
    const TEST_PUBKEY: &str =
        "033c4a45cbd643ff97d77f41ea37e843648d50fd894b864b0d52febc62f6454f7c";

    fn test_pubkey() -> ProjectivePoint {
        parse_pubkey(TEST_PUBKEY).unwrap()
    }

    fn scalar(v: u64) -> Scalar {
        <Scalar as Reduce<K256U256>>::reduce(K256U256::from(v))
    }

    /// Encodes a `u64` as a 256-bit little-endian integer.
    fn le(v: u64) -> [u8; 32] {
        let mut out = [0u8; 32];
        out[..8].copy_from_slice(&v.to_le_bytes());
        out
    }

    #[test]
    fn test_mod_constraint_identity() {
        let result = ModConstraint::new("1", "0", &test_pubkey(), &[0u8; 32], 20).unwrap();
        assert!(result.is_none(), "M=1,R=0 should return None");
    }

    #[test]
    fn test_mod_constraint_m7_r0() {
        let pubkey = test_pubkey();
        let mut start = [0u8; 32];
        start[2] = 0x08;
        let c = ModConstraint::new("7", "0", &pubkey, &start, 20)
            .unwrap()
            .expect("M=7,R=0 should return Some");
        assert_eq!(
            c.base_point,
            ProjectivePoint::mul_by_generator(&scalar(7)),
            "base_point should be 7*G"
        );
        assert_eq!(
            c.transformed_pubkey, pubkey,
            "transformed_pubkey should equal pubkey when R=0"
        );
        assert_eq!(
            c.effective_range_bits, 18,
            "effective_range_bits should be 18 for M=7"
        );
        // start = 0x080000 = 524288 = 7 * 74898 + 2, so j_start = ceil(524288 / 7) = 74899.
        assert_eq!(c.j_start, le(74_899), "j_start should be ceil(start / M)");
    }

    #[test]
    fn test_mod_constraint_nonzero_r() {
        let pubkey = test_pubkey();
        let c = ModConstraint::new("7", "3", &pubkey, &le(11), 20)
            .unwrap()
            .expect("M=7,R=3 should return Some");
        assert_eq!(c.mod_step, scalar(7));
        assert_eq!(c.mod_start, scalar(3));
        assert_eq!(
            c.transformed_pubkey + ProjectivePoint::mul_by_generator(&scalar(3)),
            pubkey,
            "transformed_pubkey should be pubkey - 3*G"
        );
        // Smallest k = 3 + 7j with k >= 11 is 17, i.e. j = 2.
        assert_eq!(c.j_start, le(2));
    }

    #[test]
    fn test_mod_constraint_j_start_exact_and_below_r() {
        let pubkey = test_pubkey();
        // start = 10 = 3 + 7*1 exactly, so no rounding up.
        let exact = ModConstraint::new("7", "3", &pubkey, &le(10), 20)
            .unwrap()
            .expect("M=7,R=3 should return Some");
        assert_eq!(exact.j_start, le(1));
        // start = 2 < R = 5, so the first candidate is k = R itself (j = 0).
        let below = ModConstraint::new("7", "5", &pubkey, &le(2), 20)
            .unwrap()
            .expect("M=7,R=5 should return Some");
        assert_eq!(below.j_start, le(0));
    }

    #[test]
    fn test_mod_constraint_hex_prefix_and_range_clamp() {
        // 0x100 = 256 has log2 = 8 > range_bits = 4, so effective bits clamp to 1.
        let c = ModConstraint::new("0x100", "0x0", &test_pubkey(), &[0u8; 32], 4)
            .unwrap()
            .expect("M=0x100 should return Some");
        assert_eq!(c.effective_range_bits, 1);
    }

    #[test]
    fn test_mod_constraint_invalid_hex() {
        let result = ModConstraint::new("zz", "0", &test_pubkey(), &[0u8; 32], 20);
        assert!(result.is_err(), "non-hex M should return Err");
    }

    #[test]
    fn test_mod_constraint_invalid_r_ge_m() {
        let result = ModConstraint::new("3", "3", &test_pubkey(), &[0u8; 32], 20);
        assert!(result.is_err(), "R >= M should return Err");
    }

    #[test]
    fn test_mod_constraint_m0_invalid() {
        let result = ModConstraint::new("0", "0", &test_pubkey(), &[0u8; 32], 20);
        assert!(result.is_err(), "M=0 should return Err");
    }

    #[test]
    fn test_u256_helpers() {
        let mut two_pow_64 = [0u8; 32];
        two_pow_64[8] = 1;
        assert_eq!(
            sub_u64_from_u256_le(&le(5), 7),
            [0u8; 32],
            "underflow saturates at zero"
        );
        assert_eq!(sub_u64_from_u256_le(&two_pow_64, 1), le(u64::MAX));
        assert_eq!(div_u256_le_by_u64(&le(100), 7), (le(14), 2));
        assert_eq!(add_one_u256_le(&le(u64::MAX)), two_pow_64);
        assert_eq!(
            add_one_u256_le(&[0xFF; 32]),
            [0u8; 32],
            "increment wraps at 2^256"
        );
    }
}