use super::budget::StepReservation;
use super::rewrite::{PreparedRewrite, RewriteScratch};
use crate::allocation::AllocationError;
use crate::bytes::{
NonEmptyPayloadNeedle, Payload, PayloadByteCount, PayloadNeedle, RuntimeByte,
RuntimeStateByteCount,
};
use crate::error::{RewriteSizeError, RunStepError};
use crate::input::InitialStateBytes;
use crate::program::RuntimeStateSnapshot;
use crate::rule::RewriteAction;
use crate::trace::RuntimeStateView;
use alloc::vec::Vec;
/// The mutable runtime state: the byte sequence that payloads are searched
/// for in and that committed rewrites replace.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct State {
    /// Current state contents, in order.
    bytes: Vec<RuntimeByte>,
}

impl State {
    /// Builds the starting state from the program's initial input bytes.
    pub(crate) fn from_input(input: InitialStateBytes) -> Self {
        let bytes = input.into_runtime_bytes();
        Self { bytes }
    }

    /// Number of bytes currently held in the state.
    pub(crate) fn byte_count(&self) -> RuntimeStateByteCount {
        let len = self.bytes.len();
        RuntimeStateByteCount::new(len)
    }

    /// Borrowed, read-only view of the current state bytes.
    pub(crate) fn view(&self) -> RuntimeStateView<'_> {
        let bytes = &self.bytes;
        RuntimeStateView::new(bytes)
    }

    /// Installs `rewrite` as the new state contents.
    ///
    /// The bytes being replaced are handed to `scratch` through
    /// `store_previous_state` (presumably so their allocation can be reused by
    /// a later rewrite — confirm in `RewriteScratch`).
    pub(crate) fn commit_rewrite(
        &mut self,
        rewrite: PreparedRewrite,
        scratch: &mut RewriteScratch,
    ) {
        let incoming = rewrite.into_runtime_bytes();
        let outgoing = core::mem::replace(&mut self.bytes, incoming);
        scratch.store_previous_state(outgoing);
    }

    /// Matches `payload` against the beginning of the state.
    ///
    /// For an empty needle, this is the span of the needle's byte count at
    /// offset 0 (if it fits). For a non-empty needle, the leading state bytes
    /// must equal the needle's program bytes.
    pub(crate) fn starts_with_payload(&self, payload: &Payload) -> Option<StateMatch<'_>> {
        match payload.needle() {
            PayloadNeedle::NonEmpty(needle) => self.matches_payload_at(StateIndex::start(), needle),
            PayloadNeedle::Empty(needle) => StateMatch::at_start(needle.byte_count(), &self.bytes),
        }
    }

    /// Matches `payload` against the end of the state.
    ///
    /// Returns `None` when the state is shorter than the payload or the
    /// trailing bytes differ from it.
    pub(crate) fn ends_with_payload(&self, payload: &Payload) -> Option<StateMatch<'_>> {
        match payload.needle() {
            PayloadNeedle::NonEmpty(needle) => {
                StateIndex::ending_match_start(self.byte_count(), needle.byte_count())
                    .and_then(|start| self.matches_payload_at(start, needle))
            }
            PayloadNeedle::Empty(needle) => StateMatch::at_end(needle.byte_count(), &self.bytes),
        }
    }

    /// Finds the leftmost occurrence of `payload` in the state.
    ///
    /// An empty needle is matched at offset 0. For a non-empty needle, every
    /// offset from 0 up to the last offset where the needle still fits is
    /// tried in ascending order. An offset whose first byte differs from the
    /// needle's first byte is rejected before the full comparison runs.
    pub(crate) fn find_payload(&self, payload: &Payload) -> Option<StateMatch<'_>> {
        let needle = match payload.needle() {
            PayloadNeedle::Empty(needle) => {
                return StateMatch::at_start(needle.byte_count(), &self.bytes);
            }
            PayloadNeedle::NonEmpty(needle) => needle,
        };
        let final_start = StateIndex::ending_match_start(self.byte_count(), needle.byte_count())?;
        StateSearchRange::from_start_to(final_start)
            .filter(|candidate| {
                // Cheap pre-check on the first byte before comparing the whole needle.
                self.bytes
                    .get(candidate.get())
                    .copied()
                    .and_then(RuntimeByte::program_byte)
                    == Some(needle.first_byte())
            })
            .find_map(|candidate| self.matches_payload_at(candidate, needle))
    }

    /// Checks whether `needle` occurs at `position`.
    ///
    /// A state byte matches only if its `program_byte()` is `Some` and equals
    /// the corresponding needle byte. Returns `None` if the span does not fit
    /// in the state or any byte differs.
    fn matches_payload_at(
        &self,
        position: StateIndex,
        needle: NonEmptyPayloadNeedle<'_>,
    ) -> Option<StateMatch<'_>> {
        let candidate = StateMatch::at_position(position, needle.byte_count(), &self.bytes)?;
        let actual_bytes = candidate.matched_bytes();
        let expected_bytes = needle.program_bytes().iter().copied();
        let every_byte_matches = actual_bytes
            .zip(expected_bytes)
            .all(|(actual, expected)| actual.program_byte() == Some(expected));
        if every_byte_matches {
            Some(candidate)
        } else {
            None
        }
    }

    /// Consumes the state and builds a snapshot from its final view.
    ///
    /// # Errors
    ///
    /// Propagates the `AllocationError` returned by
    /// `RuntimeStateSnapshot::from_final_state_view`.
    pub(crate) fn into_snapshot(self) -> Result<RuntimeStateSnapshot, AllocationError> {
        let final_view = self.view();
        RuntimeStateSnapshot::from_final_state_view(final_view)
    }
}
/// A zero-based byte offset into the runtime state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StateIndex {
    zero_based: usize,
}

