use super::budget::RuntimeBudgetState;
use super::rewrite::{PreparedRewrite, RewriteScratch};
use crate::allocation::{
AllocationContext, AllocationError, RequestedCapacity, try_push, try_reserve_total_exact,
};
use crate::bytes::{
NonEmptyPayloadNeedle, Payload, PayloadByteCount, PayloadNeedle, RuntimeByte,
RuntimeStateByteCount,
};
use crate::error::{RunError, RunInvariantError, StateSizeError};
use crate::input::InitialStateBytes;
use crate::program::RuntimeStateSnapshot;
use crate::rule::RewriteAction;
use crate::trace::RuntimeStateView;
use alloc::vec::Vec;
/// Mutable runtime state: the ordered [`RuntimeByte`] buffer that rewrite
/// rules search and replace.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct State {
    bytes: Vec<RuntimeByte>,
}

impl State {
    /// Builds the initial state from the program's input bytes.
    pub(crate) fn from_input(input: InitialStateBytes) -> Self {
        Self {
            bytes: input.into_runtime_bytes(),
        }
    }

    /// Number of runtime bytes currently held in the state.
    pub(crate) fn byte_count(&self) -> RuntimeStateByteCount {
        RuntimeStateByteCount::new(self.bytes.len())
    }

    /// Borrowed, read-only [`RuntimeStateView`] over the current bytes.
    pub(crate) fn view(&self) -> RuntimeStateView<'_> {
        RuntimeStateView::new(&self.bytes)
    }

    /// Replaces the current bytes with the bytes of `rewrite`.
    ///
    /// The displaced buffer is handed to `scratch` instead of being dropped
    /// (presumably so its allocation can be reused by a later rewrite —
    /// TODO confirm against `RewriteScratch::store_previous_state`).
    pub(crate) fn commit_rewrite(
        &mut self,
        rewrite: PreparedRewrite,
        scratch: &mut RewriteScratch,
    ) {
        let previous_state = core::mem::replace(&mut self.bytes, rewrite.into_runtime_bytes());
        scratch.store_previous_state(previous_state);
    }

    /// Matches `payload` anchored at the start of the state.
    ///
    /// An `Empty` needle is resolved positionally at offset 0 with no byte
    /// comparison; a `NonEmpty` needle matches only if it is a prefix of the
    /// state.
    pub(crate) fn starts_with_payload(&self, payload: &Payload) -> Option<StateMatch> {
        match payload.needle() {
            PayloadNeedle::Empty(needle) => {
                StateMatch::at_start(needle.byte_count(), self.byte_count())
            }
            PayloadNeedle::NonEmpty(needle) => self.matches_payload_at(StateIndex::start(), needle),
        }
    }

    /// Matches `payload` anchored at the end of the state.
    ///
    /// Returns `None` when the payload is longer than the state or when its
    /// bytes differ from the state's trailing bytes.
    pub(crate) fn ends_with_payload(&self, payload: &Payload) -> Option<StateMatch> {
        match payload.needle() {
            PayloadNeedle::Empty(needle) => {
                StateMatch::at_end(needle.byte_count(), self.byte_count())
            }
            PayloadNeedle::NonEmpty(needle) => {
                let start = StateIndex::ending_match_start(self.byte_count(), needle.byte_count())?;
                self.matches_payload_at(start, needle)
            }
        }
    }

    /// Finds the leftmost occurrence of `payload` in the state.
    ///
    /// An `Empty` needle is resolved positionally at offset 0. A `NonEmpty`
    /// needle is searched naively over every candidate start
    /// `0..=len - needle_len`: each candidate is first screened by its first
    /// byte and only then compared in full.
    pub(crate) fn find_payload(&self, payload: &Payload) -> Option<StateMatch> {
        match payload.needle() {
            PayloadNeedle::Empty(needle) => {
                StateMatch::at_start(needle.byte_count(), self.byte_count())
            }
            PayloadNeedle::NonEmpty(needle) => {
                // `None` here means the needle is longer than the whole state.
                let last_start =
                    StateIndex::ending_match_start(self.byte_count(), needle.byte_count())?;
                // Loop-invariant: the byte every candidate must start with.
                let first_byte = Some(needle.first_byte());
                StateSearchRange::from_start_to(last_start)
                    // Cheap first-byte screen before the full comparison.
                    .filter(|position| {
                        self.bytes
                            .get(position.get())
                            .copied()
                            .and_then(RuntimeByte::program_byte)
                            == first_byte
                    })
                    .find_map(|position| self.matches_payload_at(position, needle))
            }
        }
    }

    /// Checks whether `needle` occurs at `position`.
    ///
    /// Returns `None` if the span would run past the end of the state, or if
    /// any state byte differs from the needle. A state byte whose
    /// `program_byte()` is `None` never matches.
    fn matches_payload_at(
        &self,
        position: StateIndex,
        needle: NonEmptyPayloadNeedle<'_>,
    ) -> Option<StateMatch> {
        let state_match =
            StateMatch::at_position(position, needle.byte_count(), self.byte_count())?;
        let matches = state_match
            .matched_bytes(&self.bytes)
            .zip(needle.program_bytes().iter().copied())
            .all(|(actual, expected)| actual.program_byte() == Some(expected));
        matches.then_some(state_match)
    }

    /// Builds the state that results from applying `action` at `state_match`,
    /// writing it into `output`.
    ///
    /// With `prefix`/`suffix` being the bytes before/after the match and
    /// `rhs` the action's payload, the result is laid out as:
    /// - `Replace`:   prefix, rhs, suffix
    /// - `MoveStart`: rhs, prefix, suffix
    /// - `MoveEnd`:   prefix, suffix, rhs
    ///
    /// # Errors
    /// - `state_match` was computed for a state of a different length.
    /// - The resulting length overflows or is rejected by `budget`.
    /// - Reserving or pushing into `output` fails.
    pub(crate) fn rewrite_into(
        &self,
        state_match: StateMatch,
        action: &RewriteAction,
        output: &mut RewriteScratch,
        budget: RuntimeBudgetState,
    ) -> Result<PreparedRewrite, RunError> {
        let state_match = state_match.open(self.byte_count())?;
        self.prepare_replacement_buffer(state_match, action.payload(), output, budget)?;
        match action {
            RewriteAction::Replace(rhs) => {
                output.push_existing(state_match.prefix_bytes(&self.bytes))?;
                output.push_payload(rhs)?;
                output.push_existing(state_match.suffix_bytes(&self.bytes))?;
            }
            RewriteAction::MoveStart(rhs) => {
                output.push_payload(rhs)?;
                output.push_existing(state_match.prefix_bytes(&self.bytes))?;
                output.push_existing(state_match.suffix_bytes(&self.bytes))?;
            }
            RewriteAction::MoveEnd(rhs) => {
                output.push_existing(state_match.prefix_bytes(&self.bytes))?;
                output.push_existing(state_match.suffix_bytes(&self.bytes))?;
                output.push_payload(rhs)?;
            }
        }
        Ok(output.take_prepared())
    }

    /// Length of the state after the matched bytes are replaced by `rhs`:
    /// `state_len - matched_len + rhs_len`.
    ///
    /// # Errors
    /// `StateSizeError` if the subtraction underflows or the addition
    /// overflows `usize`.
    fn replaced_byte_count(
        &self,
        state_match: CheckedStateMatch,
        rhs: &Payload,
    ) -> Result<RuntimeStateByteCount, StateSizeError> {
        let state_len = self.byte_count();
        let lhs_len = state_match.matched_len();
        let rhs_len = rhs.byte_count();
        state_len
            .get()
            .checked_sub(lhs_len.get())
            .and_then(|base| base.checked_add(rhs_len.get()))
            .map(RuntimeStateByteCount::new)
            .ok_or_else(|| StateSizeError::new(state_len, lhs_len, rhs_len))
    }

    /// Computes the post-rewrite length, checks it against `budget`, and then
    /// clears `output` and reserves exactly that capacity.
    fn prepare_replacement_buffer(
        &self,
        state_match: CheckedStateMatch,
        rhs: &Payload,
        output: &mut RewriteScratch,
        budget: RuntimeBudgetState,
    ) -> Result<(), RunError> {
        let capacity = self.replaced_byte_count(state_match, rhs)?;
        budget.ensure_rewrite_state_len(capacity)?;
        output.clear_and_reserve(capacity)?;
        Ok(())
    }

    /// Converts every runtime byte into a plain `u8` via
    /// `RuntimeByte::materialize`.
    ///
    /// The exact length is reserved up front, and every reservation or push
    /// failure is reported as an [`AllocationError`] tagged with `context`.
    fn materialize(&self, context: AllocationContext) -> Result<Vec<u8>, AllocationError> {
        let mut output = Vec::new();
        try_reserve_total_exact(
            &mut output,
            RequestedCapacity::from_runtime_state_count(self.byte_count()),
            context,
        )?;
        for byte in self.bytes.iter().copied() {
            try_push(&mut output, byte.materialize(), context)?;
        }
        Ok(output)
    }

    /// Consumes the state and materializes it into a [`RuntimeStateSnapshot`].
    ///
    /// # Errors
    /// Allocation failure while materializing, converted into [`RunError`].
    pub(crate) fn into_snapshot(self) -> Result<RuntimeStateSnapshot, RunError> {
        // `?` applies `From<AllocationError> for RunError`; no explicit map_err needed.
        let bytes = self.materialize(AllocationContext::FinalOutput)?;
        Ok(RuntimeStateSnapshot::from_materialized(bytes))
    }
}
/// A zero-based byte offset into the runtime state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StateIndex {
    zero_based: usize,
}

