#pragma once
#include <utility>

#include "risc0/zkvm/circuit/context.h"
namespace risc0::circuit {
/// Handle to a single value in the circuit under construction.
///
/// A `Value` wraps a shared `ValueImplBase` node; copies of a `Value` share the
/// same node. Constructing from `int` or `Fp` builds a constant through the
/// global context's `constant()` at the caller's source location.
class Value {
public:
  /// Creates an empty handle (its implementation pointer is null).
  Value() = default;

  /// Wraps an existing node. Implicit so that context results can be returned
  /// directly where a `Value` is expected. The pointer is moved into place to
  /// avoid an extra reference-count increment/decrement.
  Value(std::shared_ptr<ValueImplBase> impl) : impl(std::move(impl)) {}

  /// Creates a constant from an integer (converted through `Fp`).
  Value(int value, SourceLoc loc = SourceLoc::current())
      : Value(getGlobalContext()->constant(Fp(value), loc)) {}

  /// Creates a constant from a field element.
  Value(Fp value, SourceLoc loc = SourceLoc::current())
      : Value(getGlobalContext()->constant(value, loc)) {}

  /// Returns a shared copy of the underlying node (null if default-constructed).
  std::shared_ptr<ValueImplBase> getImpl() const { return impl; }

private:
  std::shared_ptr<ValueImplBase> impl;
};
/// Handle to a buffer of circuit values held by a shared `BufferImplBase`.
///
/// Element reads and writes go through the `BufAccess` proxy returned by
/// `operator[]` / `at()`. The proxy forwards them to the global context together
/// with the source location of the indexing expression.
class Buffer {
public:
  /// Wraps an existing buffer implementation. Implicit so that context results
  /// can be returned directly as a `Buffer`.
  Buffer(std::shared_ptr<BufferImplBase> impl) : impl(std::move(impl)) {}

  /// The implementation's `size` field (number of elements).
  size_t size() const { return impl->size; }

  /// The implementation's `back` field (presumably the supported back-reference
  /// depth -- confirm against `BufferImplBase`).
  size_t back() const { return impl->back; }

  /// Returns the buffer produced by the context's `slice` starting at `start`
  /// with `size` elements.
  Buffer slice(size_t start, size_t size, SourceLoc loc = SourceLoc::current()) const {
    return getGlobalContext()->slice(impl, start, size, loc);
  }

  /// Returns the buffer produced by the context's `back` operation for `back`.
  Buffer back(size_t back, SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->back(impl, back, loc);
  }

  /// Forwards to the context's `requireDigits` with `bits`; returns its buffer.
  Buffer requireDigits(size_t bits, SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->requireDigits(impl, bits, loc);
  }

  /// Forwards to the context's `requireMux`; returns its buffer.
  Buffer requireMux(SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->requireMux(impl, loc);
  }

  /// Proxy for element `idx` of a buffer.
  ///
  /// Converting to `Value` reads the element, and assigning writes it. Copy
  /// construction is disabled; the proxy can only be moved.
  struct BufAccess {
  public:
    BufAccess(std::shared_ptr<BufferImplBase> impl, size_t idx, SourceLoc loc)
        : impl(std::move(impl)), idx(idx), loc(loc) {}
    BufAccess(const BufAccess& rhs) = delete;
    BufAccess(BufAccess&& rhs) = default;

    /// Element-to-element copy: reads the element referenced by `rhs` and
    /// writes it into this element, so `a = b` behaves like `a = Value(b)`.
    void operator=(const BufAccess& rhs) {
      // The read must come from `rhs` (at its own location). Reading from this
      // element's `impl`/`idx` would copy the destination onto itself.
      auto rawVal = getGlobalContext()->getVal(rhs.impl, rhs.idx, rhs.loc);
      getGlobalContext()->setVal(impl, idx, rawVal, loc);
    }

    /// Writes `val` into this element.
    void operator=(Value val) { getGlobalContext()->setVal(impl, idx, val.getImpl(), loc); }

    /// Reads this element.
    operator Value() const { return getGlobalContext()->getVal(impl, idx, loc); }

  private:
    std::shared_ptr<BufferImplBase> impl;
    size_t idx;
    SourceLoc loc;
  };

  /// Index wrapper that captures the caller's source location. In C++17,
  /// `operator[]` takes exactly one parameter, so the location cannot be a
  /// defaulted second argument.
  struct CaptureIdxLoc {
    CaptureIdxLoc(size_t idx, SourceLoc loc = SourceLoc::current()) : idx(idx), loc(loc) {}
    size_t idx;
    SourceLoc loc;
  };

  /// Read-only element access. The returned proxy is const, so it cannot be
  /// assigned to.
  const BufAccess operator[](CaptureIdxLoc idxLoc) const {
    return BufAccess(impl, idxLoc.idx, idxLoc.loc);
  }

  /// Read/write element access.
  BufAccess operator[](CaptureIdxLoc idxLoc) { return BufAccess(impl, idxLoc.idx, idxLoc.loc); }

