use std::{
cmp::Reverse,
collections::{BTreeMap, BinaryHeap, HashMap, HashSet},
ops::RangeInclusive,
time::Instant,
};
use bellframe::Mask;
use itertools::Itertools;
use crate::{
graph::{ChunkId, Graph, LinkSide, RowIdx},
parameters::{MethodIdx, MethodVec, OptionalRangeInclusive, Parameters},
utils::lengths::TotalLength,
};
/// Fractional slack applied to the 'preferred' per-method count bounds computed by
/// `method_bounds`: preferred minimums are divided by `1 + factor` and preferred maximums are
/// multiplied by it, so unbounded methods aren't forced into an exact even split.
const METHOD_COUNT_RELAX_FACTOR: f32 = 0.1_f32;
/// The length ranges left after proving: the overall length range and, for every method, the
/// range of row counts that method is allowed to contribute.
#[derive(Debug)]
pub(crate) struct RefinedRanges {
    /// Inclusive range of total lengths, narrowed to lengths found reachable in the graph.
    pub length: RangeInclusive<TotalLength>,
    /// Inclusive range of row counts for each method, indexed by `MethodIdx`.
    pub method_counts: MethodVec<RangeInclusive<TotalLength>>,
}
pub(crate) fn prove_lengths(graph: &Graph, params: &Parameters) -> crate::Result<RefinedRanges> {
log::debug!("Proving lengths");
let possible_lengths = possible_lengths(graph, params);
let refined_len_range = match matching_lengths(&possible_lengths, ¶ms.length) {
LengthMatches {
range: Some(range), ..
} => range,
LengthMatches {
next_smaller,
range: None,
next_larger,
} => {
return Err(crate::Error::UnachievableLength {
requested_range: params.length.clone(),
next_shorter_len: next_smaller.map(TotalLength::as_usize),
next_longer_len: next_larger.map(TotalLength::as_usize),
});
}
};
log::debug!(
" Total length bounded to {}..={}",
refined_len_range.start(),
refined_len_range.end(),
);
let start = Instant::now();
let possible_lengths_by_method = params
.methods
.iter_enumerated()
.map(|(idx, method)| possible_method_counts(idx, method, graph, params))
.collect::<MethodVec<_>>();
let method_bounds_min = method_bounds(params, &refined_len_range, Bound::Min);
let method_bounds_max = method_bounds(params, &refined_len_range, Bound::Max);
let mut refined_method_counts = MethodVec::new();
for (method_idx, possible_lengths) in possible_lengths_by_method.into_iter_enumerated() {
let min_bound = method_bounds_min[method_idx];
let max_bound = method_bounds_max[method_idx];
let method = ¶ms.methods[method_idx];
let refined_counts = refine_method_counts(min_bound, max_bound, &possible_lengths, method)?;
refined_method_counts.push(refined_counts);
}
log::debug!(" Method count ranges computed in {:.2?}", start.elapsed());
if params.is_spliced() {
print_method_counts(&refined_method_counts, params);
}
check_final_bounds(&refined_method_counts, &refined_len_range)?;
Ok(RefinedRanges {
length: refined_len_range,
method_counts: refined_method_counts,
})
}
/// Lists, in ascending order and without repeats, the total lengths of start-to-end paths
/// through a simplified copy of `graph`.
///
/// Paths are explored in increasing order of length using a min-heap.  The search stops as
/// soon as a length greater than `params.max_length()` has been recorded, so the result may end
/// with exactly one over-long length (which callers use as the 'next longer' length).
fn possible_lengths(graph: &Graph, params: &Parameters) -> Vec<TotalLength> {
    let graph_timer = Instant::now();
    let simple_graph = compute_simplified_graph(params, graph);
    log::debug!(" Simplified graph generated in {:.2?}", graph_timer.elapsed());

    let search_timer = Instant::now();
    // Min-heap of `(length so far, link side to visit next)`
    let mut queue = BinaryHeap::new();
    for start_chunk in &simple_graph.starts {
        queue.push(Reverse((TotalLength::ZERO, LinkSide::Chunk(start_chunk))));
    }

    let mut lengths = Vec::<TotalLength>::new();
    let mut prev = None;
    while let Some(Reverse(entry)) = queue.pop() {
        // Equal entries are popped back-to-back, so only the first copy needs expanding
        if prev == Some(entry) {
            continue;
        }
        prev = Some(entry);

        let (len_so_far, link_side) = entry;
        match link_side {
            LinkSide::Chunk(chunk) => {
                if let Some(succs) = simple_graph.successors.get(chunk) {
                    queue.extend(
                        succs
                            .iter()
                            .map(|(step, succ)| Reverse((len_so_far + *step, succ.as_ref()))),
                    );
                }
            }
            LinkSide::StartOrEnd => {
                if lengths.last() != Some(&len_so_far) {
                    lengths.push(len_so_far);
                }
                // One too-long length is enough to report the next longer achievable length
                if len_so_far > params.max_length() {
                    break;
                }
            }
        }
    }
    log::debug!(" Lengths computed in {:.2?}", search_timer.elapsed());
    lengths
}
/// A chunk reduced to its row index plus one allowed lead-head mask of its method which matches
/// the chunk's lead head.
type SimpleChunk = (RowIdx, Mask);

