use std::{
cmp,
collections::{BTreeSet, HashMap},
ops::Range,
};
use super::{RegionColumn, RegionShape};
use crate::{circuit::RegionStart, plonk::Any};
/// A contiguous run of rows, `start..start + length`, allocated to a region within one column.
///
/// Ordering is lexicographic on `(start, length)`. It is derived, so it follows field
/// declaration order and `start` must stay the first field. The ordering must remain consistent
/// with `Eq`: a `BTreeSet` treats elements that compare `Equal` as duplicates. Ordering by
/// `start` alone would silently drop a second region that shares a start row but has a
/// different length (e.g. a zero-length region followed by a real one).
#[derive(Clone, Default, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct AllocatedRegion {
    /// First row of the allocation.
    start: usize,
    /// Number of rows allocated.
    length: usize,
}
/// A run of unallocated rows in a column, as produced by `Allocations::free_intervals`.
pub(crate) struct EmptySpace {
    /// First free row (inclusive).
    start: usize,
    /// One past the last free row, or `None` if the space extends without bound.
    end: Option<usize>,
}

impl EmptySpace {
    /// Returns the half-open row range `start..end`, or `None` when the space is unbounded.
    pub(crate) fn range(&self) -> Option<Range<usize>> {
        let end = self.end?;
        Some(self.start..end)
    }
}
/// The set of regions allocated within a single column, ordered by start row.
#[derive(Clone, Default, Debug)]
pub struct Allocations(BTreeSet<AllocatedRegion>);

impl Allocations {
    /// Returns the row just past the end of the last allocated region (the greatest element of
    /// the set), or 0 if nothing has been allocated in this column.
    ///
    /// Every row from this point on is free, assuming allocations within a column never
    /// overlap.
    pub(crate) fn unbounded_interval_start(&self) -> usize {
        match self.0.iter().next_back() {
            Some(last) => last.start + last.length,
            None => 0,
        }
    }

    /// Returns the free (unallocated) intervals of this column that begin at or after `start`.
    ///
    /// The gaps between allocated regions are yielded in ascending order as bounded
    /// `EmptySpace`s. They are followed by a final space that begins after the last allocation
    /// considered.
    ///
    /// When `end` is `Some`, regions starting at or past `end` are ignored. The final space is
    /// then bounded by `end`, and it is omitted if it would be empty. When `end` is `None`, the
    /// final space is unbounded and is always yielded.
    pub(crate) fn free_intervals(
        &self,
        start: usize,
        end: Option<usize>,
    ) -> impl Iterator<Item = EmptySpace> + '_ {
        // True when `row` lies at or beyond the (optional) upper bound of the search.
        let past_end = move |row: usize| matches!(end, Some(limit) if row >= limit);

        // Walk the regions followed by a single `None` sentinel marking the end of the column.
        // `cursor` tracks the first row not yet known to be covered by an allocation.
        self.0
            .iter()
            .map(Some)
            .chain(std::iter::once(None))
            .scan(start, move |cursor, next| {
                let gap = match next {
                    Some(region) if past_end(region.start) => None,
                    Some(region) => {
                        let gap = if *cursor < region.start {
                            Some(EmptySpace {
                                start: *cursor,
                                end: Some(region.start),
                            })
                        } else {
                            None
                        };
                        // Regions may start before `start`; never move the cursor backwards.
                        *cursor = cmp::max(*cursor, region.start + region.length);
                        gap
                    }
                    None if past_end(*cursor) => None,
                    None => Some(EmptySpace {
                        start: *cursor,
                        end,
                    }),
                };
                Some(gap)
            })
            .flatten()
    }
}