impl StateIndex {
    /// Offset of the very first byte.
    const fn start() -> Self {
        Self::from_zero_based(0)
    }

    /// Wraps a raw zero-based offset.
    const fn from_zero_based(zero_based: usize) -> Self {
        Self { zero_based }
    }

    /// Offset at which a span of `matched_len` bytes must begin so that it
    /// ends exactly at the end of a state of `state_len` bytes.
    ///
    /// Yields `None` when the span is longer than the state.
    fn ending_match_start(
        state_len: RuntimeStateByteCount,
        matched_len: PayloadByteCount,
    ) -> Option<Self> {
        state_len
            .get()
            .checked_sub(matched_len.get())
            .map(Self::from_zero_based)
    }

    /// Moves forward by `count` bytes; `None` on `usize` overflow.
    fn checked_add_count(self, count: PayloadByteCount) -> Option<Self> {
        self.zero_based
            .checked_add(count.get())
            .map(Self::from_zero_based)
    }

    /// The offset immediately after this one; `None` on `usize` overflow.
    fn checked_next(self) -> Option<Self> {
        self.zero_based.checked_add(1).map(Self::from_zero_based)
    }

    /// The raw zero-based offset.
    const fn get(self) -> usize {
        self.zero_based
    }
}
/// Ascending iterator over the state offsets `0..=end`, used to enumerate
/// candidate match starts.
///
/// Backed by [`core::ops::RangeInclusive`], which handles the inclusive upper
/// bound (including `usize::MAX`) without overflow and stays exhausted once
/// finished, so no hand-rolled cursor state machine is needed.
struct StateSearchRange {
    positions: core::ops::RangeInclusive<usize>,
}

