#![no_std]
use core::{fmt::Debug, iter::FusedIterator, mem::take, pin::Pin};
use either::Either;
use temp_inst::{TempInst, TempInstPin, TempRefPin, TempRepr, TempReprMut};
/// A stack whose nodes are temporary representations: the bottom node is a `Root` holding
/// root data, and every other node is a `Frame` holding frame data plus a reference to the
/// node directly below it.
///
/// Nodes only refer downward to their parent, so several frames can be pushed onto the same
/// parent, forming independent branches that share everything below that parent.
///
/// Build stacks from shared data with [`TempStack::new_root`] / [`TempStack::new_frame`], or
/// from mutable data with [`TempStack::new_root_mut`] / [`TempStack::new_frame_mut`].
#[derive(TempRepr, TempReprMut)]
pub enum TempStack<Root: TempRepr, Frame: TempRepr> {
// The bottom node of the stack.
Root {
// Data stored at the root.
data: Root,
},
// A node pushed on top of `parent`.
Frame {
// Data stored in this frame.
data: Frame,
// The node directly below this frame; when mutating, it is reached as a pinned
// `TempStack` via `get_mut_pinned`.
parent: TempRefPin<TempStack<Root, Frame>>,
},
}
impl<Root: TempRepr, Frame: TempRepr> TempStack<Root, Frame> {
    /// Starts a new stack whose only node is a root holding `data`.
    ///
    /// The returned [`TempInst`] borrows `data`. Push frames on top of it with
    /// [`TempStack::new_frame`].
    pub fn new_root(data: Root::Shared<'_>) -> TempStackFrame<'_, Root, Frame> {
        let root_node = Either::Left(data);
        TempInst::new(root_node)
    }

    /// Pushes a frame holding `data` on top of `self` and returns the new top node.
    ///
    /// `self` is only borrowed, so more than one frame may be pushed onto the same parent.
    /// Each push creates a separate branch.
    pub fn new_frame<'a>(&'a self, data: Frame::Shared<'a>) -> TempStackFrame<'a, Root, Frame> {
        let parent: &'a Self = self;
        let frame_node = Either::Right((data, parent));
        TempInst::new(frame_node)
    }

    /// Returns an iterator over the frame data, starting at this node and walking toward the
    /// root.
    ///
    /// The root data itself is not yielded; reach it with [`TempStackIter::into_root`].
    pub fn iter(&self) -> TempStackIter<'_, Root, Frame> {
        TempStackIter::new(self)
    }
}
impl<Root: TempReprMut, Frame: TempReprMut> TempStack<Root, Frame> {
    /// Starts a new stack whose only node is a root holding the mutable `data`.
    ///
    /// The result is a [`TempInstPin`]. To push frames onto it, pin it and obtain a
    /// `Pin<&mut TempStack>`, which is what [`TempStack::new_frame_mut`] takes.
    pub fn new_root_mut(data: Root::Mutable<'_>) -> TempStackFrameMut<'_, Root, Frame> {
        let root_node = Either::Left(data);
        TempInstPin::new(root_node)
    }

    /// Pushes a frame holding the mutable `data` on top of the pinned node `self` and returns
    /// the new top node.
    ///
    /// The parent stays mutably borrowed for `'a`, while the returned node is in use.
    pub fn new_frame_mut<'a>(
        self: Pin<&'a mut Self>,
        data: Frame::Mutable<'a>,
    ) -> TempStackFrameMut<'a, Root, Frame> {
        let frame_node = Either::Right((data, self));
        TempInstPin::new(frame_node)
    }

    /// Returns an iterator giving mutable access to the frame data, starting at this node and
    /// walking toward the root.
    ///
    /// The root data itself is not yielded; reach it with [`TempStackIterMut::into_root`].
    pub fn iter_mut(self: Pin<&mut Self>) -> TempStackIterMut<'_, Root, Frame> {
        TempStackIterMut::new(self)
    }
}
impl<Root: TempRepr + Debug, Frame: TempRepr + Debug> Debug for TempStack<Root, Frame> {
    /// Formats the whole stack inside square brackets, root first and this node last.
    ///
    /// Example shape: `[root; frame1, frame2]`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("[")?;
        self.fmt_contents(f)?;
        f.write_str("]")
    }
}
impl<Root: TempRepr + Debug, Frame: TempRepr + Debug> TempStack<Root, Frame> {
    /// Writes every node from the root up to `self`, without the surrounding brackets.
    ///
    /// The root is separated from the first frame by `"; "`; consecutive frames are separated
    /// by `", "`. The walk recurses into the parent first, so output runs bottom-to-top.
    fn fmt_contents(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (data, parent) = match self {
            TempStack::Root { data } => return data.fmt(f),
            TempStack::Frame { data, parent } => (data, parent),
        };

        parent.fmt_contents(f)?;

        // The first frame above the root gets a distinct separator.
        let separator = match **parent {
            TempStack::Root { .. } => "; ",
            TempStack::Frame { .. } => ", ",
        };
        f.write_str(separator)?;

        data.fmt(f)
    }
}
/// Shared reference to a [`TempStack`] node.
pub type TempStackRef<'a, Root, Frame> = &'a TempStack<Root, Frame>;
/// Temporary instance holding a [`TempStack`] node over shared data.
///
/// Returned by [`TempStack::new_root`] and [`TempStack::new_frame`].
pub type TempStackFrame<'a, Root, Frame> = TempInst<'a, TempStack<Root, Frame>>;

