use std::{
cmp, fmt,
num::{NonZeroU8, NonZeroUsize},
ops::{self, Add, AddAssign, IndexMut, RangeBounds, Sub, SubAssign},
};
use crate::{types::Lit, utils::unreachable_none};
/// Identifier of a node in a node database.
///
/// Wraps the node's index into the database; transparent so it has the same
/// layout as a plain `usize`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct NodeId(pub usize);
impl fmt::Display for NodeId {
    /// Renders the id as `#<index>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("#{}", self.0))
    }
}
impl Add<usize> for NodeId {
type Output = NodeId;
fn add(self, rhs: usize) -> Self::Output {
NodeId(self.0 + rhs)
}
}
impl Add for &NodeId {
    type Output = NodeId;

    /// Adds the indices of two borrowed node ids into a fresh [`NodeId`].
    fn add(self, rhs: Self) -> Self::Output {
        let sum = self.0 + rhs.0;
        NodeId(sum)
    }
}
impl Add<usize> for &NodeId {
    type Output = NodeId;

    /// Shifts the borrowed node id forward by `rhs` positions.
    fn add(self, rhs: usize) -> Self::Output {
        let shifted = self.0 + rhs;
        NodeId(shifted)
    }
}
impl AddAssign for NodeId {
fn add_assign(&mut self, rhs: Self) {
self.0 += rhs.0;
}
}
impl AddAssign<usize> for NodeId {
fn add_assign(&mut self, rhs: usize) {
self.0 += rhs;
}
}
impl Sub for NodeId {
type Output = usize;
fn sub(self, rhs: Self) -> Self::Output {
self.0 - rhs.0
}
}
impl Sub<usize> for NodeId {
type Output = NodeId;
fn sub(self, rhs: usize) -> Self::Output {
NodeId(self.0 - rhs)
}
}
impl Sub for &NodeId {
    type Output = NodeId;

    /// Subtracts the index of `rhs` from `self`, yielding a fresh [`NodeId`].
    fn sub(self, rhs: Self) -> Self::Output {
        let diff = self.0 - rhs.0;
        NodeId(diff)
    }
}
impl Sub<usize> for &NodeId {
    type Output = NodeId;

