#![allow(missing_docs)]
use crate::delta::{Observation, TileVertexId};
use core::mem::size_of;
/// Capacity of the hypothesis table in an [`EvidenceAccumulator`].
pub const MAX_HYPOTHESES: usize = 16;

/// Capacity of the observation ring buffer in an [`EvidenceAccumulator`].
pub const WINDOW_SIZE: usize = 64;

/// Fixed-point log2 of an e-value or likelihood ratio: a stored value `v`
/// stands for `2^(v / FIXED_SCALE)`.
pub type LogEValue = i32;

/// Fixed-point scale: `FIXED_SCALE` represents 1.0 in log2 units (a factor of 2).
pub const FIXED_SCALE: i32 = 1 << 16;

/// Evidence threshold for rejecting a hypothesis and for flagging the
/// accumulator as significant: 282944 / 65536 ≈ 4.317, an e-value of about 20.
pub const LOG_E_STRONG: LogEValue = 282_944;

/// "Very strong" evidence threshold: 436906 / 65536 ≈ 6.667, an e-value of about 100.
pub const LOG_E_VERY_STRONG: LogEValue = 436_906;

/// Log-LR of a positive connectivity observation: ≈ 0.588, a factor of about 1.5.
pub const LOG_LR_CONNECTIVITY_POS: LogEValue = 38_550;

/// Log-LR of a negative connectivity observation: -1.0, a factor of 0.5.
pub const LOG_LR_CONNECTIVITY_NEG: LogEValue = -FIXED_SCALE;

/// Log-LR of a positive witness observation: 1.0, a factor of 2.
pub const LOG_LR_WITNESS_POS: LogEValue = FIXED_SCALE;

