1#![allow(missing_docs)]
16
17use crate::delta::{Observation, TileVertexId};
18use core::mem::size_of;
19
/// Capacity of the hypothesis table in [`EvidenceAccumulator`].
pub const MAX_HYPOTHESES: usize = 16;

/// Number of slots in the observation ring buffer.
///
/// The ring cursor is advanced with a `WINDOW_SIZE - 1` bit mask, so this must stay a
/// power of two.
pub const WINDOW_SIZE: usize = 64;

/// Log2 of an e-value in signed 16.16 fixed point (`1 << 16` represents one bit).
pub type LogEValue = i32;

/// Rejection threshold: 282944 / 65536 ≈ 4.317 bits, i.e. an e-value of about 19.9
/// (close to, but not exactly, log2(20)).
pub const LOG_E_STRONG: LogEValue = 282_944;

/// "Very strong" threshold: 436906 / 65536 ≈ 6.667 bits, i.e. an e-value of about 101.6.
pub const LOG_E_VERY_STRONG: LogEValue = 436_906;

/// Positive connectivity observation: ≈ 0.588 bits (likelihood ratio ≈ 1.504).
/// NOTE(review): log2(1.5) in 16.16 would be ≈ 38336; confirm 38550 is intentional.
pub const LOG_LR_CONNECTIVITY_POS: LogEValue = 38_550;

/// Negative connectivity observation: exactly -1 bit (likelihood ratio 0.5).
pub const LOG_LR_CONNECTIVITY_NEG: LogEValue = -FIXED_SCALE;

/// Positive witness observation: exactly +1 bit (likelihood ratio 2).
pub const LOG_LR_WITNESS_POS: LogEValue = FIXED_SCALE;

/// Negative witness observation: exactly -1 bit (likelihood ratio 0.5).
pub const LOG_LR_WITNESS_NEG: LogEValue = -FIXED_SCALE;

