use ndarray::{ArrayView1, ArrayView2};
/// Bounded selector that retains the `capacity` candidates with the largest
/// absolute score.
///
/// Candidates are ranked by `|score|` in descending order; ties are broken in
/// favour of the smaller atom index. NaN scores rank below every other score,
/// so a NaN is only retained while fewer than `capacity` better candidates
/// have been offered.
#[derive(Clone, Debug)]
pub struct TopSSelector {
    /// Retained candidates as `(atom, score, rank_key)`, in no particular
    /// order. Despite the name this is a flat buffer that is scanned
    /// linearly, not a binary heap.
    heap: Vec<(u32, f32, f32)>,
    /// Maximum number of retained candidates; `new` guarantees it is >= 1.
    capacity: usize,
}

impl TopSSelector {
    /// Creates an empty selector that keeps at most `capacity` candidates.
    ///
    /// A `capacity` of 0 is raised to 1.
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        Self {
            heap: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Ranking key for a score: its magnitude, with NaN mapped to negative
    /// infinity.
    ///
    /// Keeping NaN out of the stored keys is what makes the comparisons in
    /// `offer` and `finish` a total order. With a raw NaN key every
    /// comparison is false, so the entry could never be evicted and could
    /// block all later replacements.
    #[inline]
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score.abs()
        }
    }

    /// Offers one candidate `(atom, score)` to the selector.
    ///
    /// While the selector is not full the candidate is always stored. Once it
    /// is full, the candidate replaces the weakest retained entry only if it
    /// ranks strictly better (larger magnitude, or equal magnitude and a
    /// smaller atom index). Each call on a full selector scans all retained
    /// entries.
    #[inline]
    pub fn offer(&mut self, atom: u32, score: f32) {
        let key = Self::rank_key(score);
        if self.heap.len() < self.capacity {
            self.heap.push((atom, score, key));
            return;
        }
        // Weakest entry = smallest key, and among equal keys the largest atom
        // index. `min_by` returns the first of several equal minima, which
        // matches a left-to-right scan that only moves on a strictly weaker
        // entry.
        let weakest = self
            .heap
            .iter_mut()
            .min_by(|a, b| a.2.total_cmp(&b.2).then_with(|| b.0.cmp(&a.0)));
        if let Some(slot) = weakest {
            let beats = match key.total_cmp(&slot.2) {
                std::cmp::Ordering::Greater => true,
                std::cmp::Ordering::Equal => atom < slot.0,
                std::cmp::Ordering::Less => false,
            };
            if beats {
                *slot = (atom, score, key);
            }
        }
    }

    /// Consumes the selector and returns the retained `(atom, score)` pairs,
    /// best first: descending magnitude, ties by ascending atom index, NaN
    /// scores (if any were retained) last.
    pub fn finish(mut self) -> Vec<(u32, f32)> {
        // Keys never contain NaN, so this comparator is a total order.
        self.heap
            .sort_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.0.cmp(&b.0)));
        self.heap
            .into_iter()
            .map(|(atom, score, _)| (atom, score))
            .collect()
    }
}
/// Scores `row` against every atom of `atoms_tile` and offers each result to
/// `sel`.
///
/// Each atom is one subview along the outer axis of `atoms_tile`. Its score is
/// the dot product with `row` over the first `row.len()` components,
/// accumulated left to right in `f32`. The atom at position `i` of the tile is
/// reported under index `atom_offset + i`, narrowed to `u32` with an `as`
/// cast (so indices above `u32::MAX` would wrap).
///
/// # Panics
///
/// Panics if an atom has fewer than `row.len()` components.
#[inline]
pub fn score_row_tile(
    row: ArrayView1<'_, f32>,
    atoms_tile: ArrayView2<'_, f32>,
    atom_offset: usize,
    sel: &mut TopSSelector,
) {
    let width = row.len();
    for (idx, atom_vec) in atoms_tile.outer_iter().enumerate() {
        // Sequential accumulation; the summation order is part of the result.
        let dot = (0..width).fold(0.0f32, |sum, j| sum + row[j] * atom_vec[j]);
        sel.offer((atom_offset + idx) as u32, dot);
    }
}
/// Selects the atoms of `decoder` whose dot product with `row` has the largest
/// magnitude, scanning the decoder in row tiles.
///
/// * `row` - the vector being scored.
/// * `decoder` - one atom per row.
/// * `s` - number of atoms to keep; 0 is treated as 1.
/// * `tile` - number of decoder rows scored per tile; 0 is treated as 1.
///
/// Returns `(atom index, score)` pairs ordered by decreasing `|score|`, ties
/// by increasing atom index. The result holds fewer than `s` pairs when the
/// decoder has fewer than `s` rows.
pub fn top_s_online(
    row: ArrayView1<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    s: usize,
    tile: usize,
) -> Vec<(u32, f32)> {
    let n_atoms = decoder.nrows();
    let step = tile.max(1);
    let mut selector = TopSSelector::new(s);
    for lo in (0..n_atoms).step_by(step) {
        // The last tile may be shorter than `step`.
        let hi = (lo + step).min(n_atoms);
        let rows = decoder.slice(ndarray::s![lo..hi, ..]);
        score_row_tile(row, rows, lo, &mut selector);
    }
    selector.finish()
}
/// Configuration for routing rows through a decoder: how many atoms to keep
/// per row and how many decoder rows to score per tile.
#[derive(Clone, Copy, Debug)]
pub struct TileScorer {
    /// Number of decoder rows scored per tile.
    pub tile: usize,
    /// Number of atoms kept for each routed row.
    pub active: usize,
}

impl TileScorer {
    /// Builds a scorer keeping `active` atoms per row and scoring `tile`
    /// decoder rows at a time. Both values are raised to at least 1.
    pub fn new(active: usize, tile: usize) -> Self {
        let active = active.max(1);
        let tile = tile.max(1);
        Self { tile, active }
    }

    /// Returns the selected `(atom index, score)` pairs for `row` against
    /// `decoder`, as produced by `top_s_online` with this scorer's settings.
    pub fn route_row(
        &self,
        row: ArrayView1<'_, f32>,
        decoder: ArrayView2<'_, f32>,
    ) -> Vec<(u32, f32)> {
        let Self { tile, active } = *self;
        top_s_online(row, decoder, active, tile)
    }
}