1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "risc0/zkvm/circuit/decode_cycle.h"
#include "risc0/zkvm/circuit/step_state.h"
namespace risc0::circuit {
// Helper for matching in nondet sections
static Value match(int pat, Value x) {
if (pat == -1) {
return 1;
}
return 1 - nonzero(pat - x);
}
// Do the decode cycle. Turn off max function size since macro inclusion make it too big.
// NOLINTNEXTLINE(readability-function-size)
void DecodeCycle::set(StepState& state) {
// Get previous instruction (either reset or final)
DataRegs& prev = state.getPrev(1);
FinalCycle prevFinal = prev.asFinal();
// Make sure PC is an aligned access
equate(prevFinal.pc.getPart(0, 2), 0);
// Get PC to read from
Value pc = prevFinal.pc.getPart(2, kMemBits);
// Do the memory read of the instruction
MemIORegs& memIO = state.data.memIO;
Value cycle = state.code.cycle.get();
memIO.doRead(cycle, pc);
ValueU32 instVal = memIO.value.get();
// Extract bits from the instruction
inst.set(instVal);
// Require low 2 bits of opcode to be 1 (normal compressed instruction0
equate(inst.get(0), 1);
equate(inst.get(1), 1);
// Now we generate the various intermediate formats, which will later by selected during decode.
// Basically this is an encoding of figure 2.4 from the RiscV spec, with the additional
// complication that we need to split the low and high 16 bits. We initially extract and shift
// each of the common fields. Then we put them together into various instruction types, and
// finally pick the right one.
// Extract + shift
auto instField = [&](size_t low, size_t high) { return inst.getPart(low, high - low + 1); };
Value imm3025 = instField(25, 30) * (1 << 5);
Value imm2421 = instField(21, 24) * (1 << 1);
Value imm1108 = instField(8, 11) * (1 << 1);
Value imm1512 = instField(12, 15) * (1 << 12);
Value imm1916 = instField(16, 19);
Value imm3020 = instField(20, 30) * (1 << 4);
Value imm07 = inst.get(07);
Value imm20 = inst.get(20);
Value imm31 = inst.get(31);
// Combine by type
ValueU32 immR = {0, 0};
ValueU32 immI = {imm3025 + imm2421 + imm20 + imm31 * int(uint16_t(-(1 << 11))), imm31 * 0xffff};
ValueU32 immS = {imm3025 + imm1108 + imm07 + imm31 * int(uint16_t(-(1 << 11))), imm31 * 0xffff};
ValueU32 immB = {imm3025 + imm1108 + imm07 * (1 << 11) + imm31 * int(uint16_t(-(1 << 12))),
imm31 * 0xffff};
ValueU32 immU = {imm1512, imm3020 + imm1916 + imm31 * 0x8000};
ValueU32 immJ = {imm1512 + imm20 * (1 << 11) + imm3025 + imm2421,
imm1916 + imm31 * int(uint16_t(-(1 << 4)))};
// Extract rs1 + rs2 and get 1-hot versions
Value rs1Id = inst.getPart(15, 5);
Value rs2Id = inst.getPart(20, 5);
rs1Low.set(inst.getPart(15, 3));
rs2Low.set(inst.getPart(20, 3));
rs1High.set(inst.getPart(18, 2));
rs2High.set(inst.getPart(23, 2));
// Demux rs1 + rs2. Missing case of 0 handles the zero register load
ValueU32 rs1Val = {0, 0};
ValueU32 rs2Val = {0, 0};
for (int i = 1; i < 32; i++) {
Value r1Sel = rs1Low.is(i & 0b111) * rs1High.is(i >> 3);
Value r2Sel = rs2Low.is(i & 0b111) * rs2High.is(i >> 3);
rs1Val = rs1Val + r1Sel * prevFinal.regs[i].get();
rs2Val = rs2Val + r2Sel * prevFinal.regs[i].get();
}
rs1.set(rs1Val);
rs2.set(rs2Val);
// Extract the various fields useful to instruction decoding
Value opcode = inst.getPart(2, 5);
Value func3 = inst.getPart(12, 3);
Value func7 = inst.getPart(25, 7);
// Decode the instruction. We do this in two parts: First we set the opID nondeterministicly, and
// then we verify we got it right. For the purposed of decoding, we treat OP<X> instructions the
// same, so we redefine them all to ANYOP
#define OPC(...) ANYOP(COMPUTE_0, __VA_ARGS__)
#define OPM(...) ANYOP(MULTIPLY, __VA_ARGS__)
#define OPD(...) ANYOP(DIVIDE, __VA_ARGS__)
// Pick out the op id
BYZ_NONDET {
#define ANYOP(ct, id, mnem, oc, f3, f7, ...) \
BYZ_IF(match(oc, opcode) * match(f3, func3) * match(f7, func7)) { \
opID1.set(id & 7); \
opID2.set(id >> 3); \
}
#include "risc0/zkvm/circuit/riscv32im.inl"
// Special case for halt, basically to avoid OPH
ANYOP(SHA_SYNC, 63, HALT, 0b11100, 0, 0, R, 0)
#undef ANYOP
}
// Now verify we got it right + set the immediate value + next step
#define ANYOP(ct, id, mnem, oc, f3, f7, immFmt, ...) \
BYZ_IF(opID1.is(id & 7) * opID2.is(id >> 3)) { \
equate(opcode, oc); \
if (f3 >= 0) { \
equate(func3, f3); \
} \
if (f7 >= 0) { \
equate(func7, f7); \
} \
BYZ_GROUP { \
imm.set(imm##immFmt); \
Value val2Low16 = (#immFmt[0] == 'R') ? rs2.low() : imm.low(); \
val2Low.setPartExact(val2Low16, 0, 16); \
val2Split.setPartExact(val2Low.getPart(4, 2), 0, 2); \
val2OH.set(val2Low.getPart(0, 4) + 16 * val2Split.get(0)); \
} \
nextCycleType.set(DataCycleType::ct + (id < 32 ? id >> 3 : 0)); \
risc0Log("C%u: pc: %08x Decode: " #mnem " r%u=0x%04x%04x, r%u=0x%04x%04x, imm=0x%04x%04x", \
{cycle, \
pc * 4, \
rs1Id, \
rs1.high(), \
rs1.low(), \
rs2Id, \
rs2.high(), \
rs2.low(), \
imm.high(), \
imm.low()}); \
}
#include "risc0/zkvm/circuit/riscv32im.inl"
ANYOP(SHA_SYNC, 63, HALT, 0b11100, 0, 0, R, 0)
#undef ANYOP
// Done with decode
#undef OPC
#undef OPM
#undef OPD
}
} // namespace risc0::circuit