use super::{
columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS},
ShaCompressChip, SHA_COMPRESS_K,
};
use crate::{
air::{MemoryAirBuilder, SP1CoreAirBuilder, WordAirBuilder},
operations::{
Add5Operation, AddU32Operation, AddrAddOperation, AndU32Operation,
FixedRotateRightOperation, NotU32Operation, XorU32Operation,
},
utils::u32_to_half_word,
};
use core::borrow::Borrow;
use slop_air::{Air, BaseAir};
use slop_algebra::AbstractField;
use slop_matrix::Matrix;
use sp1_hypercube::{
air::{AirInteraction, BaseAirBuilder, InteractionScope, SP1AirBuilder},
InteractionKind, Word,
};
use std::iter::once;
impl<Val> BaseAir<Val> for ShaCompressChip {
    /// Number of columns in one row of this chip's trace, i.e. [`NUM_SHA_COMPRESS_COLS`].
    ///
    /// The width is the same for every field type `Val`.
    fn width(&self) -> usize {
        NUM_SHA_COMPRESS_COLS
    }
}
impl<AB: SP1CoreAirBuilder> Air<AB> for ShaCompressChip {
    /// Evaluates every constraint of the SHA-256 compress chip on the current row.
    ///
    /// The row is viewed as [`ShaCompressCols`] and handed to four sub-evaluations:
    /// control flow, memory, compression round, finalize. Keep this call order as is:
    /// it fixes the order in which constraints and interactions are emitted into `builder`.
    fn eval(&self, builder: &mut AB) {
        let trace = builder.main();
        let row = trace.row_slice(0);
        let cols: &ShaCompressCols<AB::Var> = (*row).borrow();

        self.eval_control_flow_flags(builder, cols);
        self.eval_memory(builder, cols);
        self.eval_compression_ops(builder, cols);
        self.eval_finalize_ops(builder, cols);
    }
}
impl ShaCompressChip {
    /// Constrains the row-kind selectors and the row index. Also links consecutive rows of
    /// one compression through `ShaCompress` interactions.
    ///
    /// Every row carries two one-hot vectors:
    /// - `octet`: the position (0..8) inside a group of eight rows;
    /// - `octet_num`: which of the ten groups the row belongs to.
    ///
    /// Together they pin `index = octet + 8 * octet_num`, so `index` lies in `0..80`:
    /// - `octet_num[0]` (indices 0..8): initialize rows (`is_initialize`);
    /// - `octet_num[1..=8]` (indices 8..72): the 64 compression rounds (`is_compression`);
    /// - `octet_num[9]` (indices 72..80): finalize rows (`is_finalize`).
    ///
    /// Each real row receives `(clk_high, clk_low, w_ptr, h_ptr, index, a..h)` and sends the
    /// tuple for `index + 1`:
    /// - on initialize and finalize rows the state is unchanged;
    /// - on compression rows it is the post-round state.
    fn eval_control_flow_flags<AB: SP1CoreAirBuilder>(
        &self,
        builder: &mut AB,
        local: &ShaCompressCols<AB::Var>,
    ) {
        builder.assert_bool(local.is_real);

        // `octet` is one-hot. The set bit contributes its position `i` to the index.
        let mut computed_index = AB::Expr::zero();
        let mut octet_sum = AB::Expr::zero();
        for i in 0..8 {
            builder.assert_bool(local.octet[i]);
            octet_sum = octet_sum.clone() + local.octet[i].into();
            computed_index = computed_index.clone()
                + local.octet[i].into() * AB::Expr::from_canonical_u32(i as u32);
        }
        builder.assert_one(octet_sum);

        // `octet_num` is one-hot over ten groups of eight rows. The set bit contributes `8 * i`.
        let mut octet_num_sum = AB::Expr::zero();
        for i in 0..10 {
            builder.assert_bool(local.octet_num[i]);
            octet_num_sum = octet_num_sum.clone() + local.octet_num[i].into();
            computed_index = computed_index.clone()
                + local.octet_num[i].into() * AB::Expr::from_canonical_u32(8 * i as u32);
        }
        builder.assert_one(octet_num_sum);
        // This is the last use of `computed_index`, so move it instead of cloning.
        builder.assert_eq(local.index, computed_index);

        // Row-kind selectors come from `octet_num` and are gated by `is_real`.
        // On rows with `is_real = 0` all three selectors are therefore zero.
        builder.assert_eq(local.is_initialize, local.octet_num[0] * local.is_real);
        builder.assert_eq(
            local.is_compression,
            (local.octet_num[1]
                + local.octet_num[2]
                + local.octet_num[3]
                + local.octet_num[4]
                + local.octet_num[5]
                + local.octet_num[6]
                + local.octet_num[7]
                + local.octet_num[8])
                * local.is_real,
        );
        builder.assert_eq(local.is_finalize, local.octet_num[9] * local.is_real);

        // Receive the state entering this row.
        let receive_values = once(local.clk_high.into())
            .chain(once(local.clk_low.into()))
            .chain(local.w_ptr.map(Into::into))
            .chain(local.h_ptr.map(Into::into))
            .chain(once(local.index.into()))
            .chain(
                [local.a, local.b, local.c, local.d, local.e, local.f, local.g, local.h]
                    .into_iter()
                    .flat_map(|word| word.into_iter())
                    .map(Into::into),
            )
            .collect::<Vec<_>>();
        builder.receive(
            AirInteraction::new(receive_values, local.is_real.into(), InteractionKind::ShaCompress),
            InteractionScope::Local,
        );

        // Initialize and finalize rows pass the state through unchanged to `index + 1`.
        let send_values = once(local.clk_high.into())
            .chain(once(local.clk_low.into()))
            .chain(local.w_ptr.map(Into::into))
            .chain(local.h_ptr.map(Into::into))
            .chain(once(local.index.into() + AB::Expr::one()))
            .chain(
                [local.a, local.b, local.c, local.d, local.e, local.f, local.g, local.h]
                    .into_iter()
                    .flat_map(|word| word.into_iter())
                    .map(Into::into),
            )
            .collect::<Vec<_>>();
        builder.send(
            AirInteraction::new(
                send_values,
                local.is_initialize + local.is_finalize,
                InteractionKind::ShaCompress,
            ),
            InteractionScope::Local,
        );

        // Compression rows forward the post-round state to `index + 1`, following the
        // SHA-256 round update (a..h) <- (T1 + T2, a, b, c, d + T1, e, f, g).
        let compression_send_values = once(local.clk_high.into())
            .chain(once(local.clk_low.into()))
            .chain(local.w_ptr.map(Into::into))
            .chain(local.h_ptr.map(Into::into))
            .chain(once(local.index.into() + AB::Expr::one()))
            .chain(
                [
                    local.temp1_add_temp2.value,
                    local.a,
                    local.b,
                    local.c,
                    local.d_add_temp1.value,
                    local.e,
                    local.f,
                    local.g,
                ]
                .into_iter()
                .flat_map(|word| word.into_iter())
                .map(Into::into),
            )
            .collect::<Vec<_>>();
        builder.send(
            AirInteraction::new(
                compression_send_values,
                local.is_compression.into(),
                InteractionKind::ShaCompress,
            ),
            InteractionScope::Local,
        );
    }

