#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::vec::Vec;
use core::fmt;
#[cfg(feature = "std")]
use std::{error, vec::Vec};
use generic_array::{ArrayLength, GenericArray};
use crate::{
values::{self, BytesId, ConstantId},
Constant, Section, SectionMut,
};
/// Identifier of a predicate, stored as a raw 16-bit value.
///
/// The facts section encodes this id as two big-endian bytes at the start of
/// each predicate record. The id converts to and from [`BytesId`] without
/// changing the underlying value.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct PredicateId(pub(crate) u16);
impl From<BytesId> for PredicateId {
fn from(value: BytesId) -> Self {
Self(value.0)
}
}
impl From<PredicateId> for BytesId {
fn from(value: PredicateId) -> Self {
Self(value.0)
}
}
/// Iterator over the predicate ids of the records stored in a facts section.
///
/// Each record is read as a 2-byte big-endian predicate id, a 2-byte
/// big-endian byte length, and that many bytes of payload, which are skipped.
#[derive(Debug)]
struct PredIter<'a> {
/// Bytes of the records that have not been visited yet.
data: &'a [u8],
}
/// Returns an iterator over the ids of all predicates that have a record in
/// `facts`, in storage order.
pub(crate) fn pred_iter(facts: Section<'_>) -> impl Iterator<Item = PredicateId> + '_ {
    let data = facts.data();
    PredIter { data }
}
impl<'a> Iterator for PredIter<'a> {
    type Item = PredicateId;

    /// Decodes the header of the next record, skips its payload and returns
    /// the record's predicate id.
    ///
    /// Panics if the remaining bytes are non-empty but shorter than the
    /// 4-byte header plus the payload length the header announces.
    fn next(&mut self) -> Option<Self::Item> {
        if self.data.is_empty() {
            return None;
        }
        // Header layout: [id_hi, id_lo, len_hi, len_lo].
        let (header, body) = self.data.split_at(4);
        let id = u16::from_be_bytes([header[0], header[1]]);
        let body_len = usize::from(u16::from_be_bytes([header[2], header[3]]));
        // Jump over this record's payload to the start of the next record.
        self.data = &body[body_len..];
        Some(PredicateId(id))
    }
}
/// Iterator over the constant ids of a single fact.
///
/// The ids are decoded as consecutive 2-byte big-endian values from `data`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct TermIter<'a> {
/// Bytes of the constant ids that have not been yielded yet.
data: &'a [u8],
}
impl<'a> TermIter<'a> {
fn new(data: &'a [u8]) -> Self {
Self { data }
}
}
impl<'a> Iterator for TermIter<'a> {
    type Item = ConstantId;

    /// Decodes the next big-endian `u16` as a [`ConstantId`].
    ///
    /// Panics if exactly one byte is left, because an id needs two bytes.
    fn next(&mut self) -> Option<Self::Item> {
        if self.data.is_empty() {
            return None;
        }
        let (head, tail) = self.data.split_at(2);
        self.data = tail;
        Some(ConstantId(u16::from_be_bytes([head[0], head[1]])))
    }
}
/// Iterator over the facts of one predicate record.
///
/// Each item is a [`TermIter`] over the next `len` bytes of `data`.
#[derive(Debug)]
struct FactTermsIter<'a> {
/// Size of one fact in bytes (two bytes per term).
len: usize,
/// Bytes of the facts that have not been yielded yet.
data: &'a [u8],
}
impl<'a> FactTermsIter<'a> {
fn new(data: &'a [u8], pred_len: usize) -> Self {
Self {
len: pred_len * 2,
data,
}
}
}
impl<'a> Iterator for FactTermsIter<'a> {
    type Item = TermIter<'a>;

