use crate::event::{AttrSet, KindRef};
use crate::location::action_entry::ActionEntry;
use crate::schema::attr::{AttrType, Value};
use crate::schema::schema::Schema;
use crate::{AttrId, KindId, UnixTime};
use smallvec::SmallVec;
use std::sync::Arc;
/// State for a single location: fixed-layout attribute slots for the kinds
/// in `schema`, side storage for `F32Arr` attributes, a bitmap of kinds that
/// have received a successful update, and the attached action entries
/// (generic over the action payload type `T`).
pub struct LocationState<T> {
    /// Slot buffer, `schema.slot_layout.total_bytes` long and zero-filled by
    /// `new`. Scalar attributes are written little-endian at their slot
    /// offsets.
    pub buf: Box<[u8]>,
    /// Byte storage for `F32Arr` attributes. An array's slot in `buf` holds a
    /// little-endian `u32` byte offset into this buffer followed by a `u32`
    /// element count; the elements themselves are in native byte order.
    /// `apply_update` only appends here, so superseded arrays are not
    /// reclaimed by this code.
    pub arr_buf: Vec<u8>,
    /// One bit per kind id (64 per word); a kind's bit is set when
    /// `apply_update` succeeds for it.
    pub kinds_present: Vec<u64>,
    /// Incremented (wrapping) on every successful `apply_update`.
    pub version: u64,
    /// Action entries attached to this location; `expire` removes the ones
    /// whose `end` has passed.
    pub actions: SmallVec<[ActionEntry<T>; 16]>,
    /// Schema defining the kinds, attributes and slot layout of `buf`.
    pub schema: Arc<Schema>,
}
impl<T> LocationState<T> {
#[must_use]
pub fn new(schema: Arc<Schema>) -> Self {
#[allow(clippy::cast_possible_truncation)]
let nbytes = schema.slot_layout.total_bytes as usize;
#[allow(clippy::integer_division)]
let nwords = (schema.kind_names.len() + 63) / 64;
Self {
buf: vec![0u8; nbytes].into_boxed_slice(),
arr_buf: Vec::new(),
kinds_present: vec![0u64; nwords],
version: 0,
actions: SmallVec::new(),
schema,
}
}
fn resolve_kind(&self, k: KindRef<'_>) -> Option<KindId> {
match k {
KindRef::Id(id) => (usize::from(id.0) < self.schema.kind_names.len()).then_some(id),
KindRef::Name(n) => self.schema.kind(n),
}
}
pub fn apply_update(&mut self, kind: KindRef<'_>, attrs: &AttrSet<'_>) -> bool {
let Some(kind_id) = self.resolve_kind(kind) else {
return false;
};
for &(aid, ref val) in attrs.iter() {
let Some(slot) = self.schema.slot_layout.resolve(kind_id, aid) else {
return false;
};
#[allow(clippy::cast_possible_truncation)]
match (slot.ty, val) {
(AttrType::Int, Value::Int(n)) => {
write_i64(&mut self.buf, slot.offset as usize, *n);
}
(AttrType::F32, Value::F32(x)) => {
write_f32(&mut self.buf, slot.offset as usize, *x);
}
(AttrType::F64, Value::F64(x)) => {
write_f64(&mut self.buf, slot.offset as usize, *x);
}
(AttrType::EnumStr, Value::EnumCode(c)) => {
write_u32(&mut self.buf, slot.offset as usize, *c);
}
(AttrType::F32Arr, Value::F32Arr(v)) => {
use crate::schema::attr::MAX_EMBEDDING_DIM;
if v.len() > MAX_EMBEDDING_DIM {
return false;
}
if self.arr_buf.len().saturating_add(v.len() * 4) > u32::MAX as usize {
return false;
}
let off = self.arr_buf.len() as u32;
let len = v.len() as u32;
let bytes: &[u8] =
unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<u8>(), v.len() * 4) };
self.arr_buf.extend_from_slice(bytes);
write_u32(&mut self.buf, slot.offset as usize, off);
write_u32(&mut self.buf, slot.offset as usize + 4, len);
}
_ => return false,
}
}
let bit = usize::from(kind_id.0);
self.kinds_present[bit / 64] |= 1u64 << (bit % 64);
self.version = self.version.wrapping_add(1);
true
}
pub fn expire(&mut self, now: UnixTime) -> usize {
let k = self.actions.partition_point(|a| a.end <= now);
self.actions.drain(..k);
k
}
#[must_use]
pub fn view(&self) -> LocationView<'_> {
LocationView {
buf: &self.buf,
schema: &self.schema,
version: self.version,
}
}
#[must_use]
pub fn read_f32_arr(&self, kind: KindId, attr: AttrId) -> Option<&[f32]> {
let slot = self.schema.slot_layout.resolve(kind, attr)?;
if !matches!(slot.ty, AttrType::F32Arr) {
return None;
}
let off_bytes = slot.offset as usize;
let off = read_u32_le(&self.buf, off_bytes) as usize;
let len = read_u32_le(&self.buf, off_bytes + 4) as usize;
if len == 0 {
return None; }
let end = off + len * 4;
if end > self.arr_buf.len() {
return None; }
let bytes = &self.arr_buf[off..end];
debug_assert_eq!(
bytes.as_ptr().align_offset(std::mem::align_of::<f32>()),
0,
"arr_buf alignment invariant"
);
#[allow(clippy::cast_ptr_alignment)]
Some(unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<f32>(), len) })
}
}
/// Decodes the little-endian `u32` stored in `buf[off..off + 4]`.
///
/// # Panics
///
/// Panics if any of those four bytes is out of bounds.
#[inline]
const fn read_u32_le(buf: &[u8], off: usize) -> u32 {
    let (lo, mid_lo, mid_hi, hi) = (buf[off], buf[off + 1], buf[off + 2], buf[off + 3]);
    u32::from_le_bytes([lo, mid_lo, mid_hi, hi])
}
/// Stores `v` little-endian into `buf[off..off + 8]`.
///
/// # Panics
///
/// Panics if that range is out of bounds.
#[inline]
fn write_i64(buf: &mut [u8], off: usize, v: i64) {
    let encoded = v.to_le_bytes();
    buf[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Stores `v` little-endian into `buf[off..off + 4]`.
///
/// # Panics
///
/// Panics if that range is out of bounds.
#[inline]
fn write_f32(buf: &mut [u8], off: usize, v: f32) {
    let encoded = v.to_le_bytes();
    buf[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Stores `v` little-endian into `buf[off..off + 8]`.
///
/// # Panics
///
/// Panics if that range is out of bounds.
#[inline]
fn write_f64(buf: &mut [u8], off: usize, v: f64) {
    let encoded = v.to_le_bytes();
    buf[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Stores `v` little-endian into `buf[off..off + 4]`.
///
/// # Panics
///
/// Panics if that range is out of bounds.
#[inline]
fn write_u32(buf: &mut [u8], off: usize, v: u32) {
    let encoded = v.to_le_bytes();
    buf[off..off + encoded.len()].copy_from_slice(&encoded);
}
/// Read-only view of a location's slot buffer together with its schema and
/// the version it was taken at. `F32Arr` element data is not part of the
/// view; the caller supplies it to `read_f32_arr_in`.
pub struct LocationView<'a> {
    /// Slot buffer; scalar values are stored little-endian at slot offsets.
    pub buf: &'a [u8],
    /// Schema whose slot layout describes `buf`.
    pub schema: &'a Schema,
    /// Version of the state at the time the view was created.
    pub version: u64,
}
impl<'a> LocationView<'a> {
    /// Copies the `N` bytes at `buf[off..off + N]`.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if that range is not inside `buf`
    /// (including when `off + N` overflows).
    // Callers are expected to pass offsets taken from the slot layout; an
    // out-of-range one is a programming error, not bad input.
    #[allow(clippy::expect_used)]
    fn read_bytes<const N: usize>(&self, off: usize) -> [u8; N] {
        off.checked_add(N)
            .and_then(|end| self.buf.get(off..end))
            .and_then(|bytes| bytes.try_into().ok())
            .expect("slot layout bug")
    }

    /// Reads the little-endian `i64` at byte offset `off`.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if `off..off + 8` is outside `buf`.
    #[must_use]
    pub fn read_i64(&self, off: usize) -> i64 {
        i64::from_le_bytes(self.read_bytes(off))
    }

    /// Reads the little-endian `f32` at byte offset `off`.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if `off..off + 4` is outside `buf`.
    #[must_use]
    pub fn read_f32(&self, off: usize) -> f32 {
        f32::from_le_bytes(self.read_bytes(off))
    }

    /// Reads the little-endian `f64` at byte offset `off`.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if `off..off + 8` is outside `buf`.
    #[must_use]
    pub fn read_f64(&self, off: usize) -> f64 {
        f64::from_le_bytes(self.read_bytes(off))
    }

    /// Reads the little-endian `u32` at byte offset `off`.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if `off..off + 4` is outside `buf`.
    #[must_use]
    pub fn read_u32(&self, off: usize) -> u32 {
        u32::from_le_bytes(self.read_bytes(off))
    }

    /// Returns the `F32Arr` value for `(kind, attr)`, with its elements read
    /// from `arr_buf`.
    ///
    /// The slot in `buf` holds a `u32` byte offset into `arr_buf` followed by
    /// a `u32` element count. Returns `None` in any of these cases:
    /// - the slot does not exist or is not `F32Arr`;
    /// - the count is 0;
    /// - the recorded range is not inside `arr_buf`;
    /// - the data is not aligned for `f32`.
    ///
    /// The alignment is checked at runtime because a byte buffer only
    /// guarantees 1-byte alignment. Debug builds still panic on it.
    ///
    /// # Panics
    ///
    /// Panics with "slot layout bug" if the slot's 8-byte header is outside
    /// `buf`.
    #[must_use]
    pub fn read_f32_arr_in(
        &self,
        arr_buf: &'a [u8],
        kind: KindId,
        attr: AttrId,
    ) -> Option<&'a [f32]> {
        let slot = self.schema.slot_layout.resolve(kind, attr)?;
        if !matches!(slot.ty, AttrType::F32Arr) {
            return None;
        }
        // Slot offsets index into `buf`, which is in memory, so they fit.
        #[allow(clippy::cast_possible_truncation)]
        let header = slot.offset as usize;
        let off = self.read_u32(header) as usize;
        let len = self.read_u32(header + 4) as usize;
        if len == 0 {
            return None;
        }
        let end = len
            .checked_mul(std::mem::size_of::<f32>())
            .and_then(|n| n.checked_add(off))?;
        let bytes = arr_buf.get(off..end)?;
        let aligned = (bytes.as_ptr() as usize) % std::mem::align_of::<f32>() == 0;
        debug_assert!(aligned, "arr_buf alignment invariant");
        if !aligned {
            return None;
        }
        // Alignment for `f32` was verified just above.
        #[allow(clippy::cast_ptr_alignment)]
        let ptr = bytes.as_ptr().cast::<f32>();
        // SAFETY: `ptr` is aligned for `f32`; `bytes` is exactly
        // `len * size_of::<f32>()` initialised bytes borrowed for `'a`; every
        // bit pattern is a valid `f32`.
        Some(unsafe { std::slice::from_raw_parts(ptr, len) })
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::items_after_statements)]
    use super::*;
    use crate::engine::compiled_scorer::CompiledScorer;
    use crate::location::action_entry::ActionEntry;
    use crate::schema::attr::MAX_EMBEDDING_DIM;
    use crate::schema::SchemaBuilder;
    use crate::scoring::backends::predicate::bytecode::{Op, Program};
    use smallvec::smallvec;

    /// Predicate scorer that just pushes `0.0`; for entries that are never scored.
    fn zero_scorer() -> CompiledScorer {
        let program = Program {
            ops: vec![Op::PushF32(0.0)],
            max_stack: 1,
        };
        CompiledScorer::Predicate(program)
    }

    /// Schema with one kind, "audience", holding `male_frac: F32` and `dwell: Int`.
    fn tiny_schema() -> Arc<Schema> {
        let mut builder = SchemaBuilder::new();
        let audience_attrs = [("male_frac", AttrType::F32), ("dwell", AttrType::Int)];
        let _ = builder.kind("audience", &audience_attrs);
        builder.build()
    }

    /// Schema with one kind, "loc", holding a single `embed: F32Arr` attribute.
    fn embed_schema() -> (Arc<Schema>, KindId, AttrId) {
        let mut builder = crate::schema::builder::SchemaBuilder::new();
        let kind_id = builder.kind("loc", &[("embed", AttrType::F32Arr)]);
        let schema = builder.build();
        let attr_id = schema.attr_names.get("embed").unwrap();
        (schema, kind_id, attr_id)
    }

    /// Unit-payload action with the given id that starts at 0 and ends at `end`.
    fn action(id: &'static str, end: UnixTime) -> ActionEntry<()> {
        ActionEntry {
            action_id: crate::ActionId::from(id),
            start: 0,
            end,
            priority: 0,
            scorer: zero_scorer(),
            payload: (),
            post: None,
        }
    }

    #[test]
    fn apply_update_writes_slots_and_bumps_version() {
        let schema = tiny_schema();
        let male_id = schema.attr("male_frac").unwrap();
        let dwell_id = schema.attr("dwell").unwrap();
        let audience = schema.kind("audience").unwrap();
        let mut state: LocationState<()> = LocationState::new(Arc::clone(&schema));

        let update = AttrSet {
            entries: smallvec![(male_id, Value::F32(0.7)), (dwell_id, Value::Int(42))],
        };
        assert!(state.apply_update(KindRef::Id(audience), &update));
        assert_eq!(state.version, 1);

        let male_slot = schema.slot_layout.resolve(audience, male_id).unwrap();
        let dwell_slot = schema.slot_layout.resolve(audience, dwell_id).unwrap();
        #[allow(clippy::cast_possible_truncation)]
        let (male_off, dwell_off) = (male_slot.offset as usize, dwell_slot.offset as usize);
        let view = state.view();
        assert!((view.read_f32(male_off) - 0.7).abs() < 1e-6);
        assert_eq!(view.read_i64(dwell_off), 42);
    }

    #[test]
    fn apply_update_rejects_unknown_kind() {
        let mut state: LocationState<()> = LocationState::new(tiny_schema());
        assert!(!state.apply_update(KindRef::Name("nope"), &AttrSet::new()));
    }

    #[test]
    fn location_state_f32_arr_round_trip() {
        let (schema, kind_id, attr_id) = embed_schema();
        let mut state: LocationState<()> = LocationState::new(schema);
        let expected: [f32; 4] = [1.0, 2.0, -0.5, 3.25];

        let mut attrs = AttrSet::new();
        attrs.push(attr_id, Value::F32Arr(&expected));
        assert!(state.apply_update(KindRef::Id(kind_id), &attrs));

        let stored = state.read_f32_arr(kind_id, attr_id).unwrap();
        assert_eq!(stored, &expected);
    }

    #[test]
    fn location_state_rejects_oversize_embedding() {
        let (schema, kind_id, attr_id) = embed_schema();
        let mut state: LocationState<()> = LocationState::new(schema);
        let oversize = vec![0.0_f32; MAX_EMBEDDING_DIM + 1];

        let mut attrs = AttrSet::new();
        attrs.push(attr_id, Value::F32Arr(&oversize));
        assert!(!state.apply_update(KindRef::Id(kind_id), &attrs));
    }

    #[test]
    fn expire_drops_past_end_actions() {
        let mut state: LocationState<()> = LocationState::new(tiny_schema());
        state.actions.push(action("action-1", 10));
        state.actions.push(action("action-2", 20));

        assert_eq!(state.expire(15), 1);
        assert_eq!(state.actions.len(), 1);
        assert_eq!(state.actions[0].action_id.as_str(), "action-2");
    }
}