use std::{
cmp,
collections::{BTreeSet, HashMap},
ops::Range,
};
use super::{RegionColumn, RegionShape};
use crate::{circuit::RegionStart, plonk::Any};
/// A contiguous run of `length` rows, beginning at row `start`, that has been
/// handed out to a region within a single column.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
struct AllocatedRegion {
    start: usize,
    length: usize,
}

impl Ord for AllocatedRegion {
    /// Orders regions by start row, breaking ties by length.
    ///
    /// The `length` tie-break keeps `Ord` consistent with the derived
    /// `PartialEq`/`Eq`, as the `Ord` contract requires. Without it, two regions
    /// that differ only in length compare `Equal`, and a `BTreeSet` silently
    /// refuses to store the second one (e.g. a zero-length region followed by a
    /// non-empty region at the same row). Among regions with distinct starts the
    /// ordering is unchanged.
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        (self.start, self.length).cmp(&(other.start, other.length))
    }
}

impl PartialOrd for AllocatedRegion {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// A run of unallocated rows in a column, from `start` (inclusive) up to `end`
/// (exclusive), or without an upper limit when `end` is `None`.
pub(crate) struct EmptySpace {
    start: usize,
    end: Option<usize>,
}

impl EmptySpace {
    /// Returns the rows covered by this space as a range, or `None` when the
    /// space is unbounded above.
    pub(crate) fn range(&self) -> Option<Range<usize>> {
        let end = self.end?;
        Some(self.start..end)
    }
}
/// The regions allocated within a single column, kept sorted by their
/// `AllocatedRegion` ordering (start row first).
#[derive(Clone, Default, Debug)]
pub struct Allocations(BTreeSet<AllocatedRegion>);

impl Allocations {
    /// Returns the row just past the allocated region with the greatest start
    /// row, or `0` when nothing is allocated.
    ///
    /// NOTE(review): this equals the end of all allocations only if regions in
    /// the column never overlap, which is how `first_fit_region` places them.
    pub(crate) fn unbounded_interval_start(&self) -> usize {
        match self.0.iter().next_back() {
            Some(last) => last.start + last.length,
            None => 0,
        }
    }

    /// Yields the free gaps in this column, scanning upward from row `start`.
    ///
    /// Each yielded space starts at or after `start`. Gaps between allocations
    /// are bounded by the next allocated region. The final gap after the last
    /// relevant allocation ends at `end`, or is unbounded when `end` is `None`.
    /// Allocations starting at or beyond `end` are ignored, and zero-width gaps
    /// are never yielded.
    pub(crate) fn free_intervals(
        &self,
        start: usize,
        end: Option<usize>,
    ) -> impl Iterator<Item = EmptySpace> + '_ {
        // A trailing `None` acts as a sentinel meaning "past the last
        // allocation", so the open-ended gap is considered as well.
        let with_sentinel = self.0.iter().map(Some).chain(std::iter::once(None));