/// Log-LR of a negative witness observation: -1.0, a factor of 0.5.
pub const LOG_LR_WITNESS_NEG: LogEValue = -FIXED_SCALE;
/// Sums fixed-point log e-values into an `i64` using four independent lane
/// accumulators (a shape intended to let the compiler vectorise the loop).
///
/// Accumulating in `i64` keeps the sum of `i32` inputs from overflowing for any
/// realistic slice length. An empty slice yields 0.
#[inline]
pub fn simd_aggregate_log_e(log_e_values: &[LogEValue]) -> i64 {
    let mut acc = [0i64; 4];
    let mut quads = log_e_values.chunks_exact(4);
    for quad in &mut quads {
        for (lane, &v) in acc.iter_mut().zip(quad) {
            *lane += i64::from(v);
        }
    }
    // At most three leftover values; they land in lanes 0, 1, 2 in order.
    for (lane, &v) in acc.iter_mut().zip(quads.remainder()) {
        *lane += i64::from(v);
    }
    acc.iter().sum()
}
/// Eight-lane variant of `simd_aggregate_log_e`: sums fixed-point log
/// e-values into an `i64` using eight independent lane accumulators.
///
/// The lanes are reduced as two halves (0..4 and 4..8) that are then added.
/// An empty slice yields 0.
#[inline]
pub fn simd_aggregate_log_e_wide(log_e_values: &[LogEValue]) -> i64 {
    let mut acc = [0i64; 8];
    let mut octets = log_e_values.chunks_exact(8);
    for octet in &mut octets {
        for (lane, &v) in acc.iter_mut().zip(octet) {
            *lane += i64::from(v);
        }
    }
    // At most seven leftover values; they land in lanes 0..7 in order.
    for (lane, &v) in acc.iter_mut().zip(octets.remainder()) {
        *lane += i64::from(v);
    }
    let (low, high) = acc.split_at(4);
    low.iter().sum::<i64>() + high.iter().sum::<i64>()
}
/// Sums a fixed-size array of 255 log e-values via `simd_aggregate_log_e`.
#[inline]
pub fn aggregate_tile_evidence(tile_log_e_values: &[LogEValue; 255]) -> i64 {
    let values: &[LogEValue] = tile_log_e_values;
    simd_aggregate_log_e(values)
}
/// Converts a fixed-point log value to an `f32` in log2 units (`log_e / 65536`).
///
/// NOTE(review): despite the name, the result is the *logarithm* of the
/// e-value, not the e-value itself; apply `exp2` to recover the e-value.
#[inline(always)]
pub const fn log_e_to_f32(log_e: LogEValue) -> f32 {
    log_e as f32 / 65536.0
}
/// Converts a positive real value (e-value or likelihood ratio) to the
/// fixed-point log2 representation, i.e. `log2(e) * 65536` truncated to `i32`.
///
/// * `e <= 0.0` has no logarithm and maps to `i32::MIN`.
/// * 1.0, 2.0 and 0.5 are returned directly as 0, `FIXED_SCALE` and `-FIXED_SCALE`.
/// * Everything else goes through `libm::log2f`. A NaN input falls through to
///   this path and the float-to-int cast turns NaN into 0.
#[inline(always)]
pub fn f32_to_log_e(e: f32) -> LogEValue {
    if e <= 0.0 {
        return i32::MIN;
    }
    match e {
        x if x == 1.0 => 0,
        x if x == 2.0 => FIXED_SCALE,
        x if x == 0.5 => -FIXED_SCALE,
        x => (libm::log2f(x) * 65536.0) as i32,
    }
}
/// Fixed-point log2 likelihood ratio for a flag-driven observation kind.
///
/// Connectivity and witness observations map to their `*_POS` constant when
/// `flags` is non-zero and to their `*_NEG` constant otherwise. Every other
/// observation type yields 0 (no evidence either way).
///
/// NOTE(review): `value` is ignored here, so cut-membership and flow
/// observations produce 0. `EvidenceAccumulator::compute_log_likelihood_ratio`
/// instead derives a non-zero ratio from `obs.value` for those types. Confirm
/// that callers expect this difference.
#[inline(always)]
pub const fn log_lr_for_obs_type(obs_type: u8, flags: u8, value: u16) -> LogEValue {
    let _ = value;
    let positive = flags != 0;
    match obs_type {
        Observation::TYPE_CONNECTIVITY if positive => LOG_LR_CONNECTIVITY_POS,
        Observation::TYPE_CONNECTIVITY => LOG_LR_CONNECTIVITY_NEG,
        Observation::TYPE_WITNESS if positive => LOG_LR_WITNESS_POS,
        Observation::TYPE_WITNESS => LOG_LR_WITNESS_NEG,
        _ => 0,
    }
}
/// Evidence state for a single hypothesis, expected to occupy 16 bytes
/// (a compile-time `size_of` assertion in this file checks this).
///
/// `repr(C)` fixes the in-memory field order, so reordering or resizing the
/// fields below changes the layout.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(16))]
pub struct HypothesisState {
    /// Accumulated evidence as a fixed-point log2 e-value (see `LogEValue`).
    pub log_e_value: LogEValue,
    /// Number of updates applied since construction or the last `reset`.
    pub obs_count: u32,
    /// Identifier; the accumulator's `add_*_hypothesis` helpers pass the slot index.
    pub id: u16,
    /// Vertex an observation must carry for `EvidenceAccumulator` to treat it as relevant.
    pub target: TileVertexId,
    /// Extra vertex parameter set by `cut_membership`; the other constructors store 0.
    pub threshold: TileVertexId,
    /// One of the `HypothesisState::TYPE_*` constants.
    pub hyp_type: u8,
    /// Bitset of the `HypothesisState::FLAG_*` constants.
    pub flags: u8,
}
impl Default for HypothesisState {
#[inline]
fn default() -> Self {
Self::new(0, 0)
}
}
impl HypothesisState {
    /// Flag: the hypothesis participates in updates.
    pub const FLAG_ACTIVE: u8 = 0x01;
    /// Flag: evidence crossed `LOG_E_STRONG`; `update` refuses further changes.
    pub const FLAG_REJECTED: u8 = 0x02;
    /// Flag: after the latest update, evidence exceeded `LOG_E_STRONG`.
    pub const FLAG_STRONG: u8 = 0x04;
    /// Flag: after the latest update, evidence exceeded `LOG_E_VERY_STRONG`.
    pub const FLAG_VERY_STRONG: u8 = 0x08;

    /// Hypothesis matched against `Observation::TYPE_CONNECTIVITY` observations.
    pub const TYPE_CONNECTIVITY: u8 = 0;
    /// Hypothesis matched against `Observation::TYPE_CUT_MEMBERSHIP` observations.
    pub const TYPE_CUT: u8 = 1;
    /// Hypothesis matched against `Observation::TYPE_FLOW` observations.
    pub const TYPE_FLOW: u8 = 2;

