use std::{
iter::{Chain, Enumerate, Rev},
ops::RangeBounds,
};
use gap_buf::GapBuffer;
use crate::utils::{binary_search_by_key_and_index, get_range};
/// A list of [`Shiftable`] elements with a lazily applied shift.
///
/// Instead of rewriting every element after each shift, a single
/// pending `shift` is remembered together with the index `from` at
/// which it starts applying. Reads add the pending shift to elements
/// at index `>= from` on the fly; elements before `from` are stored
/// with their final value.
#[derive(Default, Clone)]
pub(super) struct ShiftList<S: Shiftable> {
/// Backing storage. Elements at index `>= from` are stored WITHOUT
/// the pending `shift` applied.
pub(super) buf: GapBuffer<S>,
/// First index whose element still needs `shift` applied on read.
pub(super) from: usize,
/// The pending shift. `S::Shift::default()` means nothing is
/// pending.
pub(super) shift: S::Shift,
/// Running total: `shift_by` adds every `by` to it, and `extend`
/// uses it as the offset for the elements it appends.
pub(super) max: S::Shift,
}
impl<S: Shiftable> ShiftList<S> {
    /// Creates an empty list with no pending shift and `max` as its
    /// initial running total.
    pub(super) fn new(max: S::Shift) -> Self {
        Self {
            buf: GapBuffer::new(),
            from: 0,
            shift: S::Shift::default(),
            max,
        }
    }

    /// Returns the logical value of `s`, the element stored at index
    /// `i`.
    ///
    /// Elements at or past `self.from` are stored without the pending
    /// shift, so it is applied here on the way out.
    #[inline]
    fn resolve(&self, i: usize, s: S) -> S {
        if i >= self.from {
            s.shift(self.shift)
        } else {
            s
        }
    }

    /// Rewrites the stored elements between the current boundary and
    /// `to`, so that `to` can become the new boundary.
    ///
    /// Moving forward bakes the pending shift into `self.from..to`;
    /// moving backward strips it again from `to..self.from`. Does
    /// nothing when no shift is pending. `self.from` itself is left
    /// for the caller to update.
    #[track_caller]
    fn settle_until(&mut self, to: usize) {
        let pending = self.shift;
        if pending == S::Shift::default() {
            return;
        }

        if to >= self.from {
            for s in self.buf.range_mut(self.from..to).iter_mut() {
                *s = s.shift(pending);
            }
        } else {
            let undo = pending.neg();
            for s in self.buf.range_mut(to..self.from).iter_mut() {
                *s = s.shift(undo);
            }
        }
    }

    /// Inserts `new` at index `i`.
    ///
    /// `new` is stored as given (it is not shifted). The pending
    /// shift, if any, keeps applying to the elements after `i`; when
    /// there are none, the pending shift is dropped.
    #[track_caller]
    pub(super) fn insert(&mut self, i: usize, new: S) {
        self.settle_until(i);
        self.buf.insert(i, new);

        let has_pending = self.shift != S::Shift::default();
        if has_pending && i + 1 < self.buf.len() {
            self.from = i + 1;
        } else {
            // Nothing left after `i` for a shift to apply to.
            self.from = 0;
            self.shift = S::Shift::default();
        }
    }

    /// Removes the element at index `i`.
    #[track_caller]
    pub(super) fn remove(&mut self, i: usize) {
        self.buf.remove(i);
        // The boundary follows the elements that slid down by one.
        if self.from > i {
            self.from -= 1;
        }
    }

