use std::{
cmp, fmt,
num::{NonZeroU8, NonZeroUsize},
ops::{Add, AddAssign, IndexMut, RangeBounds, Sub, SubAssign},
};
use crate::{types::Lit, utils::unreachable_none};
/// Identifier of a node, wrapping a plain `usize` index.
///
/// The arithmetic operator implementations on this type all act on the wrapped index.
// NOTE(review): presumably the index addresses a node inside a `NodeById` collection — confirm
// against the implementors.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct NodeId(pub usize);
impl fmt::Display for NodeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "#{}", self.0)
}
}
impl Add<usize> for NodeId {
type Output = NodeId;
fn add(self, rhs: usize) -> Self::Output {
NodeId(self.0 + rhs)
}
}
impl Add for &NodeId {
type Output = NodeId;
fn add(self, rhs: Self) -> Self::Output {
NodeId(self.0 + rhs.0)
}
}
impl Add<usize> for &NodeId {
type Output = NodeId;
fn add(self, rhs: usize) -> Self::Output {
NodeId(self.0 + rhs)
}
}
impl AddAssign for NodeId {
fn add_assign(&mut self, rhs: Self) {
self.0 += rhs.0;
}
}
impl AddAssign<usize> for NodeId {
fn add_assign(&mut self, rhs: usize) {
self.0 += rhs;
}
}
impl Sub for NodeId {
type Output = usize;
fn sub(self, rhs: Self) -> Self::Output {
self.0 - rhs.0
}
}
impl Sub<usize> for NodeId {
type Output = NodeId;
fn sub(self, rhs: usize) -> Self::Output {
NodeId(self.0 - rhs)
}
}
impl Sub for &NodeId {
type Output = NodeId;
fn sub(self, rhs: Self) -> Self::Output {
NodeId(self.0 - rhs.0)
}
}
impl Sub<usize> for &NodeId {
type Output = NodeId;
fn sub(self, rhs: usize) -> Self::Output {
NodeId(self.0 - rhs)
}
}
impl SubAssign for NodeId {
fn sub_assign(&mut self, rhs: Self) {
self.0 -= rhs.0;
}
}
impl SubAssign<usize> for NodeId {
fn sub_assign(&mut self, rhs: usize) {
self.0 -= rhs;
}
}
/// Interface of a node that can be stored in a [`NodeById`] collection.
// The trait exposes `len` without an `is_empty` counterpart, hence the silenced clippy lint.
#[allow(clippy::len_without_is_empty)]
pub trait NodeLike {
/// Iterator type returned by [`NodeLike::vals`]; it yields `usize` values and can be
/// traversed from both ends.
type ValIter: DoubleEndedIterator<Item = usize>;
/// Returns `true` if the node is a leaf; by default this is the case exactly when
/// [`NodeLike::len`] is 1.
fn is_leaf(&self) -> bool {
self.len() == 1
}
/// Returns the maximum value of the node.
// NOTE(review): presumably the largest output value the node can take — confirm with implementors.
fn max_val(&self) -> usize;
/// Returns the length of the node; a length of 1 marks a leaf (see [`NodeLike::is_leaf`]).
// NOTE(review): presumably the number of outputs of the node — confirm with implementors.
fn len(&self) -> usize;
/// Returns an iterator over the node's values restricted to `range`.
fn vals<R>(&self, range: R) -> Self::ValIter
where
R: RangeBounds<usize>;
/// Returns the connection to the right child, if there is one.
fn right(&self) -> Option<NodeCon>;
/// Returns the connection to the left child, if there is one.
fn left(&self) -> Option<NodeCon>;
/// Returns the depth of the node.
// NOTE(review): whether leaves have depth 0 or 1 is not visible here — confirm with implementors.
fn depth(&self) -> usize;
/// Creates an internal node from the connections to its `left` and `right` children.
///
/// `db` is the node collection that the two connections refer into.
fn internal<Db>(left: NodeCon, right: NodeCon, db: &Db) -> Self
where
Db: NodeById<Node = Self>;
/// Creates a leaf node for the literal `lit`.
fn leaf(lit: Lit) -> Self;
}
/// A connection to a node together with a transformation of the node's values.
///
/// As implemented by [`NodeCon::map`], a value of the referenced node is transformed by
/// subtracting `offset`, dividing by `divisor`, capping at `len_limit` (if set) and finally
/// multiplying by `multiplier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeCon {
/// The id of the node that the connection refers to.
pub id: NodeId,
/// The amount subtracted from a node value before any other transformation.
pub(crate) offset: usize,
/// The (non-zero) divisor applied after subtracting the offset.
pub(crate) divisor: NonZeroU8,
/// The (non-zero) weight that the transformed value is multiplied by.
pub(crate) multiplier: NonZeroUsize,
/// Optional (non-zero) cap on the divided value, applied before multiplying.
pub(crate) len_limit: Option<NonZeroUsize>,
}
impl NodeCon {
    /// Creates a connection that passes the values of node `id` through unchanged.
    ///
    /// The connection has offset 0, divisor 1, multiplier 1 and no length limit.
    #[must_use]
    pub fn full(id: NodeId) -> NodeCon {
        NodeCon {
            id,
            offset: 0,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: unreachable_none!(NonZeroUsize::new(1)),
            len_limit: None,
        }
    }