/// Per-column allocation state for an entire circuit.
pub type CircuitAllocations = HashMap<RegionColumn, Allocations>;
fn first_fit_region(
column_allocations: &mut CircuitAllocations,
region_columns: &[RegionColumn],
region_length: usize,
start: usize,
slack: Option<usize>,
) -> Option<usize> {
let (c, remaining_columns) = match region_columns.split_first() {
Some(cols) => cols,
None => return Some(start),
};
let end = slack.map(|slack| start + region_length + slack);
for space in column_allocations.entry(*c).or_default().clone().free_intervals(start, end) {
let s_slack = space
.end
.map(|end| (end as isize - space.start as isize) - region_length as isize);
if let Some((slack, s_slack)) = slack.zip(s_slack) {
assert!(s_slack <= slack as isize);
}
if s_slack.unwrap_or(0) >= 0 {
let row = first_fit_region(
column_allocations,
remaining_columns,
region_length,
space.start,
s_slack.map(|s| s as usize),
);
if let Some(row) = row {
if let Some(end) = end {
assert!(row + region_length <= end);
}
column_allocations.get_mut(c).unwrap().0.insert(AllocatedRegion {
start: row,
length: region_length,
});
return Some(row);
}
}
}
None
}
/// Places each region, in the order given, at the lowest row where all of its columns are free.
///
/// Returns every region paired with its assigned start row (in input order), together with the
/// resulting per-column allocations.
///
/// # Panics
///
/// Panics if a region cannot be placed. This is not expected: the search is unbounded
/// (`slack = None`), and the trailing free interval of every column is then unbounded.
fn slot_in(
    region_shapes: Vec<RegionShape>,
) -> (Vec<(RegionStart, RegionShape)>, CircuitAllocations) {
    let mut column_allocations: CircuitAllocations = Default::default();
    let regions = region_shapes
        .into_iter()
        .map(|region| {
            // Sort the columns so that placement does not depend on the order in which
            // `columns()` happens to yield them.
            let mut region_columns: Vec<_> = region.columns().iter().cloned().collect();
            region_columns.sort_unstable();
            let region_start = first_fit_region(
                &mut column_allocations,
                &region_columns,
                region.row_count(),
                0,
                None,
            )
            .expect("We can always fit a region somewhere");
            (region_start.into(), region)
        })
        .collect();
    (regions, column_allocations)
}
/// Places regions in order of decreasing advice footprint, then reports their start rows in
/// region-index order.
///
/// A region's footprint is its number of advice columns multiplied by its row count. Regions
/// are stably sorted by ascending footprint and the result is reversed. Larger regions are
/// therefore placed first, and regions with equal footprints end up in reverse input order.
///
/// Returns the start row of each region, ordered by `region_index`, together with the
/// resulting column allocations.
pub fn slot_in_biggest_advice_first(
    region_shapes: Vec<RegionShape>,
) -> (Vec<RegionStart>, CircuitAllocations) {
    /// Number of advice columns in `shape` times its row count.
    fn advice_area(shape: &RegionShape) -> usize {
        let advice_columns = shape
            .columns()
            .iter()
            .filter(|column| {
                if let RegionColumn::Column(column) = column {
                    matches!(column.column_type(), Any::Advice(_))
                } else {
                    false
                }
            })
            .count();
        advice_columns * shape.row_count()
    }

    let mut ordered = region_shapes;
    ordered.sort_by_cached_key(advice_area);
    ordered.reverse();

    let (mut placed, column_allocations) = slot_in(ordered);
    placed.sort_unstable_by_key(|(_, shape)| shape.region_index().0);
    let starts = placed.into_iter().map(|(start, _)| start).collect();
    (starts, column_allocations)
}
/// Checks that `slot_in` packs regions into the lowest row where all of their columns are free.
///
/// - Region 0 (columns 0 and 1, 15 rows) goes to row 0.
/// - Region 1 (column 2, 10 rows) also goes to row 0, since it shares no columns with region 0.
/// - Region 2 (columns 2 and 0, 10 rows) must wait for column 0 to free up at row 15.
#[test]
fn test_slot_in() {
    use crate::plonk::Column;

    let shape = |index: usize, columns: Vec<Column<Any>>, row_count: usize| RegionShape {
        region_index: index.into(),
        columns: columns.into_iter().map(|column| column.into()).collect(),
        row_count,
    };

    let regions = vec![
        shape(
            0,
            vec![Column::new(0, Any::advice()), Column::new(1, Any::advice())],
            15,
        ),
        shape(1, vec![Column::new(2, Any::advice())], 10),
        shape(
            2,
            vec![Column::new(2, Any::advice()), Column::new(0, Any::advice())],
            10,
        ),
    ];

    let starts: Vec<RegionStart> = slot_in(regions)
        .0
        .into_iter()
        .map(|(start, _)| start)
        .collect();
    let expected: Vec<RegionStart> = vec![0.into(), 0.into(), 15.into()];
    assert_eq!(starts, expected);
}