    /// Shifts the borrowed node id backward by `rhs` positions.
    fn sub(self, rhs: usize) -> Self::Output {
        let shifted = self.0 - rhs;
        NodeId(shifted)
    }
}
impl SubAssign for NodeId {
fn sub_assign(&mut self, rhs: Self) {
self.0 -= rhs.0;
}
}
impl SubAssign<usize> for NodeId {
fn sub_assign(&mut self, rhs: usize) {
self.0 -= rhs;
}
}
/// Interface of a node in the node database.
///
/// Indexing a node with a `usize` value yields the output literal for that
/// value.
#[allow(clippy::len_without_is_empty)]
pub trait NodeLike: ops::Index<usize, Output = Lit> {
    /// Iterator type over the output values of a node.
    type ValIter: DoubleEndedIterator<Item = usize>;
    /// A node is a leaf iff it has exactly one output.
    fn is_leaf(&self) -> bool {
        self.len() == 1
    }
    /// The maximum output value of this node (as reported by the implementor).
    fn max_val(&self) -> usize;
    /// The number of outputs of this node.
    fn len(&self) -> usize;
    /// Iterates over the output values of this node that lie within `range`.
    fn vals<R>(&self, range: R) -> Self::ValIter
    where
        R: RangeBounds<usize>;
    /// Connection to the right child, if this is an internal node.
    fn right(&self) -> Option<NodeCon>;
    /// Connection to the left child, if this is an internal node.
    fn left(&self) -> Option<NodeCon>;
    /// Depth of the subtree rooted at this node (implementor-defined).
    fn depth(&self) -> usize;
    /// Number of leaves in the subtree rooted at this node.
    fn n_leaves(&self) -> usize;
    /// Creates a new internal node over two child connections, using `db` to
    /// look up the children.
    fn internal<Db>(left: NodeCon, right: NodeCon, db: &Db) -> Self
    where
        Db: NodeById<Node = Self>;
    /// Creates a new leaf node representing the single literal `lit`.
    fn leaf(lit: Lit) -> Self;
}
/// A (weighted, offset, limited) connection to a node.
///
/// A child value `v` maps into the parent's domain as
/// `min((v - offset) / divisor, len_limit) * multiplier` (see
/// [`NodeCon::map`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeCon {
    /// The id of the connected node
    pub id: NodeId,
    /// Number of leading child outputs that are skipped
    pub(crate) offset: usize,
    /// How many child steps make up one parent step (non-zero)
    pub(crate) divisor: NonZeroU8,
    /// Weight multiplied onto every mapped value (non-zero)
    pub(crate) multiplier: NonZeroUsize,
    /// Optional cap on how many child outputs the connection exposes
    pub(crate) len_limit: Option<NonZeroUsize>,
}
impl NodeCon {
    /// Creates a connection that passes all of the node's outputs through
    /// unchanged (no offset, divisor and multiplier of 1, no length limit).
    #[must_use]
    pub fn full(id: NodeId) -> NodeCon {
        NodeCon {
            id,
            offset: 0,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: unreachable_none!(NonZeroUsize::new(1)),
            len_limit: None,
        }
    }
    /// Creates a connection weighting all of the node's outputs by `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero (the multiplier is stored as [`NonZeroUsize`]).
    #[must_use]
    pub fn weighted(id: NodeId, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset: 0,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: weight.try_into().unwrap(),
            len_limit: None,
        }
    }
    /// Creates a weighted connection that additionally skips the first
    /// `offset` outputs of the node.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[cfg_attr(feature = "_internals", visibility::make(pub))]
    #[must_use]
    pub(crate) fn offset_weighted(id: NodeId, offset: usize, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: weight.try_into().unwrap(),
            len_limit: None,
        }
    }
    /// Creates a connection to exactly one output of the node (the
    /// 1-indexed `output`), weighted by `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero, or if `output` is zero (the `output - 1`
    /// computation underflows).
    #[cfg(any(test, feature = "_internals"))]
    #[must_use]
    pub fn single(id: NodeId, output: usize, weight: usize) -> NodeCon {
        NodeCon {
            id,
            offset: output - 1,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: weight.try_into().unwrap(),
            len_limit: NonZeroUsize::new(1),
        }
    }
    /// Creates a connection to at most `n_lits` outputs of the node, skipping
    /// the first `offset`, weighted by `weight`.
    ///
    /// # Panics
    ///
    /// If `n_lits` or `weight` is zero.
    #[cfg(any(test, feature = "_internals"))]
    #[must_use]
    pub fn limited(id: NodeId, offset: usize, n_lits: usize, weight: usize) -> NodeCon {
        assert_ne!(n_lits, 0);
        NodeCon {
            id,
            offset,
            divisor: unreachable_none!(NonZeroU8::new(1)),
            multiplier: weight.try_into().unwrap(),
            len_limit: NonZeroUsize::new(n_lits),
        }
    }
    /// Returns a copy of this connection with the multiplier replaced by
    /// `weight`.
    ///
    /// # Panics
    ///
    /// If `weight` is zero.
    #[inline]
    #[cfg(feature = "_internals")]
    #[must_use]
    pub fn reweight(self, weight: usize) -> NodeCon {
        NodeCon {
            multiplier: weight.try_into().unwrap(),
            ..self
        }
    }
    /// The number of leading child outputs this connection skips.
    #[inline]
    #[must_use]
    pub fn offset(&self) -> usize {
        self.offset
    }
    /// The divisor of this connection, widened to `usize`.
    #[inline]
    #[must_use]
    pub fn divisor(&self) -> usize {
        let div: u8 = self.divisor.into();
        div.into()
    }
    /// The multiplier (weight) of this connection as `usize`.
    #[inline]
    #[must_use]
    pub fn multiplier(&self) -> usize {
        self.multiplier.into()
    }
    /// Maps a value of the connected node through this connection:
    /// `min((val - offset) / divisor, len_limit) * multiplier`, clamped to 0
    /// for values at or below the offset.
    #[inline]
    #[must_use]
    pub fn map(&self, val: usize) -> usize {
        if val <= self.offset() {
            0
        } else if let Some(limit) = self.len_limit {
            cmp::min((val - self.offset()) / self.divisor(), limit.into()) * self.multiplier()
        } else {
            (val - self.offset()) / self.divisor() * self.multiplier()
        }
    }
    /// Inverse of [`NodeCon::map`]: maps a parent-domain value back to a
    /// value of the connected node, rounding down and capping at the length
    /// limit. A result that caps to zero stays zero (the offset is not added).
    #[inline]
    #[must_use]
    pub fn rev_map(&self, val: usize) -> usize {
        if let Some(limit) = self.len_limit {
            match cmp::min(val / self.multiplier(), limit.into()) * self.divisor() {
                0 => 0,
                x => x + self.offset(),
            }
        } else {
            val / self.multiplier() * self.divisor() + self.offset()
        }
    }
    /// Like [`NodeCon::rev_map`], but ignoring any length limit.
    #[inline]
    #[must_use]
    #[cfg(any(feature = "_internals", feature = "proof-logging"))]
    pub fn rev_map_no_limit(&self, val: usize) -> usize {
        val / self.multiplier() * self.divisor() + self.offset()
    }
    /// Reverse-maps `val`, rounding up to the next child value the connection
    /// can represent; values beyond the length limit map to one step past the
    /// limit.
    ///
    /// NOTE(review): `(val - 1)` underflows for `val == 0` when a length
    /// limit is set — presumably callers only pass `val >= 1`; confirm.
    #[inline]
    #[must_use]
    pub fn rev_map_round_up(&self, mut val: usize) -> usize {
        if let Some(limit) = self.len_limit {
            if (val - 1) / self.multiplier() >= limit.into() {
                return (Into::<usize>::into(limit) + 1) * self.divisor() + self.offset();
            }
        }
        // round up to the next multiple of the multiplier before reverse-mapping
        if val % self.multiplier() > 0 {
            val += self.multiplier();
        }
        self.rev_map(val)
    }
    /// Whether `val` is a value this connection can produce: a multiple of
    /// the multiplier, within the length limit if one is set.
    #[inline]
    #[must_use]
    pub fn is_possible(&self, val: usize) -> bool {
        if let Some(limit) = self.len_limit {
            val % self.multiplier() == 0 && val / self.multiplier() <= limit.into()
        } else {
            val % self.multiplier() == 0
        }
    }
}
/// Error returned by [`NodeById::drain`] when a node located after the
/// drained range still references a node inside the range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[error(
    "node {referencing} from after the drain range references node {referenced} in the drain range"
)]
pub struct DrainError {
    /// The node (after the range) holding the dangling reference
    pub referencing: NodeId,
    /// The node inside the drained range that is still referenced
    pub referenced: NodeId,
}
/// Database of nodes, indexable (mutably) by [`NodeId`].
///
/// Default methods build merge trees over connections to stored nodes.
#[allow(dead_code)]
pub trait NodeById: IndexMut<NodeId, Output = Self::Node> {
    /// The type of node stored in the database.
    type Node: NodeLike;
    /// Inserts a new node into the database and returns its id.
    fn insert(&mut self, node: Self::Node) -> NodeId;
    /// Iterator type over all nodes in the database.
    type Iter<'own>: Iterator<Item = &'own Self::Node>
    where
        Self: 'own;
    /// Iterates over all nodes in the database.
    fn iter(&self) -> Self::Iter<'_>;
    /// The number of nodes in the database.
    fn len(&self) -> usize;
    /// Whether the database contains no nodes.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// The number of outputs a connection exposes: the connected node's
    /// length after applying the connection's offset and divisor, capped at
    /// the length limit if one is set.
    fn con_len(&self, con: NodeCon) -> usize {
        let len = (self[con.id].len() - con.offset()) / con.divisor();
        if let Some(limit) = con.len_limit {
            cmp::min(len, limit.into())
        } else {
            len
        }
    }
    /// Iterator type yielding the nodes removed by [`NodeById::drain`].
    type Drain<'own>: Iterator<Item = Self::Node>
    where
        Self: 'own;
    /// Removes the nodes with ids in `range` from the database.
    ///
    /// # Errors
    ///
    /// [`DrainError`] if a node after the range still references a node
    /// inside the range.
    fn drain<R: RangeBounds<NodeId>>(&mut self, range: R) -> Result<Self::Drain<'_>, DrainError>;
    /// Inserts each literal as a fresh leaf, merges them into a tree, and
    /// returns the id of the root; [`None`] if `lits` is empty.
    fn lit_tree<I>(&mut self, lits: I) -> Option<NodeId>
    where
        I: IntoIterator<Item = Lit>,
        Self: Sized,
    {
        let mut cons: Vec<_> = lits
            .into_iter()
            .map(|l| NodeCon::full(self.insert(Self::Node::leaf(l))))
            .collect();
        let con = self.merge(&mut cons)?;
        // merging full unit connections must yield a full unit connection
        debug_assert_eq!(con.offset(), 0);
        debug_assert_eq!(con.divisor(), 1);
        debug_assert_eq!(con.multiplier(), 1);
        Some(con.id)
    }
    /// Builds a tree over weighted literals: runs of equal weight become
    /// unweighted subtrees, which are then merged balanced by length.
    ///
    /// NOTE(review): only *adjacent* equal weights are grouped — presumably
    /// `lits` is expected to be sorted/grouped by weight; confirm at callers.
    ///
    /// # Panics
    ///
    /// In debug builds, if `lits` is empty; always, if any weight is zero.
    fn weighted_lit_tree(&mut self, lits: &[(Lit, usize)]) -> Option<NodeCon>
    where
        Self: Sized,
    {
        debug_assert!(!lits.is_empty());
        let mut seg_begin = 0;
        let mut cons = vec![];
        // one unweighted subtree per run of equal adjacent weights
        for seg_end in 1..lits.len() {
            if lits[seg_end].1 == lits[seg_begin].1 {
                continue;
            }
            let seg = lits[seg_begin..seg_end].iter().map(|&(lit, _)| lit);
            let id = self.lit_tree(seg).unwrap();
            cons.push(NodeCon::weighted(id, lits[seg_begin].1));
            seg_begin = seg_end;
        }
        // final segment runs to the end of the slice
        let seg = lits[seg_begin..].iter().map(|&(lit, _)| lit);
        let id = self.lit_tree(seg).unwrap();
        cons.push(NodeCon::weighted(id, lits[seg_begin].1));
        self.merge_balanced(&cons)
    }
    /// Merges the given connections into a single tree, layer by layer and in
    /// place, returning the connection to the root ([`None`] if `cons` is
    /// empty).
    ///
    /// When two children carry the same multiplier greater than one, the
    /// factor is pulled out of the children and onto the parent connection.
    ///
    /// # Panics
    ///
    /// If `cons.len() >= isize::MAX` (indices are used in signed bit tricks).
    fn merge(&mut self, cons: &mut [NodeCon]) -> Option<NodeCon>
    where
        Self: Sized,
    {
        if cons.is_empty() {
            return None;
        }
        assert!(
            cons.len() < isize::MAX.unsigned_abs(),
            "due to bit operations the number of literals must be at most `isize::MAX`"
        );
        // `reverse_width` collects the binary digits of `cons.len()` below
        // its most significant bit, in reverse order; each layer of the tree
        // consumes one digit to decide which subtrees absorb an extra element
        let mut reverse_width = 0;
        let mut width = cons.len();
        while width > 1 {
            reverse_width = (reverse_width << 1) | (width & 1);
            width /= 2;
        }
        debug_assert_eq!(width, 1);
        #[allow(clippy::cast_possible_wrap)]
        let mut reverse_width = reverse_width as isize;
        let mut last_reverse_width = reverse_width << 1;
        // bottom-up passes: each pass pairs spans of (roughly) `width`
        // connections and writes the merged connection to the span start
        while width <= cons.len() {
            let mut start = 0;
            let mut split_idx = 1;
            while start < cons.len() {
                // `split_idx & -split_idx` isolates the lowest set bit;
                // NOTE(review): appears to decide whether this span takes the
                // extra element encoded for the current layer — confirm
                let extend = (split_idx & -split_idx & reverse_width) != 0;
                let true_width = width + usize::from(extend);
                if true_width == 1 {
                    // a single connection is carried up unchanged
                    split_idx += 1;
                    start += 1;
                    continue;
                }
                if true_width % 2 == 0 {
                    // even span: split exactly in half
                    let lcon = cons[start];
                    let rcon = cons[start + true_width / 2];
                    cons[start] = if lcon.multiplier() > 1 && lcon.multiplier() == rcon.multiplier()
                    {
                        // factor the shared weight out of the children
                        let weight = lcon.multiplier();
                        let lcon = NodeCon {
                            multiplier: unreachable_none!(NonZeroUsize::new(1)),
                            ..lcon
                        };
                        let rcon = NodeCon {
                            multiplier: unreachable_none!(NonZeroUsize::new(1)),
                            ..rcon
                        };
                        NodeCon::weighted(
                            self.insert(Self::Node::internal(lcon, rcon, self)),
                            weight,
                        )
                    } else {
                        NodeCon::full(self.insert(Self::Node::internal(lcon, rcon, self)))
                    };
                } else {
                    // odd span: the previous layer decided whether the left
                    // child got the extra element, shifting the split point
                    let left_child_split_idx = (split_idx - 1) * 2 + 1;
                    let left_child_extend =
                        (left_child_split_idx & -left_child_split_idx & last_reverse_width) != 0;
                    let lcon = cons[start];
                    let rcon = cons[start + true_width / 2 + usize::from(left_child_extend)];
                    cons[start] = if lcon.multiplier() > 1 && lcon.multiplier() == rcon.multiplier()
                    {
                        // factor the shared weight out of the children
                        let weight = lcon.multiplier();
                        let lcon = NodeCon {
                            multiplier: unreachable_none!(NonZeroUsize::new(1)),
                            ..lcon
                        };
                        let rcon = NodeCon {
                            multiplier: unreachable_none!(NonZeroUsize::new(1)),
                            ..rcon
                        };
                        NodeCon::weighted(
                            self.insert(Self::Node::internal(lcon, rcon, self)),
                            weight,
                        )
                    } else {
                        NodeCon::full(self.insert(Self::Node::internal(lcon, rcon, self)))
                    };
                }
                split_idx += 1;
                start += true_width;
            }
            // widen to the next layer, re-appending the digit consumed above
            #[allow(clippy::cast_sign_loss)]
            {
                width = (width << 1) | (reverse_width as usize & 1);
            }
            reverse_width >>= 1;
            last_reverse_width >>= 1;
        }
        debug_assert_eq!(width, cons.len() * 2);
        Some(cons[0])
    }
    /// Merges the given connections into a tree balanced by connection
    /// length, splitting at the midpoint of cumulative [`NodeById::con_len`]
    /// weights; [`None`] if `cons` is empty.
    fn merge_balanced(&mut self, cons: &[NodeCon]) -> Option<NodeCon>
    where
        Self: Sized,
    {
        if cons.is_empty() {
            return None;
        }
        // inclusive prefix sums of the connection lengths
        let cum_weight = cons
            .iter()
            .fold(Vec::with_capacity(cons.len()), |mut cum_weight, con| {
                cum_weight.push(cum_weight.last().copied().unwrap_or(0) + self.con_len(*con));
                cum_weight
            });
        Some(merge_balanced_recursive(self, cons, &cum_weight, 0))
    }
    /// Merges connections by first grouping them by multiplier: each group of
    /// equal multipliers is merged as an unweighted, length-sorted balanced
    /// subtree, then the per-weight results are merged balanced by length.
    ///
    /// # Panics
    ///
    /// If any multiplier reconstruction fails (zero weight).
    #[cfg(feature = "_internals")]
    fn merge_thorough(&mut self, cons: &mut [NodeCon]) -> Option<NodeCon>
    where
        Self: Sized,
    {
        if cons.is_empty() {
            return None;
        }
        cons.sort_unstable_by_key(NodeCon::multiplier);
        let mut seg_begin = 0;
        let mut merged_cons = vec![];
        for seg_end in 1..cons.len() {
            if cons[seg_end].multiplier() == cons[seg_begin].multiplier() {
                continue;
            }
            if seg_end > seg_begin + 1 {
                // merge the equal-multiplier segment unweighted, then restore
                // the segment's weight on the merged connection
                let mut seg: Vec<_> = cons[seg_begin..seg_end]
                    .iter()
                    .map(|&con| con.reweight(1))
                    .collect();
                seg.sort_unstable_by_key(|&con| self.con_len(con));
                let con = self.merge_balanced(&seg).unwrap();
                debug_assert_eq!(con.multiplier(), 1);
                merged_cons.push(con.reweight(cons[seg_begin].multiplier()));
            } else {
                merged_cons.push(cons[seg_begin]);
            }
            seg_begin = seg_end;
        }
        // handle the trailing segment the same way
        if cons.len() > seg_begin + 1 {
            let mut seg: Vec<_> = cons[seg_begin..]
                .iter()
                .map(|&con| con.reweight(1))
                .collect();
            seg.sort_unstable_by_key(|&con| self.con_len(con));
            let con = self.merge_balanced(&seg).unwrap();
            debug_assert_eq!(con.multiplier(), 1);
            merged_cons.push(con.reweight(cons[seg_begin].multiplier()));
        } else {
            merged_cons.push(cons[seg_begin]);
        }
        merged_cons.sort_unstable_by_key(|&con| self.con_len(con));
        self.merge_balanced(&merged_cons)
    }
    /// Iterates over the leaves of the subtree rooted at `node`.
    #[cfg(any(feature = "_internals", feature = "proof-logging"))]
    fn leaf_iter(&self, node: NodeId) -> LeafIter<'_, Self>
    where
        Self: Sized,
    {
        LeafIter::new(self, node)
    }
}
/// Recursively merges `cons` into a tree balanced by connection length.
///
/// `cum_weight` holds inclusive prefix sums of [`NodeById::con_len`] computed
/// over the original full slice; `offset` is the prefix weight consumed
/// before `cons[0]`, so `cum_weight[i] - offset` is the weight of
/// `cons[..=i]` within the current sub-slice.
fn merge_balanced_recursive<NDb>(
    db: &mut NDb,
    cons: &[NodeCon],
    cum_weight: &[usize],
    offset: usize,
) -> NodeCon
where
    NDb: NodeById,
{
    debug_assert!(!cons.is_empty());
    if cons.len() == 1 {
        return cons[0];
    }
    // split where the cumulative weight first reaches half of the remainder
    let threshold = (cum_weight[cum_weight.len() - 1] - offset) / 2 + offset;
    let (split, _) = cum_weight
        .iter()
        .enumerate()
        .skip(1)
        .find(|&(_, &val)| val >= threshold)
        .unwrap();
    let lcon = merge_balanced_recursive(db, &cons[..split], &cum_weight[..split], offset);
    let rcon = merge_balanced_recursive(
        db,
        &cons[split..],
        &cum_weight[split..],
        cum_weight[split - 1],
    );
    // if both children share a multiplier greater than one, factor it out of
    // the children and move it onto the connection to the new parent node
    if lcon.multiplier() > 1 && lcon.multiplier() == rcon.multiplier() {
        let weight = lcon.multiplier();
        let lcon = NodeCon {
            multiplier: unreachable_none!(NonZeroUsize::new(1)),
            ..lcon
        };
        let rcon = NodeCon {
            multiplier: unreachable_none!(NonZeroUsize::new(1)),
            ..rcon
        };
        NodeCon::weighted(db.insert(NDb::Node::internal(lcon, rcon, db)), weight)
    } else {
        NodeCon::full(db.insert(NDb::Node::internal(lcon, rcon, db)))
    }
}
/// Iterator over the leaves of a tree in a node database.
#[derive(Debug)]
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
pub struct LeafIter<'db, Db> {
    /// The database the tree lives in
    db: &'db Db,
    /// Path from the root to the current leaf: (node id, whether the entry
    /// was reached as a right child, multiplier accumulated along the path)
    trace: Vec<(NodeId, bool, usize)>,
    /// Value range of the current leaf (restricted by offset/length limit of
    /// the connection it was reached through)
    val_range: std::ops::Range<usize>,
}
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
impl<'db, Db> LeafIter<'db, Db>
where
    Db: NodeById,
{
    /// Creates a leaf iterator over the subtree rooted at `root`, descending
    /// to the leftmost leaf and recording the path (with accumulated
    /// multipliers) in the trace.
    ///
    /// Descent stops early at a connection with an offset or length limit:
    /// that connected node is treated as a leaf with the restricted value
    /// range instead of being descended into further.
    pub fn new(db: &'db Db, root: NodeId) -> Self {
        let mut trace = vec![(root, false, 1)];
        let mut current = root;
        let mut mult = 1;
        let mut val_range = 1..2;
        while let Some(con) = db[current].left() {
            debug_assert_eq!(con.divisor(), 1);
            mult *= con.multiplier();
            trace.push((con.id, false, mult));
            if con.offset() > 0 || con.len_limit.is_some() {
                // stop here: expose the offset/limited node as a pseudo-leaf
                val_range = con.offset() + 1
                    ..con
                        .len_limit
                        .map_or(db[con.id].max_val() + 1, |lim| lim.get() + con.offset() + 1);
                break;
            }
            current = con.id;
        }
        Self {
            db,
            trace,
            val_range,
        }
    }
    /// Advances the trace to the next leaf: pops exhausted right branches,
    /// steps into the next right child, and descends to its leftmost leaf.
    ///
    /// Leaves the trace empty when the last leaf has been consumed.
    fn find_next_leaf_node(&mut self) {
        // walk back over entries that were reached as right children (done);
        // the root entry carries `false`, so `last >= 1` after this loop
        let mut last = self.trace.len();
        while last > 0 && self.trace[last - 1].1 {
            last -= 1;
        }
        // also drop the left-child entry we are pivoting away from
        last -= 1;
        self.trace.drain(last..);
        if last == 0 {
            // the root itself was popped: iteration is finished
            return;
        }
        // pivot into the right child of the deepest remaining node
        let con = unreachable_none!(self.db[self.trace.last().unwrap().0].right());
        let mut mult = unreachable_none!(self.trace.last()).2 * con.multiplier();
        self.trace.push((con.id, true, mult));
        if con.offset() > 0 || con.len_limit.is_some() {
            // offset/limited connection: treat the node as a pseudo-leaf
            self.val_range = con.offset() + 1
                ..con.len_limit.map_or(self.db[con.id].max_val() + 1, |lim| {
                    lim.get() + con.offset() + 1
                });
            return;
        }
        // descend to the leftmost leaf of the new right subtree
        let mut current = con.id;
        while let Some(con) = self.db[current].left() {
            mult *= con.multiplier();
            self.trace.push((con.id, false, mult));
            if con.offset() > 0 || con.len_limit.is_some() {
                self.val_range = con.offset() + 1
                    ..con.len_limit.map_or(self.db[con.id].max_val() + 1, |lim| {
                        lim.get() + con.offset() + 1
                    });
                return;
            }
            current = con.id;
        }
        // plain leaf: single value 1
        self.val_range = 1..2;
    }
    /// Converts this iterator into one over the leaves' literals and weights.
    pub fn lits(self) -> LeafLitIter<'db, Db> {
        LeafLitIter::new(self)
    }
}
/// Information about one leaf yielded by [`LeafIter`].
#[derive(Debug, Clone)]
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
pub struct LeafInfo {
    /// Id of the leaf node
    pub id: NodeId,
    /// Multiplier accumulated over the connections from the root to the leaf
    pub weight: usize,
    /// Value range the leaf contributes (restricted by offset/length limit)
    pub val_range: std::ops::Range<usize>,
}
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
impl<Db> Iterator for LeafIter<'_, Db>
where
    Db: NodeById,
{
    type Item = LeafInfo;

    /// Yields the leaf currently at the tip of the trace, then advances the
    /// trace to the next leaf.
    fn next(&mut self) -> Option<Self::Item> {
        let &(id, _, weight) = self.trace.last()?;
        let item = LeafInfo {
            id,
            weight,
            val_range: self.val_range.clone(),
        };
        self.find_next_leaf_node();
        Some(item)
    }
}
/// Iterator over the output literals of a tree's leaves together with their
/// weights, built on top of [`LeafIter`].
#[derive(Debug)]
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
pub struct LeafLitIter<'db, Db> {
    /// The underlying leaf iterator
    leaves: LeafIter<'db, Db>,
    /// The leaf currently being enumerated (starts as an empty placeholder)
    current: LeafInfo,
    /// The last value yielded from the current leaf
    last_val: usize,
}
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
impl<'db, Db> LeafLitIter<'db, Db>
where
    Db: NodeById,
{
    /// Wraps a leaf iterator; `current` starts as an empty placeholder range
    /// so the first `next` call immediately fetches the first real leaf.
    fn new(leaves: LeafIter<'db, Db>) -> Self {
        let placeholder = LeafInfo {
            id: NodeId(0),
            weight: 0,
            val_range: 0..0,
        };
        Self {
            leaves,
            current: placeholder,
            last_val: 0,
        }
    }
}
#[cfg(any(feature = "_internals", feature = "proof-logging"))]
impl<Db> Iterator for LeafLitIter<'_, Db>
where
    Db: NodeById,
{
    // item: an output literal with its weight (leaf weight times the number
    // of values the literal covers since the previous yielded value)
    type Item = (Lit, usize);
    fn next(&mut self) -> Option<Self::Item> {
        // advance to the next leaf once the current value range is exhausted
        if self.current.val_range.is_empty() {
            self.current = self.leaves.next()?;
            self.last_val = self.current.val_range.start - 1;
        }
        // find the next assigned value in the current leaf; a leaf's value
        // range may have gaps or be empty, so skip forward over leaves until
        // one yields a value
        let val = loop {
            let Some(val) = self.leaves.db[self.current.id]
                .vals(self.current.val_range.clone())
                .next()
            else {
                self.current = self.leaves.next()?;
                self.last_val = self.current.val_range.start - 1;
                continue;
            };
            break val;
        };
        let lit = self.leaves.db[self.current.id][val];
        // the literal's weight covers all values since the last yielded one
        let weight = self.current.weight * (val - self.last_val);
        self.current.val_range.start = val + 1;
        self.last_val = val;
        Some((lit, weight))
    }
}
#[cfg(test)]
mod tests {
    use super::{NodeCon, NodeId};

