risc0-zkvm-circuit 0.2.0

RISC Zero zero-knowledge VM circuit crate
Documentation
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "risc0/zkvm/circuit/compute_cycle.h"

#include "risc0/zkvm/circuit/step_state.h"

namespace risc0::circuit {

// Do the compute cycle.  Turn off max function size since macro inclusion make it too big.
// NOLINTNEXTLINE(readability-function-size)
void ComputeCycle::set(StepState& state, int highID) {
  Value cycle = state.code.cycle.get();

  // Cache some commonly used values
  auto decode = state.getPrev(1).asDecode();
  auto final = state.getPrev(2).asFinal();

  // Setup values used by the instruction macros
  auto R1 = decode.rs1.get();
#define R2 decode.rs2.get()
  auto IM = decode.imm.get();
#define MM state.data.memIO.value.get()
  auto memoryStore = [&]() {
    // Load the low 16 with the low 16 of rs2, and also the bit of the memory IO we will
    // be overwriting (so we can mix the bytes if needed).
    return ValueU32({decode.rs2.low(),
                     (1 - x1.get(1)) * state.data.memIO.value.low() +
                         x1.get(1) * state.data.memIO.value.high()});
  };
  auto& X1 = x1;
  auto& X2 = x2;
  auto& X3 = x3;
  auto PC = final.pc.get();
  // Do a memory read
  auto GET = [&](bool signExt, int bits) {
    if (bits == 32) {
      equate(x1.getPart(0, 2), 0);
      return x2.get();
    }
    ValueU32 out;
    Value highBit = 0;
    if (bits == 8) {
      Value tot = 0;
      for (int i = 0; i < 4; i++) {
        tot = tot + x1Low.is(i) * x2.getPart(i * 8, 8);
        if (signExt) {
          highBit = highBit + x1Low.is(i) * x2.get(i * 8 + 7);
        }
      }
      out.low = 0xff00 * highBit + tot;
    } else {
      equate(x1.get(0), 0);
      out.low = (1 - x1.get(1)) * x2.getPart(0, 16) + x1.get(1) * x2.getPart(16, 16);
      if (signExt) {
        highBit = highBit + (1 - x1.get(1)) * x2.get(15) + x1.get(1) * x2.get(31);
      }
    }
    out.high = 0xffff * highBit;
    return out;
  };
  // Do a memory write
  auto PUT = [&](int bits) {
    if (bits == 32) {
      equate(x1.getPart(0, 2), 0);
      return R2;
    }
    ValueU32 out;
    Value halfUpdate;
    if (bits == 8) {
      Value lowByte = (1 - x1.get(0)) * x2.getPart(0, 8) + x1.get(0) * x2.getPart(16, 8);
      Value highByte = x1.get(0) * x2.getPart(0, 8) + (1 - x1.get(0)) * x2.getPart(24, 8);
      tmp.set(lowByte + 256 * highByte);
      halfUpdate = tmp.get();
    } else {
      equate(x1.get(0), 0);
      halfUpdate = x2.getPart(0, 16);
    }
    out.low = (1 - x1.get(1)) * halfUpdate + x1.get(1) * state.data.memIO.value.low();
    out.high = x1.get(1) * halfUpdate + (1 - x1.get(1)) * state.data.memIO.value.high();
    return out;
  };
  auto maybeBranch = [](ValueU32 pc, Value doBranch, ValueU32 imm) {
    return pc + doBranch * imm + (1 - doBranch) * ValueU32({4, 0});
  };
#define MS memoryStore()
#define EQ ValueU32({1 - nz.get(), 0})
#define LT ValueU32({lt.get(), 0})
#define LTU ValueU32({1 - x3CarryHigh.get(), 0})
#define BR_TRUE maybeBranch(PC, resultInfo.result.low(), IM)
#define BR_FALSE maybeBranch(PC, (1 - resultInfo.result.low()), IM)

  // We call this once *inside* the risc0-if in a purely nondet way to compute the values without a
  // high degree, and then once outside to verify.
  auto setStatusFlags = [&]() {
    nzLow.set(x3.getPart(0, 16));
    nzHigh.set(x3.getPart(16, 16));
    nz.set(1 - (1 - nzLow.get()) * (1 - nzHigh.get()));
    Value s1 = x1.get(31);
    Value s2 = x2.get(31);
    Value s3 = x3.get(31);
    overflow.set(s1 * (1 - s2) * (1 - s3) + (1 - s1) * s2 * s3);
    lt.set(overflow.get() + s3 - 2 * overflow.get() * s3);
  };
#define OPC(                                                                                       \
    id, mnem, opc, f3, f7, immFmt, x1Expr, doLoad, x2Expr, x3Expr, resExpr, sreg, dstore, pcExpr)  \
  if ((id) >> 3 == highID) {                                                                       \
    BYZ_IF(decode.opID1.is((id)&7)) {                                                              \
      x1.setCarry(x1Expr, x1CarryLow, x1CarryHigh);                                                \
      Value cycle = state.code.cycle.get();                                                        \
      if (doLoad) {                                                                                \
        state.data.memIO.doRead(cycle, x1.getPart(2, kMemBits));                                   \
        equate(x1.getPart(2 + kMemBits, 32 - kMemBits - 2), 0);                                    \
      } else {                                                                                     \
        state.data.memIO.doRead(cycle, 0);                                                         \
      }                                                                                            \
      x2.set(x2Expr);                                                                              \
      x1Low.set(x1.getPart(0, 2));                                                                 \
      x3.setCarry(x3Expr, x3CarryLow, x3CarryHigh);                                                \
      BYZ_NONDET { setStatusFlags(); }                                                             \
      resultInfo.result.set(resExpr);                                                              \
      resultInfo.setReg.set(sreg);                                                                 \
      resultInfo.doStore.set(dstore);                                                              \
      resultInfo.pcRaw.set(pcExpr);                                                                \
      risc0Log(                                                                                    \
          "C%u: pc: %04x%04x Compute: x1 = 0x%04x%04x, x2 = 0x%04x%04x, x3 = 0x%04x%04x, nz = "    \
          "%d, carry = %d, "                                                                       \
          "overflow = %d, lt = %d, result = 0x%04x%04x",                                           \
          {cycle,                                                                                  \
           PC.high,                                                                                \
           PC.low,                                                                                 \
           x1.getPart(16, 16),                                                                     \
           x1.getPart(0, 16),                                                                      \
           x2.getPart(16, 16),                                                                     \
           x2.getPart(0, 16),                                                                      \
           x3.getPart(16, 16),                                                                     \
           x3.getPart(0, 16),                                                                      \
           nz.get(),                                                                               \
           x3CarryHigh.get(),                                                                      \
           overflow.get(),                                                                         \
           lt.get(),                                                                               \
           resultInfo.result.high(),                                                               \
           resultInfo.result.low()});                                                              \
    }                                                                                              \
  }
#define OPM(...) /**/
#define OPD(...) /**/
#include "risc0/zkvm/circuit/riscv32im.inl"
#undef MS
#undef EQ
#undef LT
#undef LTU
#undef BR_TRUE
#undef BR_FALSE
#undef OPC
#undef OPM
#undef OPD
  setStatusFlags();
}

} // namespace risc0::circuit