  /// Read-only element access with an explicit (or defaulted) source location.
  const BufAccess at(size_t idx, SourceLoc loc = SourceLoc::current()) const {
    return BufAccess(impl, idx, loc);
  }

  /// Read/write element access with an explicit (or defaulted) source location.
  BufAccess at(size_t idx, SourceLoc loc = SourceLoc::current()) {
    return BufAccess(impl, idx, loc);
  }

  /// Forwards to the context's `getDigits` with `bits`.
  Value getDigits(size_t bits, SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->getDigits(impl, bits, loc);
  }

  /// Forwards `val` to the context's `setDigits` with `bits`; returns its result.
  Value setDigits(size_t bits, Value val, SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->setDigits(impl, bits, val.getImpl(), loc);
  }

  /// Forwards to the context's `getMux`.
  Value getMux(SourceLoc loc = SourceLoc::current()) {
    return getGlobalContext()->getMux(impl, loc);
  }

  /// Forwards `val` to the context's `setMux`.
  void setMux(Value val, SourceLoc loc = SourceLoc::current()) {
    getGlobalContext()->setMux(impl, val.getImpl(), loc);
  }

private:
  std::shared_ptr<BufferImplBase> impl;
};
/// Reads the global at `offset` through the context's `getGlobal`.
inline Value getGlobal(size_t offset, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  Value result = ctx->getGlobal(offset, loc);
  return result;
}
/// Writes `val` to the global at `offset` through the context's `setGlobal`.
inline void setGlobal(size_t offset, Value val, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  ctx->setGlobal(offset, val.getImpl(), loc);
}
/// Operand wrapper used by the binary operators below: it pairs a `Value`
/// with the source location captured where the operand was converted.
///
/// Accepts a `Value`, a buffer element proxy, an `Fp`, or an `int`.
struct CaptureValLoc {
  CaptureValLoc(Value val, SourceLoc loc = SourceLoc::current())
      : val(std::move(val)), loc(loc) {}
  CaptureValLoc(Buffer::BufAccess access, SourceLoc loc = SourceLoc::current())
      : val(access), loc(loc) {}
  // Forward `loc` when building constants. Otherwise Value's own defaulted
  // SourceLoc::current() is evaluated here, in this header, and the constant
  // is attributed to this line rather than to the caller.
  CaptureValLoc(Fp fp, SourceLoc loc = SourceLoc::current()) : val(fp, loc), loc(loc) {}
  CaptureValLoc(int num, SourceLoc loc = SourceLoc::current()) : val(Fp(num), loc), loc(loc) {}
  Value val;
  SourceLoc loc;
};
/// Builds an `add` node from the two operands, attributed to `b`'s location.
inline Value operator+(CaptureValLoc a, CaptureValLoc b) {
  auto&& ctx = getGlobalContext();
  Value sum = ctx->add(a.val.getImpl(), b.val.getImpl(), b.loc);
  return sum;
}
/// Builds a `sub` node (a minus b), attributed to `b`'s location.
inline Value operator-(CaptureValLoc a, CaptureValLoc b) {
  auto&& ctx = getGlobalContext();
  Value difference = ctx->sub(a.val.getImpl(), b.val.getImpl(), b.loc);
  return difference;
}
/// Builds a `mul` node from the two operands, attributed to `b`'s location.
inline Value operator*(CaptureValLoc a, CaptureValLoc b) {
  auto&& ctx = getGlobalContext();
  Value product = ctx->mul(a.val.getImpl(), b.val.getImpl(), b.loc);
  return product;
}
/// Division expressed as `a * inv(b)`; both nodes use `b`'s location.
inline Value operator/(CaptureValLoc a, CaptureValLoc b) {
  auto&& ctx = getGlobalContext();
  auto reciprocal = ctx->inv(b.val.getImpl(), b.loc);
  Value quotient = ctx->mul(a.val.getImpl(), reciprocal, b.loc);
  return quotient;
}
/// Builds a `bitAnd` node from the two operands, attributed to `b`'s location.
inline Value operator&(CaptureValLoc a, CaptureValLoc b) {
  auto&& ctx = getGlobalContext();
  Value masked = ctx->bitAnd(a.val.getImpl(), b.val.getImpl(), b.loc);
  return masked;
}
/// Returns the context's `inv` of `a`.
inline Value inv(Value a, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  Value inverse = ctx->inv(a.getImpl(), loc);
  return inverse;
}
/// Returns the context's `nonzero` node for `a`.
inline Value nonzero(Value a, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  Value flag = ctx->nonzero(a.getImpl(), loc);
  return flag;
}
/// Constrains `a` and `b` to be equal by asserting that `a - b` is zero.
inline void equate(Value a, Value b, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  auto difference = ctx->sub(a.getImpl(), b.getImpl(), loc);
  ctx->assertZero(difference, loc);
}
inline void risc0Log(const char* str, std::vector<Value> vals) {
std::vector<std::shared_ptr<ValueImplBase>> impls;
for (auto& val : vals) {
impls.push_back(val.getImpl());
}
getGlobalContext()->log(str, impls);
}
/// Forwards the numerator/denominator parts to the context's `divide32` and
/// returns its four results, in order, as `Value`s.
inline std::array<Value, 4> divide(Value numerLow,
                                   Value numerHigh,
                                   Value denomLow,
                                   Value denomHigh,
                                   SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  auto [out0, out1, out2, out3] = ctx->divide32(
      numerLow.getImpl(), numerHigh.getImpl(), denomLow.getImpl(), denomHigh.getImpl(), loc);
  std::array<Value, 4> results{out0, out1, out2, out3};
  return results;
}
/// Forwards `cycle`, `addr`, `low` and `high` to the context's `memWrite`.
inline void
memWrite(Value cycle, Value addr, Value low, Value high, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  ctx->memWrite(cycle.getImpl(), addr.getImpl(), low.getImpl(), high.getImpl(), loc);
}
/// Forwards `cycle` and `addr` to the context's `memRead` and returns its two
/// results, in order, as `Value`s.
inline std::array<Value, 2> memRead(Value cycle, Value addr, SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  auto [first, second] = ctx->memRead(cycle.getImpl(), addr.getImpl(), loc);
  std::array<Value, 2> result{first, second};
  return result;
}
/// Calls the context's `memCheck` and returns its five results, in order, as
/// `Value`s.
inline std::array<Value, 5> memCheck(SourceLoc loc = SourceLoc::current()) {
  auto&& ctx = getGlobalContext();
  auto [out0, out1, out2, out3, out4] = ctx->memCheck(loc);
  std::array<Value, 5> result{out0, out1, out2, out3, out4};
  return result;
}
/// Scope guard for a nondeterministic section.
///
/// The constructor calls the context's `beginNondet` and stores its result in
/// `doNondet`; the destructor calls `endNondet`. Converting to `bool` yields
/// `doNondet`, so `BYZ_NONDET` enters its body only when that result is true.
class NondetGuard {
public:
  bool doNondet;

