use std::{
cmp::Ordering,
collections::{BinaryHeap, HashSet},
ops::Deref,
};
use bellframe::Row;
use bit_vec::BitVec;
use datasize::DataSize;
use itertools::Itertools;
use crate::{
builder::{MusicTypeId, SpliceStyle},
composition::{Composition, PathElem},
graph::LinkSide,
group::PartHead,
utils::{
counts::Counts,
div_rounding_up,
lengths::{PerPartLength, TotalLength},
Score,
},
};
use super::{
graph::{ChunkIdx, Graph},
path::{PathId, Paths},
Search,
};
/// A partial composition (a "prefix") held in the search frontier.
///
/// The fields needed for ordering (`score` and `length`) are stored inline; everything else lives
/// in the boxed [`PrefixInner`], which is reachable transparently through the `Deref` impl.
#[derive(Debug, Clone)]
pub(super) struct CompPrefix {
/// Running score: the sum of the scores of every chunk expanded and every link followed so far.
score: Score,
/// Running length: the sum of `total_length` over every chunk expanded so far.
length: TotalLength,
/// The remaining per-prefix state.
// NOTE(review): presumably boxed to keep `CompPrefix` small while it is moved around inside the
// `BinaryHeap` - confirm.
inner: Box<PrefixInner>,
}
/// The part of a [`CompPrefix`] that is not needed for ordering prefixes in the frontier.
#[derive(Debug, Clone)]
pub(super) struct PrefixInner {
/// Handle into `Paths` identifying the sequence of links taken to reach this prefix.
path: PathId,
/// What comes next: either the chunk that `expand` will add to this prefix, or
/// `LinkSide::StartOrEnd` if the prefix is finished and only needs checking.
next_link_side: LinkSide<ChunkIdx>,
/// One bit per chunk in the graph.  A set bit means that chunk appears in the `falseness` set of
/// some chunk already expanded in this prefix, so `expand` refuses to step onto it.
unringable_chunks: BitVec,
/// The start's part head multiplied by the `ph_rotation` of every link followed so far.
part_head: PartHead,
/// Per-part length of the unbroken run of duffer chunks ending at the last expanded chunk
/// (reset to zero whenever a non-duffer chunk is expanded).
contiguous_duffer: PerPartLength,
/// Sum of `total_length` over every duffer chunk expanded so far.
total_duffer: TotalLength,
/// Element-wise sum of the `method_counts` of every chunk expanded so far.
method_counts: Counts,
}
impl CompPrefix {
    /// Builds the initial frontier: one empty prefix (zero score, zero length, no duffers) for
    /// every entry of `graph.starts`.
    ///
    /// Each prefix registers its start in `paths` and is set up so that the first call to
    /// `expand` adds that start's chunk.
    pub fn starts(graph: &Graph, paths: &mut Paths) -> BinaryHeap<Self> {
        // Nothing has been expanded yet, so no chunk is excluded by falseness.
        let no_chunks_excluded = BitVec::from_elem(graph.chunks.len(), false);
        graph
            .starts
            .iter_enumerated()
            .map(|(start_idx, &(chunk_idx, _link_id, part_head))| {
                let num_methods = graph.chunks[chunk_idx].method_counts.len();
                let inner = PrefixInner {
                    path: paths.add_start(start_idx),
                    next_link_side: LinkSide::Chunk(chunk_idx),
                    unringable_chunks: no_chunks_excluded.clone(),
                    part_head,
                    contiguous_duffer: PerPartLength::ZERO,
                    total_duffer: TotalLength::ZERO,
                    method_counts: Counts::zeros(num_methods),
                };
                Self {
                    score: Score::from(0.0),
                    length: TotalLength::ZERO,
                    inner: Box::new(inner),
                }
            })
            .collect()
    }

    /// Estimated memory footprint of this prefix in bytes: the inline struct, the boxed
    /// [`PrefixInner`], one bit per entry of `unringable_chunks` (rounded up to whole bytes) and
    /// the estimated heap usage of `method_counts`.
    pub fn size(&self) -> usize {
        let struct_bytes = std::mem::size_of::<Self>() + std::mem::size_of::<PrefixInner>();
        let bitmap_bytes = div_rounding_up(self.inner.unringable_chunks.len(), 8);
        let counts_bytes = self.inner.method_counts.estimate_heap_size();
        struct_bytes + bitmap_bytes + counts_bytes
    }

