use ndarray::{ArrayRef3, ArrayView3};
/// Cursor that yields, for each element of a 3-D mask, the offsets of its kernel neighbours.
///
/// Neighbour offsets for every combination of "low edge / interior / high edge" positions
/// are precomputed once into `offsets`. Moving the cursor only updates `at`, the index of
/// the first of the `n` offsets that belong to the current coordinates.
#[derive(Debug)]
pub struct Offsets {
    /// Strides of the mask, as returned by `mask.strides()`.
    mask_strides: Vec<isize>,
    /// Last valid coordinate (`len - 1`) along each mask axis.
    dim_m1: Vec<usize>,
    /// Precomputed offset table, `n` offsets per table entry.
    offsets: Vec<isize>,
    /// Axis traversal order used by `next`, fastest-advancing axis first.
    axes: [usize; 3],
    /// `axes` reversed (slowest-advancing axis first), used by `move_to`.
    axes_rev: [usize; 3],
    /// Strides into `offsets` along each axis.
    strides: Vec<usize>,
    /// Amount to rewind in `offsets` when an axis wraps from its last entry to its first.
    backstrides: Vec<usize>,
    /// Per-axis coordinate ranges from which advancing keeps the interior table entry.
    bounds: Vec<std::ops::Range<usize>>,
    /// Number of offsets per table entry (kernel elements, center excluded).
    n: usize,
    /// Current coordinates of the cursor in the mask.
    pub coordinates: Vec<usize>,
    /// Index in `offsets` of the first offset for the current coordinates.
    at: usize,
}
impl Offsets {
    /// Builds a cursor over the neighbour offsets of `mask` for the structuring element
    /// `kernel`, positioned at coordinates `[0, 0, 0]`.
    ///
    /// If the stride of axis 0 is larger than that of axis 2, axis 2 is advanced first by
    /// `next`; otherwise axis 0 is. `is_dilate` is forwarded to `build_offsets`.
    ///
    /// NOTE(review): the edge/interior classification assumes every mask axis is at least
    /// `2 * radius + 1` long, where `radius = (kernel_len - 1) / 2` on that axis. Shorter
    /// axes either underflow in the `bounds` computation or select the wrong table
    /// entries. Confirm that callers guarantee this.
    pub fn new<A>(mask: &ArrayRef3<A>, kernel: ArrayView3<bool>, is_dilate: bool) -> Offsets {
        let mask_shape = mask.shape();
        let mask_strides = mask.strides().to_vec();
        // `axes`: fastest- to slowest-advancing axis (walked by `next`).
        // `axes_rev`: the reverse (walked by `move_to` to decode a linear offset).
        let (axes, axes_rev) = if mask_strides[0] > mask_strides[2] {
            ([2, 1, 0], [0, 1, 2])
        } else {
            ([0, 1, 2], [2, 1, 0])
        };
        let (offsets, n) = build_offsets(mask_shape, &mask_strides, kernel.view(), is_dilate);
        let dim_m1: Vec<_> = mask_shape.iter().map(|&len| len - 1).collect();

        let ndim = mask.ndim();
        let radii: Vec<usize> = kernel.shape().iter().map(|&len| (len - 1) / 2).collect();
        // `build_offsets` stores `radius` low-edge entries, one interior entry and `radius`
        // high-edge entries per axis. The table is therefore `2 * radius + 1` entries long
        // along each axis, which equals the kernel length for odd-sized kernels.
        let table_len: Vec<usize> = radii.iter().map(|&r| 2 * r + 1).collect();

        // Row-major strides into the table, in units of offsets. The innermost unit is one
        // run of `n` offsets. The stride of axis `d` is the size of one slab of the axes
        // after it, so it multiplies by the extent of axis `d + 1`, not of axis `d`.
        let mut strides = vec![0; ndim];
        strides[ndim - 1] = n;
        for d in (0..ndim - 1).rev() {
            strides[d] = strides[d + 1] * table_len[d + 1];
        }
        // Distance from an axis's last table entry back to its first, used when `next` wraps.
        let backstrides = strides.iter().zip(&table_len).map(|(&s, &l)| (l - 1) * s).collect();
        // `next` keeps the interior table entry when advancing from a coordinate inside
        // `bounds[d]`. `move_to` maps every coordinate in `start..=end` to that entry.
        let bounds = radii.iter().zip(&dim_m1).map(|(&r, &last)| r..last - r).collect();

        Offsets {
            mask_strides,
            dim_m1,
            offsets,
            axes,
            axes_rev,
            strides,
            backstrides,
            bounds,
            n,
            coordinates: vec![0; ndim],
            at: 0,
        }
    }

    /// Returns the `n` precomputed neighbour offsets for the current coordinates.
    pub fn range(&self) -> &[isize] {
        &self.offsets[self.at..self.at + self.n]
    }