    // NOTE: these tests previously used `debug_assert_eq!`, which compiles to
    // a no-op without debug assertions, so `cargo test --release` would have
    // passed vacuously. `assert_eq!` checks in every build profile.

    /// A full connection maps and reverse-maps every value to itself.
    #[test]
    fn node_con_map_full() {
        let id = NodeId(0);
        let nc = NodeCon::full(id);
        for val in 1..=10 {
            assert_eq!(nc.map(val), val);
            assert_eq!(nc.rev_map(val), val);
            assert_eq!(nc.rev_map_round_up(val), val);
        }
    }

    /// A weighted connection multiplies forward and divides (rounding down,
    /// or up for `rev_map_round_up`) backward.
    #[test]
    fn node_con_map_mult() {
        let id = NodeId(0);
        let weight = 3;
        let nc = NodeCon::weighted(id, weight);
        for val in 1..=10 {
            assert_eq!(nc.map(val), weight * val);
            assert_eq!(nc.rev_map(val), val / weight);
            assert_eq!(
                nc.rev_map_round_up(val),
                if val % weight == 0 {
                    val / weight
                } else {
                    val / weight + 1
                }
            );
        }
    }

    /// A divisor connection divides forward and multiplies backward.
    #[test]
    fn node_con_map_div() {
        let id = NodeId(0);
        let div = 2;
        let nc = NodeCon {
            id,
            offset: 0,
            divisor: div.try_into().unwrap(),
            multiplier: 1.try_into().unwrap(),
            len_limit: None,
        };
        let div: usize = div.into();
        for val in 1..=10 {
            assert_eq!(nc.map(val), val / div);
            assert_eq!(nc.rev_map(val), val * div);
            assert_eq!(nc.rev_map_round_up(val), val * div);
        }
    }

