use std::fmt::{self, Debug};
use std::hash::Hash;
use rustc_hash::FxHashMap;
use slab::Slab;
use crate::constraint::CallSequence;
/// Decision tree mapping a key plus the observed results of a sequence of
/// calls to a stored value.
///
/// A lookup starts at the entry point registered for a key and, at every
/// inner node, follows the edge labelled with the result of that node's
/// call until it reaches a leaf (see `get`).
pub struct CallTree<C, T> {
    /// Inner nodes; each holds the call whose result selects the next edge.
    inner: Slab<InnerNode<C>>,
    /// Leaf nodes; each holds a stored value.
    leaves: Slab<LeafNode<T>>,
    /// First node visited for each key.
    start: FxHashMap<u128, NodeId>,
    /// Outgoing edges, keyed by (source inner node, call result).
    edges: FxHashMap<(InnerId, u128), NodeId>,
}
/// Interior node of a `CallTree`.
struct InnerNode<C> {
    /// Call whose result decides which outgoing edge is followed.
    call: C,
    /// Number of outgoing edges. Incremented by `link`; `retain` decrements
    /// it when a child branch is pruned, or removes the node instead when
    /// its last remaining child goes away.
    children: usize,
    /// Inner node with an edge to this one; `None` for an entry point.
    parent: Option<InnerId>,
}
/// Terminal node of a `CallTree`, holding one stored value.
struct LeafNode<T> {
    /// Value returned by `get` when this leaf is reached.
    value: T,
    /// Inner node with an edge to this leaf; `None` when the leaf is itself
    /// the entry point of its key (inserted with an empty call sequence).
    parent: Option<InnerId>,
}
impl<C, T> CallTree<C, T> {
pub fn new() -> Self {
Self {
inner: Slab::new(),
leaves: Slab::new(),
edges: FxHashMap::default(),
start: FxHashMap::default(),
}
}
}
impl<C: Hash, T> CallTree<C, T> {
    /// Looks up the value stored under `key`.
    ///
    /// The walk starts at the entry point registered for `key`. At every
    /// inner node, `oracle` is asked for the result of that node's call, and
    /// the edge labelled with that result is followed; the first leaf
    /// reached is returned.
    ///
    /// Returns `None` if `key` has no entry point, or if `oracle` produces a
    /// result for which no edge was recorded.
    pub fn get(&self, key: u128, mut oracle: impl FnMut(&C) -> u128) -> Option<&T> {
        let mut cursor = *self.start.get(&key)?;
        loop {
            match cursor.kind() {
                NodeIdKind::Leaf(id) => {
                    return Some(&self.leaves[id].value);
                }
                NodeIdKind::Inner(id) => {
                    let call = &self.inner[id].call;
                    let ret = oracle(call);
                    cursor = *self.edges.get(&(id, ret))?;
                }
            }
        }
    }