    /// Creates an active hypothesis of type `hyp_type` with no evidence,
    /// target 0 and threshold 0.
    #[inline(always)]
    pub const fn new(id: u16, hyp_type: u8) -> Self {
        Self {
            log_e_value: 0,
            obs_count: 0,
            id,
            target: 0,
            threshold: 0,
            hyp_type,
            flags: Self::FLAG_ACTIVE,
        }
    }

    /// Creates an active connectivity hypothesis about `vertex`.
    #[inline(always)]
    pub const fn connectivity(id: u16, vertex: TileVertexId) -> Self {
        Self {
            target: vertex,
            ..Self::new(id, Self::TYPE_CONNECTIVITY)
        }
    }

    /// Creates an active cut-membership hypothesis about `vertex` with the
    /// given `threshold`.
    #[inline(always)]
    pub const fn cut_membership(id: u16, vertex: TileVertexId, threshold: TileVertexId) -> Self {
        Self {
            target: vertex,
            threshold,
            ..Self::new(id, Self::TYPE_CUT)
        }
    }

    /// Whether `FLAG_ACTIVE` is set (a rejected hypothesis may still be active).
    #[inline(always)]
    pub const fn is_active(&self) -> bool {
        self.flags & Self::FLAG_ACTIVE != 0
    }

    /// Whether `FLAG_REJECTED` is set.
    #[inline(always)]
    pub const fn is_rejected(&self) -> bool {
        self.flags & Self::FLAG_REJECTED != 0
    }

    /// Whether the hypothesis accepts evidence: active and not yet rejected.
    #[inline(always)]
    pub const fn can_update(&self) -> bool {
        (self.flags & (Self::FLAG_ACTIVE | Self::FLAG_REJECTED)) == Self::FLAG_ACTIVE
    }

    /// Approximate e-value, `2^(log_e_value / 65536)`.
    #[inline(always)]
    pub fn e_value_approx(&self) -> f32 {
        libm::exp2f(self.log_e_value as f32 / 65536.0)
    }

    /// Multiplies the e-value by `likelihood_ratio` (via `f32_to_log_e`).
    ///
    /// If the hypothesis cannot be updated, the state is left untouched and the
    /// result is `is_rejected()`. Otherwise the result is that of
    /// `update_with_log_lr`: `true` when this update crossed `LOG_E_STRONG`.
    #[inline]
    pub fn update(&mut self, likelihood_ratio: f32) -> bool {
        if self.can_update() {
            self.update_with_log_lr(f32_to_log_e(likelihood_ratio))
        } else {
            self.is_rejected()
        }
    }

    /// Adds a fixed-point log likelihood ratio to the accumulated evidence.
    ///
    /// Recomputes `FLAG_STRONG` / `FLAG_VERY_STRONG` from the new value and sets
    /// `FLAG_REJECTED` (never cleared here) when the evidence exceeds
    /// `LOG_E_STRONG`. Returns `true` in that case.
    ///
    /// Unlike `update`, this does not check `can_update()`.
    #[inline(always)]
    pub fn update_with_log_lr(&mut self, log_lr: LogEValue) -> bool {
        self.log_e_value = self.log_e_value.saturating_add(log_lr);
        // Saturate instead of `+= 1`. A hypothesis that never crosses the
        // threshold can be updated indefinitely, and `+=` would panic on
        // overflow in debug builds.
        self.obs_count = self.obs_count.saturating_add(1);

        self.flags &= !(Self::FLAG_STRONG | Self::FLAG_VERY_STRONG);
        if self.log_e_value > LOG_E_VERY_STRONG {
            self.flags |= Self::FLAG_STRONG | Self::FLAG_VERY_STRONG;
        } else if self.log_e_value > LOG_E_STRONG {
            self.flags |= Self::FLAG_STRONG;
        }

        let rejected = self.log_e_value > LOG_E_STRONG;
        if rejected {
            self.flags |= Self::FLAG_REJECTED;
        }
        rejected
    }

