use crate::error::ExtractionError;
use crate::neutron::{Neutron, NeutronBatch};
/// Settings that control how clusters of hits are reduced to neutron events.
#[derive(Clone, Debug)]
pub struct ExtractionConfig {
    /// Multiplier applied to every centroid coordinate before a neutron is emitted.
    pub super_resolution_factor: f64,
    /// `true` selects the ToT-weighted centroid; `false` selects the plain arithmetic mean.
    pub weighted_by_tot: bool,
    /// Hits whose ToT is strictly below this value are ignored; `0` switches the filter off.
    pub min_tot_threshold: u16,
}

impl Default for ExtractionConfig {
    /// Factor `8.0`, ToT weighting enabled, minimum ToT of `10`.
    fn default() -> Self {
        ExtractionConfig {
            min_tot_threshold: 10,
            weighted_by_tot: true,
            super_resolution_factor: 8.0,
        }
    }
}

impl ExtractionConfig {
    /// Named preset; currently yields exactly the same values as [`Default::default`].
    #[must_use]
    pub fn venus_defaults() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `super_resolution_factor` replaced by `factor`.
    #[must_use]
    pub fn with_super_resolution(self, factor: f64) -> Self {
        Self {
            super_resolution_factor: factor,
            ..self
        }
    }

    /// Returns the configuration with `weighted_by_tot` replaced by `weighted`.
    #[must_use]
    pub fn with_weighted_by_tot(self, weighted: bool) -> Self {
        Self {
            weighted_by_tot: weighted,
            ..self
        }
    }