    /// Returns a [`TermIter`] over the next fact, or `None` once every fact
    /// has been yielded.
    ///
    /// Panics if the remaining bytes are non-empty but shorter than one fact,
    /// i.e. the stored data is not a whole number of facts of this arity.
    fn next(&mut self) -> Option<Self::Item> {
        // A zero stride would never consume any bytes, so without this guard
        // a non-empty `data` would yield empty facts forever.
        if self.len == 0 || self.data.is_empty() {
            return None;
        }
        let (fact, rest) = self.data.split_at(self.len);
        self.data = rest;
        Some(TermIter::new(fact))
    }
}
/// Returns an iterator over the facts stored for `pred` in `facts`.
///
/// The section is scanned record by record (2-byte big-endian predicate id,
/// 2-byte big-endian byte length, payload). The payload of the first record
/// whose id equals `pred` is split into facts of `pred_len` terms. If no
/// record matches, the returned iterator is empty.
///
/// Panics if a record header or payload extends past the end of the section
/// data.
pub(crate) fn terms_iter<'a>(
    facts: &Section<'a>,
    pred: PredicateId,
    pred_len: usize,
) -> impl Iterator<Item = TermIter<'a>> + 'a {
    let mut rest = facts.data();
    while !rest.is_empty() {
        let id = PredicateId(u16::from_be_bytes([rest[0], rest[1]]));
        let body_len = usize::from(u16::from_be_bytes([rest[2], rest[3]]));
        let body_end = 4 + body_len;
        if id == pred {
            return FactTermsIter::new(&rest[4..body_end], pred_len);
        }
        // Not the requested predicate: skip the header and payload.
        rest = &rest[body_end..];
    }
    // Ran off the end without a match; `rest` is empty here.
    FactTermsIter::new(rest, pred_len)
}
/// Bytes handed to `SectionMut::init` by `push` before any fact is written.
///
/// `push` starts scanning records 3 bytes after the section start, which
/// matches the length of this array.
/// NOTE(review): presumably a section tag (`2`) followed by a zero 16-bit
/// length — confirm against `SectionMut::init`.
const SECTION_INIT: [u8; 3] = [2, 0, 0];
/// Byte iterator adapter that reports an exact, caller-supplied length.
///
/// `push` wraps the flattened constant bytes in this type before passing them
/// to `splice`, so that the consumer sees an exact `size_hint` even when the
/// wrapped iterator does not provide one.
struct ConstantBytesIter<T> {
    /// Underlying source of bytes.
    iter: T,
    /// Number of bytes `iter` is still expected to yield. The caller sets it
    /// to the total byte count and `next` keeps it up to date.
    size_hint: usize,
}