    /// The running score divided by the running length.  This is the key by which prefixes are
    /// ordered in the frontier.
    // NOTE(review): a prefix fresh out of `starts` has zero score and zero length, so this is a
    // zero-by-zero division - confirm that `Score` orders that result the way the search expects.
    pub fn avg_score(&self) -> Score {
        let length = self.length.as_usize() as f32;
        self.score / length
    }

    /// The [`PathId`] of the most recent entry in this prefix's path.
    pub fn path_head(&self) -> PathId {
        self.inner.path
    }

    /// The running length of this prefix.
    pub fn length(&self) -> TotalLength {
        self.length
    }
}
/// Prefixes are considered equal exactly when their average scores are equal; the paths they
/// describe are not compared.
impl PartialEq for CompPrefix {
    fn eq(&self, other: &Self) -> bool {
        let ours = self.avg_score();
        let theirs = other.avg_score();
        ours == theirs
    }
}
// `Eq` is a supertrait of `Ord`, which `BinaryHeap<CompPrefix>` requires.  Equality itself is
// defined by the `PartialEq` impl for this type.
impl Eq for CompPrefix {}
/// Always agrees with the total order defined by `Ord`, so the result is never `None`.
impl PartialOrd for CompPrefix {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total_order = Ord::cmp(self, other);
        Some(total_order)
    }
}
/// Orders prefixes by average score.  `BinaryHeap` is a max-heap, so the prefix with the greatest
/// average score is the one popped first.
impl Ord for CompPrefix {
    fn cmp(&self, other: &Self) -> Ordering {
        let ours = self.avg_score();
        let theirs = other.avg_score();
        ours.cmp(&theirs)
    }
}
/// Lets the fields of the boxed [`PrefixInner`] be read directly through a `CompPrefix`
/// (e.g. `prefix.part_head`).
impl Deref for CompPrefix {
    type Target = PrefixInner;

    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
impl CompPrefix {
    /// Consumes this prefix and advances the search by one step.
    ///
    /// * If the prefix has already reached `LinkSide::StartOrEnd`, nothing is pushed and the
    ///   result of `check_comp` is returned.
    /// * Otherwise the chunk named by `next_link_side` is added to the running totals, and one
    ///   longer prefix is pushed onto `frontier` (and recorded in `paths`) for every successor
    ///   link that survives pruning.  The return value is then always `None`.
    ///
    /// A link leading to another chunk is pruned when stepping onto that chunk would:
    ///
    /// 1. exceed `max_contiguous_duffer`, counting the chunk itself plus its
    ///    `min_dist_to_non_duffer`,
    /// 2. exceed `max_total_duffer`, counted the same way,
    /// 3. push the length past the maximum once the chunk's `min_len_to_rounds` is added,
    /// 4. land on a chunk whose bit is set in `unringable_chunks`, or
    /// 5. make the method counts infeasible for the remaining length.
    ///
    /// Links leading to `LinkSide::StartOrEnd` are never pruned here; the resulting prefix is
    /// validated by `check_comp` when it is itself expanded.
    pub(super) fn expand(
        self,
        search: &Search,
        paths: &mut Paths,
        frontier: &mut BinaryHeap<Self>,
    ) -> Option<Composition> {
        // A prefix with nothing left to expand is a candidate composition.
        let expanded_idx = match self.next_link_side {
            LinkSide::StartOrEnd => return self.check_comp(search, paths),
            LinkSide::Chunk(idx) => idx,
        };

        // Take the prefix apart so its state can be updated and handed on to the successors.
        let Self {
            mut score,
            mut length,
            inner,
        } = self;
        let PrefixInner {
            path,
            next_link_side: _,
            mut unringable_chunks,
            part_head,
            mut contiguous_duffer,
            mut total_duffer,
            mut method_counts,
        } = *inner;

        // Fold the expanded chunk into the running totals.  The chunk is confined to this block
        // so that the loop below cannot confuse it with a successor chunk.
        let succ_links = {
            let chunk = &search.graph.chunks[expanded_idx];
            length += chunk.total_length;
            score += chunk.score;
            if chunk.duffer {
                total_duffer += chunk.total_length;
                contiguous_duffer += chunk.per_part_length;
            } else {
                contiguous_duffer = PerPartLength::ZERO;
            }
            method_counts += &chunk.method_counts;
            unringable_chunks.or(&chunk.falseness);
            chunk.succs.iter_enumerated()
        };

        let max_length = *search.refined_ranges.length.end();
        for (link_idx, link) in succ_links {
            let part_head_after_link = part_head * link.ph_rotation;
            let score_after_link = score + link.score;

            if let LinkSide::Chunk(succ_idx) = link.next {
                let succ = &search.graph.chunks[succ_idx];
                let length_with_succ = length + succ.total_length;
                let counts_with_succ = &method_counts + &succ.method_counts;

                // (1) Shortest possible contiguous duffer run if we step onto `succ`.
                if succ.duffer {
                    if let Some(run_limit) = search.query.max_contiguous_duffer {
                        let shortest_run = contiguous_duffer
                            + succ.per_part_length
                            + succ.min_dist_to_non_duffer;
                        if shortest_run > run_limit {
                            continue;
                        }
                    }
                }

                // (2) Smallest possible total duffer if we step onto `succ`.
                if let Some(total_limit) = search.query.max_total_duffer {
                    let succ_duffer = if succ.duffer {
                        succ.total_length
                    } else {
                        TotalLength::ZERO
                    };
                    let dist_to_non_duffer = succ
                        .min_dist_to_non_duffer
                        .as_total(&search.query.part_head_group);
                    let smallest_total = total_duffer + succ_duffer + dist_to_non_duffer;
                    if smallest_total > total_limit {
                        continue;
                    }
                }

                // (3) Length.  This must stay ahead of (5), which subtracts from `max_length`.
                if length_with_succ + succ.min_len_to_rounds > max_length {
                    continue;
                }

                // (4) Falseness against the chunks already in the prefix.
                if unringable_chunks.get(succ_idx.index()).unwrap() {
                    continue;
                }

                // (5) Method balance.
                let length_left = (max_length - length_with_succ).as_usize();
                let method_ranges = search.refined_ranges.method_counts.as_raw_slice();
                if !counts_with_succ.is_feasible(length_left, method_ranges) {
                    continue;
                }
            }

            // The link survived: queue the prefix that follows it.
            let successor = PrefixInner {
                path: paths.add(path, link_idx),
                next_link_side: link.next,
                unringable_chunks: unringable_chunks.clone(),
                part_head: part_head_after_link,
                contiguous_duffer,
                total_duffer,
                method_counts: method_counts.clone(),
            };
            frontier.push(Self {
                inner: Box::new(successor),
                score: score_after_link,
                length,
            });
        }
        None
    }
}
impl CompPrefix {
    /// Turns a finished prefix into a [`Composition`], or returns `None` if it must be rejected.
    ///
    /// The prefix is rejected when:
    /// * its length lies outside the refined length range,
    /// * its method counts are infeasible with no length remaining,
    /// * its part head is not a generator of the part-head group, or
    /// * it is multi-part, splices over the part head, and that splice is not allowed: either the
    ///   first and last path elements share no label at the splice point, or the splice style is
    ///   `SpliceStyle::Calls` and the last element ends with a plain.
    ///
    /// An allowed splice over the part head adds `splice_weight * (num_parts - 1)` to the score.
    ///
    /// # Panics
    ///
    /// Panics if the prefix has not reached `LinkSide::StartOrEnd`, if its path contains no
    /// chunks, if a splice-point annotation is missing, or - when `require_truth` is set - if the
    /// resulting composition repeats a row.
    fn check_comp(&self, search: &Search, paths: &Paths) -> Option<Composition> {
        assert!(self.next_link_side.is_start_or_end());

        // Cheap rejections first, before the path is reconstructed.
        let passes_cheap_checks = search.refined_ranges.length.contains(&self.length)
            && self
                .method_counts
                .is_feasible(0, search.refined_ranges.method_counts.as_raw_slice())
            && search.query.part_head_group.is_generator(self.part_head);
        if !passes_cheap_checks {
            return None;
        }

        let (path, raw_music_counts, contiguous_duffer_lengths) =
            self.flattened_path(search, paths);

        // Work out the final score, handling a possible splice over the part head.
        let total_score = {
            let first_elem = path.first().expect("Must have at least one chunk");
            let last_elem = path.last().expect("Must have at least one chunk");

            let is_splice = first_elem.method != last_elem.method
                || first_elem.start_sub_lead_idx != last_elem.end_sub_lead_idx(&search.query);

            let mut score = self.score;
            if search.query.is_multipart() && is_splice {
                let start_labels = search.query.methods[first_elem.method]
                    .first_lead()
                    .get_annot(first_elem.start_sub_lead_idx)
                    .unwrap();
                let end_labels = search.query.methods[last_elem.method]
                    .first_lead()
                    .get_annot(last_elem.end_sub_lead_idx(&search.query))
                    .unwrap();

                let shares_a_label = start_labels.iter().any(|label| end_labels.contains(label));
                if !shares_a_label {
                    return None;
                }
                let needs_call = search.query.splice_style == SpliceStyle::Calls;
                if needs_call && last_elem.ends_with_plain() {
                    return None;
                }
                // The splice happens once at every part head.
                score += search.query.splice_weight * (search.query.num_parts() - 1) as f32;
            }
            score
        };

        let music_counts = search
            .query
            .music_types
            .iter_enumerated()
            .zip_eq(raw_music_counts.iter())
            .map(|((index, _), count)| (MusicTypeId { index }, *count))
            .collect();

        let comp = Composition {
            path,
            part_head: self.part_head,
            length: self.length,
            method_counts: self.method_counts.clone(),
            music_counts,
            total_score,
            contiguous_duffer_lengths,
            total_duffer: self.total_duffer,
            query: search.query.clone(),
        };

        // Sanity check: the search should never be able to produce a composition with a
        // repeated row when truth is required.
        if search.query.require_truth {
            let mut seen_rows = HashSet::<&Row>::with_capacity(comp.length());
            let all_rows = comp.rows();
            let is_true = all_rows.rows().all(|row| seen_rows.insert(row));
            if !is_true {
                panic!("Generated false composition ({})", comp.call_string());
            }
        }

        Some(comp)
    }

