#![cfg(feature = "std")]
use std::vec::Vec;
use crate::consensus::ConsensusCell;
use crate::fixed::Q16;
use crate::motif::MotifClass;
/// Thresholds that decide which consensus cells are worth keeping and how
/// long a run of such cells must be before it is reported.
///
/// A cell qualifies when it meets the detector-count threshold *or* the
/// residual threshold (see `is_interesting`); consecutive qualifying windows
/// of one entity are merged into a single `CandidateInterval` by `prepare`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CandidateConfig {
    /// Smallest `detector_count` that makes a cell qualify on its own.
    pub min_detector_count: u32,
    /// Smallest raw `axis1_residual_q` value that makes a cell qualify on
    /// its own. Compared directly against `Q16::raw()`.
    pub min_residual_q_raw: i32,
    /// Shortest run, in windows, that is emitted as a candidate.
    pub min_length_windows: u32,
}

impl CandidateConfig {
    /// Default thresholds: two detectors, or a raw residual of `3 << 16`
    /// (196 608), with single-window runs allowed.
    ///
    /// NOTE(review): `3 << 16` reads as 3.0 assuming `Q16` carries 16
    /// fractional bits — confirm against `crate::fixed`.
    pub const CANONICAL: Self = Self {
        min_detector_count: 2,
        min_residual_q_raw: 3 << 16,
        min_length_windows: 1,
    };
}
/// One maximal run of consecutive "interesting" windows for a single entity,
/// together with summary statistics gathered over that run.
///
/// `prepare` fills the window bounds and the peak fields;
/// `prepare_with_detectors` additionally fills `union_mask`, `entity_avg_q`
/// and `grid_avg_q` (they stay at zero when only `prepare` is used).
///
/// The struct is `#[repr(C)]`, so fields are laid out in declaration order —
/// do not reorder them.
#[repr(C)]
#[derive(Copy, Clone, Eq, PartialEq, Debug, Default)]
pub struct CandidateInterval {
/// Entity (grid row) the run belongs to.
pub entity_id: u32,
/// First window of the run (inclusive).
pub start_window: u32,
/// One past the last window of the run (exclusive).
pub end_window: u32,
/// Number of windows in the run: `end_window - start_window`.
pub length_windows: u32,
/// Bitwise OR of the per-cell detector masks over the run.
pub union_mask: u32,
/// Largest `axis1_residual_q` seen in the run; never below `Q16::ZERO`
/// because the accumulator starts there.
pub peak_residual_q: Q16,
/// Largest `axis2_drift_q` seen in the run (same zero floor).
pub peak_drift_q: Q16,
/// Largest `axis3_slew_q` seen in the run (same zero floor).
pub peak_slew_q: Q16,
/// Largest `axis4_temporal_q` seen in the run (same zero floor).
pub peak_temporal_q: Q16,
/// Largest `axis7_consensus_q` seen in the run (same zero floor).
pub peak_consensus_q: Q16,
/// Mean raw `axis7_consensus_q` of this entity over the run's windows.
pub entity_avg_q: Q16,
/// Mean raw `axis7_consensus_q` of all entities over the run's windows.
pub grid_avg_q: Q16,
}
impl CandidateInterval {
    /// Returns `true` when at least one bit of `class.bit_mask()` is set in
    /// this interval's `union_mask`.
    #[must_use]
    pub const fn covers(&self, class: MotifClass) -> bool {
        let wanted = class.bit_mask();
        let overlap = self.union_mask & wanted;
        overlap != 0
    }
}
/// Maps an `(entity_id, window_idx)` pair to its position in a flattened,
/// entity-major grid with `n_windows` cells per entity.
///
/// Every operand is widened to `usize` *before* the multiply-add. Doing the
/// arithmetic in `u32` and casting afterwards overflows as soon as the grid
/// has more than `u32::MAX` cells: a panic in debug builds and a silently
/// wrong (wrapped) index in release builds.
#[inline]
const fn flat(entity_id: u32, window_idx: u32, n_windows: u32) -> usize {
    // `u32 -> usize` is lossless on 32- and 64-bit targets.
    let row = entity_id as usize;
    let stride = n_windows as usize;
    let col = window_idx as usize;
    row * stride + col
}
/// Decides whether a single consensus cell should be part of a candidate run.
///
/// A cell qualifies if *either* threshold is met: enough detectors fired
/// (`detector_count >= min_detector_count`) or the raw axis-1 residual is at
/// least `min_residual_q_raw`.
fn is_interesting(cell: &ConsensusCell, config: &CandidateConfig) -> bool {
    if cell.detector_count >= config.min_detector_count {
        return true;
    }
    cell.axis1_residual_q.raw() >= config.min_residual_q_raw
}
/// Scans the consensus grid entity by entity and collapses every maximal run
/// of consecutive interesting windows into one [`CandidateInterval`].
///
/// * `consensus` — flattened grid, indexed through `flat(entity, window, n_windows)`.
/// * `n_windows` / `n_entities` — grid dimensions.
/// * `config` — cell thresholds and the minimum run length to report.
///
/// Intervals are returned in scan order (entity ascending, then window
/// ascending). For each interval the window bounds, length and the five
/// `peak_*` fields are filled; `union_mask`, `entity_avg_q` and `grid_avg_q`
/// are left at zero (see `prepare_with_detectors`). Runs shorter than
/// `config.min_length_windows` are dropped.
///
/// # Panics
///
/// Panics if `consensus` does not contain a cell for every
/// `(entity, window)` pair of the grid.
#[must_use]
pub fn prepare(
    consensus: &[ConsensusCell],
    n_windows: u32,
    n_entities: u32,
    config: &CandidateConfig,
) -> Vec<CandidateInterval> {
    let min_len = config.min_length_windows;
    let mut found: Vec<CandidateInterval> = Vec::new();

    for entity_id in 0..n_entities {
        // The run currently being extended for this entity, if any.
        let mut open: Option<CandidateInterval> = None;

        for window_idx in 0..n_windows {
            let cell = &consensus[flat(entity_id, window_idx, n_windows)];

            if !is_interesting(cell, config) {
                // A quiet window closes whatever run was in progress.
                found.extend(open.take().filter(|run| run.length_windows >= min_len));
                continue;
            }

            // Start a fresh run on the first interesting window; peaks begin
            // at zero, so they never report a negative maximum.
            let run = open.get_or_insert_with(|| CandidateInterval {
                entity_id,
                start_window: window_idx,
                end_window: window_idx,
                length_windows: 0,
                union_mask: 0,
                peak_residual_q: Q16::ZERO,
                peak_drift_q: Q16::ZERO,
                peak_slew_q: Q16::ZERO,
                peak_temporal_q: Q16::ZERO,
                peak_consensus_q: Q16::ZERO,
                entity_avg_q: Q16::ZERO,
                grid_avg_q: Q16::ZERO,
            });

            // `end_window` is exclusive.
            run.end_window = window_idx + 1;
            run.length_windows = run.end_window - run.start_window;

            run.peak_residual_q = peak(run.peak_residual_q, cell.axis1_residual_q);
            run.peak_drift_q = peak(run.peak_drift_q, cell.axis2_drift_q);
            run.peak_slew_q = peak(run.peak_slew_q, cell.axis3_slew_q);
            run.peak_temporal_q = peak(run.peak_temporal_q, cell.axis4_temporal_q);
            run.peak_consensus_q = peak(run.peak_consensus_q, cell.axis7_consensus_q);
        }

        // A run that reaches the last window is closed by the end of the row.
        found.extend(open.filter(|run| run.length_windows >= min_len));
    }

    found
}
/// Runs [`prepare`] and then enriches every interval with detector and
/// averaging information.
///
/// For each interval, over its windows `start_window..end_window`:
///
/// * `union_mask` becomes the bitwise OR of the interval entity's entries in
///   `detector_masks`;
/// * `entity_avg_q` becomes the mean raw `axis7_consensus_q` of the interval's
///   own entity (integer division, truncating toward zero);
/// * `grid_avg_q` becomes the mean raw `axis7_consensus_q` over *all*
///   entities in those windows.
///
/// `detector_masks` is indexed exactly like `consensus`; the two lengths are
/// compared with a `debug_assert_eq!`, i.e. only in builds with debug
/// assertions enabled.
///
/// # Panics
///
/// Panics if either slice is too short for the grid described by
/// `n_windows` and `n_entities`.
#[must_use]
pub fn prepare_with_detectors(
    consensus: &[ConsensusCell],
    detector_masks: &[u32],
    n_windows: u32,
    n_entities: u32,
    config: &CandidateConfig,
) -> Vec<CandidateInterval> {
    debug_assert_eq!(consensus.len(), detector_masks.len());

    let mut intervals = prepare(consensus, n_windows, n_entities, config);

    for interval in intervals.iter_mut() {
        let own = interval.entity_id;
        let (first, past_last) = (interval.start_window, interval.end_window);

        let mut fired = 0u32;
        let mut own_total = 0i64;
        let mut grid_total = 0i64;

        for w in first..past_last {
            fired |= detector_masks[flat(own, w, n_windows)];

            // Accumulate in i64 so summing many i32 raw values cannot overflow.
            for e in 0..n_entities {
                let raw = i64::from(consensus[flat(e, w, n_windows)].axis7_consensus_q.raw());
                grid_total += raw;
                if e == own {
                    own_total += raw;
                }
            }
        }

        let windows = i64::from(past_last - first);
        // One grid cell per (window, entity) pair visited above.
        let cells = windows * i64::from(n_entities);

        // A mean of i32 values always fits back into i32.
        let own_avg = (own_total / windows.max(1)) as i32;
        let grid_avg = if cells > 0 {
            (grid_total / cells) as i32
        } else {
            0
        };

        interval.union_mask = fired;
        interval.entity_avg_q = Q16::from_raw(own_avg);
        interval.grid_avg_q = Q16::from_raw(grid_avg);
    }

    intervals
}
/// Returns whichever of the two values has the larger raw representation.
/// On a tie the first argument (`a`) is returned.
#[inline]
fn peak(a: Q16, b: Q16) -> Q16 {
    if a.raw() >= b.raw() {
        return a;
    }
    b
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consensus::form as consensus_form;
    use crate::detector::{evaluate as detector_evaluate, DetectorThresholds};
    use crate::fixture::{synthesize, DEFAULT_SEED, N_ENTITIES, N_WINDOWS, WINDOW_SIZE_NS};
    use crate::residual::{compute as residual_compute, Baseline};
    use crate::sign::compute as sign_compute;
    use crate::window::compute_features;

    /// Smoothing factor handed to the sign stage (raw `0x2000`).
    const ALPHA: Q16 = Q16::from_raw(0x2000);

    /// Runs every upstream stage on the default synthetic fixture and returns
    /// the consensus grid together with the per-cell detector masks.
    fn full_pipeline() -> (Vec<ConsensusCell>, Vec<u32>) {
        let events = synthesize(DEFAULT_SEED);
        let features = compute_features(&events, N_WINDOWS, N_ENTITIES, WINDOW_SIZE_NS);
        let residuals = residual_compute(&features, &Baseline::CANONICAL);
        let signs = sign_compute(&residuals, ALPHA, N_WINDOWS, N_ENTITIES);
        let detectors = detector_evaluate(
            &residuals,
            &signs,
            &DetectorThresholds::CANONICAL,
            N_WINDOWS,
            N_ENTITIES,
        );
        let consensus = consensus_form(&signs, &detectors, N_WINDOWS, N_ENTITIES);
        let masks: Vec<u32> = detectors.iter().map(|d| d.detector_mask).collect();
        (consensus, masks)
    }

    /// Candidate extraction with the canonical config on the fixture grid.
    fn extract(consensus: &[ConsensusCell], masks: &[u32]) -> Vec<CandidateInterval> {
        prepare_with_detectors(
            consensus,
            masks,
            N_WINDOWS,
            N_ENTITIES,
            &CandidateConfig::CANONICAL,
        )
    }

    /// Full pipeline followed by canonical candidate extraction.
    fn canonical_candidates() -> Vec<CandidateInterval> {
        let (consensus, masks) = full_pipeline();
        extract(&consensus, &masks)
    }

    /// True when `c` belongs to `entity_id`, starts no later than `starts_by`
    /// and ends (exclusive) no earlier than `ends_at_least`.
    fn spans(c: &CandidateInterval, entity_id: u32, starts_by: u32, ends_at_least: u32) -> bool {
        c.entity_id == entity_id && c.start_window <= starts_by && c.end_window >= ends_at_least
    }

    #[test]
    fn candidate_extraction_is_deterministic() {
        let (consensus, masks) = full_pipeline();
        let first = extract(&consensus, &masks);
        let second = extract(&consensus, &masks);
        assert_eq!(first, second);
    }

    // NOTE(review): the entity/window numbers below presumably mirror the
    // episodes injected by `crate::fixture::synthesize` — confirm there.

    #[test]
    fn ramp_episode_yields_a_candidate_on_entity_three() {
        let intervals = canonical_candidates();
        let any_ramp = intervals.iter().any(|c| spans(c, 3, 25, 30));
        assert!(
            any_ramp,
            "no ramp candidate found among {} intervals",
            intervals.len()
        );
    }

    #[test]
    fn burst_episode_yields_a_candidate_on_entity_seven() {
        let intervals = canonical_candidates();
        let any_burst = intervals.iter().any(|c| spans(c, 7, 62, 65));
        assert!(any_burst, "no burst candidate found");
    }

    #[test]
    fn shock_episode_yields_a_candidate_on_entity_eleven() {
        let intervals = canonical_candidates();
        let any_shock = intervals.iter().any(|c| spans(c, 11, 90, 91));
        assert!(any_shock, "no shock candidate found");
    }

    #[test]
    fn candidates_carry_the_union_mask() {
        let intervals = canonical_candidates();
        // Only the first matching interval is checked.
        let ramp = intervals
            .iter()
            .find(|c| spans(c, 3, 25, 30))
            .unwrap_or_else(|| panic!("ramp interval not found"));
        assert!(ramp.covers(MotifClass::ResidualSpike));
        assert!(ramp.covers(MotifClass::DriftRamp));
    }
}