/// A coarsened copy of the composition graph, used only to enumerate achievable total lengths.
struct SimpleGraph {
    /// Simplified chunks which a composition can begin with.
    starts: HashSet<SimpleChunk>,
    /// For each simplified chunk: the set of `(length of that chunk, next link side)` pairs
    /// reachable from it.
    successors: HashMap<SimpleChunk, HashSet<(TotalLength, LinkSide<SimpleChunk>)>>,
}
/// Builds a `SimpleGraph` from `graph` by replacing each chunk with every
/// `(row index, allowed lead-head mask)` pair whose mask matches the chunk's lead head.
///
/// Chunks sharing a row index and matching mask collapse into one simplified chunk, whose
/// successor set is the union of their successors.
fn compute_simplified_graph(params: &Parameters, graph: &Graph) -> SimpleGraph {
    let lead_masks_per_method = params
        .methods
        .iter()
        .map(|method| method.allowed_lead_head_masks(params))
        .collect::<MethodVec<_>>();
    // Expands one chunk into all the simplified chunks it corresponds to
    let simplify = |chunk_id: &ChunkId| -> Vec<SimpleChunk> {
        IntoIterator::into_iter(&lead_masks_per_method[chunk_id.row_idx.method])
            .filter(|mask| mask.matches(&chunk_id.lead_head))
            .map(|mask| (chunk_id.row_idx, mask.clone()))
            .collect()
    };

    let mut starts = HashSet::new();
    for (_link_id, start_id) in &graph.starts {
        for simple_start in simplify(start_id) {
            starts.insert(simple_start);
        }
    }

    let mut successors: HashMap<SimpleChunk, HashSet<(TotalLength, LinkSide<SimpleChunk>)>> =
        HashMap::new();
    for (chunk_id, chunk) in &graph.chunks {
        let length = chunk.total_length;
        for simple_chunk in simplify(chunk_id) {
            // Entry is created even if the chunk has no successor links
            let succ_set = successors.entry(simple_chunk).or_default();
            for (_link_id, link) in chunk.succ_links(graph) {
                match &link.to {
                    LinkSide::Chunk(succ_id) => succ_set.extend(
                        simplify(succ_id)
                            .into_iter()
                            .map(|succ| (length, LinkSide::Chunk(succ))),
                    ),
                    LinkSide::StartOrEnd => {
                        succ_set.insert((length, LinkSide::StartOrEnd));
                    }
                }
            }
        }
    }

    SimpleGraph { starts, successors }
}
/// Computes, in ascending order and without duplicates, the row counts which `method` could
/// contribute to a composition through `graph`.
///
/// Every chunk is classified by whether it must/may start or end a composition.  Counts are
/// then formed as `start + end + (any number of interior chunks)`, plus any chunk which is a
/// whole composition on its own.  Chunks of other methods contribute 0 rows.  The enumeration
/// stops once a count above `params.max_length()` has been produced.  If nothing is found,
/// `[0]` is returned.
///
/// This is an over-approximation: it may list counts which aren't achievable, but must never
/// omit achievable ones.
fn possible_method_counts(
    method_idx: MethodIdx,
    method: &crate::parameters::Method,
    graph: &Graph,
    params: &Parameters,
) -> Vec<TotalLength> {
    log::trace!("Computing method counts for {}:", method.shorthand());

    let mut start_counts = HashSet::new();
    let mut end_counts = HashSet::new();
    let mut interior_counts = HashSet::new();
    let mut start_end_counts = HashSet::new(); // Chunks which form an entire composition
    for (id, chunk) in &graph.chunks {
        let only_starts = chunk.pred_links(graph).all(|(_id, link)| link.is_start());
        let any_starts = chunk.pred_links(graph).any(|(_id, link)| link.is_start());
        let only_ends = chunk.succ_links(graph).all(|(_id, link)| link.is_end());
        let any_ends = chunk.succ_links(graph).any(|(_id, link)| link.is_end());

        // A chunk which *may* start (or end) a composition, but doesn't have to, is classified
        // below as interior (or end/start only).  A composition can still begin (or finish) at
        // it, so allow a start (or end) section which contributes no rows.  Starts and ends
        // need the same treatment; otherwise counts through such chunks would be missed.
        if !only_starts && any_starts {
            start_counts.insert(TotalLength::ZERO);
        }
        if !only_ends && any_ends {
            end_counts.insert(TotalLength::ZERO);
        }

        // Chunks of other methods contribute 0 rows to this method's count
        if id.row_idx.method != method_idx {
            if any_starts {
                start_counts.insert(TotalLength::ZERO);
            }
            if any_ends {
                end_counts.insert(TotalLength::ZERO);
            }
            continue;
        }

        let count_group = match (only_starts, only_ends) {
            (true, true) => &mut start_end_counts,
            (true, false) => &mut start_counts,
            (false, true) => &mut end_counts,
            (false, false) => &mut interior_counts,
        };
        count_group.insert(chunk.total_length);
    }
    log::trace!(" Start counts: {:?}", start_counts);
    log::trace!(" Interior counts: {:?}", interior_counts);
    log::trace!(" End counts: {:?}", end_counts);
    log::trace!(" Start & end counts: {:?}", start_end_counts);

    // Enumerate `start + end + interior*` in increasing order using a min-heap
    let mut counts = Vec::new();
    let mut frontier = start_counts
        .iter()
        .cartesian_product(&end_counts)
        .map(|(s, e)| Reverse(*s + *e))
        .collect::<BinaryHeap<_>>();
    while let Some(Reverse(count)) = frontier.pop() {
        if counts.last() == Some(&count) {
            continue; // Already expanded this count
        }
        counts.push(count);
        if count > params.max_length() {
            break; // One over-long count is enough to report the 'next larger' count
        }
        for l in &interior_counts {
            frontier.push(Reverse(count + *l));
        }
    }

    // Merge in single-chunk compositions, keeping `counts` sorted and duplicate-free
    for c in start_end_counts {
        if let Err(idx) = counts.binary_search(&c) {
            counts.insert(idx, c);
        }
    }

    if counts.is_empty() {
        counts.push(TotalLength::ZERO);
    }
    log::trace!(" Final counts: {:?}", counts);
    counts
}
fn method_bounds(
params: &Parameters,
total_len_range: &RangeInclusive<TotalLength>,
bound: Bound,
) -> MethodVec<(BoundType, TotalLength)> {
let get_bound = |range: OptionalRangeInclusive| -> Option<usize> {
match bound {
Bound::Min => range.min,
Bound::Max => range.max,
}
};
let total_len_bound = match bound {
Bound::Min => total_len_range.start().as_usize() as f32,
Bound::Max => total_len_range.end().as_usize() as f32,
};
let total_method_weight = params
.methods
.iter()
.filter(|m| get_bound(m.count_range).is_none())
.map(|m| (m.lead_len() as f32).sqrt())
.sum::<f32>();
let method_bounds = params
.methods
.iter()
.map(|m| match get_bound(m.count_range) {
Some(count) => (BoundType::Explicit, TotalLength::new(count)),
None => {
let proportion_of_unbound_methods =
(m.lead_len() as f32).sqrt() / total_method_weight;
let preferred_bound = total_len_bound * proportion_of_unbound_methods;
let f = 1.0 + METHOD_COUNT_RELAX_FACTOR;
let rounded_bound = match bound {
Bound::Min => (preferred_bound / f - 1e-3).floor() as usize,
Bound::Max => (preferred_bound * f + 1e-3).ceil() as usize,
};
(BoundType::Preferred, TotalLength::new(rounded_bound))
}
})
.collect::<MethodVec<_>>();
method_bounds
}
fn refine_method_counts(
(min_type, mut min_len): (BoundType, TotalLength),
(max_type, mut max_len): (BoundType, TotalLength),
possible_lengths: &[TotalLength],
method: &crate::parameters::Method,
) -> crate::Result<RangeInclusive<TotalLength>> {
use BoundType::{Explicit as Expl, Preferred as Pref};
log::trace!("Refining method counts for {}", method.shorthand());
log::trace!(
" initial bounds: {}{} ..= {}{}",
min_len,
if min_type == Expl { " (explicit)" } else { "" },
max_len,
if max_type == Expl { " (explicit)" } else { "" },
);
if min_len > max_len {
match (min_type, max_type) {
(Expl, Pref) => max_len = min_len,
(Pref, Expl) => min_len = max_len,
(Expl, Expl) => {
unreachable!("Don't make the min range bigger than max range, you silly billy")
}
(Pref, Pref) => unreachable!(),
}
}
let matching_lengths = matching_lengths(possible_lengths, &(min_len..=max_len));
log::trace!(
" Matching lengths: {} < {}{} <= {} <= {}{} < {}",
match matching_lengths.next_smaller {
Some(len) => format!("[.., {}]", len),
None => "[]".to_owned(),
},
min_len,
if min_type == Expl { "(e)" } else { "" },
match &matching_lengths.range {
Some(range) => format!("[ {}..={} ]", range.start(), range.end()),
None => "[]".to_owned(),
},
max_len,
if max_type == Expl { "(e)" } else { "" },
match matching_lengths.next_larger {
Some(len) => format!("[{}, ..]", len),
None => "[]".to_owned(),
}
);
let refined_range = match matching_lengths {
LengthMatches {
range: Some(range), ..
} => range,
LengthMatches {
next_smaller,
range: None,
next_larger,
} => {
let refined_min = match min_type {
Expl => None, Pref => next_smaller, };
let refined_max = match max_type {
Expl => None, Pref => next_larger, };
match (refined_min, refined_max) {
(Some(min), Some(max)) => min..=max,
(Some(len), None) | (None, Some(len)) => len..=len,
(None, None) => {
assert_ne!((min_type, max_type), (Pref, Pref));
return Err(crate::Error::UnachievableMethodCount {
method_name: method.title(),
requested_range: method.count_range,
next_shorter_len: next_smaller.map(TotalLength::as_usize),
next_longer_len: next_larger.map(TotalLength::as_usize),
});
}
}
}
};
Ok(refined_range)
}
/// Logs (at `info` level) the required row counts, grouping together methods which share the
/// same count range.  Groups are reported in ascending order of `(min, max)`.
fn print_method_counts(
    refined_method_counts: &MethodVec<RangeInclusive<TotalLength>>,
    params: &Parameters,
) {
    let mut groups: BTreeMap<(TotalLength, TotalLength), Vec<MethodIdx>> = BTreeMap::new();
    for (method_idx, count_range) in refined_method_counts.iter_enumerated() {
        let key = (*count_range.start(), *count_range.end());
        groups.entry(key).or_default().push(method_idx);
    }

    for ((lower, upper), method_idxs) in groups {
        let row_string = match lower == upper {
            true => format!("exactly {lower}"),
            false => format!("{lower} to {upper}"),
        };
        let methods_string = params.method_list_string(&method_idxs);
        log::info!("Requiring {row_string} rows of {methods_string}");
    }
}
/// Selects which end of a range is being computed (see `method_bounds`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Bound {
    /// The lower end of a range.
    Min,
    /// The upper end of a range.
    Max,
}
/// Where a per-method count bound came from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BoundType {
    /// Taken from the method's own count range; never relaxed.
    Explicit,
    /// Derived from the method's share of the total length; may be relaxed to the nearest
    /// achievable count.
    Preferred,
}
fn check_final_bounds(
method_counts: &MethodVec<RangeInclusive<TotalLength>>,
length_range: &RangeInclusive<TotalLength>,
) -> crate::Result<()> {
let min_total_method_count = method_counts
.iter()
.map(|range| *range.start())
.sum::<TotalLength>();
let max_total_method_count = method_counts
.iter()
.map(|range| *range.end())
.sum::<TotalLength>();
let min_length = *length_range.start();
let max_length = *length_range.end();
if max_total_method_count < min_length {
return Err(crate::Error::TooLittleMethodCount {
max_total_method_count: max_total_method_count.as_usize(),
min_length: min_length.as_usize(),
});
}
if min_total_method_count > max_length {
return Err(crate::Error::TooMuchMethodCount {
min_total_method_count: min_total_method_count.as_usize(),
max_length: max_length.as_usize(),
});
}
Ok(())
}
/// The result of `matching_lengths`: which of a list of lengths fall inside a range, plus the
/// nearest lengths on either side (assuming the list is in ascending order).
#[derive(Debug)]
struct LengthMatches {
    /// The last length (in list order) below the range, if any.
    next_smaller: Option<TotalLength>,
    /// The first and last lengths (in list order) inside the range, if any are.
    range: Option<RangeInclusive<TotalLength>>,
    /// The first length (in list order) beyond the range, if any.
    next_larger: Option<TotalLength>,
}
/// Splits `lengths` into those below, inside and above `range`.
///
/// Returns the last length below the range, the first..=last lengths inside it, and the first
/// length above it (all in list order).  If `lengths` is sorted ascending, these are the nearest
/// shorter length, the tightest sub-range and the nearest longer length.
fn matching_lengths(lengths: &[TotalLength], range: &RangeInclusive<TotalLength>) -> LengthMatches {
    let next_smaller = lengths.iter().copied().filter(|l| l < range.start()).last();

    let mut inside = lengths.iter().copied().filter(|l| range.contains(l));
    let first_inside = inside.next();
    let last_inside = inside.last().or(first_inside);

    // Mirrors "not below and not inside", which also holds for an empty (inverted) range
    let next_larger = lengths
        .iter()
        .copied()
        .find(|l| l >= range.start() && l > range.end());

    LengthMatches {
        next_smaller,
        range: first_inside
            .zip(last_inside)
            .map(|(lo, hi)| lo..=hi),
        next_larger,
    }
}