    /// Clears evidence and count and restores the flags to just `FLAG_ACTIVE`.
    /// `id`, `target`, `threshold` and `hyp_type` are kept.
    #[inline]
    pub fn reset(&mut self) {
        self.log_e_value = 0;
        self.obs_count = 0;
        self.flags = Self::FLAG_ACTIVE;
    }
}
/// One slot of the accumulator's observation ring buffer: an observation plus
/// the tick it was processed at.
///
/// Expected to be 12 bytes (a compile-time `size_of` assertion in this file
/// checks this); `repr(C)` fixes the field order.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
pub struct ObsRecord {
    /// The recorded observation.
    pub obs: Observation,
    /// Tick passed to `EvidenceAccumulator::process_observation` together with `obs`.
    pub tick: u32,
}
/// Fixed-capacity evidence accumulator. It holds:
///
/// * a global fixed-point log e-value,
/// * up to `MAX_HYPOTHESES` per-hypothesis states, and
/// * a ring buffer of the most recent `WINDOW_SIZE` observations.
///
/// Layout (`repr(C, align(64))`): the scalar header fields occupy 22 bytes, and
/// `_hot_pad` brings that to 62. Because `HypothesisState` is 16-byte aligned,
/// `hypotheses` then starts at offset 64.
/// NOTE(review): presumably this keeps the header in its own 64-byte cache line.
/// Confirm before adding or resizing header fields.
///
/// All fields are public. The accumulator's own methods keep
/// `window_head < WINDOW_SIZE` and `num_hypotheses <= MAX_HYPOTHESES`. Direct
/// writes to those fields should preserve that.
#[derive(Clone)]
#[repr(C, align(64))]
pub struct EvidenceAccumulator {
    /// Sum of the log-LRs of every processed observation (saturating).
    pub global_log_e: LogEValue,
    /// Number of observations processed.
    pub total_obs: u32,
    /// Tick of the most recently processed observation.
    pub current_tick: u32,
    /// Index of the ring-buffer slot the next observation is written to.
    pub window_head: u16,
    /// Number of valid ring-buffer slots (capped at `WINDOW_SIZE`).
    pub window_count: u16,
    /// Number of occupied slots in `hypotheses`.
    pub num_hypotheses: u8,
    pub _reserved: [u8; 1],
    /// Number of hypotheses rejected since the last reset.
    pub rejected_count: u16,
    /// Bitset of the `EvidenceAccumulator::STATUS_*` constants.
    pub status: u16,
    // Pads the header so `hypotheses` begins at offset 64 (see the struct docs).
    _hot_pad: [u8; 40],
    /// Hypothesis table; only the first `num_hypotheses` entries are in use.
    pub hypotheses: [HypothesisState; MAX_HYPOTHESES],
    /// Ring buffer of recent observations.
    pub window: [ObsRecord; WINDOW_SIZE],
}
impl Default for EvidenceAccumulator {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl EvidenceAccumulator {
    /// Status bit: the accumulator is live (set by `new` and `reset`).
    pub const STATUS_ACTIVE: u16 = 0x0001;
    /// Status bit: some hypothesis was rejected since the last reset.
    pub const STATUS_HAS_REJECTION: u16 = 0x0002;
    /// Status bit: `global_log_e` exceeded `LOG_E_STRONG` since the last reset.
    pub const STATUS_SIGNIFICANT: u16 = 0x0004;

    /// Maximum number of observations consumed by one
    /// `process_observation_batch` call.
    const MAX_BATCH: usize = 64;

    /// Creates an empty accumulator: no hypotheses, an empty observation
    /// window, zero global evidence, and status `STATUS_ACTIVE`.
    pub const fn new() -> Self {
        Self {
            global_log_e: 0,
            total_obs: 0,
            current_tick: 0,
            window_head: 0,
            window_count: 0,
            num_hypotheses: 0,
            _reserved: [0; 1],
            rejected_count: 0,
            status: Self::STATUS_ACTIVE,
            _hot_pad: [0; 40],
            hypotheses: [HypothesisState::new(0, 0); MAX_HYPOTHESES],
            window: [ObsRecord {
                obs: Observation {
                    vertex: 0,
                    obs_type: 0,
                    flags: 0,
                    value: 0,
                },
                tick: 0,
            }; WINDOW_SIZE],
        }
    }