    /// Moves the cursor to the mask element at linear offset `idx`.
    ///
    /// `idx` is decomposed into per-axis coordinates in `axes_rev` order using the mask
    /// strides. The matching table entry is then selected along every axis.
    ///
    /// NOTE(review): assumes `idx` and the mask strides are non-negative, since both are
    /// cast to `usize`.
    pub fn move_to(&mut self, idx: isize) {
        let mut idx = idx as usize;
        for &d in &self.axes_rev {
            let s = self.mask_strides[d] as usize;
            self.coordinates[d] = idx / s;
            idx -= self.coordinates[d] * s;
        }
        self.at = 0;
        for &d in &self.axes {
            let (start, end) = (self.bounds[d].start, self.bounds[d].end);
            let c = self.coordinates[d];
            let j = if c < start {
                // Low edge: one table entry per coordinate.
                c
            } else if c > end && end >= start {
                // High edge: entries follow the single interior entry at index `start`.
                c + start - end
            } else {
                // Interior: all such coordinates share one entry.
                start
            };
            self.at += self.strides[d] * j;
        }
    }

    /// Advances the cursor to the next mask element in traversal order, updating the
    /// table position incrementally instead of recomputing it.
    pub fn next(&mut self) {
        for &d in &self.axes {
            if self.coordinates[d] < self.dim_m1[d] {
                // Inside `bounds` the next coordinate still uses the interior entry, so the
                // table position only moves on the edges.
                if !self.bounds[d].contains(&self.coordinates[d]) {
                    self.at += self.strides[d];
                }
                self.coordinates[d] += 1;
                break;
            } else {
                // Axis wrapped: rewind to its first table entry and carry into the next axis.
                self.coordinates[d] = 0;
                self.at -= self.backstrides[d];
            }
        }
    }
}
/// Precomputes the neighbour-offset table used by `Offsets`.
///
/// Along each axis the table holds `radius` low-edge positions (`0..radius`), one interior
/// representative (`shape / 2`) and `radius` high-edge positions (`shape - radius..shape`),
/// where `radius = (kernel_len - 1) / 2`. For every (z, y, x) combination of those
/// positions, one run of `n` offsets is stored, with `n` the number of kernel elements
/// returned by `build_indices`. Each offset is the stride-weighted kernel displacement.
/// Neighbours that fall outside the mask are replaced by the product of the mask shape.
/// Each run is sorted ascending.
///
/// Returns `(offsets, n)`. When the kernel has no set element besides its center,
/// `n == 0` and the table is empty.
fn build_offsets(
    shape: &[usize],
    strides: &[isize],
    kernel: ArrayView3<bool>,
    is_dilate: bool,
) -> (Vec<isize>, usize) {
    let radii: Vec<_> = kernel.shape().iter().map(|&len| (len - 1) / 2).collect();
    let indices = build_indices(kernel, &radii, is_dilate);
    let n = indices.len();
    let shape = [shape[0] as isize, shape[1] as isize, shape[2] as isize];
    // Sentinel for "neighbour is outside the image".
    let ooi_offset: isize = shape.iter().product();

    // Representative positions along axis `d`: low edge, one interior, high edge.
    let build_pos = |d: usize| -> Vec<isize> {
        let radius = radii[d] as isize;
        let mut pos = Vec::with_capacity(2 * radii[d] + 1);
        pos.extend(0..radius);
        pos.push(shape[d] / 2);
        pos.extend(shape[d] - radius..shape[d]);
        pos
    };
    let (z_pos, y_pos, x_pos) = (build_pos(0), build_pos(1), build_pos(2));

    let mut offsets = Vec::with_capacity(z_pos.len() * y_pos.len() * x_pos.len() * n);
    for &z in &z_pos {
        for &y in &y_pos {
            for &x in &x_pos {
                for idx2 in &indices {
                    let idx = [z + idx2[0], y + idx2[1], x + idx2[2]];
                    let outside = idx.iter().zip(shape).any(|(i, s)| !(0..s).contains(i));
                    offsets.push(if outside {
                        ooi_offset
                    } else {
                        idx2.iter().zip(strides).map(|(i, s)| i * s).sum::<isize>()
                    });
                }
            }
        }
    }
    // `chunks_mut(0)` panics unconditionally, so a kernel whose only set element is its
    // center (n == 0) must skip the sort. The table is empty in that case anyway.
    if n > 0 {
        for chunk in offsets.chunks_exact_mut(n) {
            chunk.sort_unstable();
        }
    }
    (offsets, n)
}
/// Lists the coordinates of the set elements of `kernel`, relative to its center.
///
/// `radii[d]` is taken as the center index along axis `d`. The center itself is skipped.
/// When `is_dilate` is true, every coordinate is negated, which reflects the kernel
/// through its center. Coordinates are returned in `kernel.indexed_iter()` order.
fn build_indices(kernel: ArrayView3<bool>, radii: &[usize], is_dilate: bool) -> Vec<[isize; 3]> {
    let radii = [radii[0] as isize, radii[1] as isize, radii[2] as isize];
    kernel
        .indexed_iter()
        .filter(|&(_, &set)| set)
        .map(|((z, y, x), _)| {
            [z as isize - radii[0], y as isize - radii[1], x as isize - radii[2]]
        })
        .filter(|&centered| centered != [0, 0, 0])
        .map(|centered| if is_dilate { centered.map(|c| -c) } else { centered })
        .collect()
}