use std::collections::HashSet;
use crate::{HashKind, Kind, StateId};
/// Request to attach a new node `id` underneath the existing node `parent_id`.
#[derive(Debug)]
pub struct Insert<Id> {
    /// Node that receives the new child. `Node::into_insert` panics if this is `None`.
    pub parent_id: Option<Id>,
    /// Id of the node being inserted.
    pub id: Id,
}
/// Request to replace the state stored on the node `id` with `state`.
///
/// Derives `Debug` like its sibling `Insert`, so requests can be logged or asserted on.
#[derive(Debug)]
pub struct Update<Id, S> {
    /// Node whose state is replaced.
    pub id: Id,
    /// New state for that node.
    pub state: S,
}
/// One node of the state tree: its id, its state, and the subtrees hanging below it.
#[derive(Debug)]
pub struct Node<Id, S> {
    /// Identifier of this node.
    pub id: Id,
    /// State currently stored on this node.
    pub state: S,
    /// Ids of every node in this node's subtrees (this node's own id is not included).
    /// Lookups use it to pick the single child to descend into instead of searching
    /// every branch.
    descendant_keys: HashSet<Id>,
    /// Direct children of this node.
    pub children: Vec<Self>,
}
impl<K> Node<StateId<K>, K::State>
where
    K: Kind + HashKind,
{
    /// Creates a leaf node for `id` whose state is `id.new_state()`.
    #[must_use]
    pub fn new(id: StateId<K>) -> Self {
        Self {
            state: id.new_state(),
            id,
            descendant_keys: HashSet::new(),
            children: Vec::new(),
        }
    }

    /// Wraps this node in a root [`Zipper`] (no parent) so it can be navigated and edited.
    #[must_use]
    pub const fn zipper(self) -> Zipper<StateId<K>, K::State> {
        Zipper {
            node: self,
            parent: None,
            self_idx: 0,
        }
    }

    /// Returns the node with `id` from this subtree, or `None` if it is not present.
    ///
    /// The walk follows `descendant_keys`, descending into exactly one child per level.
    ///
    /// # Panics
    /// Panics if `descendant_keys` lists an id that no child leads to (a broken invariant).
    #[must_use]
    pub fn get(&self, id: StateId<K>) -> Option<&Self> {
        if self.id == id {
            return Some(self);
        }
        if !self.descendant_keys.contains(&id) {
            return None;
        }
        let mut node = self;
        // Once a node no longer lists `id` as a descendant, it is the node itself.
        while node.descendant_keys.contains(&id) {
            node = node
                .child(id)
                .expect("descendant_keys lists an id that no child leads to");
        }
        Some(node)
    }

    /// Returns the state of the node `id`, or `None` if it is not in this subtree.
    #[must_use]
    pub fn get_state(&self, id: StateId<K>) -> Option<&K::State> {
        self.get(id).map(|n| &n.state)
    }

    /// Returns the direct child that either is `id` or has `id` among its descendants.
    #[must_use]
    pub fn child(&self, id: StateId<K>) -> Option<&Self> {
        self.children.iter().find(|node| node.is_or_contains(id))
    }

    /// Index in `children` of the child that either is `id` or has `id` below it.
    #[must_use]
    pub fn child_idx(&self, id: StateId<K>) -> Option<usize> {
        self.children
            .iter()
            .position(|node| node.is_or_contains(id))
    }

    /// True when this node is `id` or `id` is somewhere in its subtrees.
    fn is_or_contains(&self, id: StateId<K>) -> bool {
        self.id == id || self.descendant_keys.contains(&id)
    }

    /// Moves the whole tree out of `self`, leaving a fresh placeholder root with the same id.
    ///
    /// The zipper API consumes nodes by value; the `&mut self` wrappers use this to get
    /// ownership and must write the rebuilt tree back into `*self`.
    fn take_tree(&mut self) -> Self {
        let placeholder = Self::new(self.id);
        std::mem::replace(self, placeholder)
    }

    /// In-place form of [`Node::into_insert`].
    ///
    /// # Panics
    /// Panics if `insert.parent_id` is `None` or does not name a node in this tree.
    pub fn insert(&mut self, insert: Insert<StateId<K>>) {
        *self = self.take_tree().into_insert(insert);
    }

    /// Adds a new leaf `id` under `parent_id` and records `id` in the `descendant_keys`
    /// of the parent and every ancestor. Does not check whether `id` already exists.
    ///
    /// # Panics
    /// Panics if `parent_id` is `None` or does not name a node in this tree.
    #[must_use]
    pub fn into_insert(self, Insert { parent_id, id }: Insert<StateId<K>>) -> Self {
        let parent_id = parent_id.expect("Insert::parent_id must be Some to insert a child");
        self.zipper()
            .by_id(parent_id)
            .insert_child(id)
            .finish_insert(id)
    }

    /// Returns the id of the parent of `id`.
    ///
    /// Returns `None` when `id` is this (root) node or is not in the tree.
    ///
    /// # Panics
    /// Panics if `descendant_keys` lists an id that no child leads to (a broken invariant).
    #[must_use]
    pub fn get_parent_id(&self, id: StateId<K>) -> Option<StateId<K>> {
        if !self.descendant_keys.contains(&id) {
            return None;
        }
        let mut node = self;
        while node.descendant_keys.contains(&id) {
            let child_node = node
                .child(id)
                .expect("descendant_keys lists an id that no child leads to");
            if child_node.id == id {
                return Some(node.id);
            }
            node = child_node;
        }
        None
    }

    /// In-place form of [`Node::into_update`].
    ///
    /// # Panics
    /// Panics if `update.id` is not in the tree.
    pub fn update(&mut self, update: Update<StateId<K>, K::State>) {
        *self = self.take_tree().into_update(update);
    }

    /// Replaces the state of `id` and returns the id of its parent.
    ///
    /// Returns `None` when `id` is the root.
    ///
    /// # Panics
    /// Panics if `id` is not in the tree.
    pub fn update_and_get_parent_id(
        &mut self,
        Update { id, state }: Update<StateId<K>, K::State>,
    ) -> Option<StateId<K>> {
        let (parent_id, tree) = self
            .take_tree()
            .zipper()
            .by_id(id)
            .set_state(state)
            .finish_update_parent_id();
        *self = tree;
        parent_id
    }

    /// Rebuilds the tree by applying `f` to every node, children before their parent.
    ///
    /// `f` receives a zipper focused on a detached node (no parent) whose children have
    /// already been processed, and returns the node to keep in its place.
    pub fn update_all_fn<F>(&mut self, f: F)
    where
        F: Fn(Zipper<StateId<K>, K::State>) -> Self + Clone,
    {
        *self = self.take_tree().zipper().finish_update_fn(f);
    }

    /// Returns the tree with the state of `id` replaced by `state`.
    ///
    /// # Panics
    /// Panics if `id` is not in the tree.
    #[must_use]
    pub fn into_update(self, Update { id, state }: Update<StateId<K>, K::State>) -> Self {
        self.zipper().by_id(id).set_state(state).finish_update()
    }
}
/// Owning cursor into a [`Node`] tree.
///
/// `node` is the focused subtree, detached from its parent's `children`. `parent` holds
/// the rest of the tree above it (`None` at the root). `self_idx` is the slot in the
/// parent's `children` the focused node was taken from, so moving back up can restore
/// the original child order.
#[derive(Debug)]
pub struct Zipper<Id, S> {
    pub node: Node<Id, S>,
    pub parent: Option<Box<Zipper<Id, S>>>,
    self_idx: usize,
}
/// The concrete node type that a zipper over kind `K` focuses on.
type ZipperNode<K> = Node<StateId<K>, <K as Kind>::State>;