    /// Number of hypothesis slots in use, clamped to `MAX_HYPOTHESES`.
    ///
    /// `num_hypotheses` is a public field and cannot be trusted as an index
    /// bound. Clamping keeps every access into `hypotheses` in range.
    #[inline(always)]
    fn hypothesis_count(&self) -> usize {
        usize::from(self.num_hypotheses).min(MAX_HYPOTHESES)
    }

    /// Stores `hypothesis` in the next free slot.
    ///
    /// Returns `false`, leaving the accumulator unchanged, when all
    /// `MAX_HYPOTHESES` slots are in use.
    pub fn add_hypothesis(&mut self, hypothesis: HypothesisState) -> bool {
        let slot = usize::from(self.num_hypotheses);
        if slot >= MAX_HYPOTHESES {
            return false;
        }
        self.hypotheses[slot] = hypothesis;
        self.num_hypotheses += 1;
        true
    }

    /// Adds a connectivity hypothesis for `vertex`, using the slot index as its id.
    pub fn add_connectivity_hypothesis(&mut self, vertex: TileVertexId) -> bool {
        let id = u16::from(self.num_hypotheses);
        self.add_hypothesis(HypothesisState::connectivity(id, vertex))
    }

    /// Adds a cut-membership hypothesis for `vertex`, using the slot index as its id.
    pub fn add_cut_hypothesis(&mut self, vertex: TileVertexId, threshold: TileVertexId) -> bool {
        let id = u16::from(self.num_hypotheses);
        self.add_hypothesis(HypothesisState::cut_membership(id, vertex, threshold))
    }

    /// Records one observation and updates all evidence it bears on.
    ///
    /// * Writes `(obs, tick)` into the ring buffer and advances the head.
    /// * Adds the observation's fixed-point log-LR to `global_log_e` (saturating).
    /// * Adds the same log-LR to every updatable hypothesis whose type and
    ///   target vertex match the observation. Each newly rejected hypothesis
    ///   bumps `rejected_count` and sets `STATUS_HAS_REJECTION`.
    /// * Sets `STATUS_SIGNIFICANT` once `global_log_e` exceeds `LOG_E_STRONG`.
    #[inline]
    pub fn process_observation(&mut self, obs: Observation, tick: u32) {
        self.current_tick = tick;
        self.total_obs = self.total_obs.saturating_add(1);

        // Reduce the public `window_head` modulo WINDOW_SIZE instead of trusting
        // it. Out-of-range values therefore cannot index past `window`, and the
        // bounds check on the provably in-range index can be optimised away.
        let idx = usize::from(self.window_head) % WINDOW_SIZE;
        self.window[idx] = ObsRecord { obs, tick };
        // The result is < WINDOW_SIZE (64), so the cast to u16 is lossless.
        self.window_head = ((idx + 1) % WINDOW_SIZE) as u16;
        if usize::from(self.window_count) < WINDOW_SIZE {
            self.window_count += 1;
        }

        let log_lr = self.compute_log_likelihood_ratio(&obs);
        self.global_log_e = self.global_log_e.saturating_add(log_lr);

        for i in 0..self.hypothesis_count() {
            let hyp = &self.hypotheses[i];
            // Frozen (inactive or rejected) and unrelated hypotheses are skipped.
            if !hyp.can_update() || !self.is_obs_relevant(hyp, &obs) {
                continue;
            }
            if self.hypotheses[i].update_with_log_lr(log_lr) {
                self.rejected_count = self.rejected_count.saturating_add(1);
                self.status |= Self::STATUS_HAS_REJECTION;
            }
        }

        if self.global_log_e > LOG_E_STRONG {
            self.status |= Self::STATUS_SIGNIFICANT;
        }
    }

    /// Whether `obs` bears on `hyp`: the observation kind must correspond to the
    /// hypothesis kind, and the observed vertex must be the hypothesis target.
    #[inline(always)]
    fn is_obs_relevant(&self, hyp: &HypothesisState, obs: &Observation) -> bool {
        let kinds_match = match (hyp.hyp_type, obs.obs_type) {
            (HypothesisState::TYPE_CONNECTIVITY, Observation::TYPE_CONNECTIVITY) => true,
            (HypothesisState::TYPE_CUT, Observation::TYPE_CUT_MEMBERSHIP) => true,
            (HypothesisState::TYPE_FLOW, Observation::TYPE_FLOW) => true,
            _ => false,
        };
        kinds_match && obs.vertex == hyp.target
    }