    /// Constrains the single memory access each real row performs, and that access's address.
    ///
    /// - Initialize rows access `h_ptr + 8 * index`. They check that the state word selected
    ///   by `octet` equals the stored value.
    /// - Compression rows access `w_ptr + 8 * (index - 8)`. The value read (`mem_value`) is
    ///   the `W` operand of the round's `temp1`.
    /// - Finalize rows access `h_ptr + 8 * (index - 72)` and write back `finalize_add.value`.
    ///
    /// Initialize and compression accesses must write back the value they found, so they act
    /// as reads. The access timestamp is `clk_low` plus 0, 1 or 2 for initialize, compression
    /// and finalize rows respectively.
    fn eval_memory<AB: SP1AirBuilder>(&self, builder: &mut AB, local: &ShaCompressCols<AB::Var>) {
        // `mem_value` holds two half-word limbs. Widen it to a full `Word` for the memory access.
        let mem_value_word = Word::extend_half::<AB>(&local.mem_value);
        builder.eval_memory_access_write(
            local.clk_high,
            local.clk_low
                + local.is_compression
                + local.is_finalize * AB::Expr::from_canonical_u32(2),
            &local.mem_addr.map(Into::into),
            local.mem,
            mem_value_word.clone(),
            local.is_real,
        );

        // Initialize and compression rows are reads: the written value equals the previous one.
        builder
            .when(local.is_initialize + local.is_compression)
            .assert_word_eq(local.mem.prev_value, mem_value_word.clone());
        // Each 32-bit SHA word sits in its own 8-byte slot, so the upper two limbs of the
        // previous value must be zero. This is asserted on every row.
        builder.assert_zero(local.mem.prev_value[2]);
        builder.assert_zero(local.mem.prev_value[3]);

        // The accessed address is the one computed for the active row kind.
        builder.when(local.is_initialize).assert_all_eq(local.mem_addr, local.mem_addr_init.value);
        builder
            .when(local.is_compression)
            .assert_all_eq(local.mem_addr, local.mem_addr_compress.value);
        builder
            .when(local.is_finalize)
            .assert_all_eq(local.mem_addr, local.mem_addr_finalize.value);

        // Initialize: h_ptr + 8 * index  (index in 0..8).
        AddrAddOperation::<AB::F>::eval(
            builder,
            Word([
                local.h_ptr[0].into(),
                local.h_ptr[1].into(),
                local.h_ptr[2].into(),
                AB::Expr::zero(),
            ]),
            Word::extend_expr::<AB>(local.index * AB::Expr::from_canonical_u32(8)),
            local.mem_addr_init,
            local.is_initialize.into(),
        );
        // Compression: w_ptr + 8 * (index - 8)  (round 0..64).
        AddrAddOperation::<AB::F>::eval(
            builder,
            Word([
                local.w_ptr[0].into(),
                local.w_ptr[1].into(),
                local.w_ptr[2].into(),
                AB::Expr::zero(),
            ]),
            Word::extend_expr::<AB>(
                (local.index - AB::Expr::from_canonical_u32(8)) * AB::Expr::from_canonical_u32(8),
            ),
            local.mem_addr_compress,
            local.is_compression.into(),
        );
        // Finalize: h_ptr + 8 * (index - 72)  (word 0..8).
        AddrAddOperation::<AB::F>::eval(
            builder,
            Word([
                local.h_ptr[0].into(),
                local.h_ptr[1].into(),
                local.h_ptr[2].into(),
                AB::Expr::zero(),
            ]),
            Word::extend_expr::<AB>(
                (local.index - AB::Expr::from_canonical_u32(72)) * AB::Expr::from_canonical_u32(8),
            ),
            local.mem_addr_finalize,
            local.is_finalize.into(),
        );

        // On initialize row `i` (octet[i] set), the i-th working variable must equal H[i] as
        // stored in memory.
        let a_word = Word::extend_half::<AB>(&local.a);
        let b_word = Word::extend_half::<AB>(&local.b);
        let c_word = Word::extend_half::<AB>(&local.c);
        let d_word = Word::extend_half::<AB>(&local.d);
        let e_word = Word::extend_half::<AB>(&local.e);
        let f_word = Word::extend_half::<AB>(&local.f);
        let g_word = Word::extend_half::<AB>(&local.g);
        let h_word = Word::extend_half::<AB>(&local.h);
        let vars = [a_word, b_word, c_word, d_word, e_word, f_word, g_word, h_word];
        for (i, var) in vars.iter().enumerate() {
            builder
                .when(local.is_initialize * local.octet[i])
                .assert_word_eq(var.clone(), local.mem.prev_value);
            // Implied by the read check above (prev_value == mem_value_word on initialize
            // rows). It is still part of the constraint set, so removing it changes the circuit.
            builder
                .when(local.is_initialize * local.octet[i])
                .assert_word_eq(var.clone(), mem_value_word.clone());
        }

        // Finalize rows write H[i] + state[i], computed in `eval_finalize_ops`.
        builder.when(local.is_finalize).assert_all_eq(local.mem_value, local.finalize_add.value);
    }