    /// Creates a connection to node `id` whose values are multiplied by `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[must_use]
    pub fn weighted(id: NodeId, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset: 0,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: NonZeroUsize::new(weight).expect("connection weight must be non-zero"),
            len_limit: None,
        }
    }

    /// Creates a connection to node `id` that first subtracts `offset` and then multiplies by
    /// `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[cfg(any(test, feature = "internals"))]
    #[must_use]
    pub fn offset_weighted(id: NodeId, offset: usize, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: NonZeroUsize::new(weight).expect("connection weight must be non-zero"),
            len_limit: None,
        }
    }

    /// Creates a connection that picks the single value `output` of node `id` and weights it
    /// with `weight`.
    ///
    /// The connection uses offset `output - 1` and a length limit of 1, so `output` must be at
    /// least 1; otherwise the subtraction underflows.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[cfg(any(test, feature = "internals"))]
    #[must_use]
    pub fn single(id: NodeId, output: usize, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset: output - 1,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: NonZeroUsize::new(weight).expect("connection weight must be non-zero"),
            len_limit: NonZeroUsize::new(1),
        }
    }

    /// Creates a connection to node `id` that subtracts `offset`, caps the result at `n_lits`
    /// and multiplies it by `weight`.
    ///
    /// # Panics
    ///
    /// If `n_lits` or `weight` is zero.
    #[cfg(any(test, feature = "internals"))]
    #[must_use]
    pub fn limited(id: NodeId, offset: usize, n_lits: usize, weight: usize) -> NodeCon {
        assert_ne!(n_lits, 0);
        NodeCon {
            id,
            offset,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: NonZeroUsize::new(weight).expect("connection weight must be non-zero"),
            len_limit: NonZeroUsize::new(n_lits),
        }
    }

    /// Returns a copy of the connection with its multiplier replaced by `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[inline]
    #[cfg(feature = "internals")]
    #[must_use]
    pub fn reweight(self, weight: usize) -> NodeCon {
        NodeCon {
            multiplier: NonZeroUsize::new(weight).expect("connection weight must be non-zero"),
            ..self
        }
    }

    /// Returns the offset of the connection.
    #[inline]
    #[must_use]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Returns the divisor of the connection as a `usize`.
    #[inline]
    #[must_use]
    pub fn divisor(&self) -> usize {
        usize::from(self.divisor.get())
    }

    /// Returns the multiplier (weight) of the connection.
    #[inline]
    #[must_use]
    pub fn multiplier(&self) -> usize {
        self.multiplier.get()
    }

    /// Maps a value of the connected node to the value seen through the connection.
    ///
    /// The offset is subtracted, the result is divided by the divisor (rounding down), capped
    /// at the length limit if one is set, and finally multiplied by the multiplier.
    ///
    /// `val` must be at least the offset; otherwise the subtraction underflows.
    #[inline]
    #[must_use]
    pub fn map(&self, val: usize) -> usize {
        let scaled = (val - self.offset()) / self.divisor();
        let capped = match self.len_limit {
            Some(limit) => cmp::min(scaled, limit.get()),
            None => scaled,
        };
        capped * self.multiplier()
    }

    /// Maps a value seen through the connection back to a value of the connected node,
    /// rounding down.
    ///
    /// With a length limit set, a result of zero is returned as plain `0` without adding the
    /// offset; without a length limit the offset is always added.
    #[inline]
    #[must_use]
    pub fn rev_map(&self, val: usize) -> usize {
        let outputs = val / self.multiplier();
        match self.len_limit {
            Some(limit) => match cmp::min(outputs, limit.get()) * self.divisor() {
                0 => 0,
                x => x + self.offset(),
            },
            None => outputs * self.divisor() + self.offset(),
        }
    }

    /// Maps a value seen through the connection back to a value of the connected node,
    /// rounding up.
    ///
    /// If a length limit is set and `val` exceeds what the limited connection can produce
    /// (`val > limit * multiplier`), the value one step past the limit is returned.
    #[inline]
    #[must_use]
    pub fn rev_map_round_up(&self, mut val: usize) -> usize {
        if let Some(limit) = self.len_limit {
            // `saturating_sub` keeps `val == 0` from underflowing, which would otherwise panic
            // in debug builds and wrongly report "above the limit" in release builds.
            if val.saturating_sub(1) / self.multiplier() >= limit.get() {
                return (limit.get() + 1) * self.divisor() + self.offset();
            }
        }
        // Bump values that are not a multiple of the multiplier into the next step, so that
        // the rounding-down reverse map yields the rounded-up result.
        if val % self.multiplier() > 0 {
            val += self.multiplier();
        }
        self.rev_map(val)
    }

    /// Checks whether `val` can be produced by this connection.
    ///
    /// This is the case if `val` is a multiple of the multiplier and, when a length limit is
    /// set, the quotient does not exceed that limit.
    #[inline]
    #[must_use]
    pub fn is_possible(&self, val: usize) -> bool {
        let mult = self.multiplier();
        if val % mult != 0 {
            return false;
        }
        match self.len_limit {
            Some(limit) => val / mult <= limit.get(),
            None => true,
        }
    }
}
/// Error returned by [`NodeById::drain`] when a node located after the drain range
/// references a node inside the drain range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[error(
"node {referencing} from after the drain range references node {referenced} in the drain range"
)]
pub struct DrainError {
/// The node after the drain range that holds the offending reference.
pub referencing: NodeId,
/// The node inside the drain range that is being referenced.
pub referenced: NodeId,
}
/// A collection of nodes that are addressed by [`NodeId`].
///
/// Besides the required storage operations, the trait provides default helpers that build
/// trees of nodes from literals or from existing connections.
// NOTE(review): `dead_code` is presumably allowed because some helpers go unused in certain
// feature configurations — confirm.
#[allow(dead_code)]
pub trait NodeById: IndexMut<NodeId, Output = Self::Node> {
    /// The type of node stored in the collection.
    type Node: NodeLike;