    /// Fixed-point log2 likelihood ratio contributed by `obs`.
    #[inline(always)]
    fn compute_log_likelihood_ratio(&self, obs: &Observation) -> LogEValue {
        let positive = obs.flags != 0;
        match obs.obs_type {
            Observation::TYPE_CONNECTIVITY if positive => LOG_LR_CONNECTIVITY_POS,
            Observation::TYPE_CONNECTIVITY => LOG_LR_CONNECTIVITY_NEG,
            Observation::TYPE_WITNESS if positive => LOG_LR_WITNESS_POS,
            Observation::TYPE_WITNESS => LOG_LR_WITNESS_NEG,
            // NOTE(review): this gives at most 32767 (log2 LR ≈ 0.5), whereas
            // `compute_likelihood_ratio` uses LR = 1 + value / 65535 (log2 LR up
            // to 1.0). Confirm which scaling is intended.
            Observation::TYPE_CUT_MEMBERSHIP => (obs.value as i32) >> 1,
            Observation::TYPE_FLOW => {
                let flow = obs.value as f32 / 1000.0;
                let lr = if flow > 0.5 { 1.0 + flow } else { 1.0 / (1.0 + flow) };
                f32_to_log_e(lr)
            }
            _ => 0,
        }
    }

    /// Floating-point likelihood ratio for `obs`, the non-fixed-point
    /// counterpart of `compute_log_likelihood_ratio`.
    ///
    /// Not called by the code in this file.
    #[allow(dead_code)]
    #[inline]
    fn compute_likelihood_ratio(&self, obs: &Observation) -> f32 {
        match obs.obs_type {
            Observation::TYPE_CONNECTIVITY => {
                if obs.flags != 0 {
                    1.5
                } else {
                    0.5
                }
            }
            Observation::TYPE_CUT_MEMBERSHIP => {
                let confidence = (obs.value as f32) / 65535.0;
                1.0 + confidence
            }
            Observation::TYPE_FLOW => {
                let flow = (obs.value as f32) / 1000.0;
                if flow > 0.5 {
                    1.0 + flow
                } else {
                    1.0 / (1.0 + flow)
                }
            }
            Observation::TYPE_WITNESS => {
                if obs.flags != 0 {
                    2.0
                } else {
                    0.5
                }
            }
            _ => 1.0,
        }
    }

    /// Approximate global e-value, `2^(global_log_e / 65536)`.
    #[inline(always)]
    pub fn global_e_value(&self) -> f32 {
        libm::exp2f(self.global_log_e as f32 / 65536.0)
    }

    /// Whether any hypothesis was rejected since the last reset.
    #[inline(always)]
    pub fn has_rejection(&self) -> bool {
        self.status & Self::STATUS_HAS_REJECTION != 0
    }

    /// Whether the global evidence exceeded `LOG_E_STRONG` since the last reset.
    #[inline(always)]
    pub fn is_significant(&self) -> bool {
        self.status & Self::STATUS_SIGNIFICANT != 0
    }

    /// Clears evidence, window bookkeeping, rejection count and status.
    ///
    /// Registered hypotheses are kept but their state is reset. `total_obs`,
    /// `current_tick` and the window contents are not touched; the window is
    /// logically emptied via `window_count = 0`.
    pub fn reset(&mut self) {
        let n = self.hypothesis_count();
        for hyp in self.hypotheses[..n].iter_mut() {
            hyp.reset();
        }
        self.window_head = 0;
        self.window_count = 0;
        self.global_log_e = 0;
        self.rejected_count = 0;
        self.status = Self::STATUS_ACTIVE;
    }

    /// Processes `observations` in order, but at most the first 64 of them.
    /// Any further entries are ignored.
    #[inline]
    pub fn process_observation_batch(&mut self, observations: &[(Observation, u32)]) {
        for &(obs, tick) in observations.iter().take(Self::MAX_BATCH) {
            self.process_observation(obs, tick);
        }
    }