    /// Stores `value` under `key` so that [`get`](Self::get) finds it when
    /// the oracle answers each call with the result recorded in `sequence`.
    ///
    /// Insertion happens in two phases:
    ///
    /// 1. While no new node has been created, the existing path for `key` is
    ///    followed. Each inner node's call is looked up in `sequence` via
    ///    `extract`, and the obtained result selects the next edge.
    /// 2. At the first missing edge (or immediately, if `key` has no entry
    ///    point yet), the calls still yielded by `sequence.next()` are
    ///    appended as a fresh chain of inner nodes. The chain ends in a new
    ///    leaf holding `value`.
    ///
    /// Errors can only be returned in phase 1, before the tree is modified,
    /// so a failed insert leaves the tree unchanged.
    ///
    /// # Errors
    ///
    /// - [`InsertError::AlreadyExists`] if the existing path reaches a leaf,
    ///   i.e. a value is already reachable with these call results.
    /// - [`InsertError::MissingCall`] if an inner node on the existing path
    ///   asks for a call that `sequence` cannot provide.
    pub fn insert(
        &mut self,
        key: u128,
        mut sequence: CallSequence<C>,
        value: T,
    ) -> Result<(), InsertError> {
        // Node currently positioned at, if any.
        let mut cursor = self.start.get(&key).copied();
        // Dangling edge `(inner node, call result)` that the next new node
        // hangs from. Once set, existing edges are never followed again.
        let mut predecessor = None;
        loop {
            if predecessor.is_none()
                && let Some(pos) = cursor
            {
                // Phase 1: follow the existing path.
                let NodeIdKind::Inner(id) = pos.kind() else {
                    return Err(InsertError::AlreadyExists);
                };
                let call = &self.inner[id].call;
                let Some(ret) = sequence.extract(call) else {
                    return Err(InsertError::MissingCall);
                };
                let pair = (id, ret);
                if let Some(&next) = self.edges.get(&pair) {
                    cursor = Some(next);
                } else {
                    predecessor = Some(pair);
                }
            } else {
                // Phase 2: append the remaining calls as new inner nodes.
                let Some((call, ret)) = sequence.next() else { break };
                let new_inner_id = self.inner.insert(InnerNode {
                    call,
                    children: 0,
                    parent: predecessor.map(|(id, _)| id),
                });
                let new_id = NodeId::inner(new_inner_id);
                // `cursor` is `None` only while `key` has no entry point and
                // nothing has been created yet, so the first node created in
                // that situation becomes the entry point.
                self.link(cursor.is_none(), key, predecessor.take(), new_id);
                predecessor = Some((new_inner_id, ret));
                cursor = Some(new_id);
            }
        }
        // The loop only exits through the `break` in phase 2. That branch is
        // entered only when `predecessor` is set or there is no cursor, and
        // an existing leaf was already rejected in phase 1.
        debug_assert!(predecessor.is_some() || cursor.is_none());
        let target = NodeId::leaf(
            self.leaves
                .insert(LeafNode { value, parent: predecessor.map(|(id, _)| id) }),
        );
        self.link(cursor.is_none(), key, predecessor, target);
        Ok(())
    }

    /// Points `to` from the entry point of `key` (when `at_start`) and/or
    /// from the edge `from`. Adding an edge bumps the child count of its
    /// source node.
    fn link(
        &mut self,
        at_start: bool,
        key: u128,
        from: Option<(InnerId, u128)>,
        to: NodeId,
    ) {
        if at_start {
            self.start.insert(key, to);
        }
        if let Some(pair) = from {
            self.inner[pair.0].children += 1;
            self.edges.insert(pair, to);
        }
    }
}
impl<C, T> CallTree<C, T> {
    /// Keeps only the stored values for which `f` returns `true`.
    ///
    /// For each dropped leaf, its ancestors are visited bottom-up:
    /// - an ancestor whose only child was the pruned branch is removed too;
    /// - the first ancestor with more than one child just has its child
    ///   count decremented, and the walk stops there.
    ///
    /// Finally, every edge and entry point whose target no longer exists is
    /// discarded.
    pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) {
        let inner = &mut self.inner;
        self.leaves.retain(|_, leaf| {
            if f(&mut leaf.value) {
                return true;
            }
            let mut ancestor = leaf.parent;
            while let Some(id) = ancestor {
                let node = &mut inner[id];
                if node.children > 1 {
                    node.children -= 1;
                    break;
                }
                ancestor = node.parent;
                inner.remove(id);
            }
            false
        });

        let (inner, leaves) = (&self.inner, &self.leaves);
        let is_live = |node: NodeId| match node.kind() {
            NodeIdKind::Inner(id) => inner.contains(id),
            NodeIdKind::Leaf(id) => leaves.contains(id),
        };
        self.edges.retain(|_, target| is_live(*target));
        self.start.retain(|_, target| is_live(*target));
    }

    /// Test-only check that every entry point and every edge refers to live
    /// nodes, and that every edge originates from a live inner node.
    #[cfg(test)]
    fn assert_consistency(&self) {
        let exists = |node: NodeId| match node.kind() {
            NodeIdKind::Inner(id) => self.inner.contains(id),
            NodeIdKind::Leaf(id) => self.leaves.contains(id),
        };
        self.start.values().for_each(|&node| assert!(exists(node)));
        for (&(inner_id, _), &node) in self.edges.iter() {
            assert!(exists(node));
            assert!(self.inner.contains(inner_id));
        }
    }
}
impl<C: Debug, T: Debug> Debug for CallTree<C, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (&(inner_id, ret), next) in &self.edges {
let call = &self.inner[inner_id].call;
write!(f, "[{inner_id}] ({call:?}, {ret:?}) -> ")?;
match next.kind() {
NodeIdKind::Inner(id) => writeln!(f, "{id}")?,
NodeIdKind::Leaf(id) => writeln!(f, "{:?}", &self.leaves[id].value)?,
}
}
Ok(())
}
}
impl<C, T> Default for CallTree<C, T> {
fn default() -> Self {
Self::new()
}
}
/// Compact reference to a node of a `CallTree`.
///
/// Non-negative values are indices into the inner-node slab. Negative
/// values encode leaf indices as their bitwise complement (`!index`, which
/// equals `-index - 1`), so leaf `0` is stored as `-1`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct NodeId(isize);

