use alloc::vec::Vec;
use p3_matrix::Dimensions;
use p3_util::log2_ceil_usize;
/// Shape of a table: `width` columns and a height of exactly `2^num_variables`.
///
/// Wraps [`Dimensions`]. The height is always a power of two because the inner
/// value is private and only ever built by [`TableShape::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TableShape(Dimensions);

impl TableShape {
    /// Creates the shape of a table with `2^num_variables` rows and `width` columns.
    ///
    /// # Panics
    ///
    /// Panics if `width` is zero, or if `num_variables >= usize::BITS` (the height
    /// `1 << num_variables` would not fit in a `usize`).
    pub const fn new(num_variables: usize, width: usize) -> Self {
        assert!(width > 0, "table width must be non-zero");
        // Guards the shift below: shifting by `usize::BITS` or more overflows.
        assert!(
            num_variables < usize::BITS as usize,
            "num_variables must be smaller than usize::BITS"
        );
        Self(Dimensions {
            width,
            height: 1 << num_variables,
        })
    }

    /// Number of variables, i.e. `log2` of the height.
    ///
    /// The height is always `1 << num_variables`, so the ceiling logarithm
    /// recovers the value passed to [`TableShape::new`] exactly.
    pub const fn num_variables(&self) -> usize {
        log2_ceil_usize(self.0.height)
    }

    /// Number of columns in the table.
    pub const fn width(&self) -> usize {
        self.0.width
    }
}
/// Opening schedule of one table.
///
/// Each inner list is one opening and holds the column (polynomial) indices of
/// the table that are opened together.
pub type PointSchedule = Vec<Vec<usize>>;

/// A table's shape together with its opening schedule.
///
/// Invariant, checked in [`TableSpec::new`]: every index in the schedule is
/// `< shape.width()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableSpec {
    shape: TableShape,
    point_schedule: PointSchedule,
}

impl TableSpec {
    /// Creates a table spec from its shape and opening schedule.
    ///
    /// # Panics
    ///
    /// Panics if any polynomial index in `point_schedule` is `>= shape.width()`.
    /// The message names the first offending index and the table width.
    pub fn new(shape: TableShape, point_schedule: PointSchedule) -> Self {
        let width = shape.width();
        if let Some(poly_idx) = point_schedule
            .iter()
            .flatten()
            .copied()
            .find(|&poly_idx| poly_idx >= width)
        {
            panic!("polynomial index {poly_idx} out of range for table of width {width}");
        }
        Self {
            shape,
            point_schedule,
        }
    }

    /// The table's shape.
    pub const fn shape(&self) -> &TableShape {
        &self.shape
    }

    /// The table's opening schedule, in the order it was supplied to [`TableSpec::new`].
    pub const fn point_schedule(&self) -> &PointSchedule {
        &self.point_schedule
    }

    /// Raises the table's number of variables to at least `min_num_variables`.
    ///
    /// The width and the point schedule are left unchanged, so the index
    /// invariant still holds. This is a no-op when the table already has
    /// `min_num_variables` or more variables.
    ///
    /// # Panics
    ///
    /// Panics if `min_num_variables >= usize::BITS`: padding is then always
    /// required and the new height cannot be represented.
    pub const fn pad_to_min_num_variables(&mut self, min_num_variables: usize) {
        if self.shape.num_variables() < min_num_variables {
            self.shape = TableShape::new(min_num_variables, self.shape.width());
        }
    }
}
/// An ordered list of tables, each with its own opening schedule.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OpeningProtocol(Vec<TableSpec>);

impl OpeningProtocol {
    /// Creates a protocol over `tables`, keeping the given order.
    pub const fn new(tables: Vec<TableSpec>) -> Self {
        Self(tables)
    }

    /// Shapes of all tables, in protocol order.
    pub fn table_shapes(&self) -> Vec<TableShape> {
        self.0.iter().map(|table| *table.shape()).collect()
    }

    /// Pads every table to at least `min_num_variables` variables.
    ///
    /// Each table is handled by `TableSpec::pad_to_min_num_variables`, so widths
    /// and point schedules are unchanged. Consumes `self` and returns the padded
    /// protocol, so the result must be used.
    ///
    /// # Panics
    ///
    /// Panics if `min_num_variables >= usize::BITS` and the protocol contains at
    /// least one table.
    #[must_use]
    pub fn pad_to_min_num_variables(mut self, min_num_variables: usize) -> Self {
        for table in &mut self.0 {
            table.pad_to_min_num_variables(min_num_variables);
        }
        self
    }

    /// Total number of openings across all tables.
    ///
    /// Equals the number of items yielded by [`OpeningProtocol::iter_openings`].
    pub fn num_openings(&self) -> usize {
        self.0
            .iter()
            .map(|table| table.point_schedule().len())
            .sum()
    }

