#include "risc0/zkvm/circuit/divide_cycle.h"
#include "risc0/zkvm/circuit/step_state.h"
namespace risc0::circuit {
// Sets up the divide cycle for the current step.
//
// Reads the two previous cycles (decode and final), expands the OPD entries of
// riscv32im.inl into one guarded doDivide() call per entry, and then fills in
// resultInfo (result value, setReg, doStore, pcRaw).
//
// state: step state; provides the cycle counter, memIO and the previous cycles.
void DivideCycle::set(StepState& state) {
Value cycle = state.code.cycle.get();
risc0Log("C%u: Divide", {cycle});
// The value returned by doRead is discarded here.
// NOTE(review): presumably memIO has to be driven on every cycle even when this
// cycle performs no memory access -- confirm against MemIO.
state.data.memIO.doRead(state.code.cycle.get());
// Views of the previous cycle (as decode) and the one before it (as final).
auto decode = state.getPrev(1).asDecode();
auto final = state.getPrev(2).asFinal();
// R1, R2 and V2 are not referenced by name anywhere in this file.
// NOTE(review): presumably the numerExpr/denomExpr arguments of the OPD entries
// in riscv32im.inl are written in terms of these locals, so they must keep
// these exact names -- confirm before renaming or removing them.
ValueU32 R1 = decode.rs1.get();
ValueU32 R2 = decode.rs2.get();
ValueU32 V2 = decode.val2OH.asU32();
// doDivide: emits the division logic for one opcode.
//   numerExpr, denomExpr: operand expressions, each with .low and .high halves.
//   signType:  0 -> operands used unchanged, both sign flags forced to 0;
//              1 -> sign flags taken from bit 15 of each operand's high half,
//                   flagged operands replaced by negNumerImpl/negDenomImpl output;
//              2 -> only the numerator carries a sign flag; a flagged numerator
//                   has each half replaced by (0xffff - half), denominator unchanged.
//              Any other value leaves numer32/denom/sign flags unset by this lambda.
//   resultRem: true selects rem32 as the output, false selects quot32.
auto doDivide = [&](ValueU32 numerExpr, ValueU32 denomExpr, int signType, bool resultRem) {
// Fed with low + high, which is zero only when both halves are zero
// (assuming the halves are non-negative and the sum cannot wrap -- TODO confirm).
denomNonzero.set(denomExpr.low + denomExpr.high);
if (signType == 0) {
numer32.set(numerExpr);
denom.set(denomExpr);
negNumer.set(0);
negDenom.set(0);
// NOTE(review): the Impl registers are given a constant 0 on every path that
// does not use them, presumably so they are always assigned -- confirm.
negNumerImpl.set(ValueU32::fromConst(0));
negDenomImpl.set(ValueU32::fromConst(0));
} else if (signType == 1) {
// (x & 0x8000) / 0x8000 yields bit 15 of the high half as 0 or 1.
// NOTE(review): BYZ_NONDET presumably marks witness-only computation that
// adds no constraint by itself -- confirm against the macro definition.
BYZ_NONDET {
negNumer.set((numerExpr.high & 0x8000) / 0x8000);
negDenom.set((denomExpr.high & 0x8000) / 0x8000);
}
// NOTE(review): the values returned by the Impl registers are presumably the
// negated operands -- confirm against the type of negNumerImpl/negDenomImpl.
auto negNumerVal = negNumerImpl.set(numerExpr);
auto negDenomVal = negDenomImpl.set(denomExpr);
// Arithmetic select: flag * alternative + (1 - flag) * original.
numer32.set(negNumer.get() * negNumerVal + (1 - negNumer.get()) * numerExpr);
denom.set(negDenom.get() * negDenomVal + (1 - negDenom.get()) * denomExpr);
} else if (signType == 2) {
BYZ_NONDET { negNumer.set((numerExpr.high & 0x8000) / 0x8000); }
negDenom.set(0);
// Per-half arithmetic select between (0xffff - half) and the half itself.
numer32.set(
{negNumer.get() * (0xffff - numerExpr.low) + (1 - negNumer.get()) * numerExpr.low,
negNumer.get() * (0xffff - numerExpr.high) + (1 - negNumer.get()) * numerExpr.high});
denom.set(denomExpr);
negNumerImpl.set(ValueU32::fromConst(0));
negDenomImpl.set(ValueU32::fromConst(0));
}
auto denomVal = denom.get();
risc0Log("numer32 = %04x%04x, denom = %04x%04x, negNumer = %d, negDenom = %d",
{numer32.high(),
numer32.low(),
denomVal.high,
denomVal.low,
negNumer.get(),
negDenom.get()});
// divide() produces the quotient and remainder halves from the numer32 and denom
// halves; they are stored into quot32/rem32 inside a BYZ_NONDET block and are
// tied back to the operands by the equate() calls further down.
BYZ_NONDET {
auto [quotLow, quotHigh, remLow, remHigh] =
divide(numer32.low(), numer32.high(), denomVal.low, denomVal.high);
quot32.setLow(quotLow);
quot32.setHigh(quotHigh);
rem32.setLow(remLow);
rem32.setHigh(remHigh);
}
risc0Log("quot32 = %04x%04x, rem32 = %04x%04x",
{quot32.high(), quot32.low(), rem32.high(), rem32.low()});
// Copies into quot/rem, which are the registers read through getPart() below.
quot.set(quot32.get());
rem.set(rem32.get());
// Only when denomNonzero is set: computes denom - rem - 1 in two parts.
// The first part is biased by 0x10000 and the value returned by setPart is kept
// in carryDMR and added into the second part, whose returned value must equal 0.
// NOTE(review): getPart/setPart arguments look like (bit offset, width) and
// setPart presumably returns what does not fit in the stored part; if so this
// enforces rem < denom -- confirm against the register type.
// Note: the local `final` below shadows the outer `final` captured by the lambda.
BYZ_IF(denomNonzero.get()) {
carryDMR.set(
denomMinusRem.setPart(0x10000 + denom.getPart(0, 16) - rem.getPart(0, 16) - 1, 0, 16));
Value final = denomMinusRem.setPart(
denom.getPart(16, 16) - rem.getPart(16, 16) - 1 + carryDMR.get(), 16, 16);
equate(final, 0);
}
// quot * denom + rem accumulated column by column, with quot/denom/rem each read
// as parts (0,12), (12,12) and (24,8):
//   numer0 = q0*d0 + r0
//   numer1 = q0*d1 + q1*d0 + r1 + numer0 part (12,12)
//   numer2 = q0*d2 + q1*d1 + q2*d0 + r2 + numer1 part (12,8)
// NOTE(review): setPartExact presumably requires the value to fit in the given
// width with nothing left over -- confirm.
numer0.setPartExact(quot.getPart(0, 12) * denom.getPart(0, 12) + rem.getPart(0, 12), 0, 24);
numer1.setPartExact(quot.getPart(0, 12) * denom.getPart(12, 12) +
quot.getPart(12, 12) * denom.getPart(0, 12) + rem.getPart(12, 12) +
numer0.getPart(12, 12),
0,
20);
numer2.setPartExact(
quot.getPart(0, 12) * denom.getPart(24, 8) + quot.getPart(12, 12) * denom.getPart(12, 12) +
quot.getPart(24, 8) * denom.getPart(0, 12) + rem.getPart(24, 8) + numer1.getPart(12, 8),
0,
8);
risc0Log("numer0=0x%06x, numer1=0x%05x, numer2=0x%02x",
{numer0.getPart(0, 24), numer1.getPart(0, 20), numer2.getPart(0, 8)});
// Reassembles the column results and requires them to match numer32:
//   numer0 part (0,12) + 0x1000 * numer1 part (0,4) == numer32.low()
//   numer1 part (4,8)  + 0x100  * numer2 part (0,8) == numer32.high()
equate(numer0.getPart(0, 12) + 0x1000 * numer1.getPart(0, 4), numer32.low());
equate(numer1.getPart(4, 8) + 0x100 * numer2.getPart(0, 8), numer32.high());
// Chooses the output value (remainder or quotient) and its sign flag.
ValueU32 out;
Value negOut;
if (resultRem) {
out = rem32.get();
// For any non-zero signType the remainder takes the numerator's sign flag.
if (signType) {
negOut = negNumer.get();
} else {
negOut = 0;
}
} else {
out = quot32.get();
if (signType) {
// First term: negNumer + negDenom - 2*negNumer*negDenom, which is the XOR of
// the two flags when both are 0 or 1.
// Second term: subtracts negNumer whenever denomNonzero.get() is 0.
// NOTE(review): this assumes denomNonzero.get() yields 0 or 1 -- confirm.
negOut =
(negNumer.get() + negDenom.get() - 2 * negNumer.get() * negDenom.get())
- (negNumer.get() * (1 - denomNonzero.get()));
} else {
negOut = 0;
}
}
negResult.set(negOut);
// Applies the sign flag to the output with the same scheme used for the inputs:
// signType 1 selects between negResultImpl's output and `out`,
// signType 2 selects per half between (0xffff - half) and the half,
// anything else stores `out` unchanged.
if (signType == 1) {
auto negResultVal = negResultImpl.set(out);
resultInfo.result.set(negResult.get() * negResultVal + (1 - negResult.get()) * out);
} else if (signType == 2) {
resultInfo.result.set(
{negResult.get() * (0xffff - out.low) + (1 - negResult.get()) * out.low,
negResult.get() * (0xffff - out.high) + (1 - negResult.get()) * out.high});
negResultImpl.set(ValueU32::fromConst(0));
} else {
resultInfo.result.set(out);
negResultImpl.set(ValueU32::fromConst(0));
}
};
// OPC and OPM expand to nothing, so only the OPD entries of riscv32im.inl produce
// code here. Each OPD entry becomes a doDivide() call guarded by a test of
// decode.opID1 against the low three bits of the entry's id; signType and
// resultRem are passed straight through from the table entry.
#define OPC(...)
#define OPM(...)
#define OPD(id, mnem, opc, f3, f7, immFmt, numerExpr, denomExpr, signType, resultRem) \
BYZ_IF(decode.opID1.is((id)&7)) { doDivide(numerExpr, denomExpr, signType, resultRem); }
#include "risc0/zkvm/circuit/riscv32im.inl"
#undef OPC
#undef OPM
#undef OPD
// NOTE(review): presumably "write the result to a register, perform no store,
// and continue at the instruction after final.pc" -- confirm against ResultInfo.
resultInfo.setReg.set(1);
resultInfo.doStore.set(0);
resultInfo.pcRaw.set(final.pc + 4);
}
}