use std::collections::{HashMap, HashSet};
use bellframe::Bell;
use datasize::DataSize;
use itertools::Itertools;
use crate::{
graph::ChunkId,
parameters::{Method, MethodIdx, MethodVec, Parameters},
utils::{div_rounding_up, lengths::PerPartLength},
};
/// Flat index of a single flag (bit) within an `AtwBitmap`.
///
/// Bit `i` lives in chunk `i / Chunk::BITS`, at position `i % Chunk::BITS` within that chunk
/// (see `AtwBitmap::split_idx`).
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
struct BitIndex(usize);
/// A number of unique row positions (used both for the ATW total and for the per-chunk
/// multipliers of an `AtwTable`).
type UniqueRowCount = usize;
/// Storage unit of an `AtwBitmap`.
type Chunk = u16;
/// Number of flags stored in one `Chunk`.
const FLAGS_PER_CHUNK: usize = Chunk::BITS as usize;
/// Precomputed lookup data for measuring ATW-ness via `AtwBitmap`s.
///
/// Every bit of a bitmap stands for one `AtwFlag`. Flags are laid out so that all flags in the
/// same chunk account for the same number of unique row positions. That lets the number of
/// positions rung be computed by popcounting each chunk and multiplying by that chunk's entry
/// in `bitmap_chunk_multipliers`.
#[derive(Debug, Clone)]
pub(super) struct AtwTable {
// Weight multiplied by the ATW factor to produce the ATW score.
atw_weight: f32,
// Denominator of the ATW factor: the total number of unique row positions.
total_unique_row_positions: UniqueRowCount,
// For each bitmap chunk, the number of unique row positions represented by each set bit in it.
bitmap_chunk_multipliers: Vec<UniqueRowCount>,
// `(bell, place bell, method)` -> `(sub-lead start index, bit)` for every flag containing that
// pair, sorted by sub-lead start index (so it can be binary-searched).
bell_place_to_bitmap_index: HashMap<(Bell, u8, MethodIdx), Vec<(usize, BitIndex)>>,
}
/// One flag of the ATW bitmap: a contiguous range of sub-lead indices within one method,
/// shared by a set of `(bell, place bell)` pairs.
///
/// The flag's bit is set when any of its pairs is rung over this range, but it counts for
/// `sub_lead_chunk_len * bell_place_bell_pairs.len()` positions (see `unique_row_positions`).
/// NOTE(review): this presumably relies on the pairs in a set always being rung together
/// (e.g. across parts); confirm against `bell_place_sets`.
#[derive(Debug, Clone)]
struct AtwFlag {
// The method whose leads this range lies in.
method_idx: MethodIdx,
// Sub-lead index of the first row in the range (inclusive).
sub_lead_chunk_start: usize,
// Number of sub-lead indices covered by the range.
sub_lead_chunk_len: PerPartLength,
// The `(bell, place bell)` pairs that this flag accounts for.
bell_place_bell_pairs: Vec<(Bell, u8)>,
}
impl AtwFlag {
    /// Number of unique row positions that this flag accounts for: one for every sub-lead index
    /// in the range, for each of its `(bell, place bell)` pairs.
    fn unique_row_positions(&self) -> usize {
        let pair_count = self.bell_place_bell_pairs.len();
        let range_len = self.sub_lead_chunk_len.as_usize();
        range_len * pair_count
    }
}
impl AtwTable {
    /// Builds the ATW table for `params`, given every `(chunk, length)` that the search may use.
    ///
    /// If `params.atw_weight` is `None` and `params.require_atw` is `false`, an empty table is
    /// returned (see `empty`). If ATW is required but no weight is given, the full table is
    /// still built, but with a weight of `0.0`, so `atw_score` always returns `0.0`.
    pub fn new(params: &Parameters, chunk_lengths: &[(ChunkId, PerPartLength)]) -> Self {
        let atw_weight = match params.atw_weight {
            Some(w) => w,
            None if params.require_atw => 0.0,
            None => return Self::empty(),
        };

        let working_bells = params.working_bells();
        // Only non-trivial part-head cycles are passed on; bells in 1-cycles end up being
        // tracked like any other bell not covered by a cycle (see `bell_place_sets`).
        let part_head_cycles = params
            .part_head_group
            .bell_cycles()
            .into_iter()
            .filter(|g| g.len() > 1)
            .collect_vec();

        // Split each method's lead into the sub-lead ranges delimited by chunk lead regions,
        // then turn those ranges into flags.
        let place_bell_range_boundaries: HashMap<(Bell, u8, MethodIdx), Vec<usize>> =
            place_bell_range_boundaries(params, chunk_lengths);
        let flags: Vec<AtwFlag> = range_boundaries_to_flags(
            &working_bells,
            &part_head_cycles,
            params,
            place_bell_range_boundaries,
        );
        // Previously `¶ms.methods` (a mangled `&params`), which did not compile.
        let total_unique_row_positions =
            total_unique_row_positions(&working_bells, &params.methods, &flags);

        let (bitmap_chunk_multipliers, flag_per_bit) = split_flags_into_bitmap_chunks(flags);
        Self {
            atw_weight,
            bell_place_to_bitmap_index: make_bell_place_to_bitmap_index(&flag_per_bit),
            total_unique_row_positions,
            bitmap_chunk_multipliers,
        }
    }