impl NodeId {
    /// Creates an id referring to inner node `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` exceeds `isize::MAX`. A silent wrap would turn the id
    /// negative and make it decode as a leaf.
    fn inner(i: usize) -> Self {
        let raw = isize::try_from(i).expect("inner node index exceeds isize::MAX");
        Self(raw)
    }

    /// Creates an id referring to leaf `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` exceeds `isize::MAX`.
    fn leaf(i: usize) -> Self {
        let raw = isize::try_from(i).expect("leaf index exceeds isize::MAX");
        // Bitwise NOT maps 0.. onto -1.. and, unlike `-raw - 1` followed by
        // negation on decode, never overflows (e.g. for `isize::MIN`).
        Self(!raw)
    }

    /// Decodes the id into the slab it refers to and the index within it.
    fn kind(self) -> NodeIdKind {
        if self.0 >= 0 {
            NodeIdKind::Inner(self.0 as usize)
        } else {
            // `!self.0` is non-negative here, so the cast is lossless.
            NodeIdKind::Leaf((!self.0) as usize)
        }
    }
}

/// Decoded form of a [`NodeId`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum NodeIdKind {
    /// Index into the inner-node slab.
    Inner(InnerId),
    /// Index into the leaf slab.
    Leaf(LeafId),
}

/// Index of an inner node in its slab.
type InnerId = usize;
/// Index of a leaf node in its slab.
type LeafId = usize;
/// Reasons `CallTree::insert` can reject an entry.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum InsertError {
    /// Following the inserted call results along the existing path reaches
    /// a leaf, so a value is already reachable for them.
    AlreadyExists,
    /// An inner node on the existing path asks for a call that the inserted
    /// sequence does not contain.
    MissingCall,
}

impl fmt::Display for InsertError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::AlreadyExists => "a value is already reachable for this call sequence",
            Self::MissingCall => "call sequence is missing a call required by the tree",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for InsertError {}
#[cfg(test)]
mod tests {
    use quickcheck::Arbitrary;
    use super::*;

    /// Hand-written scenarios using `char` calls and `&str` values.
    #[test]
    fn test_call_tree() {
        // Three entries under key 0 that all start with call 'a'; the first
        // two share 'a' -> 10 and diverge on the result of 'b'.
        test_ops([
            Op::Insert(0, vec![('a', 10), ('b', 15)], "first"),
            Op::Insert(0, vec![('a', 10), ('b', 20)], "second"),
            Op::Insert(0, vec![('a', 15), ('c', 15)], "third"),
        ]);
        // Sequences listing their calls in different orders, with exact
        // slab/map sizes checked before and after a retain.
        test_ops([
            Op::Insert(0, vec![('a', 10), ('b', 15)], "first"),
            Op::Insert(0, vec![('a', 10), ('c', 15), ('b', 20)], "second"),
            Op::Insert(0, vec![('a', 15), ('b', 30), ('c', 15)], "third"),
            Op::Manual(|tree| {
                assert_eq!(tree.inner.len(), 5);
                assert_eq!(tree.leaves.len(), 3);
                assert_eq!(tree.edges.len(), 7);
                assert_eq!(tree.start.len(), 1);
            }),
            // Keep only "second": nodes and edges used solely by "first" and
            // "third" must be pruned.
            Op::Retain(Box::new(|v| *v == "second")),
            Op::Manual(|tree| {
                assert_eq!(tree.inner.len(), 3);
                assert_eq!(tree.leaves.len(), 1);
                assert_eq!(tree.edges.len(), 3);
                assert_eq!(tree.start.len(), 1);
            }),
        ]);
    }

