use alloc::vec::Vec;
use p3_util::log2_ceil_usize;
use crate::layout::witness::{Selector, TablePlacement};
/// Size description of one table to be packed by [`plan_layout`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LayoutShape {
    /// Base-2 logarithm of the row count of one column slot: each column of the table is
    /// assigned a contiguous slot of `2^arity` rows.
    pub(crate) arity: usize,
    /// Number of columns in the table; every column receives its own slot.
    pub(crate) width: usize,
}
/// Packs the columns of several tables into a single power-of-two row space.
///
/// Every column of a table with arity `a` gets its own contiguous slot of `2^a` rows. The total
/// row count is `sum(width * 2^arity)` over all shapes. `k` is `log2_ceil_usize` of that total,
/// so the total is rounded up to a power of two.
///
/// Tables are laid out from the largest arity to the smallest. Tables with equal arity are
/// visited in reverse input order, because a stable ascending sort is walked backwards. The
/// columns of one table occupy consecutive slots.
///
/// Returns `(k, placements)`. `placements` is in layout order, not input order. Each
/// [`TablePlacement`] records the index of its table in `shapes` plus one [`Selector`] per
/// column, built as `Selector::new(k - arity, start_row >> arity)`.
pub(crate) fn plan_layout(shapes: &[LayoutShape]) -> (usize, Vec<TablePlacement>) {
    let total_rows: usize = shapes
        .iter()
        .map(|shape| shape.width * (1usize << shape.arity))
        .sum();
    let k = log2_ceil_usize(total_rows);

    // Stable ascending sort by arity; walking it in reverse visits the largest tables first.
    let mut by_arity: Vec<usize> = (0..shapes.len()).collect();
    by_arity.sort_by_key(|&i| shapes[i].arity);

    let mut placements = Vec::with_capacity(shapes.len());
    let mut next_row = 0usize;
    for table_idx in by_arity.into_iter().rev() {
        let arity = shapes[table_idx].arity;
        let width = shapes[table_idx].width;
        let slot_rows = 1usize << arity;

        // Every slot placed so far is at least `slot_rows` long and all sizes are powers of
        // two, so `next_row` is a multiple of `slot_rows` and the shift below is exact.
        // `k - arity` stays inside the closure so it is never evaluated for a zero-width table.
        let selectors = (0..width)
            .map(|column| {
                let start_row = next_row + column * slot_rows;
                Selector::new(k - arity, start_row >> arity)
            })
            .collect();
        next_row += width * slot_rows;

        placements.push(TablePlacement::new(table_idx, selectors));
    }

    (k, placements)
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;

    #[test]
    fn plan_layout_empty_returns_zero_arity() {
        let (k, placements) = plan_layout(&[]);
        assert_eq!(k, 0);
        assert!(placements.is_empty());
    }

    #[test]
    fn plan_layout_single_table_places_at_origin() {
        let (k, placements) = plan_layout(&[LayoutShape { arity: 3, width: 2 }]);
        assert_eq!(k, 4);
        assert_eq!(placements.len(), 1);
        assert_eq!(placements[0].idx(), 0);
        assert_eq!(placements[0].selectors()[0].index(), 0);
        assert_eq!(placements[0].selectors()[1].index(), 1);
    }

    #[test]
    fn plan_layout_sorts_largest_first() {
        let shapes = vec![
            LayoutShape { arity: 3, width: 1 },
            LayoutShape { arity: 5, width: 1 },
        ];
        let (_, placements) = plan_layout(&shapes);
        assert_eq!(placements[0].idx(), 1);
        assert_eq!(placements[1].idx(), 0);
    }

    #[test]
    fn plan_layout_offsets_are_contiguous_and_aligned() {
        let shapes = vec![
            LayoutShape { arity: 3, width: 2 },
            LayoutShape { arity: 5, width: 2 },
        ];
        let (k, placements) = plan_layout(&shapes);
        assert_eq!(k, 7);
        assert_eq!(placements[0].idx(), 1);
        assert_eq!(placements[0].selectors()[0].index() << 5, 0);
        assert_eq!(placements[0].selectors()[1].index() << 5, 32);
        assert_eq!(placements[1].idx(), 0);
        assert_eq!(placements[1].selectors()[0].index() << 3, 64);
        assert_eq!(placements[1].selectors()[1].index() << 3, 72);
    }

    /// A table with no columns consumes no rows. Its arity may exceed `k`, so this also
    /// guards against evaluating `k - arity` for tables that get no selectors.
    #[test]
    fn plan_layout_zero_width_table_gets_no_selectors() {
        let shapes = vec![
            LayoutShape { arity: 5, width: 0 },
            LayoutShape { arity: 2, width: 1 },
        ];
        let (k, placements) = plan_layout(&shapes);
        // Only the second table contributes rows: 1 * 2^2 = 4.
        assert_eq!(k, 2);
        assert_eq!(placements.len(), 2);
        assert_eq!(placements[0].idx(), 0);
        assert!(placements[0].selectors().is_empty());
        assert_eq!(placements[1].idx(), 1);
        assert_eq!(placements[1].selectors()[0].index(), 0);
    }

    /// With three distinct arities, slots are packed back to back and each slot starts at a
    /// row that is a multiple of its own size.
    #[test]
    fn plan_layout_mixed_arities_are_packed_without_gaps() {
        let shapes = vec![
            LayoutShape { arity: 1, width: 3 },
            LayoutShape { arity: 2, width: 1 },
            LayoutShape { arity: 0, width: 1 },
        ];
        let (k, placements) = plan_layout(&shapes);
        // 3 * 2 + 1 * 4 + 1 * 1 = 11 rows, rounded up to 2^4.
        assert_eq!(k, 4);
        assert_eq!(placements.len(), 3);

        assert_eq!(placements[0].idx(), 1);
        assert_eq!(placements[0].selectors()[0].index() << 2, 0);

        assert_eq!(placements[1].idx(), 0);
        assert_eq!(placements[1].selectors()[0].index() << 1, 4);
        assert_eq!(placements[1].selectors()[1].index() << 1, 6);
        assert_eq!(placements[1].selectors()[2].index() << 1, 8);

        assert_eq!(placements[2].idx(), 2);
        assert_eq!(placements[2].selectors()[0].index(), 10);
    }
}