impl StateIndex {
    /// Offset of the first state byte.
    const fn start() -> Self {
        Self::from_zero_based(0)
    }

    /// Wraps a raw zero-based offset.
    const fn from_zero_based(zero_based: usize) -> Self {
        Self { zero_based }
    }

    /// Offset at which a match of `matched_len` bytes has to begin to end
    /// exactly at the end of a state of `state_len` bytes.
    ///
    /// Returns `None` when the match is longer than the state.
    fn ending_match_start(
        state_len: RuntimeStateByteCount,
        matched_len: PayloadByteCount,
    ) -> Option<Self> {
        state_len
            .get()
            .checked_sub(matched_len.get())
            .map(Self::from_zero_based)
    }

    /// Moves forward by `count` bytes, or `None` on `usize` overflow.
    fn checked_add_count(self, count: PayloadByteCount) -> Option<Self> {
        self.zero_based
            .checked_add(count.get())
            .map(Self::from_zero_based)
    }

    /// The following offset, or `None` on `usize` overflow.
    fn checked_next(self) -> Option<Self> {
        self.zero_based.checked_add(1).map(Self::from_zero_based)
    }

    /// The raw zero-based offset.
    const fn get(self) -> usize {
        self.zero_based
    }
}
/// Iterator over every state offset from the start up to and including a
/// final offset.
struct StateSearchRange {
    cursor: StateSearchCursor,
}

/// Progress of a `StateSearchRange`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StateSearchCursor {
    /// `next` is the offset to yield next; `end` is the last offset to yield.
    Active {
        next: StateIndex,
        end: StateIndex,
    },
    /// Every offset has been yielded; the iterator keeps returning `None`.
    Done,
}

impl StateSearchRange {
    /// Offsets `start()..=end`.
    const fn from_start_to(end: StateIndex) -> Self {
        let next = StateIndex::start();
        Self {
            cursor: StateSearchCursor::Active { next, end },
        }
    }
}

impl Iterator for StateSearchRange {
    type Item = StateIndex;

    fn next(&mut self) -> Option<Self::Item> {
        match self.cursor {
            StateSearchCursor::Done => None,
            StateSearchCursor::Active { next: current, end } => {
                // Once `end` has been yielded (or the offset can no longer
                // advance), the range is exhausted.
                let following = if current == end {
                    None
                } else {
                    current.checked_next()
                };
                self.cursor = following.map_or(StateSearchCursor::Done, |next| {
                    StateSearchCursor::Active { next, end }
                });
                Some(current)
            }
        }
    }
}
/// A validated half-open byte span `[start, end)` of the runtime state,
/// together with its length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StateSpanRange {
    start: StateIndex,
    end: StateIndex,
    matched_len: PayloadByteCount,
}

impl StateSpanRange {
    /// Builds the span `[start, start + matched_len)`.
    ///
    /// Returns `None` if the end offset overflows `usize` or if either bound
    /// lies past a state of `state_len` bytes.
    fn at_position(
        start: StateIndex,
        matched_len: PayloadByteCount,
        state_len: RuntimeStateByteCount,
    ) -> Option<Self> {
        let end = start.checked_add_count(matched_len)?;
        let limit = state_len.get();
        if start.get() > limit || end.get() > limit {
            return None;
        }
        Some(Self {
            start,
            end,
            matched_len,
        })
    }

    /// Offset of the first byte in the span.
    const fn start(self) -> usize {
        self.start.get()
    }

    /// Offset one past the last byte in the span.
    const fn end(self) -> usize {
        self.end.get()
    }

    /// Length of the span in bytes.
    fn byte_count(self) -> PayloadByteCount {
        self.matched_len
    }
}
/// A validated span of the runtime state, together with a borrow of the full
/// state bytes, so the match can be inspected and used to build a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct StateMatch<'state> {
    range: StateSpanRange,
    bytes: &'state [RuntimeByte],
}

impl<'state> StateMatch<'state> {
    /// Span of `matched_len` bytes beginning at offset 0.
    ///
    /// Returns `None` if the state is shorter than `matched_len`.
    fn at_start(matched_len: PayloadByteCount, bytes: &'state [RuntimeByte]) -> Option<Self> {
        Self::at_position(StateIndex::start(), matched_len, bytes)
    }