impl<T> Iterator for ConstantBytesIter<T>
where
    T: Iterator<Item = u8>,
{
    type Item = u8;

    /// Forwards to the wrapped iterator and keeps the remaining count in sync.
    fn next(&mut self) -> Option<Self::Item> {
        let byte = self.iter.next();
        // `size_hint` must describe the *remaining* length, so count down as
        // bytes are handed out and drop to zero once the source is exhausted.
        self.size_hint = match byte {
            Some(_) => self.size_hint.saturating_sub(1),
            None => 0,
        };
        byte
    }

    /// Returns the number of bytes still expected, as an exact bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.size_hint, Some(self.size_hint))
    }
}
/// Adds the fact `pred(constants...)` to the facts section, unless an
/// identical fact is already stored for `pred`.
///
/// As read and written here, the records start 3 bytes after `facts.start`.
/// Each record is a big-endian `u16` predicate id, a big-endian `u16` byte
/// length, and that many bytes of facts. Every fact is `constants.len()`
/// big-endian `u16` constant ids.
///
/// If a record for `pred` already exists, the new fact is inserted right
/// after that record's last fact and the record's length is increased.
/// Otherwise a new record is appended at `facts.end`.
///
/// # Panics
///
/// Panics if `constants` is empty, if the stored byte length of the `pred`
/// record is not a multiple of the new fact's byte length, or if the record's
/// byte length does not fit in a `u16`. Indexing into the buffer can also
/// panic when the stored data is malformed.
pub(crate) fn push<N: ArrayLength<ConstantId>>(
mut facts: SectionMut<'_>,
pred: PredicateId,
constants: GenericArray<ConstantId, N>,
) {
assert!(!constants.is_empty());
// NOTE(review): presumably writes the section header if the section does not
// exist yet — confirm against `SectionMut::init`.
facts.init(&SECTION_INIT);
let constants_len = constants.len();
// Every constant id takes two bytes.
let bytes_len = constants_len * 2;
// Skip the 3 bytes in front of the first record (same size as `SECTION_INIT`).
let mut pos = facts.start + 3;
// Walk the records until one for `pred` is found or the section ends.
loop {
if pos == facts.end {
break;
}
debug_assert!(pos < facts.end);
// Remember where this record's header starts so its length can be patched.
let pred_start_pos = pos;
let predicate = PredicateId(u16::from_be_bytes([facts.data[pos], facts.data[pos + 1]]));
let mut len = usize::from(u16::from_be_bytes([
facts.data[pos + 2],
facts.data[pos + 3],
]));
// Move past the 4-byte record header.
pos += 4;
if predicate != pred {
// Different predicate: skip its facts and look at the next record.
pos += len;
continue;
}
// Compare the new fact with every stored fact of this predicate. `len`
// counts the bytes of the record that have not been scanned yet.
loop {
if len == 0 {
break;
}
let mut is_equal = true;
let mut offset = 0;
for i in 0..constants_len {
let first_byte = facts.data[pos + offset];
offset += 1;
let second_byte = facts.data[pos + offset];
offset += 1;
let existing_constant = ConstantId(u16::from_be_bytes([first_byte, second_byte]));
if existing_constant != constants[i] {
is_equal = false;
break;
}
}
if is_equal {
// The fact is already stored; nothing to do.
return;
}
// Panics if the record does not hold a whole number of facts of this size.
len = len.checked_sub(bytes_len).unwrap();
pos += bytes_len;
}
// No duplicate found; `pos` now points just past the record's last fact.
// Grow the stored record length by one fact (panics on `u16` overflow).
let len = u16::from_be_bytes([
facts.data[pred_start_pos + 2],
facts.data[pred_start_pos + 3],
])
.checked_add(u16::try_from(bytes_len).unwrap())
.unwrap();
let len_bytes = len.to_be_bytes();
facts.data[pred_start_pos + 2] = len_bytes[0];
facts.data[pred_start_pos + 3] = len_bytes[1];
// Insert the new fact's bytes at `pos` without removing anything.
let constants = constants.into_iter().flat_map(|c| c.0.to_be_bytes());
facts.data.splice(
pos..pos,
ConstantBytesIter {
iter: constants,
size_hint: bytes_len,
},
);
facts.end += bytes_len;
// NOTE(review): presumably rewrites the section's stored length to match the
// new `end` — confirm against `SectionMut::update_len`.
facts.update_len();
return;
}
// No record for `pred` exists yet: append a new one at the end of the section,
// made of the predicate id, the byte length, and the constants.
debug_assert_eq!(pos, facts.end);
facts
.data
.splice(facts.end..facts.end, pred.0.to_be_bytes());
facts.end += 2;
facts.data.splice(
facts.end..facts.end,
u16::try_from(bytes_len).unwrap().to_be_bytes(),
);
facts.end += 2;
let constants = constants.into_iter().flat_map(|c| c.0.to_be_bytes());
facts.data.splice(
facts.end..facts.end,
ConstantBytesIter {
iter: constants,
size_hint: bytes_len,
},
);
facts.end += bytes_len;
facts.update_len();
}
/// Error returned when the terms of a fact cannot be copied out.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum FactTermsError {
    /// The destination buffer is too short to hold every term of the fact.
    InvalidLength,
}

/// `Debug` prints the same human-readable message as `Display`.
impl fmt::Debug for FactTermsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for FactTermsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            FactTermsError::InvalidLength => "invalid length",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl error::Error for FactTermsError {
    /// No variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match *self {
            FactTermsError::InvalidLength => None,
        }
    }
}
/// The terms of a single fact, together with the section needed to resolve
/// them into [`Constant`] values.
#[derive(Debug)]
pub struct FactTerms<'a> {
/// Ids of the fact's terms, in order.
pub(crate) constants: TermIter<'a>,
/// Section passed to `values::constant` to resolve each id.
pub(crate) context: Section<'a>,
}
impl<'a> FactTerms<'a> {
    /// Resolves every term of the fact and collects the values, in order,
    /// into a freshly allocated vector.
    #[must_use]
    pub fn to_vec(&self) -> Vec<Constant<'a>> {
        let section = &self.context;
        self.constants
            .clone()
            .map(|id| values::constant(section, id))
            .collect()
    }

    /// Resolves every term of the fact into the front of `dst` and returns
    /// the filled prefix.
    ///
    /// # Errors
    ///
    /// Returns [`FactTermsError::InvalidLength`] if `dst` has fewer slots than
    /// the fact has terms. The slots written before the shortage was detected
    /// keep their new values.
    pub fn fill_buf<'c>(
        &self,
        dst: &'c mut [Constant<'a>],
    ) -> Result<&'c [Constant<'a>], FactTermsError> {
        let mut written = 0;
        for (idx, id) in self.constants.clone().enumerate() {
            match dst.get_mut(idx) {
                Some(slot) => *slot = values::constant(&self.context, id),
                None => return Err(FactTermsError::InvalidLength),
            }
            written = idx + 1;
        }
        Ok(&dst[..written])
    }
}