use crate::{
bflatn_from::{read_tag, vlr_blob_bytes},
blob_store::BlobStore,
eq::BytesPage,
indexes::PageOffset,
page::Page,
row_hash::{read_from_bytes, run_vlo_bytes},
var_len::{VarLenGranule, VarLenRef},
};
use core::str;
use spacetimedb_sats::bsatn::{eq::eq_bsatn, Deserializer};
use spacetimedb_sats::layout::{align_to, AlgebraicTypeLayout, HasLayout as _, ProductTypeLayoutView, RowTypeLayout};
use spacetimedb_sats::{AlgebraicValue, ProductValue};
/// Compares the row stored in `page`, whose fixed-length part begins at `fixed_offset`,
/// against the product value `rhs`, interpreting both at the row type `ty`.
///
/// Returns `true` iff every element of `rhs` equals the corresponding column of the stored row.
/// If `rhs` has a different number of elements than `ty` has columns, the result is `false`.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// 1. `fixed_offset` is the start of a row in `page` that is valid at `ty`.
///    The row's bytes are read according to `ty`'s layout,
///    so e.g. `bool` columns must hold valid `bool`s.
/// 2. Every `VarLenRef` stored in that row refers to valid var-len granules in `page`
///    or, for a large blob, to a blob that `blob_store` can provide.
/// 3. The bytes of every string column are valid UTF-8,
///    as they may be reinterpreted as `str` without validation.
pub unsafe fn eq_row_in_page_to_pv(
    blob_store: &dyn BlobStore,
    page: &Page,
    fixed_offset: PageOffset,
    rhs: &ProductValue,
    ty: &RowTypeLayout,
) -> bool {
    // The traversal starts at offset 0 of the row's fixed-length bytes.
    let mut ctx = EqCtx {
        lhs: BytesPage::new(page, fixed_offset, ty),
        blob_store,
        curr_offset: 0,
    };
    // SAFETY: By requirement 1., the row at `fixed_offset` is valid at `ty`.
    // By requirements 2. and 3., its var-len parts are valid too,
    // which is what `eq_product` needs to read the row at `ty.product()`.
    unsafe { eq_product(&mut ctx, ty.product(), rhs) }
}
/// Cursor state threaded through the recursive equality traversal.
///
/// It is `Copy` so that a sub-traversal (e.g. into a sum's variant data)
/// can work on its own cursor without disturbing the caller's.
#[derive(Clone, Copy)]
struct EqCtx<'page> {
    /// Resolves large var-len objects referenced from `lhs`.
    blob_store: &'page dyn BlobStore,
    /// The stored row (left-hand side) together with the page holding it.
    lhs: BytesPage<'page>,
    /// Byte offset into `lhs.bytes` of the next value to compare.
    curr_offset: usize,
}
/// Compares the stored product starting at `ctx.curr_offset` with `rhs`, element by element.
///
/// Each element is located at the product's starting offset plus that element's layout offset.
/// Comparison stops at the first unequal element.
/// On return, `ctx.curr_offset` is NOT advanced to the end of the product.
/// It is left wherever the last compared element moved it,
/// or untouched when the element counts differ.
///
/// # Safety
///
/// The bytes of `ctx.lhs` at `ctx.curr_offset` must hold a value valid at `ty`,
/// with any var-len references valid in `ctx.lhs.page` / `ctx.blob_store`.
unsafe fn eq_product(ctx: &mut EqCtx<'_>, ty: ProductTypeLayoutView<'_>, rhs: &ProductValue) -> bool {
    // Differently shaped products can never be equal.
    if ty.elements.len() != rhs.elements.len() {
        return false;
    }

    let product_start = ctx.curr_offset;
    for (field_ty, field_val) in ty.elements.iter().zip(rhs.elements.iter()) {
        // Jump straight to the field, skipping any padding before it.
        ctx.curr_offset = product_start + field_ty.offset as usize;
        // SAFETY: The product is valid at `ty`, so each properly offset field
        // is valid at `field_ty.ty`, with valid var-len references.
        let field_eq = unsafe { eq_value(ctx, &field_ty.ty, field_val) };
        if !field_eq {
            return false;
        }
    }
    true
}
/// Compares the value stored in `ctx.lhs` at `ctx.curr_offset` against `rhs`, both at type `ty`.
///
/// Behavior by case:
/// - Sums: the tags are compared first, then the variant data is compared recursively.
///   The outer cursor advances by the sum's full size.
/// - Products: delegated to `eq_product`.
/// - Primitives and strings: handled by `eq_at` and `eq_str` respectively.
/// - Arrays: compared via their BSATN encoding.
///
/// Any layout/value combination not covered above compares unequal.
///
/// # Safety
///
/// The bytes of `ctx.lhs` at `ctx.curr_offset` must hold a value valid at `ty`,
/// with any var-len references valid in `ctx.lhs.page` / `ctx.blob_store`.
unsafe fn eq_value(ctx: &mut EqCtx<'_>, ty: &AlgebraicTypeLayout, rhs: &AlgebraicValue) -> bool {
    debug_assert_eq!(
        ctx.curr_offset,
        align_to(ctx.curr_offset, ty.align()),
        "curr_offset {} insufficiently aligned for type {:?}",
        ctx.curr_offset,
        ty
    );

    match (ty, rhs) {
        (AlgebraicTypeLayout::Sum(sum_ty), AlgebraicValue::Sum(sum_val)) => {
            // Sum layout: (tag, data). Differing tags short-circuit to `false`.
            let (tag, variant_ty) = read_tag(ctx.lhs.bytes, sum_ty, ctx.curr_offset);
            if tag != sum_val.tag {
                return false;
            }

            let data_offset = ctx.curr_offset + sum_ty.offset_of_variant_data(tag);
            // The caller's cursor skips the whole sum, regardless of which variant is present.
            ctx.curr_offset += sum_ty.size();
            let mut variant_ctx = EqCtx {
                curr_offset: data_offset,
                ..*ctx
            };
            // SAFETY: The sum is valid at `sum_ty` with tag `tag`,
            // so its data is valid at `variant_ty`.
            unsafe { eq_value(&mut variant_ctx, variant_ty, &sum_val.value) }
        }
        // SAFETY (all arms below): `ctx.lhs` at `ctx.curr_offset` is valid at `ty`,
        // which is exactly the type each helper reads.
        (AlgebraicTypeLayout::Product(prod_ty), AlgebraicValue::Product(prod_val)) => unsafe {
            eq_product(ctx, prod_ty.view(), prod_val)
        },
        (&AlgebraicTypeLayout::Bool, AlgebraicValue::Bool(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::I8, AlgebraicValue::I8(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::U8, AlgebraicValue::U8(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::I16, AlgebraicValue::I16(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::U16, AlgebraicValue::U16(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::I32, AlgebraicValue::I32(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::U32, AlgebraicValue::U32(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::I64, AlgebraicValue::I64(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::U64, AlgebraicValue::U64(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::I128, AlgebraicValue::I128(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::U128, AlgebraicValue::U128(x)) => unsafe { eq_at(ctx, x) },
        // 256-bit integers are boxed in `AlgebraicValue`; compare against the unboxed value.
        (&AlgebraicTypeLayout::I256, AlgebraicValue::I256(x)) => unsafe { eq_at(ctx, &**x) },
        (&AlgebraicTypeLayout::U256, AlgebraicValue::U256(x)) => unsafe { eq_at(ctx, &**x) },
        (&AlgebraicTypeLayout::F32, AlgebraicValue::F32(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::F64, AlgebraicValue::F64(x)) => unsafe { eq_at(ctx, x) },
        (&AlgebraicTypeLayout::String, AlgebraicValue::String(s)) => unsafe { eq_str(ctx, s) },
        // Arrays: compare the stored var-len bytes, decoded as BSATN, against `rhs`.
        (AlgebraicTypeLayout::VarLen(_), AlgebraicValue::Array(_)) => unsafe {
            run_vlo_bytes(
                ctx.lhs.page,
                ctx.lhs.bytes,
                ctx.blob_store,
                &mut ctx.curr_offset,
                |mut bsatn| eq_bsatn(rhs, Deserializer::new(&mut bsatn)),
            )
        },
        _ => false,
    }
}
/// Compares the string whose `VarLenRef` is stored at `ctx.curr_offset` with `rhs`.
///
/// - Inline strings: after a cheap length check, the page's var-len granules are compared
///   against `VarLenGranule::DATA_SIZE`-sized chunks of `rhs`.
/// - Large-blob strings: the whole blob is fetched from the blob store and compared.
///
/// # Safety
///
/// `ctx.lhs` at `ctx.curr_offset` must hold a valid `VarLenRef` for a string.
/// It must point to valid granules in `ctx.lhs.page`, or to a blob in `ctx.blob_store`
/// containing valid UTF-8.
unsafe fn eq_str(ctx: &mut EqCtx<'_>, rhs: &str) -> bool {
    // SAFETY: A `::String` is represented by a `VarLenRef`, which is valid and aligned here.
    let vlr = unsafe { read_from_bytes::<VarLenRef>(ctx.lhs.bytes, &mut ctx.curr_offset) };

    if !vlr.is_large_blob() {
        // Differing lengths can be rejected without touching any granule data.
        if vlr.length_in_bytes as usize != rhs.len() {
            return false;
        }
        // SAFETY: `vlr.first_granule` refers to valid granule data in the page.
        let lhs_granules = unsafe { ctx.lhs.page.iter_vlo_data(vlr.first_granule) };
        let rhs_chunks = rhs.as_bytes().chunks(VarLenGranule::DATA_SIZE);
        return lhs_granules
            .zip(rhs_chunks)
            .all(|(lhs_chunk, rhs_chunk)| lhs_chunk == rhs_chunk);
    }

    // SAFETY: `vlr` is a large blob, so it references a blob retrievable from the store.
    let blob = unsafe { vlr_blob_bytes(ctx.lhs.page, ctx.blob_store, vlr) };
    // SAFETY: The blob backing a `::String` always holds valid UTF-8.
    let lhs = unsafe { str::from_utf8_unchecked(blob) };
    rhs == lhs
}
unsafe fn eq_at<T: Copy + Eq>(ctx: &mut EqCtx<'_>, rhs: &T) -> bool {
&unsafe { read_from_bytes::<T>(ctx.lhs.bytes, &mut ctx.curr_offset) } == rhs
}
#[cfg(test)]
mod tests {
    use crate::{blob_store::HashMapBlobStore, page_pool::PagePool};
    use proptest::prelude::*;
    use spacetimedb_sats::proptest::generate_typed_row;

    /// Miri is much slower than native execution, so it only runs a handful of cases.
    const NUM_CASES: u32 = if cfg!(miri) { 8 } else { 2048 };

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(NUM_CASES))]
        // Inserting a generated row and reading it back as a `RowRef`
        // must compare equal to the original `ProductValue`.
        #[test]
        fn pv_row_ref_eq((ty, val) in generate_typed_row()) {
            let mut table = crate::table::test::table(ty);
            let mut blob_store = HashMapBlobStore::default();
            let (_, row_ref) = table
                .insert(&PagePool::new_for_test(), &mut blob_store, &val)
                .unwrap();
            prop_assert_eq!(row_ref, val);
        }
    }
}