    /// Inserts `node` into the collection and returns the [`NodeId`] assigned to it.
    fn insert(&mut self, node: Self::Node) -> NodeId;

    /// Iterator over references to the stored nodes.
    type Iter<'own>: Iterator<Item = &'own Self::Node>
    where
        Self: 'own;

    /// Returns an iterator over references to the stored nodes.
    fn iter(&self) -> Self::Iter<'_>;

    /// Returns the number of nodes in the collection.
    fn len(&self) -> usize;

    /// Returns `true` if the collection holds no nodes.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Computes the length of the node behind `con` as seen through the connection.
    ///
    /// The offset is subtracted from the node's length, the result is divided by the divisor
    /// and capped at the connection's length limit, if one is set. The multiplier does not
    /// influence the result.
    fn con_len(&self, con: NodeCon) -> usize {
        let len = (self[con.id].len() - con.offset()) / con.divisor();
        match con.len_limit {
            Some(limit) => cmp::min(len, limit.into()),
            None => len,
        }
    }

    /// Iterator that yields drained nodes by value.
    type Drain<'own>: Iterator<Item = Self::Node>
    where
        Self: 'own;

    /// Drains the nodes in `range` from the collection, yielding them by value.
    ///
    /// # Errors
    ///
    /// Returns a [`DrainError`], which describes a node after the drain range referencing a
    /// node inside the drain range.
    fn drain<R: RangeBounds<NodeId>>(&mut self, range: R) -> Result<Self::Drain<'_>, DrainError>;