    /// Returns the configuration with `min_tot_threshold` replaced by `threshold`.
    #[must_use]
    pub fn with_min_tot_threshold(self, threshold: u16) -> Self {
        Self {
            min_tot_threshold: threshold,
            ..self
        }
    }
}
/// Strategy interface for reducing labelled hits to neutron events.
///
/// Implementors are required to be `Send + Sync`.
pub trait NeutronExtraction: Send + Sync {
/// Returns a static name identifying the implementation.
fn name(&self) -> &'static str;
/// Supplies a new configuration to the extractor.
fn configure(&mut self, config: ExtractionConfig);
/// Returns a reference to the configuration currently in use.
fn config(&self) -> &ExtractionConfig;
/// Extracts neutrons from a struct-of-arrays hit batch.
///
/// `num_clusters` is the number of cluster slots to consider. In the
/// implementation in this file, hits whose label is negative or not below
/// `num_clusters` are skipped.
///
/// # Errors
///
/// Returns an [`ExtractionError`] if the implementation cannot complete the extraction.
fn extract_soa(
&self,
batch: &crate::soa::HitBatch,
num_clusters: usize,
) -> Result<Vec<Neutron>, ExtractionError>;
}
/// Running per-cluster totals gathered while scanning the hit batch.
///
/// `Default` yields an all-zero accumulator; a `count` of zero marks a cluster
/// that received no hits and is skipped when neutrons are built.
#[derive(Clone, Debug, Default)]
struct ClusterAccumulator {
// Sum of x * ToT over accepted hits (only filled by the weighted accumulation path).
sum_x: f64,
// Sum of y * ToT over accepted hits (only filled by the weighted accumulation path).
sum_y: f64,
// Plain sum of x over accepted hits.
raw_sum_x: f64,
// Plain sum of y over accepted hits.
raw_sum_y: f64,
// Sum of ToT over accepted hits.
sum_tot: u64,
// Number of accepted hits.
count: u32,
// Largest ToT seen so far among accepted hits.
max_tot: u16,
// TOF of the hit that currently holds `max_tot` (on ties the later hit wins).
rep_tof: u32,
// Chip id of the hit that currently holds `max_tot` (on ties the later hit wins).
rep_chip: u8,
}
/// Centroid-based neutron extractor: each cluster of hits becomes one neutron
/// located at the (optionally ToT-weighted) mean hit position.
#[derive(Clone, Debug, Default)]
pub struct SimpleCentroidExtraction {
// Active extraction settings; replaced wholesale by `configure`.
config: ExtractionConfig,
}
impl SimpleCentroidExtraction {
    /// Creates an extractor configured with [`ExtractionConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(ExtractionConfig::default())
    }

    /// Creates an extractor that uses the supplied configuration.
    #[must_use]
    pub fn with_config(config: ExtractionConfig) -> Self {
        Self { config }
    }
}
impl NeutronExtraction for SimpleCentroidExtraction {
    /// Static identifier of this extraction strategy.
    fn name(&self) -> &'static str {
        "SimpleCentroid"
    }

    /// Replaces the stored configuration with `config`.
    fn configure(&mut self, config: ExtractionConfig) {
        self.config = config;
    }

    /// Borrows the stored configuration.
    fn config(&self) -> &ExtractionConfig {
        &self.config
    }

    /// Reduces every non-empty cluster in `batch` to one [`Neutron`].
    ///
    /// One accumulator is allocated per cluster slot (`num_clusters` of them); the
    /// configuration decides between the ToT-weighted and the unweighted path.
    /// This implementation always returns `Ok`.
    fn extract_soa(
        &self,
        batch: &crate::soa::HitBatch,
        num_clusters: usize,
    ) -> Result<Vec<Neutron>, ExtractionError> {
        let settings = &self.config;
        let mut per_cluster = vec![ClusterAccumulator::default(); num_clusters];

        let neutrons = if settings.weighted_by_tot {
            accumulate_weighted(
                &mut per_cluster,
                batch,
                num_clusters,
                settings.min_tot_threshold,
            );
            build_neutrons_weighted(per_cluster, settings.super_resolution_factor)
        } else {
            accumulate_unweighted(
                &mut per_cluster,
                batch,
                num_clusters,
                settings.min_tot_threshold,
            );
            build_neutrons_unweighted(per_cluster, settings.super_resolution_factor)
        };

        Ok(neutrons)
    }
}
impl SimpleCentroidExtraction {
    /// Same extraction as `extract_soa`, but the neutrons are collected into a
    /// [`NeutronBatch`] instead of a `Vec<Neutron>`.
    ///
    /// # Errors
    ///
    /// The signature allows an [`ExtractionError`]; this implementation always returns `Ok`.
    pub fn extract_soa_batch(
        &self,
        batch: &crate::soa::HitBatch,
        num_clusters: usize,
    ) -> Result<NeutronBatch, ExtractionError> {
        let settings = &self.config;
        let mut per_cluster = vec![ClusterAccumulator::default(); num_clusters];

        let neutrons = if settings.weighted_by_tot {
            accumulate_weighted(
                &mut per_cluster,
                batch,
                num_clusters,
                settings.min_tot_threshold,
            );
            build_neutron_batch_weighted(per_cluster, settings.super_resolution_factor)
        } else {
            accumulate_unweighted(
                &mut per_cluster,
                batch,
                num_clusters,
                settings.min_tot_threshold,
            );
            build_neutron_batch_unweighted(per_cluster, settings.super_resolution_factor)
        };

        Ok(neutrons)
    }
}
/// Maps a raw cluster label to an accumulator slot.
///
/// Yields `None` when the label is negative (the `usize` conversion fails) or
/// when it is not strictly below `num_clusters`; otherwise yields the label as
/// a `usize`.
#[inline]
fn cluster_index(label: i32, num_clusters: usize) -> Option<usize> {
    usize::try_from(label)
        .ok()
        .filter(|&slot| slot < num_clusters)
}
/// Folds every accepted hit of `batch` into its cluster's accumulator, tracking
/// both the plain and the ToT-weighted coordinate sums.
///
/// A hit is skipped when its label does not map to a slot (see `cluster_index`)
/// or when `min_tot` is non-zero and the hit's ToT is below it. The hit with
/// the largest ToT supplies the representative TOF and chip id; on ties the
/// later hit wins because the comparison is `>=`.
fn accumulate_weighted(
    accumulators: &mut [ClusterAccumulator],
    batch: &crate::soa::HitBatch,
    num_clusters: usize,
    min_tot: u16,
) {
    // A zero threshold can never reject a `u16` ToT, so the filter is simply off.
    let filter_enabled = min_tot > 0;

    for hit in 0..batch.cluster_id.len() {
        let Some(slot) = cluster_index(batch.cluster_id[hit], num_clusters) else {
            continue;
        };
        let tot = batch.tot[hit];
        if filter_enabled && tot < min_tot {
            continue;
        }

        let cluster = &mut accumulators[slot];
        let (x, y) = (f64::from(batch.x[hit]), f64::from(batch.y[hit]));
        let weight = f64::from(tot);

        cluster.count += 1;
        cluster.sum_tot += u64::from(tot);
        cluster.raw_sum_x += x;
        cluster.raw_sum_y += y;
        cluster.sum_x += x * weight;
        cluster.sum_y += y * weight;

        if tot >= cluster.max_tot {
            cluster.max_tot = tot;
            cluster.rep_tof = batch.tof[hit];
            cluster.rep_chip = batch.chip_id[hit];
        }
    }
}
/// Folds every accepted hit of `batch` into its cluster's accumulator, tracking
/// only the plain coordinate sums (`sum_x` / `sum_y` stay untouched).
///
/// A hit is skipped when its label does not map to a slot (see `cluster_index`)
/// or when `min_tot` is non-zero and the hit's ToT is below it. The hit with
/// the largest ToT supplies the representative TOF and chip id; on ties the
/// later hit wins because the comparison is `>=`.
fn accumulate_unweighted(
    accumulators: &mut [ClusterAccumulator],
    batch: &crate::soa::HitBatch,
    num_clusters: usize,
    min_tot: u16,
) {
    // A zero threshold can never reject a `u16` ToT, so the filter is simply off.
    let filter_enabled = min_tot > 0;

    for hit in 0..batch.cluster_id.len() {
        let Some(slot) = cluster_index(batch.cluster_id[hit], num_clusters) else {
            continue;
        };
        let tot = batch.tot[hit];
        if filter_enabled && tot < min_tot {
            continue;
        }

        let cluster = &mut accumulators[slot];
        let (x, y) = (f64::from(batch.x[hit]), f64::from(batch.y[hit]));

        cluster.count += 1;
        cluster.sum_tot += u64::from(tot);
        cluster.raw_sum_x += x;
        cluster.raw_sum_y += y;

        if tot >= cluster.max_tot {
            cluster.max_tot = tot;
            cluster.rep_tof = batch.tof[hit];
            cluster.rep_chip = batch.chip_id[hit];
        }
    }
}
/// Converts an accumulated ToT sum into the `f64` divisor of a weighted centroid.
///
/// The value must not be clamped: the weighted numerators (`sum_x`, `sum_y`)
/// accumulate the full weights in `f64`, so a divisor capped at `u32::MAX`
/// would inflate the centroid of any cluster whose ToT sum exceeds that cap.
///
/// To stay free of lossy `as` casts the sum is split into two 32-bit halves,
/// each of which converts losslessly through `f64::from`. Scaling the high
/// half by 2^32 is exact, so the result is exact for every sum below 2^53 and
/// correctly rounded above that.
fn sum_tot_as_f64(sum_tot: u64) -> f64 {
    const TWO_POW_32: f64 = 4_294_967_296.0;

    // Both halves are guaranteed to fit in `u32`; the fallbacks are unreachable.
    let high = u32::try_from(sum_tot >> 32).unwrap_or(u32::MAX);
    let low = u32::try_from(sum_tot & u64::from(u32::MAX)).unwrap_or(u32::MAX);

    f64::from(high) * TWO_POW_32 + f64::from(low)
}
/// Turns accumulated clusters into neutrons located at the ToT-weighted centroid.
///
/// Clusters with no accepted hits are dropped. When a cluster's ToT sum is zero
/// the weighted mean is undefined, so the plain arithmetic mean is used instead.
/// Coordinates are multiplied by `scale`; the ToT sum and hit count saturate at
/// `u16::MAX`.
fn build_neutrons_weighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> Vec<Neutron> {
    let mut neutrons = Vec::with_capacity(accumulators.len());
    neutrons.extend(
        accumulators
            .into_iter()
            .filter(|cluster| cluster.count != 0)
            .map(|cluster| {
                let (cx, cy) = match cluster.sum_tot {
                    0 => {
                        let hits = f64::from(cluster.count);
                        (cluster.raw_sum_x / hits, cluster.raw_sum_y / hits)
                    }
                    total => {
                        let weight_sum = sum_tot_as_f64(total);
                        (cluster.sum_x / weight_sum, cluster.sum_y / weight_sum)
                    }
                };
                // `try_from` fails exactly when the value exceeds u16::MAX, so the
                // fallback saturates.
                let total_tot = u16::try_from(cluster.sum_tot).unwrap_or(u16::MAX);
                let n_hits = u16::try_from(cluster.count).unwrap_or(u16::MAX);
                Neutron::new(
                    cx * scale,
                    cy * scale,
                    cluster.rep_tof,
                    total_tot,
                    n_hits,
                    cluster.rep_chip,
                )
            }),
    );
    neutrons
}
/// Turns accumulated clusters into neutrons located at the plain mean hit position.
///
/// Clusters with no accepted hits are dropped. Coordinates are multiplied by
/// `scale`; the ToT sum and hit count saturate at `u16::MAX`.
fn build_neutrons_unweighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> Vec<Neutron> {
    let mut neutrons = Vec::with_capacity(accumulators.len());
    neutrons.extend(
        accumulators
            .into_iter()
            .filter(|cluster| cluster.count != 0)
            .map(|cluster| {
                let hits = f64::from(cluster.count);
                let mean_x = cluster.raw_sum_x / hits;
                let mean_y = cluster.raw_sum_y / hits;
                // `try_from` fails exactly when the value exceeds u16::MAX, so the
                // fallback saturates.
                let total_tot = u16::try_from(cluster.sum_tot).unwrap_or(u16::MAX);
                let n_hits = u16::try_from(cluster.count).unwrap_or(u16::MAX);
                Neutron::new(
                    mean_x * scale,
                    mean_y * scale,
                    cluster.rep_tof,
                    total_tot,
                    n_hits,
                    cluster.rep_chip,
                )
            }),
    );
    neutrons
}
/// Batch-output counterpart of the weighted builder: pushes one neutron per
/// non-empty cluster into a [`NeutronBatch`].
///
/// The centroid is ToT-weighted, falling back to the plain mean when the
/// cluster's ToT sum is zero. Coordinates are multiplied by `scale`; the ToT
/// sum and hit count saturate at `u16::MAX`.
fn build_neutron_batch_weighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> NeutronBatch {
    let mut out = NeutronBatch::with_capacity(accumulators.len());

    for cluster in accumulators.into_iter().filter(|c| c.count > 0) {
        let (cx, cy) = match cluster.sum_tot {
            0 => {
                let hits = f64::from(cluster.count);
                (cluster.raw_sum_x / hits, cluster.raw_sum_y / hits)
            }
            total => {
                let weight_sum = sum_tot_as_f64(total);
                (cluster.sum_x / weight_sum, cluster.sum_y / weight_sum)
            }
        };
        // `try_from` fails exactly when the value exceeds u16::MAX, so the
        // fallback saturates.
        let total_tot = u16::try_from(cluster.sum_tot).unwrap_or(u16::MAX);
        let n_hits = u16::try_from(cluster.count).unwrap_or(u16::MAX);

        out.push(Neutron::new(
            cx * scale,
            cy * scale,
            cluster.rep_tof,
            total_tot,
            n_hits,
            cluster.rep_chip,
        ));
    }

    out
}
/// Batch-output counterpart of the unweighted builder: pushes one neutron per
/// non-empty cluster, located at the plain mean hit position, into a
/// [`NeutronBatch`].
///
/// Coordinates are multiplied by `scale`; the ToT sum and hit count saturate
/// at `u16::MAX`.
fn build_neutron_batch_unweighted(
    accumulators: Vec<ClusterAccumulator>,
    scale: f64,
) -> NeutronBatch {
    let mut out = NeutronBatch::with_capacity(accumulators.len());

    for cluster in accumulators.into_iter().filter(|c| c.count > 0) {
        let hits = f64::from(cluster.count);
        let mean_x = cluster.raw_sum_x / hits;
        let mean_y = cluster.raw_sum_y / hits;
        // `try_from` fails exactly when the value exceeds u16::MAX, so the
        // fallback saturates.
        let total_tot = u16::try_from(cluster.sum_tot).unwrap_or(u16::MAX);
        let n_hits = u16::try_from(cluster.count).unwrap_or(u16::MAX);

        out.push(Neutron::new(
            mean_x * scale,
            mean_y * scale,
            cluster.rep_tof,
            total_tot,
            n_hits,
            cluster.rep_chip,
        ));
    }

    out
}
// Unit tests for `SimpleCentroidExtraction::extract_soa`.
// Unless a test reconfigures the extractor, the defaults apply: scale factor 8,
// ToT weighting enabled, minimum ToT of 10.
#[cfg(test)]
mod tests {
use super::*;
use crate::soa::HitBatch;
// Builds a `HitBatch` from tuples laid out as
// `(tof, x, y, timestamp, tot, chip_id, cluster_id)`.
// Note that `HitBatch::push` is called with a different field order:
// `(x, y, tof, tot, timestamp, chip_id)`; the cluster label is written afterwards.
fn make_batch(hits: &[(u32, u16, u16, u32, u16, u8, i32)]) -> HitBatch {
let mut batch = HitBatch::with_capacity(hits.len());
for (i, (tof, x, y, timestamp, tot, chip_id, cluster_id)) in hits.iter().enumerate() {
batch.push((*x, *y, *tof, *tot, *timestamp, *chip_id));
batch.cluster_id[i] = *cluster_id;
}
batch
}
// A single hit at (100, 200) must come out at (800, 1600) after the default x8 scaling.
#[test]
fn test_single_hit_extraction() {
let batch = make_batch(&[(1000, 100, 200, 500, 50, 0, 0)]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert!((neutrons[0].x - 800.0).abs() < f64::EPSILON); assert!((neutrons[0].y - 1600.0).abs() < f64::EPSILON); assert_eq!(neutrons[0].tof, 1000);
assert_eq!(neutrons[0].n_hits, 1);
}
// Weighted x = (0*30 + 2*10) / 40 = 0.5, scaled by 8 gives 4.0.
// The ToT-10 hit is kept because the filter only rejects ToT strictly below 10.
#[test]
fn test_weighted_centroid() {
let batch = make_batch(&[
(1000, 0, 0, 500, 30, 0, 0), (1000, 2, 0, 500, 10, 0, 0), ]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert!((neutrons[0].x - 4.0).abs() < 0.01);
assert_eq!(neutrons[0].n_hits, 2);
assert_eq!(neutrons[0].tot, 40);
}
// Two labels produce two neutrons, emitted in label order.
#[test]
fn test_multiple_clusters() {
let batch = make_batch(&[
(1000, 10, 10, 500, 50, 0, 0),
(1000, 11, 10, 500, 50, 0, 0),
(2000, 100, 100, 500, 50, 1, 1),
]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 2).unwrap();
assert_eq!(neutrons.len(), 2);
assert_eq!(neutrons[0].n_hits, 2);
assert_eq!(neutrons[1].n_hits, 1);
}
// The ToT-5 hit is dropped by the default threshold of 10.
// Remaining: x = (10*15 + 20*20) / 35 = 15.714..., scaled by 8 gives about 125.71.
#[test]
fn test_tot_threshold_filters_low_tot_hits() {
let batch = make_batch(&[
(1000, 0, 0, 500, 5, 0, 0), (1000, 10, 0, 500, 15, 0, 0), (1000, 20, 0, 500, 20, 0, 0), ]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert_eq!(neutrons[0].n_hits, 2);
assert_eq!(neutrons[0].tot, 35);
assert!((neutrons[0].x - 125.71).abs() < 0.1);
}
// Every hit is below the threshold, so the cluster ends up empty and yields no neutron.
#[test]
fn test_tot_threshold_skips_empty_clusters_after_filtering() {
let batch = make_batch(&[
(1000, 0, 0, 500, 5, 0, 0), (1000, 1, 0, 500, 8, 0, 0), ]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 0);
}
// A threshold of 0 disables filtering: both low-ToT hits count, total ToT = 5 + 3.
#[test]
fn test_tot_threshold_disabled_when_zero() {
let batch = make_batch(&[
(1000, 0, 0, 500, 5, 0, 0), (1000, 10, 0, 500, 3, 0, 0), ]);
let mut extractor = SimpleCentroidExtraction::new();
extractor.configure(ExtractionConfig::default().with_min_tot_threshold(0));
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert_eq!(neutrons[0].n_hits, 2);
assert_eq!(neutrons[0].tot, 8); }
// The representative TOF comes from the surviving hit with the largest ToT (ToT 25, TOF 3000),
// never from the filtered-out ToT-5 hit.
#[test]
fn test_representative_tof_from_max_tot_after_filtering() {
let batch = make_batch(&[
(1000, 0, 0, 500, 5, 0, 0), (2000, 10, 0, 500, 15, 0, 0), (3000, 20, 0, 500, 25, 0, 0), ]);
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert_eq!(neutrons[0].tof, 3000);
assert_ne!(neutrons[0].tof, 1000);
}
// With all ToT values zero the weighted path falls back to the plain mean:
// x = (10 + 30) / 2 * 8 = 160, y = (20 + 40) / 2 * 8 = 240, and no NaN is produced.
#[test]
fn test_zero_tot_weighted_centroid() {
let batch = make_batch(&[
(1000, 10, 20, 500, 0, 0, 0), (1000, 30, 40, 500, 0, 0, 0), ]);
let mut extractor = SimpleCentroidExtraction::new();
extractor.configure(ExtractionConfig::default().with_min_tot_threshold(0));
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert!((neutrons[0].x - 160.0).abs() < 0.01);
assert!((neutrons[0].y - 240.0).abs() < 0.01);
assert_eq!(neutrons[0].tot, 0);
assert_eq!(neutrons[0].n_hits, 2);
assert!(!neutrons[0].x.is_nan());
assert!(!neutrons[0].y.is_nan());
}
// End-to-end check of every output field for two clusters.
// Hits are pushed directly here, in the order `(x, y, tof, tot, timestamp, chip_id)`.
#[test]
fn test_extract_soa_expected_values() {
let mut batch = HitBatch::with_capacity(3);
batch.push((10, 10, 1000, 20, 500, 0));
batch.push((20, 10, 1500, 10, 500, 0));
batch.push((5, 7, 2000, 15, 500, 1));
batch.cluster_id[0] = 0;
batch.cluster_id[1] = 0;
batch.cluster_id[2] = 1;
let extractor = SimpleCentroidExtraction::new();
let neutrons = extractor.extract_soa(&batch, 2).unwrap();
assert_eq!(neutrons.len(), 2);
// Cluster 0: ToT-weighted x over two hits; the ToT-20 hit supplies TOF and chip id.
let n0 = &neutrons[0];
let expected_x = (10.0 * 20.0 + 20.0 * 10.0) / 30.0 * 8.0;
let expected_y = 10.0 * 8.0;
assert!((n0.x - expected_x).abs() < 1e-6);
assert!((n0.y - expected_y).abs() < 1e-6);
assert_eq!(n0.tof, 1000);
assert_eq!(n0.tot, 30);
assert_eq!(n0.n_hits, 2);
assert_eq!(n0.chip_id, 0);
// Cluster 1: single hit at (5, 7), scaled by 8 to (40, 56).
let n1 = &neutrons[1];
assert!((n1.x - 40.0).abs() < 1e-6);
assert!((n1.y - 56.0).abs() < 1e-6);
assert_eq!(n1.tof, 2000);
assert_eq!(n1.tot, 15);
assert_eq!(n1.n_hits, 1);
assert_eq!(n1.chip_id, 1);
}
// The same hit at (2, 3) is extracted with factor 1 and then factor 4;
// the output coordinates must scale accordingly.
#[test]
fn test_super_resolution_factor_affects_output() {
let batch = make_batch(&[(1000, 2, 3, 500, 20, 0, 0)]);
let mut extractor = SimpleCentroidExtraction::new();
extractor.configure(
ExtractionConfig::default()
.with_super_resolution(1.0)
.with_min_tot_threshold(0),
);
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert!((neutrons[0].x - 2.0).abs() < f64::EPSILON);
assert!((neutrons[0].y - 3.0).abs() < f64::EPSILON);
extractor.configure(
ExtractionConfig::default()
.with_super_resolution(4.0)
.with_min_tot_threshold(0),
);
let neutrons = extractor.extract_soa(&batch, 1).unwrap();
assert_eq!(neutrons.len(), 1);
assert!((neutrons[0].x - 8.0).abs() < f64::EPSILON);
assert!((neutrons[0].y - 12.0).abs() < f64::EPSILON);
}
}