    /// Constrains one SHA-256 compression round on compression rows (`is_compression`).
    ///
    /// With `W = mem_value` and `K` the round constant:
    /// - `S1    = ROTR6(e) ^ ROTR11(e) ^ ROTR25(e)`
    /// - `ch    = (e & f) ^ (!e & g)`
    /// - `temp1 = h + S1 + ch + K + W`
    /// - `S0    = ROTR2(a) ^ ROTR13(a) ^ ROTR22(a)`
    /// - `maj   = (a & b) ^ (a & c) ^ (b & c)`
    /// - `temp2 = S0 + maj`
    ///
    /// It also produces `d + temp1` and `temp1 + temp2`, which `eval_control_flow_flags`
    /// sends as the next state.
    fn eval_compression_ops<AB: SP1CoreAirBuilder>(
        &self,
        builder: &mut AB,
        local: &ShaCompressCols<AB::Var>,
    ) {
        // Round `i` lives in group `octet_num[i / 8 + 1]` at position `octet[i % 8]`.
        // Pin `k` to SHA_COMPRESS_K[i] there.
        for i in 0..64 {
            let octet_num = i / 8;
            let inner_index = i % 8;
            let k: [AB::F; 2] = u32_to_half_word(SHA_COMPRESS_K[i]);
            builder
                .when(local.octet_num[octet_num + 1] * local.octet[inner_index])
                .assert_all_eq(local.k, k);
        }

        // S1 = ROTR6(e) ^ ROTR11(e) ^ ROTR25(e).
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.e,
            6,
            local.e_rr_6,
            local.is_compression,
        );
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.e,
            11,
            local.e_rr_11,
            local.is_compression,
        );
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.e,
            25,
            local.e_rr_25,
            local.is_compression,
        );
        let s1_intermediate = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            local.e_rr_6.value.map(Into::into),
            local.e_rr_11.value.map(Into::into),
            local.s1_intermediate,
            local.is_compression,
        );
        let s1 = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            s1_intermediate,
            local.e_rr_25.value.map(Into::into),
            local.s1,
            local.is_compression,
        );

        // ch = (e & f) ^ (!e & g).
        let e_and_f = AndU32Operation::<AB::F>::eval_and_u32(
            builder,
            local.e.map(Into::into),
            local.f.map(Into::into),
            local.e_and_f,
            local.is_compression,
        );
        NotU32Operation::<AB::F>::eval(
            builder,
            local.e.map(Into::into),
            local.e_not,
            local.is_compression,
        );
        let e_not_and_g = AndU32Operation::<AB::F>::eval_and_u32(
            builder,
            local.e_not.value.map(Into::into),
            local.g.map(Into::into),
            local.e_not_and_g,
            local.is_compression,
        );
        let ch = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            e_and_f,
            e_not_and_g,
            local.ch,
            local.is_compression,
        );

        // temp1 = h + S1 + ch + K + W, where W is the word read from memory this row.
        Add5Operation::<AB::F>::eval(
            builder,
            &[
                local.h.map(Into::into),
                s1,
                ch,
                local.k.map(Into::into),
                local.mem_value.map(Into::into),
            ],
            local.is_compression,
            local.temp1,
        );

        // S0 = ROTR2(a) ^ ROTR13(a) ^ ROTR22(a).
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.a,
            2,
            local.a_rr_2,
            local.is_compression,
        );
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.a,
            13,
            local.a_rr_13,
            local.is_compression,
        );
        FixedRotateRightOperation::<AB::F>::eval(
            builder,
            local.a,
            22,
            local.a_rr_22,
            local.is_compression,
        );
        let s0_intermediate = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            local.a_rr_2.value.map(Into::into),
            local.a_rr_13.value.map(Into::into),
            local.s0_intermediate,
            local.is_compression,
        );
        let s0 = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            s0_intermediate,
            local.a_rr_22.value.map(Into::into),
            local.s0,
            local.is_compression,
        );

        // maj = (a & b) ^ (a & c) ^ (b & c).
        let a_and_b = AndU32Operation::<AB::F>::eval_and_u32(
            builder,
            local.a.map(Into::into),
            local.b.map(Into::into),
            local.a_and_b,
            local.is_compression,
        );
        let a_and_c = AndU32Operation::<AB::F>::eval_and_u32(
            builder,
            local.a.map(Into::into),
            local.c.map(Into::into),
            local.a_and_c,
            local.is_compression,
        );
        let b_and_c = AndU32Operation::<AB::F>::eval_and_u32(
            builder,
            local.b.map(Into::into),
            local.c.map(Into::into),
            local.b_and_c,
            local.is_compression,
        );
        let maj_intermediate = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            a_and_b,
            a_and_c,
            local.maj_intermediate,
            local.is_compression,
        );
        let maj = XorU32Operation::<AB::F>::eval_xor_u32(
            builder,
            maj_intermediate,
            b_and_c,
            local.maj,
            local.is_compression,
        );

        // temp2 = S0 + maj; then the two sums that become the next `e` and `a`.
        AddU32Operation::<AB::F>::eval(builder, s0, maj, local.temp2, local.is_compression.into());
        AddU32Operation::<AB::F>::eval(
            builder,
            local.d.map(Into::into),
            local.temp1.value.map(Into::into),
            local.d_add_temp1,
            local.is_compression.into(),
        );
        AddU32Operation::<AB::F>::eval(
            builder,
            local.temp1.value.map(Into::into),
            local.temp2.value.map(Into::into),
            local.temp1_add_temp2,
            local.is_compression.into(),
        );
    }

    /// On finalize row `i`, constrains `finalize_add = H[i] + state[i]`, where `state[i]` is
    /// the i-th working variable (a..h).
    ///
    /// `H[i]` is taken from the low two limbs of the memory access's previous value.
    /// `eval_memory` then forces the written value to equal `finalize_add.value`.
    fn eval_finalize_ops<AB: SP1AirBuilder>(
        &self,
        builder: &mut AB,
        local: &ShaCompressCols<AB::Var>,
    ) {
        // `octet` is one-hot (see `eval_control_flow_flags`), so this sum selects exactly the
        // working variable at this row's position.
        let add_operands = [local.a, local.b, local.c, local.d, local.e, local.f, local.g, local.h];
        let mut filtered_operand = [AB::Expr::zero(), AB::Expr::zero()];
        for (flag, operand) in local.octet.into_iter().zip(add_operands.iter()) {
            filtered_operand[0] = filtered_operand[0].clone() + flag * operand[0];
            filtered_operand[1] = filtered_operand[1].clone() + flag * operand[1];
        }
        // This is the last use of `filtered_operand`, so move it instead of cloning.
        builder.when(local.is_finalize).assert_all_eq(filtered_operand, local.finalized_operand);

        AddU32Operation::<AB::F>::eval(
            builder,
            [local.mem.prev_value.0[0], local.mem.prev_value.0[1]].map(Into::into),
            local.finalized_operand.map(Into::into),
            local.finalize_add,
            local.is_finalize.into(),
        );
    }
}