/// Fixed-point scale of [`LogEValue`]: 2^16 units per bit.
pub const FIXED_SCALE: i32 = 1 << 16;
55
56#[inline]
75pub fn simd_aggregate_log_e(log_e_values: &[LogEValue]) -> i64 {
76 let mut lanes = [0i64; 4];
79
80 let chunks = log_e_values.chunks_exact(4);
82 let remainder = chunks.remainder();
83
84 for chunk in chunks {
85 lanes[0] += chunk[0] as i64;
87 lanes[1] += chunk[1] as i64;
88 lanes[2] += chunk[2] as i64;
89 lanes[3] += chunk[3] as i64;
90 }
91
92 for (i, &val) in remainder.iter().enumerate() {
94 lanes[i % 4] += val as i64;
95 }
96
97 lanes[0] + lanes[1] + lanes[2] + lanes[3]
99}
100
101#[inline]
106pub fn simd_aggregate_log_e_wide(log_e_values: &[LogEValue]) -> i64 {
107 let mut lanes = [0i64; 8];
109
110 let chunks = log_e_values.chunks_exact(8);
111 let remainder = chunks.remainder();
112
113 for chunk in chunks {
114 lanes[0] += chunk[0] as i64;
116 lanes[1] += chunk[1] as i64;
117 lanes[2] += chunk[2] as i64;
118 lanes[3] += chunk[3] as i64;
119 lanes[4] += chunk[4] as i64;
120 lanes[5] += chunk[5] as i64;
121 lanes[6] += chunk[6] as i64;
122 lanes[7] += chunk[7] as i64;
123 }
124
125 for (i, &val) in remainder.iter().enumerate() {
127 lanes[i % 8] += val as i64;
128 }
129
130 let sum_0_3 = lanes[0] + lanes[1] + lanes[2] + lanes[3];
132 let sum_4_7 = lanes[4] + lanes[5] + lanes[6] + lanes[7];
133 sum_0_3 + sum_4_7
134}
135
136#[inline]
153pub fn aggregate_tile_evidence(tile_log_e_values: &[LogEValue; 255]) -> i64 {
154 simd_aggregate_log_e(tile_log_e_values)
155}
156
157#[inline(always)]
161pub const fn log_e_to_f32(log_e: LogEValue) -> f32 {
162 let log2_val = (log_e as f32) / 65536.0;
166 log2_val
169}
170
171#[inline(always)]
175pub fn f32_to_log_e(e: f32) -> LogEValue {
176 if e <= 0.0 {
177 i32::MIN
178 } else if e == 1.0 {
179 0 } else if e == 2.0 {
181 FIXED_SCALE } else if e == 0.5 {
183 -FIXED_SCALE } else {
185 let log2_e = libm::log2f(e);
187 (log2_e * 65536.0) as i32
188 }
189}
190
191#[inline(always)]
196pub const fn log_lr_for_obs_type(obs_type: u8, flags: u8, value: u16) -> LogEValue {
197 match obs_type {
198 Observation::TYPE_CONNECTIVITY => {
199 if flags != 0 {
200 LOG_LR_CONNECTIVITY_POS
201 } else {
202 LOG_LR_CONNECTIVITY_NEG
203 }
204 }
205 Observation::TYPE_WITNESS => {
206 if flags != 0 {
207 LOG_LR_WITNESS_POS
208 } else {
209 LOG_LR_WITNESS_NEG
210 }
211 }
212 _ => 0,
214 }
215}
216
217#[derive(Debug, Clone, Copy)]
221#[repr(C, align(16))]
222pub struct HypothesisState {
223 pub log_e_value: LogEValue,
225 pub obs_count: u32,
227 pub id: u16,
229 pub target: TileVertexId,
231 pub threshold: TileVertexId,
233 pub hyp_type: u8,
235 pub flags: u8,
237}
238
239impl Default for HypothesisState {
240 #[inline]
241 fn default() -> Self {
242 Self::new(0, 0)
243 }
244}
245
246impl HypothesisState {
247 pub const FLAG_ACTIVE: u8 = 0x01;
249 pub const FLAG_REJECTED: u8 = 0x02;
251 pub const FLAG_STRONG: u8 = 0x04;
253 pub const FLAG_VERY_STRONG: u8 = 0x08;
255
256 pub const TYPE_CONNECTIVITY: u8 = 0;
258 pub const TYPE_CUT: u8 = 1;
260 pub const TYPE_FLOW: u8 = 2;
262
263 #[inline(always)]
265 pub const fn new(id: u16, hyp_type: u8) -> Self {
266 Self {
267 log_e_value: 0, obs_count: 0,
269 id,
270 target: 0,
271 threshold: 0,
272 hyp_type,
273 flags: Self::FLAG_ACTIVE,
274 }
275 }
276
277 #[inline(always)]
279 pub const fn connectivity(id: u16, vertex: TileVertexId) -> Self {
280 Self {
281 log_e_value: 0,
282 obs_count: 0,
283 id,
284 target: vertex,
285 threshold: 0,
286 hyp_type: Self::TYPE_CONNECTIVITY,
287 flags: Self::FLAG_ACTIVE,
288 }
289 }
290
291 #[inline(always)]
293 pub const fn cut_membership(id: u16, vertex: TileVertexId, threshold: TileVertexId) -> Self {
294 Self {
295 log_e_value: 0,
296 obs_count: 0,
297 id,
298 target: vertex,
299 threshold,
300 hyp_type: Self::TYPE_CUT,
301 flags: Self::FLAG_ACTIVE,
302 }
303 }
304
305 #[inline(always)]
309 pub const fn is_active(&self) -> bool {
310 self.flags & Self::FLAG_ACTIVE != 0
311 }
312
313 #[inline(always)]
315 pub const fn is_rejected(&self) -> bool {
316 self.flags & Self::FLAG_REJECTED != 0
317 }
318
319 #[inline(always)]
323 pub const fn can_update(&self) -> bool {
324 (self.flags & (Self::FLAG_ACTIVE | Self::FLAG_REJECTED)) == Self::FLAG_ACTIVE
326 }
327
328 #[inline(always)]
330 pub fn e_value_approx(&self) -> f32 {
331 let log2_val = (self.log_e_value as f32) / 65536.0;
332 libm::exp2f(log2_val)
333 }
334
335 #[inline]
340 pub fn update(&mut self, likelihood_ratio: f32) -> bool {
341 if !self.can_update() {
342 return self.is_rejected();
343 }
344
345 let log_lr = f32_to_log_e(likelihood_ratio);
347 self.update_with_log_lr(log_lr)
348 }
349
350 #[inline(always)]
355 pub fn update_with_log_lr(&mut self, log_lr: LogEValue) -> bool {
356 self.log_e_value = self.log_e_value.saturating_add(log_lr);
357 self.obs_count += 1;
358
359 if self.log_e_value > LOG_E_VERY_STRONG {
362 self.flags |= Self::FLAG_VERY_STRONG | Self::FLAG_STRONG;
363 } else if self.log_e_value > LOG_E_STRONG {
364 self.flags |= Self::FLAG_STRONG;
365 self.flags &= !Self::FLAG_VERY_STRONG;
366 } else {
367 self.flags &= !(Self::FLAG_STRONG | Self::FLAG_VERY_STRONG);
368 }
369
370 if self.log_e_value > LOG_E_STRONG {
372 self.flags |= Self::FLAG_REJECTED;
373 return true;
374 }
375
376 false
377 }
378
379 #[inline]
381 pub fn reset(&mut self) {
382 self.log_e_value = 0;
383 self.obs_count = 0;
384 self.flags = Self::FLAG_ACTIVE;
385 }
386}
387
388#[derive(Debug, Clone, Copy, Default)]
390#[repr(C)]
391pub struct ObsRecord {
392 pub obs: Observation,
394 pub tick: u32,
396}
397
398#[derive(Clone)]
402#[repr(C, align(64))]
403pub struct EvidenceAccumulator {
404 pub global_log_e: LogEValue,
407 pub total_obs: u32,
409 pub current_tick: u32,
411 pub window_head: u16,
413 pub window_count: u16,
415 pub num_hypotheses: u8,
417 pub _reserved: [u8; 1],
419 pub rejected_count: u16,
421 pub status: u16,
423 _hot_pad: [u8; 40],
425
426 pub hypotheses: [HypothesisState; MAX_HYPOTHESES],
429 pub window: [ObsRecord; WINDOW_SIZE],
431}
432
433impl Default for EvidenceAccumulator {
434 #[inline]
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
440impl EvidenceAccumulator {
441 pub const STATUS_ACTIVE: u16 = 0x0001;
443 pub const STATUS_HAS_REJECTION: u16 = 0x0002;
445 pub const STATUS_SIGNIFICANT: u16 = 0x0004;
447
448 pub const fn new() -> Self {
450 Self {
451 global_log_e: 0,
452 total_obs: 0,
453 current_tick: 0,
454 window_head: 0,
455 window_count: 0,
456 num_hypotheses: 0,
457 _reserved: [0; 1],
458 rejected_count: 0,
459 status: Self::STATUS_ACTIVE,
460 _hot_pad: [0; 40],
461 hypotheses: [HypothesisState::new(0, 0); MAX_HYPOTHESES],
462 window: [ObsRecord {
463 obs: Observation {
464 vertex: 0,
465 obs_type: 0,
466 flags: 0,
467 value: 0,
468 },
469 tick: 0,
470 }; WINDOW_SIZE],
471 }
472 }
473
474 pub fn add_hypothesis(&mut self, hypothesis: HypothesisState) -> bool {
476 if self.num_hypotheses as usize >= MAX_HYPOTHESES {
477 return false;
478 }
479
480 self.hypotheses[self.num_hypotheses as usize] = hypothesis;
481 self.num_hypotheses += 1;
482 true
483 }
484
485 pub fn add_connectivity_hypothesis(&mut self, vertex: TileVertexId) -> bool {
487 let id = self.num_hypotheses as u16;
488 self.add_hypothesis(HypothesisState::connectivity(id, vertex))
489 }
490
491 pub fn add_cut_hypothesis(&mut self, vertex: TileVertexId, threshold: TileVertexId) -> bool {
493 let id = self.num_hypotheses as u16;
494 self.add_hypothesis(HypothesisState::cut_membership(id, vertex, threshold))
495 }
496
497 #[inline]
502 pub fn process_observation(&mut self, obs: Observation, tick: u32) {
503 self.current_tick = tick;
504 self.total_obs += 1;
505
506 let idx = self.window_head as usize;
509 unsafe {
511 *self.window.get_unchecked_mut(idx) = ObsRecord { obs, tick };
512 }
513 self.window_head = ((self.window_head + 1) & (WINDOW_SIZE as u16 - 1));
515 if (self.window_count as usize) < WINDOW_SIZE {
516 self.window_count += 1;
517 }
518
519 let log_lr = self.compute_log_likelihood_ratio(&obs);
522
523 self.global_log_e = self.global_log_e.saturating_add(log_lr);
525
526 let num_hyp = self.num_hypotheses as usize;
529 for i in 0..num_hyp {
530 let hyp = unsafe { self.hypotheses.get_unchecked(i) };
532
533 if !hyp.can_update() {
535 continue;
536 }
537
538 let is_relevant = self.is_obs_relevant(hyp, &obs);
541
542 if is_relevant {
543 let hyp_mut = unsafe { self.hypotheses.get_unchecked_mut(i) };
545 if hyp_mut.update_with_log_lr(log_lr) {
546 self.rejected_count += 1;
547 self.status |= Self::STATUS_HAS_REJECTION;
548 }
549 }
550 }
551
552 if self.global_log_e > LOG_E_STRONG {
554 self.status |= Self::STATUS_SIGNIFICANT;
555 }
556 }
557
558 #[inline(always)]
562 fn is_obs_relevant(&self, hyp: &HypothesisState, obs: &Observation) -> bool {
563 match (hyp.hyp_type, obs.obs_type) {
564 (HypothesisState::TYPE_CONNECTIVITY, Observation::TYPE_CONNECTIVITY) => {
565 obs.vertex == hyp.target
566 }
567 (HypothesisState::TYPE_CUT, Observation::TYPE_CUT_MEMBERSHIP) => {
568 obs.vertex == hyp.target
569 }
570 (HypothesisState::TYPE_FLOW, Observation::TYPE_FLOW) => obs.vertex == hyp.target,
571 _ => false,
572 }
573 }
574
575 #[inline(always)]
580 fn compute_log_likelihood_ratio(&self, obs: &Observation) -> LogEValue {
581 match obs.obs_type {
582 Observation::TYPE_CONNECTIVITY => {
583 if obs.flags != 0 {
585 LOG_LR_CONNECTIVITY_POS } else {
587 LOG_LR_CONNECTIVITY_NEG }
589 }
590 Observation::TYPE_WITNESS => {
591 if obs.flags != 0 {
593 LOG_LR_WITNESS_POS } else {
595 LOG_LR_WITNESS_NEG }
597 }
598 Observation::TYPE_CUT_MEMBERSHIP => {
599 let confidence_fixed = (obs.value as i32) >> 1; confidence_fixed
604 }
605 Observation::TYPE_FLOW => {
606 let flow = (obs.value as f32) / 1000.0;
608 let lr = if flow > 0.5 {
609 1.0 + flow
610 } else {
611 1.0 / (1.0 + flow)
612 };
613 f32_to_log_e(lr)
614 }
615 _ => 0, }
617 }
618
619 #[inline]
621 fn compute_likelihood_ratio(&self, obs: &Observation) -> f32 {
622 match obs.obs_type {
623 Observation::TYPE_CONNECTIVITY => {
624 if obs.flags != 0 { 1.5 } else { 0.5 }
625 }
626 Observation::TYPE_CUT_MEMBERSHIP => {
627 let confidence = (obs.value as f32) / 65535.0;
628 1.0 + confidence
629 }
630 Observation::TYPE_FLOW => {
631 let flow = (obs.value as f32) / 1000.0;
632 if flow > 0.5 { 1.0 + flow } else { 1.0 / (1.0 + flow) }
633 }
634 Observation::TYPE_WITNESS => {
635 if obs.flags != 0 { 2.0 } else { 0.5 }
636 }
637 _ => 1.0,
638 }
639 }
640
641 #[inline(always)]
643 pub fn global_e_value(&self) -> f32 {
644 let log2_val = (self.global_log_e as f32) / 65536.0;
645 libm::exp2f(log2_val)
646 }
647
648 #[inline(always)]
650 pub fn has_rejection(&self) -> bool {
651 self.status & Self::STATUS_HAS_REJECTION != 0
652 }
653
654 #[inline(always)]
656 pub fn is_significant(&self) -> bool {
657 self.status & Self::STATUS_SIGNIFICANT != 0
658 }
659
660 pub fn reset(&mut self) {
662 for h in self.hypotheses[..self.num_hypotheses as usize].iter_mut() {
663 h.reset();
664 }
665 self.window_head = 0;
666 self.window_count = 0;
667 self.global_log_e = 0;
668 self.rejected_count = 0;
669 self.status = Self::STATUS_ACTIVE;
670 }
671
672 #[inline]
680 pub fn process_observation_batch(&mut self, observations: &[(Observation, u32)]) {
681 let batch_size = observations.len().min(64);
684
685 for &(obs, tick) in observations.iter().take(batch_size) {
687 self.process_observation(obs, tick);
688 }
689 }
690
691 #[inline]
699 pub fn aggregate_hypotheses_simd(&self) -> i64 {
700 let mut lanes = [0i64; 4];
701 let num_hyp = self.num_hypotheses as usize;
702
703 for i in 0..num_hyp {
705 let hyp = &self.hypotheses[i];
706 if hyp.is_active() {
707 lanes[i % 4] += hyp.log_e_value as i64;
708 }
709 }
710
711 lanes[0] + lanes[1] + lanes[2] + lanes[3]
712 }
713
714 #[inline(always)]
725 pub fn exceeds_threshold(&self, threshold_log: LogEValue) -> bool {
726 self.global_log_e > threshold_log
727 }
728
729 pub const fn memory_size() -> usize {
731 size_of::<Self>()
732 }
733}
734
735const _: () = assert!(
737 size_of::<HypothesisState>() == 16,
738 "HypothesisState must be 16 bytes"
739);
740const _: () = assert!(size_of::<ObsRecord>() == 12, "ObsRecord must be 12 bytes");
741
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_e_conversion() {
        // A ratio of 1 carries no evidence and must map to exactly zero.
        assert_eq!(f32_to_log_e(1.0), 0);

        // Powers of two should land (within tolerance) on multiples of 2^16.
        for (ratio, expected) in [(2.0f32, 65536), (4.0f32, 131072)] {
            assert!((f32_to_log_e(ratio) - expected).abs() < 100);
        }
    }

    #[test]
    fn test_hypothesis_state() {
        let mut hyp = HypothesisState::new(0, HypothesisState::TYPE_CONNECTIVITY);
        assert!(hyp.is_active());
        assert!(!hyp.is_rejected());
        assert_eq!(hyp.obs_count, 0);

        // Five doublings give an e-value of 2^5 = 32.
        (0..5).for_each(|_| {
            hyp.update(2.0);
        });
        assert_eq!(hyp.obs_count, 5);
        assert!(hyp.e_value_approx() > 20.0);
    }

    #[test]
    fn test_hypothesis_rejection() {
        let mut hyp = HypothesisState::new(0, HypothesisState::TYPE_CUT);
        // `any` stops at the first update that reports a rejection.
        let _ = (0..10).any(|_| hyp.update(2.0));
        assert!(hyp.is_rejected());
    }

    #[test]
    fn test_accumulator_new() {
        let fresh = EvidenceAccumulator::new();
        assert_eq!(fresh.num_hypotheses, 0);
        assert_eq!(fresh.total_obs, 0);
        assert!(!fresh.has_rejection());
    }

    #[test]
    fn test_add_hypothesis() {
        let mut acc = EvidenceAccumulator::new();
        let added = [acc.add_connectivity_hypothesis(5), acc.add_cut_hypothesis(10, 15)];
        assert_eq!(added, [true, true]);
        assert_eq!(acc.num_hypotheses, 2);
    }

    #[test]
    fn test_process_observation() {
        let mut acc = EvidenceAccumulator::new();
        acc.add_connectivity_hypothesis(5);

        (0..10u32).for_each(|tick| {
            acc.process_observation(Observation::connectivity(5, true), tick);
        });

        assert_eq!(acc.total_obs, 10);
        assert!(acc.global_e_value() > 1.0);
    }

    #[test]
    fn test_sliding_window() {
        let mut acc = EvidenceAccumulator::new();
        let total_ticks = WINDOW_SIZE as u32 + 10;

        for tick in 0..total_ticks {
            acc.process_observation(Observation::connectivity(0, true), tick);
        }

        assert_eq!(acc.window_count, WINDOW_SIZE as u16);
    }

    #[test]
    fn test_memory_size() {
        let size = EvidenceAccumulator::memory_size();
        assert!(size < 4096, "EvidenceAccumulator too large: {} bytes", size);
    }
}