impl StateSearchRange {
    /// Yields every offset from the start of the state up to and including `end`.
    const fn from_start_to(end: StateIndex) -> Self {
        Self {
            positions: StateIndex::start().get()..=end.get(),
        }
    }
}

impl Iterator for StateSearchRange {
    type Item = StateIndex;

    fn next(&mut self) -> Option<Self::Item> {
        self.positions.next().map(StateIndex::from_zero_based)
    }

    /// Exact remaining count, forwarded from the underlying range.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.positions.size_hint()
    }
}
/// A half-open span `[start, end)` of `matched_len` bytes, validated against
/// the length of the state it was computed for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StateSpanRange {
    start: StateIndex,
    end: StateIndex,
    matched_len: PayloadByteCount,
    state_len: RuntimeStateByteCount,
}

impl StateSpanRange {
    /// Builds the span of `matched_len` bytes beginning at `start`.
    ///
    /// Yields `None` if the end offset overflows `usize`, or if either
    /// endpoint lies beyond `state_len`.
    fn at_position(
        start: StateIndex,
        matched_len: PayloadByteCount,
        state_len: RuntimeStateByteCount,
    ) -> Option<Self> {
        let end = start.checked_add_count(matched_len)?;
        let limit = state_len.get();
        if start.get() > limit || end.get() > limit {
            return None;
        }
        Some(Self {
            start,
            end,
            matched_len,
            state_len,
        })
    }