        with_sentinel
            .scan(start, move |cursor, item| {
                let gap = match item {
                    // Regions at or past the exclusive bound cannot limit a gap
                    // inside the searched window.
                    Some(region) if end.map_or(false, |limit| region.start >= limit) => None,
                    Some(region) => {
                        let gap = if *cursor < region.start {
                            Some(EmptySpace {
                                start: *cursor,
                                end: Some(region.start),
                            })
                        } else {
                            None
                        };
                        // Skip over this region; never move the cursor backwards.
                        *cursor = cmp::max(*cursor, region.start + region.length);
                        gap
                    }
                    None => {
                        if end.map_or(true, |limit| *cursor < limit) {
                            Some(EmptySpace { start: *cursor, end })
                        } else {
                            None
                        }
                    }
                };
                // Always keep scanning; empty results are removed by `flatten`.
                Some(gap)
            })
            .flatten()
    }
}
/// Allocation state for a whole circuit: for each [`RegionColumn`], the
/// [`Allocations`] (row ranges already given to regions) in that column.
pub type CircuitAllocations = HashMap<RegionColumn, Allocations>;
/// Searches for the lowest row `>= start` at which a region of `region_length`
/// rows fits into free space on every column in `region_columns`. On success,
/// it records the allocation on each of those columns.
///
/// The columns are handled recursively. The head column enumerates its free
/// gaps, and each gap large enough for the region becomes the search window for
/// the remaining columns. When `slack` is `Some(s)`, the region's rows must lie
/// below `start + region_length + s`. `None` leaves the search unbounded.
///
/// Returns the chosen start row, or `None` if no row works for every column.
/// On success an `AllocatedRegion` is inserted for each column in
/// `region_columns`.
///
/// # Panics
///
/// Panics if an internal consistency check fails: a gap offering more slack
/// than permitted, or a placement extending past the bound.
fn first_fit_region(
    column_allocations: &mut CircuitAllocations,
    region_columns: &[RegionColumn],
    region_length: usize,
    start: usize,
    slack: Option<usize>,
) -> Option<usize> {
    let (column, other_columns) = match region_columns.split_first() {
        Some(split) => split,
        // Base case: every column has accepted a placement at `start`.
        None => return Some(start),
    };

    // Exclusive upper bound on the rows the region may occupy, if bounded.
    let bound = slack.map(|allowed| start + region_length + allowed);

    // Iterate over a snapshot of this column's allocations, because the
    // recursive call below needs `column_allocations` mutably.
    let snapshot = column_allocations.entry(*column).or_default().clone();

    for gap in snapshot.free_intervals(start, bound) {
        // How far the region could shift within this gap (`None` = unbounded).
        let gap_slack = gap
            .end
            .map(|gap_end| (gap_end as isize - gap.start as isize) - region_length as isize);

        if let (Some(allowed), Some(available)) = (slack, gap_slack) {
            assert!(available <= allowed as isize);
        }

        // A bounded gap shorter than the region cannot host it.
        if gap_slack.map_or(false, |available| available < 0) {
            continue;
        }

        let placed = first_fit_region(
            column_allocations,
            other_columns,
            region_length,
            gap.start,
            gap_slack.map(|available| available as usize),
        );

        if let Some(row) = placed {
            if let Some(bound) = bound {
                assert!(row + region_length <= bound);
            }
            column_allocations
                .get_mut(column)
                .unwrap()
                .0
                .insert(AllocatedRegion {
                    start: row,
                    length: region_length,
                });
            return Some(row);
        }
    }

    None
}
/// Places each region, in the order given, at the lowest row (searching from
/// row 0, with no upper bound) where it fits on all of its columns. Each
/// placement is recorded in the column allocations.
///
/// Returns every region paired with its chosen start row, in input order,
/// together with the resulting per-column allocations.
fn slot_in(
    region_shapes: Vec<RegionShape>,
) -> (Vec<(RegionStart, RegionShape)>, CircuitAllocations) {
    let mut column_allocations: CircuitAllocations = Default::default();

    let regions = region_shapes
        .into_iter()
        .map(|region| {
            // Sort the columns so the search always visits them in a fixed order.
            let mut region_columns: Vec<_> = region.columns().iter().cloned().collect();
            region_columns.sort_unstable();

            // With no slack, each column's last free interval is unbounded, so
            // the search below always finds a placement.
            let region_start = first_fit_region(
                &mut column_allocations,
                &region_columns,
                region.row_count(),
                0,
                None,
            )
            .expect("We can always fit a region somewhere");
            (region_start.into(), region)
        })
        .collect();

    (regions, column_allocations)
}
/// Places all regions, processing those with the largest advice area first.
/// Advice area is the number of advice columns times the row count.
///
/// Returns the start row of every region, ordered by region index, together
/// with the final per-column allocations.
pub fn slot_in_biggest_advice_first(
    region_shapes: Vec<RegionShape>,
) -> (Vec<RegionStart>, CircuitAllocations) {
    let mut sorted_regions = region_shapes;

    // Sort by advice area, since advice columns see the most contention.
    //
    // `sort_by_cached_key` is a *stable* sort: regions with equal keys keep their
    // relative input order, so the resulting placement depends only on the
    // input. The relative order of equal keys under `sort_unstable_by_key` is
    // unspecified and may vary between platforms or Rust versions. Equal keys
    // are common here. The cached variant also evaluates the key once per region.
    sorted_regions.sort_by_cached_key(|shape| {
        let advice_cols = shape
            .columns()
            .iter()
            .filter(|c| match c {
                RegionColumn::Column(c) => matches!(c.column_type(), Any::Advice),
                _ => false,
            })
            .count();
        advice_cols * shape.row_count()
    });
    // Largest advice area first.
    sorted_regions.reverse();

    let (mut regions, column_allocations) = slot_in(sorted_regions);

    // Restore region-index order so callers can index starts by region.
    regions.sort_unstable_by_key(|(_, region)| region.region_index().0);
    let regions = regions.into_iter().map(|(start, _)| start).collect();

    (regions, column_allocations)
}
#[test]
fn test_slot_in() {
    use crate::plonk::Column;

    // Builds an advice-only region shape for the fixtures below.
    let advice_region = |index: usize, column_indices: &[usize], row_count: usize| RegionShape {
        region_index: index.into(),
        columns: column_indices
            .iter()
            .map(|&c| Column::new(c, Any::Advice).into())
            .collect(),
        row_count,
    };

    let regions = vec![
        advice_region(0, &[0, 1], 15),
        advice_region(1, &[2], 10),
        advice_region(2, &[2, 0], 10),
    ];

    let starts = slot_in(regions)
        .0
        .into_iter()
        .map(|(start, _)| start)
        .collect::<Vec<_>>();

    // Regions 0 and 1 share no columns, so both start at row 0. Region 2 uses
    // column 0, which region 0 occupies for rows 0..15, so it starts at 15.
    assert_eq!(starts, vec![0.into(), 0.into(), 15.into()]);
}