    /// Iterates over every opening as `(table_idx, poly_indices)`.
    ///
    /// Tables are visited in protocol order. Within a table, openings are yielded
    /// in schedule order.
    pub fn iter_openings(&self) -> impl Iterator<Item = (usize, &[usize])> {
        self.0.iter().enumerate().flat_map(|(table_idx, table)| {
            table
                .point_schedule()
                .iter()
                .map(move |polys| (table_idx, polys.as_slice()))
        })
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;

    /// One table (3 variables, width 2) with two openings.
    fn single_table_protocol() -> OpeningProtocol {
        OpeningProtocol::new(vec![TableSpec::new(
            TableShape::new(3, 2),
            vec![vec![0, 1], vec![0]],
        )])
    }

    /// Two tables of different shapes with three openings in total.
    fn two_table_protocol() -> OpeningProtocol {
        OpeningProtocol::new(vec![
            TableSpec::new(TableShape::new(3, 2), vec![vec![0, 1]]),
            TableSpec::new(TableShape::new(4, 3), vec![vec![0, 2], vec![1]]),
        ])
    }

    /// Flattens `iter_openings` into owned pairs for easy comparison.
    fn collect_openings(protocol: &OpeningProtocol) -> Vec<(usize, Vec<usize>)> {
        protocol
            .iter_openings()
            .map(|(table_idx, polys)| (table_idx, polys.to_vec()))
            .collect()
    }

    #[test]
    fn table_shape_round_trips_num_variables_and_width() {
        for num_variables in [0, 1, 10, usize::BITS as usize - 1] {
            let shape = TableShape::new(num_variables, 7);
            assert_eq!(shape.num_variables(), num_variables);
            assert_eq!(shape.width(), 7);
        }
    }

    #[test]
    fn table_shape_new_is_usable_in_const_context() {
        const SHAPE: TableShape = TableShape::new(2, 4);
        assert_eq!(SHAPE.num_variables(), 2);
        assert_eq!(SHAPE.width(), 4);
    }

    #[test]
    #[should_panic]
    fn table_shape_new_panics_on_zero_width() {
        let _ = TableShape::new(3, 0);
    }

    #[test]
    #[should_panic]
    fn table_shape_new_panics_when_height_overflows() {
        let _ = TableShape::new(usize::BITS as usize, 1);
    }

    #[test]
    fn opening_protocol_table_shapes_returns_shapes_in_protocol_order() {
        let protocol = two_table_protocol();
        let shapes = protocol.table_shapes();
        assert_eq!(shapes, vec![TableShape::new(3, 2), TableShape::new(4, 3)]);
    }

    #[test]
    fn opening_protocol_num_openings_sums_per_table_schedules() {
        let protocol = two_table_protocol();
        assert_eq!(protocol.num_openings(), 3);
        let empty = OpeningProtocol::new(vec![]);
        assert_eq!(empty.num_openings(), 0);
    }

    #[test]
    fn opening_protocol_iter_openings_yields_batches_in_transcript_order() {
        let protocol = two_table_protocol();
        assert_eq!(
            collect_openings(&protocol),
            vec![(0, vec![0, 1]), (1, vec![0, 2]), (1, vec![1])],
        );
    }

    #[test]
    fn opening_protocol_iter_openings_count_matches_num_openings() {
        let protocol = two_table_protocol();
        assert_eq!(protocol.iter_openings().count(), protocol.num_openings());
        let empty = OpeningProtocol::default();
        assert_eq!(empty.iter_openings().count(), 0);
    }

    #[test]
    fn opening_protocol_pad_to_min_num_variables_grows_small_tables() {
        let padded = two_table_protocol().pad_to_min_num_variables(5);
        let shapes = padded.table_shapes();
        assert_eq!(shapes[0], TableShape::new(5, 2));
        assert_eq!(shapes[1], TableShape::new(5, 3));
        assert_eq!(padded.num_openings(), 3);
    }

    #[test]
    fn opening_protocol_pad_to_min_num_variables_leaves_large_tables_alone() {
        let original = single_table_protocol();
        let original_shapes = original.table_shapes();
        let padded = original.pad_to_min_num_variables(2);
        assert_eq!(padded.table_shapes(), original_shapes);
    }

    #[test]
    fn opening_protocol_pad_to_min_num_variables_preserves_point_schedules() {
        let original = two_table_protocol();
        let before = collect_openings(&original);
        let padded = original.pad_to_min_num_variables(6);
        assert_eq!(collect_openings(&padded), before);
    }

    #[test]
    fn table_spec_accessors_forward_constructor_arguments() {
        let shape = TableShape::new(3, 2);
        let schedule: PointSchedule = vec![vec![0, 1], vec![0]];
        let spec = TableSpec::new(shape, schedule.clone());
        assert_eq!(*spec.shape(), shape);
        assert_eq!(spec.point_schedule(), &schedule);
    }

    #[test]
    fn table_spec_new_accepts_empty_schedule_and_max_index() {
        let empty = TableSpec::new(TableShape::new(0, 1), vec![]);
        assert!(empty.point_schedule().is_empty());
        assert_eq!(OpeningProtocol::new(vec![empty]).num_openings(), 0);
        let last_column = TableSpec::new(TableShape::new(1, 3), vec![vec![2]]);
        assert_eq!(last_column.point_schedule(), &vec![vec![2]]);
    }

    #[test]
    #[should_panic]
    fn table_spec_new_panics_on_out_of_range_poly_idx() {
        let _ = TableSpec::new(TableShape::new(3, 2), vec![vec![2]]);
    }

    #[test]
    fn table_spec_pad_to_min_num_variables_grows_small_arity() {
        let mut spec = TableSpec::new(TableShape::new(2, 3), vec![vec![0]]);
        spec.pad_to_min_num_variables(4);
        assert_eq!(spec.shape().num_variables(), 4);
        assert_eq!(spec.shape().width(), 3);
    }

    #[test]
    fn table_spec_pad_to_min_num_variables_leaves_larger_arity_alone() {
        let mut spec = TableSpec::new(TableShape::new(5, 1), vec![vec![0]]);
        spec.pad_to_min_num_variables(3);
        assert_eq!(spec.shape().num_variables(), 5);
        assert_eq!(spec.shape().width(), 1);
    }

    #[test]
    #[should_panic]
    fn table_spec_pad_to_min_num_variables_panics_past_usize_bits() {
        let mut spec = TableSpec::new(TableShape::new(2, 1), vec![vec![0]]);
        spec.pad_to_min_num_variables(usize::BITS as usize);
    }
}