  NondetGuard(SourceLoc loc = SourceLoc::current())
      : doNondet(getGlobalContext()->beginNondet(loc)) {}
  ~NondetGuard() { getGlobalContext()->endNondet(); }

  // A copy would call endNondet() a second time for a single beginNondet().
  // C++17 guaranteed copy elision keeps `auto g = NondetGuard()` valid.
  NondetGuard(const NondetGuard&) = delete;
  NondetGuard& operator=(const NondetGuard&) = delete;
  NondetGuard(NondetGuard&&) = delete;
  NondetGuard& operator=(NondetGuard&&) = delete;

  operator bool() const { return doNondet; }
};
/// Scope guard for a group: calls the context's `beginGroup` on construction
/// and `endGroup` on destruction. Always converts to `true`, so `BYZ_GROUP`
/// always enters its body.
class GroupGuard {
public:
  GroupGuard(SourceLoc loc = SourceLoc::current()) { getGlobalContext()->beginGroup(loc); }
  ~GroupGuard() { getGlobalContext()->endGroup(); }

  // A copy would call endGroup() a second time for a single beginGroup().
  // C++17 guaranteed copy elision keeps `auto g = GroupGuard()` valid.
  GroupGuard(const GroupGuard&) = delete;
  GroupGuard& operator=(const GroupGuard&) = delete;
  GroupGuard(GroupGuard&&) = delete;
  GroupGuard& operator=(GroupGuard&&) = delete;

  operator bool() const { return true; }
};
/// Scope guard for a conditional section. It calls the context's
/// `beginIf(cond)` on construction and `endIf` on destruction.
///
/// It always converts to `true`, so `BYZ_IF` always enters its body; `cond` is
/// handed to the context rather than evaluated here.
class IfGuard {
public:
  IfGuard(Value cond, SourceLoc loc = SourceLoc::current()) {
    getGlobalContext()->beginIf(cond.getImpl(), loc);
  }
  ~IfGuard() { getGlobalContext()->endIf(); }

  // A copy would call endIf() a second time for a single beginIf().
  // C++17 guaranteed copy elision keeps `auto g = IfGuard(c)` valid.
  IfGuard(const IfGuard&) = delete;
  IfGuard& operator=(const IfGuard&) = delete;
  IfGuard(IfGuard&&) = delete;
  IfGuard& operator=(IfGuard&&) = delete;

  operator bool() const { return true; }
};
// Runs the following statement/block inside a NondetGuard scope. The body is
// entered only when the guard's doNondet (the result of beginNondet()) is true.
#define BYZ_NONDET if (auto nondetGuard = NondetGuard())
// Runs the following statement/block inside a GroupGuard scope (always entered).
#define BYZ_GROUP if (auto groupGuard = GroupGuard())
// Runs the following statement/block inside an IfGuard scope for `cond`. The
// body is always entered; `cond` is passed to the context's beginIf().
#define BYZ_IF(cond) if (auto ifGuard = IfGuard(cond))
}