/// Pinned mutable reference to a [`TempStack`] node.
pub type TempStackRefMut<'a, Root, Frame> = Pin<&'a mut TempStack<Root, Frame>>;
/// Pinnable temporary instance holding a [`TempStack`] node over mutable data.
///
/// Returned by [`TempStack::new_root_mut`] and [`TempStack::new_frame_mut`].
pub type TempStackFrameMut<'a, Root, Frame> = TempInstPin<'a, TempStack<Root, Frame>>;
/// Iterator over the shared frame data of a [`TempStack`], from the starting node down toward
/// the root.
///
/// The root data is never yielded by [`Iterator::next`]; call [`TempStackIter::into_root`] to
/// reach it. The iterator only holds a shared reference, so it is cheap to copy.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TempStackIter<'a, Root: TempRepr, Frame: TempRepr>(TempStackRef<'a, Root, Frame>);
impl<'a, Root: TempRepr, Frame: TempRepr> TempStackIter<'a, Root, Frame> {
    /// Creates an iterator positioned at `start`, the first node it will examine.
    fn new(start: TempStackRef<'a, Root, Frame>) -> Self {
        Self(start)
    }

    /// Consumes the iterator and returns the shared root data of the stack.
    ///
    /// Frames not yet yielded are skipped, so this works on a fresh iterator as well as on an
    /// exhausted one.
    pub fn into_root(self) -> Root::Shared<'a> {
        let mut node = self.0;
        loop {
            node = match node {
                TempStack::Root { data } => break data.get(),
                TempStack::Frame { parent, .. } => parent.get(),
            };
        }
    }
}
// `Clone`/`Copy` are written by hand because `#[derive]` would add `Root: Clone`/`Frame: Clone`
// (and `Copy`) bounds. The iterator only stores a shared reference, so it is copyable
// regardless of `Root` and `Frame`.
impl<Root: TempRepr, Frame: TempRepr> Clone for TempStackIter<'_, Root, Frame> {
    /// Duplicates the iterator; the copy continues from the same position independently.
    fn clone(&self) -> Self {
        *self
    }
}
impl<Root: TempRepr, Frame: TempRepr> Copy for TempStackIter<'_, Root, Frame> {}
impl<'a, Root: TempRepr, Frame: TempRepr> Iterator for TempStackIter<'a, Root, Frame> {
    type Item = Frame::Shared<'a>;

    /// Yields the current frame's data and steps to its parent.
    ///
    /// Returns `None` once the root is reached and stays positioned there, so
    /// [`TempStackIter::into_root`] still works afterwards.
    fn next(&mut self) -> Option<Self::Item> {
        let TempStack::Frame { data, parent } = self.0 else {
            return None;
        };
        self.0 = parent.get();
        Some(data.get())
    }
}

// Once `next` reaches the root it never moves again, so every later call returns `None`.
impl<Root: TempRepr, Frame: TempRepr> FusedIterator for TempStackIter<'_, Root, Frame> {}
/// Iterator giving mutable access to the frame data of a pinned [`TempStack`], from the
/// starting node down toward the root.
///
/// The root data is never yielded by [`Iterator::next`]; call [`TempStackIterMut::into_root`]
/// to reach it.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TempStackIterMut<'a, Root: TempReprMut, Frame: TempReprMut>(
    // Invariant: `Some` whenever the iterator is not inside `next`. The `Option` only exists so
    // `next` can temporarily move the pinned reference out with `mem::take`.
    Option<TempStackRefMut<'a, Root, Frame>>,
);
impl<'a, Root: TempReprMut, Frame: TempReprMut> TempStackIterMut<'a, Root, Frame> {
    /// Creates an iterator positioned at `start`, the first node it will examine.
    fn new(start: TempStackRefMut<'a, Root, Frame>) -> Self {
        TempStackIterMut(Some(start))
    }

