#![allow(missing_docs)]
use uor_foundation::enforcement::{GroundedShape, ShapeViolation};
use uor_foundation::pipeline::{
AxisExtension, ConstrainedTypeShape, ConstraintRef, IntoBindingValue,
};
use uor_foundation_sdk::axis;
use crate::{check_output, split_pair};
// Declares the `FieldAxis` trait through the SDK `axis!` macro.
// Only plain `//` comments are used inside the invocation, so the macro
// receives exactly the same tokens as before.
axis! {
// Extension of `AxisExtension` exposing three binary operations on
// byte-encoded operands.
pub trait FieldAxis: AxisExtension {
// Default address string for this axis; implementors may override it.
const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/FieldAxis";
// Default output-size bound in bytes.
// NOTE(review): presumably consumed by the macro-generated glue — confirm in the SDK.
const MAX_OUTPUT_BYTES: usize = 32;
// Each operation reads its operands from `input`, writes the result into
// `out`, and returns a byte count (the implementation in this file returns
// the number of bytes written) or a `ShapeViolation`.
// The operand layout inside `input` is chosen by the implementor.
fn add(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
fn sub(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
fn mul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
}
}
/// Size in bytes of one encoded field element (256 bits).
const WIDTH: usize = 32;
/// Field modulus as a big-endian byte string (index 0 is the most significant
/// byte): `0xFFFF…FFFE_FFFF_FC2F`, i.e. `2^256 - 2^32 - 977`, the secp256k1
/// base-field prime.
const P: [u8; WIDTH] = [
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xfc, 0x2f,
];
/// Returns `true` when `a >= b`, reading both arrays as big-endian unsigned
/// integers (index 0 is the most significant byte).
fn cmp_ge(a: &[u8; WIDTH], b: &[u8; WIDTH]) -> bool {
    // The first (most significant) differing byte decides the ordering;
    // if no byte differs the values are equal, which counts as `>=`.
    match a.iter().zip(b.iter()).find(|(lhs, rhs)| lhs != rhs) {
        Some((lhs, rhs)) => lhs > rhs,
        None => true,
    }
}
/// Subtracts `rhs` from `target` in place, treating both as big-endian
/// unsigned integers.
///
/// The subtraction wraps modulo `2^256`: a borrow out of the most significant
/// byte is discarded.
fn sub_assign(target: &mut [u8; WIDTH], rhs: &[u8; WIDTH]) {
    let mut borrow = false;
    // Walk from the least significant byte (last index) towards the top.
    for (dst, src) in target.iter_mut().zip(rhs.iter()).rev() {
        let (partial, under_a) = dst.overflowing_sub(*src);
        let (digit, under_b) = partial.overflowing_sub(u8::from(borrow));
        *dst = digit;
        // At most one of the two steps can underflow for a given byte.
        borrow = under_a || under_b;
    }
}
/// Adds two big-endian 256-bit integers.
///
/// Returns the low 256 bits of the sum together with the bit carried out of
/// the most significant byte (`0` or `1`).
fn add_with_carry(a: &[u8; WIDTH], b: &[u8; WIDTH]) -> ([u8; WIDTH], u8) {
    let mut sum = [0u8; WIDTH];
    let mut carry = false;
    // Least significant byte lives at the highest index.
    for idx in (0..WIDTH).rev() {
        let (partial, over_a) = a[idx].overflowing_add(b[idx]);
        let (digit, over_b) = partial.overflowing_add(u8::from(carry));
        sum[idx] = digit;
        // At most one of the two steps can overflow for a given byte.
        carry = over_a || over_b;
    }
    (sum, u8::from(carry))
}
/// Reduces the 257-bit quantity `had_carry * 2^256 + value` into the canonical
/// range `[0, P)`.
///
/// `value` is a big-endian 256-bit integer; `had_carry` is the bit carried out
/// of the top byte by a preceding addition. The result is correct for every
/// combination of inputs, including a `value` that is itself `>= P`.
fn reduce_to_field(value: [u8; WIDTH], had_carry: bool) -> [u8; WIDTH] {
    let mut v = value;
    if had_carry {
        // The real quantity is 2^256 + v. A wrapping `v - P` absorbs the 2^256
        // only when it underflows, i.e. when v < P. If v >= P the subtraction
        // is exact and the 2^256 is still outstanding, so subtract once more;
        // at that point v < 2^256 - P, which guarantees the second subtraction
        // wraps and lands below P.
        let carry_survives = cmp_ge(&v, &P);
        sub_assign(&mut v, &P);
        if carry_survives {
            sub_assign(&mut v, &P);
        }
    }
    // No carry is pending any more; bring the plain 256-bit value below P.
    while cmp_ge(&v, &P) {
        sub_assign(&mut v, &P);
    }
    v
}
/// Multiplies two big-endian 256-bit integers and returns the product reduced
/// modulo `P`, in the canonical range `[0, P)`.
///
/// The operands do not need to be canonical; any 256-bit value is accepted.
fn mod_mul(a: &[u8; WIDTH], b: &[u8; WIDTH]) -> [u8; WIDTH] {
    // Schoolbook multiplication: accumulate every byte product into its
    // column. The bytes a[i] and b[j] contribute at big-endian index i + j + 1
    // of the 64-byte product. A column holds at most 32 products of at most
    // 255 * 255 each, which fits a u32 comfortably.
    let mut columns = [0u32; 2 * WIDTH];
    for (i, &x) in a.iter().enumerate() {
        for (j, &y) in b.iter().enumerate() {
            columns[i + j + 1] += u32::from(x) * u32::from(y);
        }
    }

    // One carry-propagation pass from the least significant column upward.
    let mut product = [0u8; 2 * WIDTH];
    let mut carry = 0u32;
    for (digit, column) in product.iter_mut().zip(columns.iter()).rev() {
        let total = column + carry;
        *digit = total.to_le_bytes()[0];
        carry = total >> 8;
    }

    // Bitwise long division by P, most significant bit first. Invariant:
    // `rem < P` before every step, so `2 * rem + bit < 2 * P` and a single
    // subtraction restores the invariant.
    let mut rem = [0u8; WIDTH];
    for &byte in &product {
        for bit in (0..8).rev() {
            // rem = 2 * rem + next bit; `carry_bit` ends up holding the bit
            // shifted out of the top of the 256-bit register.
            let mut carry_bit = (byte >> bit) & 1;
            for limb in rem.iter_mut().rev() {
                let shifted_out = *limb >> 7;
                *limb = (*limb << 1) | carry_bit;
                carry_bit = shifted_out;
            }
            // Array comparison is lexicographic, which equals numeric order
            // for big-endian bytes.
            if carry_bit == 1 || rem >= P {
                // Wrapping subtraction: when the top bit was shifted out, the
                // final borrow cancels against it.
                let mut borrow = false;
                for (limb, &p) in rem.iter_mut().zip(P.iter()).rev() {
                    let (partial, under_a) = limb.overflowing_sub(p);
                    let (diff, under_b) = partial.overflowing_sub(u8::from(borrow));
                    *limb = diff;
                    borrow = under_a || under_b;
                }
            }
        }
    }
    rem
}
/// Copies the first `WIDTH` bytes of `slice` into a fixed-size array.
///
/// # Panics
///
/// Panics if `slice` holds fewer than `WIDTH` bytes.
fn read32(slice: &[u8]) -> [u8; WIDTH] {
    let head = &slice[..WIDTH];
    let mut element = [0u8; WIDTH];
    for (dst, src) in element.iter_mut().zip(head) {
        *dst = *src;
    }
    element
}
#[derive(Debug, Clone, Copy, Default)]
pub struct PrimeFieldNumericSecp256k1;
impl FieldAxis for PrimeFieldNumericSecp256k1 {
const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/FieldAxis/Secp256k1Base";
const MAX_OUTPUT_BYTES: usize = WIDTH;
fn add(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
let (a, b) = split_pair(input, WIDTH)?;
check_output(out, WIDTH)?;
let a = read32(a);
let b = read32(b);
let (sum, carry) = add_with_carry(&a, &b);
let result = reduce_to_field(sum, carry != 0);
out[..WIDTH].copy_from_slice(&result);
Ok(WIDTH)
}
fn sub(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
let (a, b) = split_pair(input, WIDTH)?;
check_output(out, WIDTH)?;
let a = read32(a);
let b = read32(b);
let mut p_minus_b = P;
sub_assign(&mut p_minus_b, &b);
let (sum, carry) = add_with_carry(&a, &p_minus_b);
let result = reduce_to_field(sum, carry != 0);
out[..WIDTH].copy_from_slice(&result);
Ok(WIDTH)
}
fn mul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
let (a, b) = split_pair(input, WIDTH)?;
check_output(out, WIDTH)?;
let a = read32(a);
let b = read32(b);
let result = mod_mul(&a, &b);
out[..WIDTH].copy_from_slice(&result);
Ok(WIDTH)
}
}
axis_extension_impl_for_field_axis!(PrimeFieldNumericSecp256k1);
/// Zero-sized marker describing a field element that is `BYTES` bytes wide.
///
/// `BYTES` exists only at the type level; the struct stores no data.
/// `Default` is derived because the hand-written implementation was exactly
/// what the derive generates.
#[derive(Debug, Clone, Copy, Default)]
pub struct FieldElementShape<const BYTES: usize>;
/// Shape constants for a `BYTES`-wide field element.
impl<const BYTES: usize> ConstrainedTypeShape for FieldElementShape<BYTES> {
    /// Type IRI reported for every width.
    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
    /// One site per byte of the element.
    const SITE_COUNT: usize = BYTES;
    /// The shape declares no constraints.
    const CONSTRAINTS: &'static [ConstraintRef] = &[];
    /// `256^BYTES`, saturating at `u64::MAX` (reached from `BYTES = 8` upward).
    // The `usize -> u32` cast could only truncate for widths of 2^32 bytes or
    // more, far beyond the point where the power has already saturated.
    #[allow(clippy::cast_possible_truncation)]
    const CYCLE_SIZE: u64 = u64::saturating_pow(256, BYTES as u32);
}
// Marker implementations with empty bodies: they add no behaviour of their own.
// NOTE(review): `__sdk_seal::Sealed` is presumably the seal that gates which
// types may implement the pipeline traits — confirm against the SDK.
impl<const BYTES: usize> uor_foundation::pipeline::__sdk_seal::Sealed for FieldElementShape<BYTES> {}
impl<const BYTES: usize> GroundedShape for FieldElementShape<BYTES> {}
/// Binding-value view of the shape marker.
impl<const BYTES: usize> IntoBindingValue for FieldElementShape<BYTES> {
/// Size bound for the binding, equal to the element width in bytes.
const MAX_BYTES: usize = BYTES;
/// Leaves `_out` untouched and reports zero bytes written; the marker type
/// stores no data to serialise.
fn into_binding_bytes(&self, _out: &mut [u8]) -> Result<usize, ShapeViolation> {
Ok(0)
}
}