impl<K> Zipper<StateId<K>, K::State>
where
    K: Kind + HashKind,
{
    /// Moves the focus down from the current node to the node `id`, following
    /// `descendant_keys` one level at a time.
    ///
    /// # Panics
    /// Panics if `id` is neither the focused node nor one of its descendants.
    fn by_id(mut self, id: StateId<K>) -> Self {
        while self.node.descendant_keys.contains(&id) {
            let idx = self
                .node
                .child_idx(id)
                .expect("descendant_keys lists an id that no child leads to");
            self = self.child(idx);
        }
        assert!(
            self.node.id == id,
            "id[{id}] should be in the node, this is a bug"
        );
        self
    }

    /// Focuses child `idx`.
    ///
    /// The child is detached with `swap_remove` (O(1), which moves the last child into
    /// slot `idx`). `idx` is remembered so [`Zipper::parent`] can undo that reordering.
    fn child(mut self, idx: usize) -> Self {
        let child = self.node.children.swap_remove(idx);
        Self {
            node: child,
            parent: Some(Box::new(self)),
            self_idx: idx,
        }
    }

    /// Replaces the state of the focused node.
    const fn set_state(mut self, state: K::State) -> Self {
        self.node.state = state;
        self
    }

    /// Appends a fresh leaf `id` to the focused node's children.
    ///
    /// `descendant_keys` is left untouched; [`Zipper::finish_insert`] records the new id.
    fn insert_child(mut self, id: StateId<K>) -> Self {
        self.node.children.push(Node::new(id));
        self
    }

    /// Reattaches the focused node to its parent and moves the focus up one level.
    ///
    /// # Panics
    /// Panics if called on the root zipper (no parent).
    fn parent(self) -> Self {
        let Self {
            node,
            parent,
            self_idx,
        } = self;
        let mut parent = *parent.expect("parent() called on the root zipper");
        // Undo `child`'s `swap_remove`: append the node, then swap it back into `self_idx`.
        // This returns the displaced former-last child to the end.
        parent.node.children.push(node);
        let last_idx = parent.node.children.len() - 1;
        parent.node.children.swap(self_idx, last_idx);
        parent
    }

    /// Records `id` as a descendant of the focused node and of every ancestor, then
    /// returns the rebuilt root node.
    fn finish_insert(mut self, id: StateId<K>) -> ZipperNode<K> {
        self.node.descendant_keys.insert(id);
        while self.parent.is_some() {
            self = self.parent();
            self.node.descendant_keys.insert(id);
        }
        self.node
    }

    /// Walks back up to the root, reattaching each level, and returns the root node.
    #[must_use]
    pub fn finish_update(mut self) -> ZipperNode<K> {
        while self.parent.is_some() {
            self = self.parent();
        }
        self.node
    }

    /// Like [`Zipper::finish_update`], but also returns the id of the focused node's
    /// parent (`None` when the focus is the root).
    fn finish_update_parent_id(self) -> (Option<StateId<K>>, ZipperNode<K>) {
        let parent_id = self.parent.as_ref().map(|z| z.node.id);
        (parent_id, self.finish_update())
    }

    /// Applies `f` to every node of the focused subtree, children before their parent.
    ///
    /// Each call of `f` receives a detached zipper (no parent) whose children were
    /// already processed. The `Clone` bound is kept for caller compatibility. Internally
    /// `f` is shared by reference rather than cloned once per child.
    fn finish_update_fn<F>(self, f: F) -> ZipperNode<K>
    where
        F: Fn(Self) -> ZipperNode<K> + Clone,
    {
        self.map_bottom_up(&f)
    }

    /// Post-order worker for [`Zipper::finish_update_fn`].
    ///
    /// Taking `&F` avoids both cloning the closure per child and unbounded generic
    /// instantiation (`F`, `&F`, `&&F`, ...).
    fn map_bottom_up<F>(mut self, f: &F) -> ZipperNode<K>
    where
        F: Fn(Self) -> ZipperNode<K>,
    {
        self.node.children = self
            .node
            .children
            .into_iter()
            .map(|child| child.zipper().map_bottom_up(f))
            .collect();
        f(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{node_state, Kind, State, StateId};

    node_state!(Alice, Bob, Charlie, Dave, Eve);

    /// Builds Alice -> {Bob, Charlie -> Dave -> Eve} and checks the parent links.
    ///
    /// It then updates each node's state, and finally runs a bottom-up
    /// `finish_update_fn` pass that marks every node not already completed as failed.
    #[test]
    fn insert_child_state() {
        let alice_id = StateId::new_rand(NodeKind::Alice);
        let bob_id = StateId::new_rand(NodeKind::Bob);
        let charlie_id = StateId::new_rand(NodeKind::Charlie);
        let dave_id = StateId::new_rand(NodeKind::Dave);
        let eve_id = StateId::new_rand(NodeKind::Eve);

        let mut tree = Node::new(alice_id);
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: bob_id,
        });
        tree.insert(Insert {
            parent_id: Some(alice_id),
            id: charlie_id,
        });
        tree.insert(Insert {
            parent_id: Some(charlie_id),
            id: dave_id,
        });
        tree.insert(Insert {
            parent_id: Some(dave_id),
            id: eve_id,
        });

        // Parent links: the root has none, everyone else points one level up.
        assert!(tree.get_parent_id(alice_id).is_none());
        assert!(tree.get_parent_id(bob_id) == Some(alice_id));
        assert!(tree.get_parent_id(charlie_id) == Some(alice_id));
        assert!(tree.get_parent_id(dave_id) == Some(charlie_id));
        assert!(tree.get_parent_id(eve_id) == Some(dave_id));

        let mut bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::New));
        tree = tree.into_update(Update {
            id: bob_id,
            state: NodeState::Bob(Bob::Awaiting),
        });
        bob = tree.get_state(bob_id).unwrap();
        assert_eq!(bob, &NodeState::Bob(Bob::Awaiting));

        let mut charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::New));
        tree = tree.into_update(Update {
            id: charlie_id,
            state: NodeState::Charlie(Charlie::Awaiting),
        });
        charlie = tree.get_state(charlie_id).unwrap();
        assert_eq!(charlie, &NodeState::Charlie(Charlie::Awaiting));

        let mut dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::New));
        tree = tree.into_update(Update {
            id: dave_id,
            state: NodeState::Dave(Dave::Completed),
        });
        dave = tree.get_state(dave_id).unwrap();
        assert_eq!(dave, &NodeState::Dave(Dave::Completed));

        let mut eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::New));
        tree = tree.into_update(Update {
            id: eve_id,
            state: NodeState::Eve(Eve::Failed),
        });
        eve = tree.get_state(eve_id).unwrap();
        assert_eq!(eve, &NodeState::Eve(Eve::Failed));

        // Fail every node that has not completed; Dave keeps its Completed state.
        tree = tree.zipper().finish_update_fn(|mut z| {
            let kind: NodeKind = *z.node.state.as_ref();
            if z.node.state != kind.completed_state() {
                z.node.state = kind.failed_state();
            }
            z.finish_update()
        });
        assert_eq!(&tree.state, &NodeState::Alice(Alice::Failed));
        assert_eq!(
            tree.get_state(bob_id).unwrap(),
            &NodeState::Bob(Bob::Failed)
        );
        assert_eq!(
            tree.get_state(charlie_id).unwrap(),
            &NodeState::Charlie(Charlie::Failed)
        );
        assert_eq!(
            tree.get_state(dave_id).unwrap(),
            &NodeState::Dave(Dave::Completed)
        );
        assert_eq!(
            tree.get_state(eve_id).unwrap(),
            &NodeState::Eve(Eve::Failed)
        );
    }
}