    /// Random interleavings of inserts and retains, checked against the
    /// reference model in `test_ops`. Insert errors are ignored because
    /// random inputs may legitimately be rejected by `insert`.
    #[quickcheck_macros::quickcheck]
    fn test_call_tree_quickcheck(ops: Vec<ArbitraryOp>) {
        test_ops(
            std::iter::once(Op::IgnoreInsertErrors)
                .chain(ops.into_iter().map(ArbitraryOp::into_op)),
        );
    }

    /// Operation description that quickcheck can generate and shrink;
    /// converted into an `Op` by `into_op`.
    #[derive(Debug, Clone)]
    enum ArbitraryOp {
        /// (key, call results, stored value)
        Insert(u128, Vec<u16>, u8),
        /// Threshold: keep values strictly greater than it.
        Retain(u8),
    }

    impl ArbitraryOp {
        fn into_op(self) -> Op<u64, u8> {
            match self {
                Self::Insert(key, nums, output) => {
                    // Call ids start at 50 and strictly increase (each step
                    // adds at least 1), so no call repeats within a sequence.
                    let mut state = 50;
                    Op::Insert(
                        key,
                        nums.iter()
                            .map(move |&v| {
                                let pair = (state, v as u128);
                                state += 1 + v as u64;
                                pair
                            })
                            .collect(),
                        output,
                    )
                }
                Self::Retain(mid) => Op::Retain(Box::new(move |v| *v > mid)),
            }
        }
    }

    impl Arbitrary for ArbitraryOp {
        /// Picks insert or retain with a coin flip.
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            if bool::arbitrary(g) {
                Self::Insert(
                    Arbitrary::arbitrary(g),
                    Arbitrary::arbitrary(g),
                    Arbitrary::arbitrary(g),
                )
            } else {
                Self::Retain(Arbitrary::arbitrary(g))
            }
        }
    }

    /// A step applied to the tree by `test_ops`.
    enum Op<C, T> {
        /// From this point on, failed inserts are skipped instead of
        /// panicking.
        IgnoreInsertErrors,
        /// Insert (key, (call, result) sequence, value).
        Insert(u128, Vec<(C, u128)>, T),
        /// Retain with the given predicate.
        Retain(Box<dyn Fn(&T) -> bool>),
        /// Arbitrary inspection of the tree.
        Manual(fn(&mut CallTree<C, T>)),
    }

    /// Applies `ops` to a fresh tree while mirroring successful inserts in a
    /// reference model (`kept`).
    ///
    /// After every op, it checks structural consistency and that every kept
    /// entry is still found by `get`, answering calls from that entry's own
    /// (call -> result) map. On retain, it also checks that dropped entries
    /// are no longer found.
    #[track_caller]
    fn test_ops<C, T>(ops: impl IntoIterator<Item = Op<C, T>>)
    where
        C: Clone + Hash + Eq,
        T: Debug + PartialEq + Clone,
    {
        let mut tree = CallTree::new();
        // (key, call -> result map, value) for every successful insert still
        // expected to be present.
        let mut kept = Vec::<(u128, FxHashMap<C, u128>, T)>::new();
        let mut ignore_insert_errors = false;
        for op in ops {
            match op {
                Op::IgnoreInsertErrors => ignore_insert_errors = true,
                Op::Insert(key, seq, value) => {
                    match tree.insert(key, seq.iter().cloned().collect(), value.clone()) {
                        Ok(()) => kept.push((
                            key,
                            seq.iter().map(|(k, v)| (k.clone(), *v)).collect(),
                            value.clone(),
                        )),
                        Err(_) if ignore_insert_errors => {}
                        Err(e) => panic!("{e:?}"),
                    }
                }
                Op::Retain(f) => {
                    tree.retain(|v| f(v));
                    kept.retain_mut(|(key, map, v)| {
                        let keep = f(v);
                        if !keep {
                            // A dropped entry must no longer be reachable.
                            assert_eq!(tree.get(*key, |s| map[s]), None);
                        }
                        keep
                    });
                }
                Op::Manual(f) => f(&mut tree),
            }
            tree.assert_consistency();
            for (key, map, value) in &kept {
                assert_eq!(tree.get(*key, |s| map[s]), Some(value));
            }
        }
    }
}