    /// An offset connection shifts values before weighting.
    #[test]
    fn node_con_map_offset_weighted() {
        let id = NodeId(0);
        let offset = 3;
        let weight = 5;
        let nc = NodeCon::offset_weighted(id, offset, weight);
        for val in offset..=10 {
            assert_eq!(nc.map(val), (val - offset) * weight);
            assert_eq!(nc.rev_map(val), val / weight + offset);
            assert_eq!(
                nc.rev_map_round_up(val),
                if val % weight == 0 {
                    val / weight + offset
                } else {
                    val / weight + offset + 1
                }
            );
        }
    }

    /// A single-output connection yields `weight` exactly once the output is
    /// reached, and reverse-maps to the output (or one past it).
    #[test]
    fn node_con_map_single() {
        let id = NodeId(0);
        let output = 5;
        let weight = 7;
        let nc = NodeCon::single(id, output, weight);
        for val in output - 1..=20 {
            assert_eq!(nc.map(val), if val >= output { weight } else { 0 });
            assert_eq!(nc.rev_map(val), if val >= weight { output } else { 0 });
            assert_eq!(
                nc.rev_map_round_up(val),
                if val > weight { output + 1 } else { output }
            );
        }
    }

    /// A limited connection caps the number of mapped values at `limit`.
    #[test]
    fn node_con_map_limited() {
        let id = NodeId(0);
        let offset = 2;
        let weight = 3;
        let limit = 5;
        let nc = NodeCon::limited(id, offset, limit, weight);
        for val in offset..=20 {
            assert_eq!(nc.map(val), std::cmp::min(val - offset, limit) * weight);
            assert_eq!(
                nc.rev_map(val),
                if val >= weight {
                    std::cmp::min(val / weight, limit) + offset
                } else {
                    0
                }
            );
            assert_eq!(
                nc.rev_map_round_up(val),
                if val > weight * limit {
                    limit + offset + 1
                } else if val % weight == 0 {
                    val / weight + offset
                } else {
                    val / weight + offset + 1
                }
            );
        }
    }
}