    /// Consumes the iterator and returns mutable access to the root data of the stack.
    ///
    /// Frames not yet yielded are skipped, so this works on a fresh iterator as well as on an
    /// exhausted one.
    ///
    /// # Panics
    ///
    /// Panics if the iterator has lost its current node. That can only happen if an earlier
    /// call to [`Iterator::next`] panicked part-way through.
    pub fn into_root(self) -> Root::Mutable<'a> {
        let mut node = self
            .0
            .expect("TempStackIterMut lost its current node (an earlier `next` call panicked)");
        loop {
            // SAFETY: the unpinned `&mut TempStack` is only used to project references to its
            // fields below. The node itself is never moved out of or overwritten, so its
            // pinning guarantee is upheld.
            let current = unsafe { node.get_unchecked_mut() };
            match current {
                TempStack::Root { data } => {
                    // SAFETY: `data` lives inside a pinned node that is never moved, so re-pinning
                    // the field is sound. This relies on `TempStack` treating its fields as
                    // structurally pinned (NOTE(review): confirm against the `TempReprMut`
                    // derive).
                    return unsafe { Pin::new_unchecked(data).get_mut_pinned() };
                }
                TempStack::Frame { parent, .. } => {
                    // SAFETY: same structural-pinning argument as for `data` above, applied to
                    // the `parent` field.
                    node = unsafe { Pin::new_unchecked(parent).get_mut_pinned() };
                }
            }
        }
    }
}
impl<'a, Root: TempReprMut, Frame: TempReprMut> Iterator for TempStackIterMut<'a, Root, Frame> {
    type Item = Frame::Mutable<'a>;

    /// Yields mutable access to the current frame's data and steps to its parent.
    ///
    /// Returns `None` once the root is reached and stays positioned there, so
    /// [`TempStackIterMut::into_root`] still works afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the iterator has lost its current node. That can only happen if a previous call
    /// panicked part-way through.
    fn next(&mut self) -> Option<Self::Item> {
        // The pinned reference is moved out so the yielded item can borrow for the full `'a`.
        // Every non-panicking path below stores a node back into `self.0`.
        let node = take(&mut self.0)
            .expect("TempStackIterMut lost its current node (an earlier `next` call panicked)");
        // SAFETY: the unpinned `&mut TempStack` is only used to project references to its
        // fields or to re-pin it in place. The node itself is never moved out of or
        // overwritten, so its pinning guarantee is upheld.
        let current = unsafe { node.get_unchecked_mut() };
        match current {
            TempStack::Root { .. } => {
                // SAFETY: `current` came from a pinned reference just above and has not been
                // moved, so pinning it again is sound.
                self.0 = Some(unsafe { Pin::new_unchecked(current) });
                None
            }
            TempStack::Frame { data, parent } => {
                // SAFETY: `parent` is a field of a pinned node that is never moved. This relies
                // on `TempStack` treating its fields as structurally pinned (NOTE(review):
                // confirm against the `TempReprMut` derive).
                self.0 = Some(unsafe { Pin::new_unchecked(parent).get_mut_pinned() });
                // SAFETY: the same structural-pinning argument, applied to `data`, a distinct
                // field of the same node.
                Some(unsafe { Pin::new_unchecked(data).get_mut_pinned() })
            }
        }
    }
}

