use core::{
marker::PhantomData,
ptr::NonNull,
};
use alloc::vec::Vec;
/// Either a node (`N`) or the tree itself (`T`).
///
/// [`WalkMut`] uses this to describe its current position: the tree while no
/// node has been walked into, otherwise the most recently walked-into node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NodeOrTree<T, N> {
    /// A node of the tree.
    Node(N),
    /// The tree itself.
    Tree(T),
}

impl<T, N> NodeOrTree<T, N> {
    /// Transforms the `Tree` value with `f`; a `Node` is passed through
    /// unchanged (and `f` is not called).
    pub fn map_tree<F, U>(self, f: F) -> NodeOrTree<U, N>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::Tree(r) => NodeOrTree::Tree(f(r)),
            Self::Node(n) => NodeOrTree::Node(n),
        }
    }

    /// Transforms the `Node` value with `f`; a `Tree` is passed through
    /// unchanged (and `f` is not called).
    pub fn map_node<F, U>(self, f: F) -> NodeOrTree<T, U>
    where
        F: FnOnce(N) -> U,
    {
        match self {
            Self::Tree(r) => NodeOrTree::Tree(r),
            Self::Node(n) => NodeOrTree::Node(f(n)),
        }
    }

    /// Returns the node, or `None` if this is the tree.
    pub fn node(self) -> Option<N> {
        match self {
            Self::Tree(_) => None,
            Self::Node(n) => Some(n),
        }
    }
}

impl<N> NodeOrTree<N, N> {
    /// Returns the contained value when tree and node share the same type.
    #[inline]
    pub fn flatten(self) -> N {
        match self {
            Self::Tree(r) => r,
            Self::Node(n) => n,
        }
    }
}

impl<N> NodeOrTree<Option<N>, N> {
    /// Returns the node wrapped in `Some`, or the tree's `Option` as-is.
    #[inline]
    pub fn flatten_optional(self) -> Option<N> {
        match self {
            Self::Tree(r) => r,
            Self::Node(n) => Some(n),
        }
    }
}
/// Mutable walker over a tree that records the path of visited nodes in an
/// explicit stack.
///
/// `T` is the tree type borrowed by [`new`], `N` the node type, and `A`
/// extra data stored with each visited node (handed back by [`pop`]).
///
/// [`new`]: WalkMut::new
/// [`pop`]: WalkMut::pop
pub struct WalkMut<'r, T: ?Sized, N: ?Sized, A = ()> {
    // `&'r mut T`: the walker holds the exclusive tree borrow for `'r` and is
    // invariant in `T`.
    // `*mut N`: makes the walker invariant in `N` too. The stack stores
    // `NonNull<N>`, which is covariant; without this marker a walker could be
    // coerced to a shorter-lived `N` and used to write short-lived data into
    // nodes that outlive it.
    _marker: PhantomData<(&'r mut T, *mut N)>,
    // The tree, taken from the `&'r mut T` passed to `new`.
    tree: NonNull<T>,
    // Nodes walked into so far; the last entry is the current position.
    stack: Vec<(NonNull<N>, A)>,
}

impl<'r, T: ?Sized, N: ?Sized, A> WalkMut<'r, T, N, A> {
    /// Starts a walk positioned at `tree`, with an empty node stack.
    pub fn new(tree: &'r mut T) -> Self {
        Self {
            _marker: PhantomData,
            tree: tree.into(),
            stack: Vec::new(),
        }
    }

    /// Descends one step.
    ///
    /// Calls `with` on the current position. On success, pushes the returned
    /// node and its `A` data, making that node the new current position.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `with` unchanged; the stack is not
    /// modified in that case.
    pub fn try_walk<F, E>(&mut self, with: F) -> Result<(), E>
    where
        F: for<'n> FnOnce(NodeOrTree<&'n mut T, &'n mut N>) -> Result<(&'n mut N, A), E>,
    {
        let (next, add) = with(self.current_mut())?;
        // Convert first so the borrow of `self` held by `next` ends before
        // `self.stack` is borrowed for the push.
        let next: NonNull<N> = next.into();
        self.stack.push((next, add));
        Ok(())
    }

    /// Leaves the current node and returns the `A` data it was pushed with.
    ///
    /// Returns `None` when no node has been walked into.
    pub fn pop(&mut self) -> Option<A> {
        self.stack.pop().map(|(_, add)| add)
    }

    /// Drops every walked node (and its `A` data) and returns the tree.
    pub fn pop_all(&mut self) -> &mut T {
        self.stack.clear();
        // SAFETY: `tree` comes from the `&'r mut T` given to `new`, which this
        // walker holds exclusively. The stack is now empty, so no node pointer
        // can be used while the returned borrow (tied to `&mut self`) lives.
        unsafe { self.tree.as_mut() }
    }

    /// Returns the current position.
    ///
    /// This is the last walked-into node, or the tree if the stack is empty.
    pub fn current_mut(&mut self) -> NodeOrTree<&mut T, &mut N> {
        match self.stack.last_mut() {
            // SAFETY: stack entries were returned by the `try_walk` closure,
            // which must derive them from the position it was given (or from
            // something outliving it). Only the top entry, or the tree when
            // the stack is empty, is ever handed out mutably. So an entry
            // cannot be freed through an ancestor while it is on the stack.
            // The result borrows `self` mutably, preventing aliasing.
            Some((cur, _)) => NodeOrTree::Node(unsafe { cur.as_mut() }),
            // SAFETY: the stack is empty, so the tree is the only live
            // position; it comes from the exclusive `&'r mut T` given to `new`.
            None => NodeOrTree::Tree(unsafe { self.tree.as_mut() }),
        }
    }

    /// Consumes the walker and returns the current position with the full
    /// tree lifetime `'r`.
    pub fn into_current_mut(mut self) -> NodeOrTree<&'r mut T, &'r mut N> {
        match self.stack.last_mut() {
            // SAFETY: same invariant as `current_mut`. The walker is consumed,
            // so no other access can be made through it during `'r`.
            Some((cur, _)) => NodeOrTree::Node(unsafe { cur.as_mut() }),
            // SAFETY: empty stack; the tree pointer comes from the exclusive
            // `&'r mut T` given to `new`, and the walker is consumed.
            None => NodeOrTree::Tree(unsafe { self.tree.as_mut() }),
        }
    }

    /// Consumes the walker and returns the tree with the full lifetime `'r`.
    pub fn into_tree_mut(mut self) -> &'r mut T {
        self.stack.clear();
        // SAFETY: the tree pointer comes from the exclusive `&'r mut T` given
        // to `new`; the walker is consumed, so nothing else can use it.
        unsafe { self.tree.as_mut() }
    }

    /// Returns a shared view of the current position.
    ///
    /// This is the last walked-into node, or the tree if the stack is empty.
    pub fn current(&self) -> NodeOrTree<&T, &N> {
        match self.stack.last() {
            // SAFETY: see `current_mut`. The result borrows `self` shared, so
            // no mutable access can happen while it lives.
            Some((cur, _)) => NodeOrTree::Node(unsafe { cur.as_ref() }),
            // SAFETY: empty stack; shared view of the exclusively borrowed tree.
            None => NodeOrTree::Tree(unsafe { self.tree.as_ref() }),
        }
    }
}