    /// Recursively builds a tree with one leaf per literal in `lits` and returns the id of its
    /// root.
    ///
    /// The literals are split in the middle at every level; a single literal becomes a leaf.
    ///
    /// # Panics
    ///
    /// If `lits` is empty.
    fn lit_tree(&mut self, lits: &[Lit]) -> NodeId
    where
        Self: Sized,
    {
        // Checked unconditionally: an empty slice would otherwise recurse without end, since
        // splitting it always yields another empty left half.
        assert!(!lits.is_empty(), "cannot build a tree over zero literals");
        if lits.len() == 1 {
            return self.insert(Self::Node::leaf(lits[0]));
        }
        let split = lits.len() / 2;
        let lid = self.lit_tree(&lits[..split]);
        let rid = self.lit_tree(&lits[split..]);
        self.insert(Self::Node::internal(
            NodeCon::full(lid),
            NodeCon::full(rid),
            self,
        ))
    }

    /// Builds a tree over weighted literals and returns a connection to its root.
    ///
    /// Every maximal run of consecutive literals with the same weight is turned into its own
    /// subtree via [`NodeById::lit_tree`] and connected with that weight; the resulting
    /// connections are then combined with [`NodeById::merge_balanced`]. Only adjacent literals
    /// are grouped, so sorting `lits` by weight beforehand yields the fewest subtrees.
    ///
    /// # Panics
    ///
    /// If `lits` is empty or contains a weight of zero.
    fn weighted_lit_tree(&mut self, lits: &[(Lit, usize)]) -> NodeCon
    where
        Self: Sized,
    {
        assert!(
            !lits.is_empty(),
            "cannot build a tree over zero weighted literals"
        );
        let mut cons = vec![];
        let mut seg_begin = 0;
        while seg_begin < lits.len() {
            let weight = lits[seg_begin].1;
            // Length of the run of equally weighted literals starting at `seg_begin`
            let seg_len = lits[seg_begin..]
                .iter()
                .take_while(|&&(_, w)| w == weight)
                .count();
            let seg_end = seg_begin + seg_len;
            let seg: Vec<_> = lits[seg_begin..seg_end]
                .iter()
                .map(|&(lit, _)| lit)
                .collect();
            let id = self.lit_tree(&seg);
            cons.push(NodeCon::weighted(id, weight));
            seg_begin = seg_end;
        }
        self.merge_balanced(&cons)
    }

    /// Recursively merges the given connections into one subtree and returns a connection to
    /// its root.
    ///
    /// The connections are split in the middle at every level, so the merging subtree is
    /// balanced in the number of connections, irrespective of their lengths. A single
    /// connection is returned as is.
    ///
    /// # Panics
    ///
    /// If `cons` is empty.
    fn merge(&mut self, cons: &[NodeCon]) -> NodeCon
    where
        Self: Sized,
    {
        // Checked unconditionally: an empty slice would otherwise recurse without end.
        assert!(!cons.is_empty(), "cannot merge zero connections");
        if cons.len() == 1 {
            return cons[0];
        }
        let split = cons.len() / 2;
        let lcon = self.merge(&cons[..split]);
        let rcon = self.merge(&cons[split..]);
        NodeCon::full(self.insert(Self::Node::internal(lcon, rcon, self)))
    }