    /// Replays this prefix's path through the graph, returning
    /// `(path elements, summed music counts, duffer run lengths)`.
    ///
    /// One `PathElem` is produced per link in the flattened path.  For the duffer run lengths,
    /// the traversal behaves as though the chunk before the first one was a non-duffer.  A run's
    /// per-part length is pushed when the first non-duffer chunk after it is reached.  A zero is
    /// pushed for a non-duffer chunk that directly follows a non-duffer and is left via a call.
    ///
    /// # Panics
    ///
    /// Panics if the flattened path does not finish on `LinkSide::StartOrEnd`.
    // NOTE(review): a duffer run that is still open when the path ends is never pushed - confirm
    // that this is intended.
    fn flattened_path(
        &self,
        search: &Search,
        paths: &Paths,
    ) -> (Vec<PathElem>, Counts, Vec<PerPartLength>) {
        let (start_idx, succ_idxs) = paths.flatten(self.path);
        let (first_chunk_idx, _start_link, start_part_head) = search.graph.starts[start_idx];

        let mut elems = Vec::<PathElem>::new();
        let mut music_counts = Counts::zeros(search.query.music_types.len());
        let mut duffer_runs = Vec::<PerPartLength>::new();

        // Traversal state.
        let mut part_head = start_part_head;
        let mut cursor = LinkSide::Chunk(first_chunk_idx);
        let mut in_duffer_run = false;
        let mut run_length = PerPartLength::ZERO;

        for succ_idx in succ_idxs {
            let chunk_idx = match cursor {
                LinkSide::StartOrEnd => unreachable!(),
                LinkSide::Chunk(idx) => idx,
            };
            let chunk = &search.graph.chunks[chunk_idx];
            let link = &chunk.succs[succ_idx];

            music_counts += &chunk.music_counts;

            // Track runs of consecutive duffer chunks.
            if chunk.duffer {
                if in_duffer_run {
                    run_length += chunk.per_part_length;
                } else {
                    run_length = chunk.per_part_length;
                }
            } else if in_duffer_run {
                duffer_runs.push(run_length);
            } else if link.call.is_some() {
                duffer_runs.push(PerPartLength::ZERO);
            }

            let method = chunk.id.row_idx.method;
            let start_sub_lead_idx = chunk.id.row_idx.sub_lead_idx;
            elems.push(PathElem {
                start_row: search.query.part_head_group.get_row(part_head)
                    * chunk.id.lead_head.as_ref()
                    * search.query.methods[method].row_in_plain_lead(start_sub_lead_idx),
                method,
                start_sub_lead_idx,
                length: chunk.per_part_length,
                call: link.call,
            });

            // Step over the link to whatever comes next.
            in_duffer_run = chunk.duffer;
            part_head = part_head * link.ph_rotation;
            cursor = link.next;
        }

        assert!(cursor.is_start_or_end());
        (elems, music_counts, duffer_runs)
    }
}