use itertools::izip;
use serde::{Deserialize, Serialize};
use sp1_core_executor::events::ByteRecord;
use sp1_hypercube::{air::SP1AirBuilder, Word};
use slop_air::AirBuilder;
use slop_algebra::{AbstractField, Field};
use sp1_derive::{AlignedBorrow, InputExpr, InputParams, IntoShape, SP1OperationBuilder};
use sp1_primitives::consts::{u64_to_u16_limbs, WORD_SIZE};
use crate::air::{SP1Operation, SP1OperationBuilder};
use super::{U16CompareOperation, U16CompareOperationInput, U16MSBOperation, U16MSBOperationInput};
/// Columns backing an unsigned less-than comparison of two words made of u16 limbs.
///
/// The witness is written by `populate_unsigned` and constrained by `eval_lt_unsigned`.
#[derive(
AlignedBorrow,
Default,
Debug,
Clone,
Copy,
Serialize,
Deserialize,
IntoShape,
SP1OperationBuilder,
)]
#[repr(C)]
pub struct LtOperationUnsigned<T> {
/// Columns of the `U16CompareOperation` that is run on the two `comparison_limbs`.
pub u16_compare_operation: U16CompareOperation<T>,
/// One flag per limb. Scanning from the last limb towards the first, the flag of the first
/// position where the operands differ is set; all flags stay zero when the operands are equal.
pub u16_flags: [T; WORD_SIZE],
/// Field inverse of `comparison_limbs[0] - comparison_limbs[1]` when the operands differ,
/// zero otherwise.
pub not_eq_inv: T,
/// The limbs of `b` and `c` (in that order) at the flagged position; both zero when no flag
/// is set.
pub comparison_limbs: [T; 2],
}
/// Columns backing a less-than comparison that can be switched between signed and unsigned.
///
/// The witness is written by `populate_signed` and constrained by `eval_lt_signed`.
#[derive(
AlignedBorrow,
Default,
Debug,
Clone,
Copy,
Serialize,
Deserialize,
IntoShape,
SP1OperationBuilder,
)]
#[repr(C)]
pub struct LtOperationSigned<T> {
/// Columns of the unsigned comparison that is run on the (sign-adjusted) operands.
pub result: LtOperationUnsigned<T>,
/// `U16MSBOperation` columns for the last limb of `b`; `msb` is zero when the comparison
/// is not signed.
pub b_msb: U16MSBOperation<T>,
/// `U16MSBOperation` columns for the last limb of `c`; `msb` is zero when the comparison
/// is not signed.
pub c_msb: U16MSBOperation<T>,
}
impl<F: Field> LtOperationSigned<F> {
pub fn populate_signed(
&mut self,
record: &mut impl ByteRecord,
a_u64: u64,
b_u64: u64,
c_u64: u64,
is_signed: bool,
) {
let b_comp = u64_to_u16_limbs(b_u64);
let c_comp = u64_to_u16_limbs(c_u64);
if is_signed {
self.b_msb.populate_msb(record, b_comp[3]);
self.c_msb.populate_msb(record, c_comp[3]);
self.result.populate_unsigned(record, a_u64, b_u64 ^ (1 << 63), c_u64 ^ (1 << 63));
} else {
self.b_msb.msb = F::zero();
self.c_msb.msb = F::zero();
self.result.populate_unsigned(record, a_u64, b_u64, c_u64);
}
}
pub fn eval_lt_signed<AB>(
builder: &mut AB,
b: Word<AB::Expr>,
c: Word<AB::Expr>,
cols: LtOperationSigned<AB::Var>,
is_signed: AB::Expr,
is_real: AB::Expr,
) where
AB: SP1AirBuilder
+ SP1OperationBuilder<U16CompareOperation<<AB as AirBuilder>::F>>
+ SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>
+ SP1OperationBuilder<LtOperationUnsigned<<AB as AirBuilder>::F>>,
{
builder.assert_bool(is_signed.clone());
builder.assert_bool(is_real.clone());
builder.when_not(is_real.clone()).assert_zero(is_signed.clone());
<U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16MSBOperationInput::<AB>::new(
b.0[WORD_SIZE - 1].clone(),
cols.b_msb,
is_signed.clone(),
),
);
<U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16MSBOperationInput::<AB>::new(
c.0[WORD_SIZE - 1].clone(),
cols.c_msb,
is_signed.clone(),
),
);
builder.when_not(is_signed.clone()).assert_zero(cols.b_msb.msb);
builder.when_not(is_signed.clone()).assert_zero(cols.c_msb.msb);
let mut b_compare = b;
let mut c_compare = c;
let base = AB::Expr::from_canonical_u32(1 << 16);
b_compare[WORD_SIZE - 1] = b_compare[WORD_SIZE - 1].clone()
+ is_signed.clone() * AB::Expr::from_canonical_u32(1 << 15)
- base.clone() * cols.b_msb.msb;
c_compare[WORD_SIZE - 1] = c_compare[WORD_SIZE - 1].clone()
+ is_signed.clone() * AB::Expr::from_canonical_u32(1 << 15)
- base.clone() * cols.c_msb.msb;
<LtOperationUnsigned<AB::F> as SP1Operation<AB>>::eval(
builder,
LtOperationUnsignedInput::<AB>::new(b_compare, c_compare, cols.result, is_real.clone()),
);
}
}
impl<F: Field> LtOperationUnsigned<F> {
pub fn populate_unsigned(
&mut self,
record: &mut impl ByteRecord,
a_u64: u64,
b_u64: u64,
c_u64: u64,
) {
self.comparison_limbs[0] = F::zero();
self.comparison_limbs[1] = F::zero();
self.not_eq_inv = F::zero();
self.u16_flags = [F::zero(), F::zero(), F::zero(), F::zero()];
let a_limbs = u64_to_u16_limbs(a_u64);
let b_limbs = u64_to_u16_limbs(b_u64);
let c_limbs = u64_to_u16_limbs(c_u64);
let a_u16 = a_limbs[0] as u16;
let mut comparison_limbs = [0u16; 2];
for (b_limb, c_limb, flag) in
izip!(b_limbs.iter().rev(), c_limbs.iter().rev(), self.u16_flags.iter_mut().rev())
{
if b_limb != c_limb {
*flag = F::one();
comparison_limbs[0] = *b_limb;
comparison_limbs[1] = *c_limb;
let b_limb = F::from_canonical_u16(*b_limb);
let c_limb = F::from_canonical_u16(*c_limb);
self.not_eq_inv = (b_limb - c_limb).inverse();
self.comparison_limbs = [b_limb, c_limb];
break;
}
}
self.u16_compare_operation.populate(
record,
a_u16,
comparison_limbs[0],
comparison_limbs[1],
);
}
pub fn eval_lt_unsigned<AB>(
builder: &mut AB,
b: Word<AB::Expr>,
c: Word<AB::Expr>,
cols: LtOperationUnsigned<AB::Var>,
is_real: AB::Expr,
) where
AB: SP1AirBuilder + SP1OperationBuilder<U16CompareOperation<<AB as AirBuilder>::F>>,
{
builder.assert_bool(is_real.clone());
let sum_flags =
cols.u16_flags[0] + cols.u16_flags[1] + cols.u16_flags[2] + cols.u16_flags[3];
builder.assert_bool(cols.u16_flags[0]);
builder.assert_bool(cols.u16_flags[1]);
builder.assert_bool(cols.u16_flags[2]);
builder.assert_bool(cols.u16_flags[3]);
builder.assert_bool(sum_flags.clone());
let is_comp_eq = AB::Expr::one() - sum_flags;
let mut is_inequality_visited = AB::Expr::zero();
let mut b_comparison_limb = AB::Expr::zero();
let mut c_comparison_limb = AB::Expr::zero();
for (b_limb, c_limb, &flag) in
izip!(b.0.iter().rev(), c.0.iter().rev(), cols.u16_flags.iter().rev())
{
is_inequality_visited = is_inequality_visited.clone() + flag.into();
builder
.when(is_real.clone() - is_inequality_visited.clone())
.assert_eq(b_limb.clone(), c_limb.clone());
b_comparison_limb = b_comparison_limb.clone() + b_limb.clone() * flag.into();
c_comparison_limb = c_comparison_limb.clone() + c_limb.clone() * flag.into();
}
let (b_comp_limb, c_comp_limb) = (cols.comparison_limbs[0], cols.comparison_limbs[1]);
builder.assert_eq(b_comparison_limb, b_comp_limb);
builder.assert_eq(c_comparison_limb, c_comp_limb);
builder
.when_not(is_comp_eq)
.assert_eq(cols.not_eq_inv * (b_comp_limb - c_comp_limb), is_real.clone());
<U16CompareOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16CompareOperationInput::<AB>::new(
b_comp_limb.into(),
c_comp_limb.into(),
cols.u16_compare_operation,
is_real.clone(),
),
);
}
}
/// Input bundle for the `SP1Operation` implementation of `LtOperationUnsigned`; its fields are
/// passed one-to-one to `LtOperationUnsigned::eval_lt_unsigned`.
#[derive(Clone, InputExpr, InputParams)]
pub struct LtOperationUnsignedInput<AB: SP1AirBuilder> {
/// First operand of the comparison.
pub b: Word<AB::Expr>,
/// Second operand of the comparison.
pub c: Word<AB::Expr>,
/// Trace columns of the operation.
pub cols: LtOperationUnsigned<AB::Var>,
/// Gate for the operation; constrained to be boolean by `eval_lt_unsigned`.
pub is_real: AB::Expr,
}
impl<AB> SP1Operation<AB> for LtOperationUnsigned<AB::F>
where
AB: SP1AirBuilder + SP1OperationBuilder<U16CompareOperation<<AB as AirBuilder>::F>>,
{
type Input = LtOperationUnsignedInput<AB>;
type Output = ();
fn lower(builder: &mut AB, input: Self::Input) -> Self::Output {
Self::eval_lt_unsigned(builder, input.b, input.c, input.cols, input.is_real);
}
}
/// Input bundle for the `SP1Operation` implementation of `LtOperationSigned`; its fields are
/// passed one-to-one to `LtOperationSigned::eval_lt_signed`.
#[derive(Clone, InputExpr, InputParams)]
pub struct LtOperationSignedInput<AB: SP1AirBuilder> {
/// First operand of the comparison.
pub b: Word<AB::Expr>,
/// Second operand of the comparison.
pub c: Word<AB::Expr>,
/// Trace columns of the operation.
pub cols: LtOperationSigned<AB::Var>,
/// Selects the signed comparison; constrained to be boolean by `eval_lt_signed`.
pub is_signed: AB::Expr,
/// Gate for the operation; constrained to be boolean by `eval_lt_signed`.
pub is_real: AB::Expr,
}
impl<AB> SP1Operation<AB> for LtOperationSigned<AB::F>
where
AB: SP1AirBuilder
+ SP1OperationBuilder<U16CompareOperation<<AB as AirBuilder>::F>>
+ SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>
+ SP1OperationBuilder<LtOperationUnsigned<<AB as AirBuilder>::F>>,
{
type Input = LtOperationSignedInput<AB>;
type Output = ();
fn lower(builder: &mut AB, input: Self::Input) -> Self::Output {
Self::eval_lt_signed(builder, input.b, input.c, input.cols, input.is_signed, input.is_real);
}
}