    /// Merges the given connections into one subtree, choosing the top-level split by
    /// connection length, and returns a connection to the root.
    ///
    /// The split point is advanced for as long as the summed [`NodeById::con_len`] of the left
    /// part, including the next connection, stays below half of the total. Both parts are then
    /// merged with [`NodeById::merge`]. A single connection is returned as is.
    ///
    /// # Panics
    ///
    /// If `cons` is empty.
    fn merge_balanced(&mut self, cons: &[NodeCon]) -> NodeCon
    where
        Self: Sized,
    {
        assert!(!cons.is_empty(), "cannot merge zero connections");
        if cons.len() == 1 {
            return cons[0];
        }
        let total_sum: usize = cons.iter().map(|&con| self.con_len(con)).sum();
        let half = total_sum / 2;
        let mut split = 1;
        let mut lsum = self.con_len(cons[0]);
        // `split` cannot run past the end: once only the last connection remains on the right,
        // adding it makes the left sum equal to the total, which is never below half.
        loop {
            let next_len = self.con_len(cons[split]);
            if lsum + next_len >= half {
                break;
            }
            lsum += next_len;
            split += 1;
        }
        let lcon = self.merge(&cons[..split]);
        let rcon = self.merge(&cons[split..]);
        NodeCon::full(self.insert(Self::Node::internal(lcon, rcon, self)))
    }

    /// Merges the given connections after grouping them by multiplier and returns a connection
    /// to the root.
    ///
    /// `cons` is sorted by multiplier in place. Connections sharing a multiplier are merged
    /// with weight 1 (shortest first) and the merged subtree is connected with the shared
    /// multiplier. Finally, all resulting connections are sorted by length and combined with
    /// [`NodeById::merge_balanced`].
    ///
    /// # Panics
    ///
    /// If `cons` is empty.
    #[cfg(feature = "internals")]
    fn merge_thorough(&mut self, cons: &mut [NodeCon]) -> NodeCon
    where
        Self: Sized,
    {
        assert!(!cons.is_empty(), "cannot merge zero connections");
        cons.sort_unstable_by_key(NodeCon::multiplier);
        let mut merged_cons = vec![];
        let mut seg_begin = 0;
        while seg_begin < cons.len() {
            let weight = cons[seg_begin].multiplier();
            // Length of the run of connections sharing `weight`, starting at `seg_begin`
            let seg_len = cons[seg_begin..]
                .iter()
                .take_while(|con| con.multiplier() == weight)
                .count();
            let seg_end = seg_begin + seg_len;
            if seg_len > 1 {
                // Merge the equally weighted connections unweighted, then apply the shared
                // weight once on the connection to the merged subtree.
                let mut seg: Vec<_> = cons[seg_begin..seg_end]
                    .iter()
                    .map(|&con| con.reweight(1))
                    .collect();
                seg.sort_unstable_by_key(|&con| self.con_len(con));
                let con = self.merge_balanced(&seg);
                debug_assert_eq!(con.multiplier(), 1);
                merged_cons.push(con.reweight(weight));
            } else {
                merged_cons.push(cons[seg_begin]);
            }
            seg_begin = seg_end;
        }
        merged_cons.sort_unstable_by_key(|&con| self.con_len(con));
        self.merge_balanced(&merged_cons)
    }
}
// Unit tests for the value mapping of `NodeCon`. Hard `assert_eq!`s are used so that the
// checks are not compiled out when the tests are built without debug assertions.
#[cfg(test)]
mod tests {
    use super::{NodeCon, NodeId};