    /// A table which tracks nothing, used when ATW is neither weighted nor required.
    fn empty() -> Self {
        Self {
            atw_weight: 0.0,
            // Non-zero so that `atw_factor` computes `0 / 1 == 0.0` instead of dividing by zero.
            total_unique_row_positions: 1,
            bitmap_chunk_multipliers: Vec::new(),
            bell_place_to_bitmap_index: HashMap::new(),
        }
    }

    /// Returns the bitmap of flags set by ringing the chunk `id` with length `chunk_len`.
    ///
    /// # Panics
    ///
    /// Panics if one of the chunk's lead regions covers a tracked `(bell, place bell, method)`
    /// but doesn't start exactly at the start of one of that pair's flags. Presumably this
    /// means `(id, chunk_len)` was not among the `chunk_lengths` this table was built from.
    pub fn bitmap_for_chunk(
        &self,
        params: &Parameters,
        id: &ChunkId,
        chunk_len: PerPartLength,
    ) -> AtwBitmap {
        let mut bitmap = self.empty_bitmap();
        for (lead_head, sub_lead_range) in params.chunk_lead_regions(id, chunk_len) {
            for (place, bell) in lead_head.bell_iter().enumerate() {
                // Pairs with no entry have no flags, so they contribute nothing.
                if let Some(bit_starts) = self
                    .bell_place_to_bitmap_index
                    .get(&(bell, place as u8, id.method))
                {
                    // `bit_starts` is sorted by sub-lead start, so the flags covered by this
                    // region form the contiguous run of entries whose start lies in
                    // `sub_lead_range.start..sub_lead_range.end`.
                    let index_within_bit_starts = |sub_lead_idx: usize| -> usize {
                        bit_starts
                            .binary_search_by_key(&sub_lead_idx, |(start, _)| *start)
                            .unwrap_or_else(|x| x)
                    };
                    let start_idx = index_within_bit_starts(sub_lead_range.start);
                    let end_idx = index_within_bit_starts(sub_lead_range.end);
                    let bit_idxs = &bit_starts[start_idx..end_idx];
                    // The flags were split at every region boundary from the chunk lengths given
                    // to `new`, so a known region must start exactly on a flag.
                    assert!(!bit_idxs.is_empty());
                    assert_eq!(bit_idxs[0].0, sub_lead_range.start);
                    for (_sub_lead_idx, bit_index) in bit_idxs {
                        bitmap.add_bit(*bit_index);
                    }
                }
            }
        }
        bitmap
    }

    /// The ATW score of `bitmap`: `atw_weight * atw_factor(bitmap)`.
    pub fn atw_score(&self, bitmap: &AtwBitmap) -> f32 {
        let factor = self.atw_factor(bitmap);
        self.atw_weight * factor
    }

    /// The fraction of all unique row positions which are rung according to `bitmap`.
    pub fn atw_factor(&self, bitmap: &AtwBitmap) -> f32 {
        self.unique_place_bell_rows_rung(bitmap) as f32 / self.total_unique_row_positions as f32
    }

    /// The number of unique row positions rung according to `bitmap`.
    ///
    /// # Panics
    ///
    /// Panics if `bitmap` has a different number of chunks than this table. For example, this
    /// happens if `bitmap` was created by a different table.
    pub fn unique_place_bell_rows_rung(&self, bitmap: &AtwBitmap) -> usize {
        self.bitmap_chunk_multipliers
            .iter()
            .zip_eq(&bitmap.chunks)
            .map(|(positions_per_bit, chunk)| *positions_per_bit * chunk.count_ones() as usize)
            .sum::<usize>()
    }