    /// Sums the log e-values of all active hypotheses using four lane accumulators.
    #[inline]
    pub fn aggregate_hypotheses_simd(&self) -> i64 {
        let mut lanes = [0i64; 4];
        let in_use = &self.hypotheses[..self.hypothesis_count()];
        for (i, hyp) in in_use.iter().enumerate() {
            if hyp.is_active() {
                lanes[i % 4] += hyp.log_e_value as i64;
            }
        }
        lanes.iter().sum()
    }

    /// Whether `global_log_e` is strictly greater than `threshold_log`.
    #[inline(always)]
    pub fn exceeds_threshold(&self, threshold_log: LogEValue) -> bool {
        self.global_log_e > threshold_log
    }

    /// Size of an `EvidenceAccumulator` in bytes.
    pub const fn memory_size() -> usize {
        size_of::<Self>()
    }
}
// Compile-time layout checks: the hot-path record types must keep their
// packed sizes, so any field change that alters them fails the build.
const _: () = {
    assert!(
        size_of::<HypothesisState>() == 16,
        "HypothesisState must be 16 bytes"
    );
    assert!(size_of::<ObsRecord>() == 12, "ObsRecord must be 12 bytes");
};
#[cfg(test)]
mod tests {
    use super::*;

    /// 1.0 maps to exactly zero. 2.0 and 4.0 land within 100 units of one and
    /// two fixed-point steps.
    #[test]
    fn test_log_e_conversion() {
        assert_eq!(f32_to_log_e(1.0), 0);
        for &(input, expected) in &[(2.0_f32, 65536), (4.0_f32, 131072)] {
            let fixed = f32_to_log_e(input);
            assert!((fixed - expected).abs() < 100);
        }
    }

    /// Five doublings accumulate log2 evidence of 5, i.e. an e-value of 32.
    #[test]
    fn test_hypothesis_state() {
        let mut hyp = HypothesisState::new(0, HypothesisState::TYPE_CONNECTIVITY);
        assert!(hyp.is_active());
        assert!(!hyp.is_rejected());
        assert_eq!(hyp.obs_count, 0);
        for _ in 0..5 {
            let _ = hyp.update(2.0);
        }
        assert_eq!(hyp.obs_count, 5);
        assert!(hyp.e_value_approx() > 20.0);
    }

    /// Repeated doublings eventually cross the rejection threshold.
    #[test]
    fn test_hypothesis_rejection() {
        let mut hyp = HypothesisState::new(0, HypothesisState::TYPE_CUT);
        // `any` stops at the first update that reports rejection.
        let _ = (0..10).any(|_| hyp.update(2.0));
        assert!(hyp.is_rejected());
    }

    #[test]
    fn test_accumulator_new() {
        let acc = EvidenceAccumulator::new();
        assert_eq!(acc.num_hypotheses, 0);
        assert_eq!(acc.total_obs, 0);
        assert!(!acc.has_rejection());
    }

    #[test]
    fn test_add_hypothesis() {
        let mut acc = EvidenceAccumulator::new();
        assert!(acc.add_connectivity_hypothesis(5));
        assert!(acc.add_cut_hypothesis(10, 15));
        assert_eq!(acc.num_hypotheses, 2);
    }

    /// Positive connectivity observations raise the global e-value above 1.
    #[test]
    fn test_process_observation() {
        let mut acc = EvidenceAccumulator::new();
        let _ = acc.add_connectivity_hypothesis(5);
        (0..10).for_each(|tick| {
            acc.process_observation(Observation::connectivity(5, true), tick);
        });
        assert_eq!(acc.total_obs, 10);
        assert!(acc.global_e_value() > 1.0);
    }

    /// The window count stops at the ring-buffer capacity.
    #[test]
    fn test_sliding_window() {
        let mut acc = EvidenceAccumulator::new();
        let total_ticks = WINDOW_SIZE as u32 + 10;
        for tick in 0..total_ticks {
            acc.process_observation(Observation::connectivity(0, true), tick);
        }
        assert_eq!(acc.window_count, WINDOW_SIZE as u16);
    }

    #[test]
    fn test_memory_size() {
        let size = EvidenceAccumulator::memory_size();
        assert!(size < 4096, "EvidenceAccumulator too large: {} bytes", size);
    }
}