    /// A full connection maps every value to itself in both directions.
    #[test]
    fn node_con_map_full() {
        let id = NodeId(0);
        let nc = NodeCon::full(id);
        for val in 1..=10 {
            assert_eq!(nc.map(val), val, "map({val})");
            assert_eq!(nc.rev_map(val), val, "rev_map({val})");
            assert_eq!(nc.rev_map_round_up(val), val, "rev_map_round_up({val})");
        }
    }

    /// A weighted connection multiplies on the way out and divides on the way back.
    #[test]
    fn node_con_map_mult() {
        let id = NodeId(0);
        let weight = 3;
        let nc = NodeCon::weighted(id, weight);
        for val in 1..=10 {
            assert_eq!(nc.map(val), weight * val, "map({val})");
            assert_eq!(nc.rev_map(val), val / weight, "rev_map({val})");
            assert_eq!(
                nc.rev_map_round_up(val),
                if val % weight == 0 {
                    val / weight
                } else {
                    val / weight + 1
                },
                "rev_map_round_up({val})"
            );
        }
    }

    /// A connection with a divisor divides on the way out and multiplies on the way back.
    #[test]
    fn node_con_map_div() {
        let id = NodeId(0);
        let div: u8 = 2;
        let nc = NodeCon {
            id,
            offset: 0,
            divisor: div.try_into().unwrap(),
            multiplier: 1.try_into().unwrap(),
            len_limit: None,
        };
        let div: usize = div.into();
        for val in 1..=10 {
            assert_eq!(nc.map(val), val / div, "map({val})");
            assert_eq!(nc.rev_map(val), val * div, "rev_map({val})");
            assert_eq!(
                nc.rev_map_round_up(val),
                val * div,
                "rev_map_round_up({val})"
            );
        }
    }

    /// Offset and weight are combined: the offset is removed before weighting.
    #[test]
    fn node_con_map_offset_weighted() {
        let id = NodeId(0);
        let offset = 3;
        let weight = 5;
        let nc = NodeCon::offset_weighted(id, offset, weight);
        for val in offset..=10 {
            assert_eq!(nc.map(val), (val - offset) * weight, "map({val})");
            assert_eq!(nc.rev_map(val), val / weight + offset, "rev_map({val})");
            assert_eq!(
                nc.rev_map_round_up(val),
                if val % weight == 0 {
                    val / weight + offset
                } else {
                    val / weight + offset + 1
                },
                "rev_map_round_up({val})"
            );
        }
    }

    /// A single-output connection yields either zero or the weight.
    #[test]
    fn node_con_map_single() {
        let id = NodeId(0);
        let output = 5;
        let weight = 7;
        let nc = NodeCon::single(id, output, weight);
        for val in output - 1..=20 {
            assert_eq!(
                nc.map(val),
                if val >= output { weight } else { 0 },
                "map({val})"
            );
            assert_eq!(
                nc.rev_map(val),
                if val >= weight { output } else { 0 },
                "rev_map({val})"
            );
            assert_eq!(
                nc.rev_map_round_up(val),
                if val > weight { output + 1 } else { output },
                "rev_map_round_up({val})"
            );
        }
    }

    /// A limited connection caps the mapped value at the length limit.
    #[test]
    fn node_con_map_limited() {
        let id = NodeId(0);
        let offset = 2;
        let weight = 3;
        let limit = 5;
        let nc = NodeCon::limited(id, offset, limit, weight);
        for val in offset..=20 {
            assert_eq!(
                nc.map(val),
                std::cmp::min(val - offset, limit) * weight,
                "map({val})"
            );
            assert_eq!(
                nc.rev_map(val),
                if val >= weight {
                    std::cmp::min(val / weight, limit) + offset
                } else {
                    0
                },
                "rev_map({val})"
            );
            assert_eq!(
                nc.rev_map_round_up(val),
                if val > weight * limit {
                    limit + offset + 1
                } else if val % weight == 0 {
                    val / weight + offset
                } else {
                    val / weight + offset + 1
                },
                "rev_map_round_up({val})"
            );
        }
    }
}