// Once `next` reaches the root it stores the root back and keeps returning `None`.
impl<'a, Root: TempReprMut, Frame: TempReprMut> FusedIterator
    for TempStackIterMut<'a, Root, Frame>
{
}
// Unit tests covering shared stacks (`TempRef`) and pinned mutable stacks (`TempRefMut`):
// root-only stacks, linear stacks of frames, and stacks that branch off a common parent.
#[cfg(test)]
mod tests {
use core::pin::pin;
use temp_inst::{TempRef, TempRefMut};
use super::*;
// A root-only stack yields no frames, and `into_root` still reaches the root.
#[test]
fn empty_stack() {
let root = 42;
let stack = TempStack::<TempRef<i32>, ()>::new_root(&root);
let mut iter = stack.iter();
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
}
// Mutable counterpart of `empty_stack`: a write through the root reference reaches the
// original variable.
#[test]
fn empty_stack_mut() {
let mut root = 42;
let stack = pin!(TempStack::<TempRefMut<i32>, ()>::new_root_mut(&mut root));
let mut iter = stack.deref_pin().iter_mut();
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
*root_ref += 1;
assert_eq!(root, 43);
}
// Frames are yielded most-recently-pushed first, and a fresh iterator's `into_root` skips
// every frame.
#[test]
fn stack_with_frames() {
let root = 42;
let stack = TempStack::<TempRef<i32>, TempRef<i32>>::new_root(&root);
let stack = stack.new_frame(&1);
let stack = stack.new_frame(&2);
let stack = stack.new_frame(&3);
let mut iter = stack.iter();
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&1));
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
let iter = stack.iter();
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
}
// Mutable frames can be modified during iteration. The changes show up in the original
// variables once the stack is no longer in use.
#[test]
fn stack_with_frames_mut() {
let mut root = 42;
let stack = pin!(TempStack::<TempRefMut<i32>, TempRefMut<i32>>::new_root_mut(
&mut root
));
let mut frame1 = 1;
let stack = pin!(stack.deref_pin().new_frame_mut(&mut frame1));
let mut frame2 = 2;
let stack = pin!(stack.deref_pin().new_frame_mut(&mut frame2));
let mut frame3 = 3;
let mut stack = pin!(stack.deref_pin().new_frame_mut(&mut frame3));
// `as_mut` reborrows the pinned top node so the stack can be iterated again below.
let mut iter = stack.as_mut().deref_pin().iter_mut();
let frame3_ref = iter.next().unwrap();
assert_eq!(frame3_ref, &mut 3);
*frame3_ref += 1;
assert_eq!(iter.next(), Some(&mut 2));
let frame1_ref = iter.next().unwrap();
assert_eq!(frame1_ref, &mut 1);
*frame1_ref -= 1;
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
*root_ref += 1;
// A second pass over the same stack observes the root write made above.
let iter = stack.deref_pin().iter_mut();
let root_ref = iter.into_root();
assert_eq!(*root_ref, 43);
// The stack is no longer used, so the original variables are readable again.
assert_eq!(root, 43);
assert_eq!(frame1, 0);
assert_eq!(frame2, 2);
assert_eq!(frame3, 4);
}
// Two frames pushed onto the same parent form independent branches that share the frames
// below that parent.
#[test]
fn stack_with_branching() {
let root = 42;
let stack = TempStack::<TempRef<i32>, TempRef<i32>>::new_root(&root);
let stack = stack.new_frame(&1);
let stack = stack.new_frame(&2);
let stack2 = stack.new_frame(&11);
let stack = stack.new_frame(&3);
let stack2 = stack2.new_frame(&12);
let stack2 = stack2.new_frame(&13);
let mut iter = stack.iter();
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&1));
assert!(iter.next().is_none());
let mut iter2 = stack2.iter();
assert_eq!(iter2.next(), Some(&13));
assert_eq!(iter2.next(), Some(&12));
assert_eq!(iter2.next(), Some(&11));
assert_eq!(iter2.next(), Some(&2));
assert_eq!(iter2.next(), Some(&1));
assert!(iter2.next().is_none());
}
// Mutable branching: `stack2` is built on a reborrow (`as_mut`) of `stack`. After `stack2`
// is finished, `stack` can be iterated again and sees the modifications made through
// `stack2`.
#[test]
fn stack_with_branching_mut() {
let mut root = 42;
let stack = pin!(TempStack::<TempRefMut<i32>, TempRefMut<i32>>::new_root_mut(
&mut root
));
let mut frame1 = 1;
let stack = pin!(stack.deref_pin().new_frame_mut(&mut frame1));
let mut frame2 = 2;
let mut stack = pin!(stack.deref_pin().new_frame_mut(&mut frame2));
let mut frame3 = 3;
let stack2 = pin!(stack.as_mut().deref_pin().new_frame_mut(&mut frame3));
let mut iter = stack2.deref_pin().iter_mut();
let frame3_ref = iter.next().unwrap();
assert_eq!(frame3_ref, &mut 3);
*frame3_ref += 1;
assert_eq!(iter.next(), Some(&mut 2));
let frame1_ref = iter.next().unwrap();
assert_eq!(frame1_ref, &mut 1);
*frame1_ref -= 1;
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 42);
*root_ref += 1;
// `frame3` belonged only to `stack2`, so this branch yields frame2 and then the modified
// frame1.
let mut iter = stack.deref_pin().iter_mut();
assert_eq!(iter.next(), Some(&mut 2));
assert_eq!(iter.next(), Some(&mut 0));
assert!(iter.next().is_none());
let root_ref = iter.into_root();
assert_eq!(*root_ref, 43);
assert_eq!(root, 43);
assert_eq!(frame1, 0);
assert_eq!(frame2, 2);
assert_eq!(frame3, 4);
}
}