    /// Zero-based offset of the first spanned byte.
    const fn start(self) -> usize {
        let Self { start, .. } = self;
        start.get()
    }

    /// Zero-based offset one past the last spanned byte.
    const fn end(self) -> usize {
        let Self { end, .. } = self;
        end.get()
    }

    /// Number of bytes covered by the span.
    fn byte_count(self) -> PayloadByteCount {
        let Self { matched_len, .. } = self;
        matched_len
    }

    /// Length of the state this span was validated against.
    const fn state_len(self) -> RuntimeStateByteCount {
        let Self { state_len, .. } = self;
        state_len
    }
}
/// An occurrence of a payload in the state, recorded together with the
/// length of the state it was found in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct StateMatch {
    range: StateSpanRange,
}

/// A [`StateMatch`] whose recorded state length has been confirmed by
/// [`StateMatch::open`]. Only this form exposes the prefix/suffix accessors
/// used when building a rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CheckedStateMatch {
    range: StateSpanRange,
}

impl StateMatch {
    /// Span of `matched_len` bytes starting at offset 0; `None` if it does not
    /// fit in a state of `state_len` bytes.
    fn at_start(matched_len: PayloadByteCount, state_len: RuntimeStateByteCount) -> Option<Self> {
        Self::at_position(StateIndex::start(), matched_len, state_len)
    }

    /// Span of `matched_len` bytes ending exactly at the end of the state;
    /// `None` when `matched_len` exceeds `state_len`.
    fn at_end(matched_len: PayloadByteCount, state_len: RuntimeStateByteCount) -> Option<Self> {
        // Reuse the same end-anchoring helper as the non-empty search paths
        // instead of repeating the `checked_sub`.
        let start = StateIndex::ending_match_start(state_len, matched_len)?;
        Self::at_position(start, matched_len, state_len)
    }

    /// Span of `matched_len` bytes starting at `start`; `None` if it overflows
    /// or runs past the end of a state of `state_len` bytes.
    fn at_position(
        start: StateIndex,
        matched_len: PayloadByteCount,
        state_len: RuntimeStateByteCount,
    ) -> Option<Self> {
        let range = StateSpanRange::at_position(start, matched_len, state_len)?;
        Some(Self { range })
    }

    /// Confirms this match was computed for a state of `current_state_len` bytes.
    ///
    /// NOTE(review): only the length is compared, so a match taken from a
    /// different state of the same length would be accepted. Callers
    /// presumably guarantee the match came from the current state — confirm.
    ///
    /// # Errors
    /// `RunInvariantError::InvalidStateMatchRange`, converted into
    /// [`RunError`], when the lengths differ.
    fn open(self, current_state_len: RuntimeStateByteCount) -> Result<CheckedStateMatch, RunError> {
        if self.range.state_len() == current_state_len {
            Ok(CheckedStateMatch { range: self.range })
        } else {
            Err(RunInvariantError::InvalidStateMatchRange {
                matched_state_len: self.range.state_len(),
                current_state_len,
            }
            .into())
        }
    }

    /// Bytes covered by the match, read lazily from `bytes` starting at the
    /// match offset. Yields fewer bytes rather than panicking if `bytes` is
    /// shorter than the span.
    fn matched_bytes(self, bytes: &[RuntimeByte]) -> impl Iterator<Item = RuntimeByte> + '_ {
        bytes
            .iter()
            .copied()
            .skip(self.range.start())
            .take(self.range.byte_count().get())
    }
}

impl CheckedStateMatch {
    /// Number of bytes the match covers.
    fn matched_len(self) -> PayloadByteCount {
        self.range.byte_count()
    }

    /// Bytes before the match (`bytes[..start]`), yielded lazily.
    fn prefix_bytes(self, bytes: &[RuntimeByte]) -> impl Iterator<Item = RuntimeByte> + '_ {
        bytes.iter().copied().take(self.range.start())
    }

    /// Bytes after the match (`bytes[end..]`), yielded lazily. Empty if
    /// `bytes` is not longer than the match end.
    fn suffix_bytes(self, bytes: &[RuntimeByte]) -> impl Iterator<Item = RuntimeByte> + '_ {
        bytes.iter().copied().skip(self.range.end())
    }
}