    /// A bitmap with the right number of chunks for this table and no flags set.
    pub fn empty_bitmap(&self) -> AtwBitmap {
        AtwBitmap {
            chunks: vec![0 as Chunk; self.bitmap_chunk_multipliers.len()],
        }
    }
}
/// A range of sub-lead indices within one method, associated with one `(bell, place bell)`
/// pair.
///
/// NOTE(review): this type is not used in the code shown here. Presumably it describes a range
/// of rows over which `bell` is ringing `place_bell`; confirm against its users.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct PlaceBellRange {
/// The bell being tracked.
pub bell: Bell,
/// The place bell (as a place index), presumably the one `bell` is ringing.
pub place_bell: u8,
/// The method that the range lies in.
pub method_idx: MethodIdx,
/// Sub-lead index of the first row of the range.
pub sub_lead_idx_start: usize,
/// Number of rows in the range.
pub length: PerPartLength,
}
/// A set of ATW flags, stored as a bitset made of `Chunk`s.
///
/// Bitmaps are created by `AtwTable::empty_bitmap` or `AtwTable::bitmap_for_chunk`, and are
/// only meaningful relative to the `AtwTable` that created them.
#[derive(Debug, Clone, DataSize)]
pub(crate) struct AtwBitmap {
// Bit `i` of the set lives in `chunks[i / Chunk::BITS]` at position `i % Chunk::BITS`.
chunks: Vec<Chunk>,
}
impl AtwBitmap {
pub fn union_with(&mut self, other: &Self) {
for (chunk, other_chunk) in self.chunks.iter_mut().zip_eq(&other.chunks) {
*chunk |= *other_chunk;
}
}
fn add_bit(&mut self, idx: BitIndex) {
let (chunk_idx, mask) = Self::split_idx(idx);
self.chunks[chunk_idx] |= mask;
}
fn split_idx(idx: BitIndex) -> (usize, Chunk) {
let chunk_idx = idx.0 / (Chunk::BITS as usize);
let sub_chunk_idx = idx.0 % (Chunk::BITS as usize);
(chunk_idx, 1 << sub_chunk_idx)
}
}
fn split_flags_into_bitmap_chunks(flags: Vec<AtwFlag>) -> (Vec<usize>, Vec<Option<AtwFlag>>) {
let mut chunk_multipliers = Vec::new();
let mut flag_per_bit = Vec::new();
let flag_groups = flags
.into_iter()
.into_group_map_by(AtwFlag::unique_row_positions);
for (unique_row_positions, flags) in &flag_groups {
let chunks_required = div_rounding_up(flags.len(), FLAGS_PER_CHUNK);
chunk_multipliers.extend(std::iter::repeat(*unique_row_positions).take(chunks_required));
let mut flag_iter = flags.iter().fuse();
for _ in 0..chunks_required * FLAGS_PER_CHUNK {
flag_per_bit.push(flag_iter.next().cloned());
}
}
assert_eq!(
flag_per_bit.len(),
chunk_multipliers.len() * FLAGS_PER_CHUNK
);
(chunk_multipliers, flag_per_bit)
}
/// Inverts `flag_per_bit` into a map from `(bell, place bell, method)` to the
/// `(sub-lead start index, bit index)` of every flag containing that pair.
///
/// Padding bits (`None`) are skipped. Each list is sorted by sub-lead start index, so that
/// `AtwTable::bitmap_for_chunk` can binary-search it.
fn make_bell_place_to_bitmap_index(
    flag_per_bit: &[Option<AtwFlag>],
) -> HashMap<(Bell, u8, MethodIdx), Vec<(usize, BitIndex)>> {
    let mut index: HashMap<(Bell, u8, MethodIdx), Vec<(usize, BitIndex)>> = HashMap::new();

    let present_flags = flag_per_bit
        .iter()
        .enumerate()
        .filter_map(|(bit, maybe_flag)| maybe_flag.as_ref().map(|flag| (BitIndex(bit), flag)));
    for (bit, flag) in present_flags {
        for (bell, place_bell) in flag.bell_place_bell_pairs.iter().copied() {
            index
                .entry((bell, place_bell, flag.method_idx))
                .or_default()
                .push((flag.sub_lead_chunk_start, bit));
        }
    }

    index
        .values_mut()
        .for_each(|entries| entries.sort_unstable_by_key(|&(start, _)| start));
    index
}
/// Computes the total number of unique row positions. There is one position for every
/// combination of bell, place bell (one per working bell) and row in the lead of each method.
///
/// Logs a warning if the positions covered by `flags` don't add up to this total, i.e. a fully
/// ATW composition isn't possible.
fn total_unique_row_positions(
    working_bells: &[Bell],
    methods: &MethodVec<Method>,
    flags: &[AtwFlag],
) -> usize {
    let num_bells = working_bells.len();
    let lead_len_over_all_methods: usize = methods.iter().map(|m| m.lead_len()).sum();
    let total = num_bells * num_bells * lead_len_over_all_methods;

    let covered_by_flags: usize = flags.iter().map(AtwFlag::unique_row_positions).sum();
    if covered_by_flags != total {
        log::warn!("Not enough place bells can be rung for a fully atw composition.");
    }
    total
}
/// Collects range boundaries for each `(bell, place, method)`.
///
/// The boundaries are the sorted, deduplicated sub-lead indices at which a chunk lead region
/// starts or ends, where `bell` is at position `place` in that region's lead head.
fn place_bell_range_boundaries(
    params: &Parameters,
    chunk_lengths: &[(ChunkId, PerPartLength)],
) -> HashMap<(Bell, u8, MethodIdx), Vec<usize>> {
    let mut boundaries: HashMap<(Bell, u8, MethodIdx), Vec<usize>> = HashMap::new();

    for (chunk_id, length) in chunk_lengths.iter() {
        let method = chunk_id.method;
        for (lead_head, region) in params.chunk_lead_regions(chunk_id, *length) {
            for (place, bell) in lead_head.bell_iter().enumerate() {
                let entry = boundaries.entry((bell, place as u8, method)).or_default();
                entry.push(region.start);
                entry.push(region.end);
            }
        }
    }

    for list in boundaries.values_mut() {
        list.sort_unstable();
        list.dedup();
    }
    boundaries
}
/// Converts per-pair range boundaries into `AtwFlag`s.
///
/// For each method, `bell_place_sets` groups the `(bell, place bell)` pairs into sets. The
/// boundaries of all pairs in a set are merged. Each pair of consecutive merged boundaries then
/// becomes one flag, which covers that sub-lead range for every pair in the set.
fn range_boundaries_to_flags(
    working_bells: &[Bell],
    part_head_cycles: &[Vec<Bell>],
    params: &Parameters,
    range_boundaries: HashMap<(Bell, u8, MethodIdx), Vec<usize>>,
) -> Vec<AtwFlag> {
    let mut flags = Vec::<AtwFlag>::new();
    for (method_idx, method) in params.methods.iter_enumerated() {
        for pair_set in bell_place_sets(working_bells, part_head_cycles, method, params) {
            // Pairs with no recorded boundaries contribute nothing.
            let mut merged: Vec<usize> = pair_set
                .iter()
                .filter_map(|&(bell, place_bell)| {
                    range_boundaries.get(&(bell, place_bell, method_idx))
                })
                .flatten()
                .copied()
                .collect();
            merged.sort_unstable();
            merged.dedup();

            for (start, end) in merged.into_iter().tuple_windows() {
                flags.push(AtwFlag {
                    method_idx,
                    sub_lead_chunk_start: start,
                    sub_lead_chunk_len: PerPartLength::new(end - start),
                    bell_place_bell_pairs: pair_set.clone(),
                });
            }
        }
    }
    flags
}
/// Partitions the `(bell, place bell)` pairs of `method` into sets. Each set is counted
/// together by a single `AtwFlag`.
///
/// 1. For each cycle in `part_head_cycles`: one set per place bell, containing every bell of
///    the cycle at that place bell. The cycle's bells are then considered tracked.
/// 2. If `method` has exactly as many allowed lead-head masks as its lead head's order: one set
///    per mask, pairing each still-untracked bell fixed by the mask with its place in the mask.
///    The bells fixed by the first mask are then considered tracked.
/// 3. Every remaining bell gets one single-pair set per place bell (in `HashSet` order).
fn bell_place_sets(
    working_bells: &[Bell],
    part_head_cycles: &[Vec<Bell>],
    method: &Method,
    params: &Parameters,
) -> Vec<Vec<(Bell, u8)>> {
    let mut untracked: HashSet<Bell> = working_bells.iter().copied().collect();
    let mut sets: Vec<Vec<(Bell, u8)>> = Vec::new();

    for cycle in part_head_cycles {
        sets.extend(working_bells.iter().map(|place_bell| {
            let place = place_bell.index_u8();
            cycle.iter().map(|&bell| (bell, place)).collect_vec()
        }));
        for bell in cycle.iter() {
            untracked.remove(bell);
        }
    }

    let lh_masks = method.allowed_lead_head_masks(params);
    if lh_masks.len() == method.lead_head().order() {
        for mask in &lh_masks {
            let pairs = mask
                .bells()
                .enumerate()
                .filter_map(|(place, maybe_bell)| maybe_bell.map(|bell| (bell, place as u8)))
                .filter(|(bell, _)| untracked.contains(bell))
                .collect_vec();
            sets.push(pairs);
        }
        for bell in lh_masks[0].bells().flatten() {
            untracked.remove(&bell);
        }
    }

    for bell in untracked {
        for place_bell in working_bells {
            sets.push(vec![(bell, place_bell.index_u8())]);
        }
    }
    sets
}