    /// Span of `matched_len` bytes ending at the last state byte.
    ///
    /// Returns `None` if the state is shorter than `matched_len`.
    fn at_end(matched_len: PayloadByteCount, bytes: &'state [RuntimeByte]) -> Option<Self> {
        let state_len = RuntimeStateByteCount::new(bytes.len());
        // Use the shared start-offset helper (also used by
        // `State::ends_with_payload`) instead of repeating the subtraction here.
        let start = StateIndex::ending_match_start(state_len, matched_len)?;
        Self::at_position(start, matched_len, bytes)
    }

    /// Span `[start, start + matched_len)` of `bytes`.
    ///
    /// Returns `None` if the span does not lie entirely within `bytes`.
    fn at_position(
        start: StateIndex,
        matched_len: PayloadByteCount,
        bytes: &'state [RuntimeByte],
    ) -> Option<Self> {
        let state_len = RuntimeStateByteCount::new(bytes.len());
        let range = StateSpanRange::at_position(start, matched_len, state_len)?;
        Some(Self { range, bytes })
    }

    /// The bytes covered by the match, in state order.
    fn matched_bytes(self) -> impl Iterator<Item = RuntimeByte> + 'state {
        self.bytes
            .iter()
            .copied()
            .skip(self.range.start())
            .take(self.range.byte_count().get())
    }
}
impl<'state> StateMatch<'state> {
    /// Builds the rewritten state for `action` into `output` and returns it.
    ///
    /// The matched span is dropped. The variant's payload is placed in the gap
    /// (`Replace`), in front of the remaining bytes (`MoveStart`), or after
    /// them (`MoveEnd`). The buffer is sized and checked against `step` before
    /// any byte is written.
    ///
    /// # Errors
    ///
    /// Returns an error if the resulting length cannot be computed, is
    /// rejected by `step`, or if reserving or filling `output` fails.
    pub(crate) fn rewrite_into(
        self,
        action: &RewriteAction,
        output: &mut RewriteScratch,
        step: &StepReservation<'_>,
    ) -> Result<PreparedRewrite, RunStepError> {
        self.prepare_replacement_buffer(action.payload(), output, step)?;
        match action {
            RewriteAction::Replace(rhs) => {
                output.push_existing(self.prefix_bytes())?;
                output.push_payload(rhs)?;
                output.push_existing(self.suffix_bytes())?;
            }
            RewriteAction::MoveStart(rhs) => {
                output.push_payload(rhs)?;
                self.push_unmatched(output)?;
            }
            RewriteAction::MoveEnd(rhs) => {
                self.push_unmatched(output)?;
                output.push_payload(rhs)?;
            }
        }
        Ok(output.take_prepared())
    }

    /// Appends every state byte outside the matched span, prefix first.
    fn push_unmatched(self, output: &mut RewriteScratch) -> Result<(), RunStepError> {
        output.push_existing(self.prefix_bytes())?;
        output.push_existing(self.suffix_bytes())?;
        Ok(())
    }

    /// Length of the state after removing the matched span and inserting `rhs`.
    ///
    /// # Errors
    ///
    /// Returns `RewriteSizeError` if `state_len - lhs_len + rhs_len` cannot be
    /// computed in `usize`.
    fn replaced_byte_count(self, rhs: &Payload) -> Result<RuntimeStateByteCount, RewriteSizeError> {
        let state_len = RuntimeStateByteCount::new(self.bytes.len());
        let lhs_len = self.matched_len();
        let rhs_len = rhs.byte_count();
        let Some(total) = state_len
            .get()
            .checked_sub(lhs_len.get())
            .and_then(|kept| kept.checked_add(rhs_len.get()))
        else {
            return Err(RewriteSizeError::new(state_len, lhs_len, rhs_len));
        };
        Ok(RuntimeStateByteCount::new(total))
    }

    /// Computes the post-rewrite length, checks it against `step`, then
    /// clears `output` and reserves room for that many bytes.
    fn prepare_replacement_buffer(
        self,
        rhs: &Payload,
        output: &mut RewriteScratch,
        step: &StepReservation<'_>,
    ) -> Result<(), RunStepError> {
        let new_len = self.replaced_byte_count(rhs)?;
        step.ensure_rewrite_state_len(new_len)?;
        output.clear_and_reserve(new_len)?;
        Ok(())
    }

    /// Number of bytes covered by the match.
    fn matched_len(self) -> PayloadByteCount {
        self.range.byte_count()
    }

    /// State bytes before the matched span.
    fn prefix_bytes(self) -> impl Iterator<Item = RuntimeByte> + 'state {
        let before = self.range.start();
        self.bytes.iter().take(before).copied()
    }

    /// State bytes after the matched span.
    fn suffix_bytes(self) -> impl Iterator<Item = RuntimeByte> + 'state {
        let after = self.range.end();
        self.bytes.iter().skip(after).copied()
    }
}