    /// Lazily removes elements in `range`, front to back.
    ///
    /// `f` receives the current index and the logical (shifted) value
    /// of each element:
    ///
    /// - `Some(true)`: the element is removed and yielded with the
    ///   index it had at that moment;
    /// - `Some(false)`: the element is kept and the scan moves on;
    /// - `None`: the iterator yields `None` without advancing.
    ///
    /// Nothing happens unless the returned iterator is polled.
    #[inline]
    #[track_caller]
    pub(super) fn extract_if_while<'a>(
        &'a mut self,
        range: impl RangeBounds<usize>,
        mut f: impl FnMut(usize, S) -> Option<bool> + 'a,
    ) -> impl Iterator<Item = (usize, S)> + 'a {
        let mut range = get_range(range, self.buf.len());

        std::iter::from_fn(move || {
            while range.start < range.end {
                let i = range.start;
                let elem = self.resolve(i, self.buf[i]);

                if f(i, elem)? {
                    self.buf.remove(i);
                    if i < self.from {
                        self.from -= 1;
                    }
                    // The rest of the range slid down by one.
                    range.end -= 1;
                    return Some((i, elem));
                }

                range.start += 1;
            }

            None
        })
    }

    /// Lazily removes elements in `range`, back to front.
    ///
    /// `f` receives the index and the logical (shifted) value of each
    /// element:
    ///
    /// - `Some(true)`: the element is removed and yielded;
    /// - `Some(false)`: the element is kept and the scan moves on;
    /// - `None`: the iterator yields `None`; the element that caused
    ///   it has already been stepped over.
    ///
    /// Nothing happens unless the returned iterator is polled.
    #[inline]
    #[track_caller]
    pub(super) fn rextract_if_while<'a>(
        &'a mut self,
        range: impl RangeBounds<usize>,
        mut f: impl FnMut(usize, S) -> Option<bool> + 'a,
    ) -> impl Iterator<Item = (usize, S)> + 'a {
        let mut range = get_range(range, self.buf.len());

        std::iter::from_fn(move || {
            while range.end > range.start {
                range.end -= 1;
                let i = range.end;
                let elem = self.resolve(i, self.buf[i]);

                if f(i, elem)? {
                    self.buf.remove(i);
                    if i < self.from {
                        self.from -= 1;
                    }
                    return Some((i, elem));
                }
            }

            None
        })
    }

    /// Shifts every element at index `>= from` by `by`, lazily.
    ///
    /// The previously pending shift is first settled up to `from`, so
    /// a single boundary keeps describing the whole list. `by` is also
    /// added to the running total returned by [`max`](Self::max).
    #[track_caller]
    pub(super) fn shift_by(&mut self, from: usize, by: S::Shift) {
        self.settle_until(from);
        self.from = from;
        self.shift = self.shift.add(by);
        self.max = self.max.add(by);
    }

    /// Appends all elements of `other`, offsetting each by this list's
    /// current `max`.
    ///
    /// This list's pending shift is fully applied first; afterwards
    /// the pending shift of `other` becomes the pending shift of the
    /// merged list, and `other.max` is added to `max`.
    pub(super) fn extend(&mut self, mut other: Self) {
        // Bake our own pending shift into every remaining element
        // (a no-op when nothing is pending).
        let len = self.buf.len();
        self.settle_until(len);

        self.from = len + other.from;
        self.shift = other.shift;

        let offset = self.max;
        self.buf
            .extend(other.buf.drain(..).map(|s| s.shift(offset)));
        self.max = self.max.add(other.max);
    }

    /// Iterates over `(index, logical value)` pairs in `range`, front
    /// to back.
    #[inline]
    #[track_caller]
    pub(super) fn iter_fwd(&self, range: impl RangeBounds<usize>) -> IterFwd<'_, S> {
        let range = get_range(range, self.buf.len());
        let (s0, s1) = self.buf.range(range.clone()).as_slices();

        IterFwd {
            iter: s0.iter().chain(s1).enumerate(),
            from: self.from,
            shift: self.shift,
            start: range.start,
        }
    }

    /// Iterates over `(index, logical value)` pairs in `range`, back
    /// to front.
    #[inline]
    #[track_caller]
    pub(super) fn iter_rev(&self, range: impl RangeBounds<usize>) -> IterRev<'_, S> {
        let range = get_range(range, self.buf.len());
        let (s0, s1) = self.buf.range(range.clone()).as_slices();

        IterRev {
            iter: s1.iter().rev().chain(s0.iter().rev()).enumerate(),
            from: self.from,
            shift: self.shift,
            end: range.end,
        }
    }

    /// Searches for an element whose key, computed by `f` on the
    /// logical value, equals `key`.
    ///
    /// On a hit, returns `Ok` with the index of the FIRST element of
    /// the run of equal keys. On a miss, returns the `Err` index
    /// reported by `binary_search_by_key_and_index` unchanged.
    #[inline(always)]
    pub(super) fn find_by_key<K: Copy + Eq + Ord>(
        &self,
        key: K,
        f: impl Fn(S) -> K,
    ) -> Result<usize, usize> {
        let key_of = |i: usize, s: &S| f(self.resolve(i, *s));

        let mut first = binary_search_by_key_and_index(&self.buf, self.buf.len(), key, key_of)?;

        // The search may land anywhere inside a run of equal keys;
        // walk back to the start of that run.
        while first > 0 && key_of(first - 1, &self.buf[first - 1]) == key {
            first -= 1;
        }

        Ok(first)
    }

    /// Returns the logical value at index `i`, or `None` if `i` is out
    /// of bounds.
    #[inline]
    pub(super) fn get(&self, i: usize) -> Option<S> {
        self.buf.get(i).map(|&s| self.resolve(i, s))
    }

    /// Returns the running total (see [`shift_by`](Self::shift_by) and
    /// [`extend`](Self::extend)).
    #[inline]
    pub(super) fn max(&self) -> S::Shift {
        self.max
    }

    /// Returns the number of stored elements.
    #[inline]
    pub(super) fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` if no elements are stored.
    #[inline]
    pub(super) fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }
}
/// Front-to-back iterator over `(index, logical value)` pairs of a
/// `ShiftList` range, created by `ShiftList::iter_fwd`.
#[derive(Debug, Clone)]
pub struct IterFwd<'a, S: Shiftable> {
/// The two slices of the requested range, chained and numbered from
/// zero.
iter: Enumerate<Chain<std::slice::Iter<'a, S>, std::slice::Iter<'a, S>>>,
/// First list index at which `shift` applies.
from: usize,
/// Shift added to elements at list index `>= from`.
shift: S::Shift,
/// List index of the first element of the range; added to the
/// enumeration counter to recover list indices.
start: usize,
}
impl<'a, S: Shiftable> Iterator for IterFwd<'a, S> {
    type Item = (usize, S);

    /// Yields the next `(list index, logical value)` pair.
    fn next(&mut self) -> Option<Self::Item> {
        let (offset, elem) = self.iter.next()?;
        // The enumeration counts from zero; translate to a list index.
        let index = self.start + offset;

        let value = if index >= self.from {
            elem.shift(self.shift)
        } else {
            *elem
        };

        Some((index, value))
    }

    /// Forwards the inner iterator's exact bounds, so that `collect`
    /// and `extend` can reserve the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// Back-to-front iterator over `(index, logical value)` pairs of a
/// `ShiftList` range, created by `ShiftList::iter_rev`.
#[derive(Debug, Clone)]
pub struct IterRev<'a, S: Shiftable> {
/// The two slices of the requested range, each reversed and chained
/// last-slice-first, numbered from zero.
iter: Enumerate<Chain<Rev<std::slice::Iter<'a, S>>, Rev<std::slice::Iter<'a, S>>>>,
/// First list index at which `shift` applies.
from: usize,
/// Shift added to elements at list index `>= from`.
shift: S::Shift,
/// Exclusive end of the range; list indices are recovered as
/// `end - (counter + 1)`.
end: usize,
}
impl<'a, S: Shiftable> Iterator for IterRev<'a, S> {
    type Item = (usize, S);

    /// Yields the next `(list index, logical value)` pair, walking
    /// from the end of the range towards its start.
    fn next(&mut self) -> Option<Self::Item> {
        let (taken, elem) = self.iter.next()?;
        // `taken` elements were already yielded from the back, so this
        // one sits `taken + 1` places before the exclusive end.
        let index = self.end - (taken + 1);

        let value = if index >= self.from {
            elem.shift(self.shift)
        } else {
            *elem
        };

        Some((index, value))
    }

    /// Forwards the inner iterator's exact bounds, so that `collect`
    /// and `extend` can reserve the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// An element that can be stored in a `ShiftList`.
pub(super) trait Shiftable: Copy + Ord + std::fmt::Debug {
/// The offset type by which elements of this type are moved.
type Shift: Shift;
/// Returns `self` moved by `by`.
///
/// `ShiftList` reverts a shift by calling this with `by.neg()`, so
/// the two calls are expected to cancel out.
fn shift(self, by: Self::Shift) -> Self;
}
/// An offset that can be applied to a [`Shiftable`].
///
/// `ShiftList` treats the `Default` value as "no pending shift".
pub(super) trait Shift: Default + Copy + Eq + std::fmt::Debug {
/// Returns the opposite offset; `ShiftList` uses it to undo a shift
/// that was already applied.
fn neg(self) -> Self;
/// Combines two offsets into one.
fn add(self, other: Self) -> Self;
}
impl<S: Shiftable + std::fmt::Debug> std::fmt::Debug for ShiftList<S> {
    /// Formats the list as a `ShiftList` struct with the fields `buf`,
    /// `from`, `by` (the pending shift) and `max`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sanity check: a full forward iteration must visit every
        // stored element exactly once.
        assert_eq!(self.iter_fwd(..).count(), self.len());

        let buf = DebugBuf(&self.buf, self.from, self.shift);

        let mut out = f.debug_struct("ShiftList");
        out.field("buf", &buf);
        out.field("from", &self.from);
        out.field("by", &self.shift);
        out.field("max", &self.max);
        out.finish()
    }
}
/// Debug-only view of a `ShiftList` buffer: the buffer itself, the
/// index from which the pending shift applies, and that shift.
struct DebugBuf<'a, S: Shiftable>(&'a GapBuffer<S>, usize, S::Shift);
impl<'a, S: Shiftable> std::fmt::Debug for DebugBuf<'a, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if f.alternate() && !self.0.is_empty() {
writeln!(f, "[")?;
for (i, elem) in self.0.iter().enumerate() {
let elem = if i >= self.1 {
elem.shift(self.2)
} else {
*elem
};
writeln!(f, " {i}: {elem:?}")?
}
writeln!(f, "]")
} else {
f.debug_list().entries(self.0.iter()).finish()
}
}
}