/// Dense two-dimensional grid of `f64` values stored in row-major order.
///
/// The element at column `x`, row `y` lives at `items[y * w + x]`. Every public
/// constructor guarantees that `items.len()` equals `w * h`.
#[derive(Clone, PartialEq)]
pub struct Array2d {
    /// Row-major element storage (`w * h` entries).
    items: Vec<f64>,
    /// Width: number of columns.
    w: u32,
    /// Height: number of rows.
    h: u32,
}

/// Array shape expressed as `(width, height)`.
pub type Dimensions = (u32, u32);
impl Array2d {
    /// Total number of elements in an array of shape `dim`.
    ///
    /// The count is computed in `usize` with an overflow check. Multiplying the two `u32`
    /// extents directly can wrap (silently in release builds) and size the buffer wrongly.
    ///
    /// # Panics
    /// Panics if `dim.0 * dim.1` does not fit in `usize`.
    fn area(dim: Dimensions) -> usize {
        (dim.0 as usize)
            .checked_mul(dim.1 as usize)
            .expect("array dimensions overflow usize")
    }

    /// Creates an array of shape `dim` (`(width, height)`) filled with `0.0`.
    ///
    /// # Panics
    /// Panics if the element count overflows `usize`.
    pub fn new(dim: Dimensions) -> Self {
        Self {
            items: vec![0.0; Self::area(dim)],
            w: dim.0,
            h: dim.1,
        }
    }

    /// Wraps `items`, interpreted in row-major order, as an array of shape `dim`.
    ///
    /// # Panics
    /// Panics if `items` does not hold exactly `dim.0 * dim.1` elements.
    pub fn from_vec<V: Into<Vec<f64>>>(items: V, dim: Dimensions) -> Self {
        let items = items.into();
        assert_eq!(
            items.len(),
            Self::area(dim),
            "number of items does not match shape"
        );
        Self {
            items,
            w: dim.0,
            h: dim.1,
        }
    }

    /// Builds an array of shape `dim` by calling `cb((x, y))` once per cell.
    ///
    /// Cells are visited in row-major order: row 0 left to right, then row 1, and so on.
    /// `cb` may be stateful (`FnMut`); every `Fn` closure also qualifies.
    pub fn from_cb<F: FnMut((u32, u32)) -> f64>(dim: Dimensions, mut cb: F) -> Self {
        let mut items = Vec::with_capacity(Self::area(dim));
        for y in 0..dim.1 {
            items.extend((0..dim.0).map(|x| cb((x, y))));
        }
        Self {
            items,
            w: dim.0,
            h: dim.1,
        }
    }

    /// Returns the shape as `(width, height)`.
    pub fn dimensions(&self) -> Dimensions {
        (self.w, self.h)
    }

    /// Flat row-major index of `(x, y)` without a bounds check.
    ///
    /// Computed in `usize` so the intermediate product cannot wrap around `u32`.
    fn pixel_index_unchecked(&self, x: u32, y: u32) -> usize {
        y as usize * self.w as usize + x as usize
    }

    /// Flat index of `(x, y)`, or `None` when the point lies outside the array.
    fn pixel_index_checked(&self, x: u32, y: u32) -> Option<usize> {
        (x < self.w && y < self.h).then(|| self.pixel_index_unchecked(x, y))
    }

    /// Flat index of `(x, y)`.
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the array.
    fn pixel_index(&self, x: u32, y: u32) -> usize {
        self.pixel_index_checked(x, y).unwrap_or_else(|| {
            panic!(
                "Point index {:?} out of bounds {:?}",
                (x, y),
                self.dimensions()
            )
        })
    }

    /// Flat index of `(x, y)` after folding each coordinate into range by symmetric
    /// reflection about the array edges.
    ///
    /// For extent `len`, the mapping is `... 1 0 | 0 1 .. len-1 | len-1 .. 1 0 | 0 1 ...`,
    /// which has period `2 * len`, so coordinates any distance outside the array are handled.
    ///
    /// # Panics
    /// Panics if the array has zero width or height.
    fn pixel_index_reflect(&self, x: isize, y: isize) -> usize {
        /// Reflects `v` into `0..len`; `None` when `len` is zero (nothing to reflect into).
        fn reflect(v: isize, len: u32) -> Option<u32> {
            let len = isize::try_from(len).ok().filter(|&l| l > 0)?;
            let period = len.checked_mul(2)?;
            let m = v.rem_euclid(period);
            let r = if m < len { m } else { period - 1 - m };
            u32::try_from(r).ok()
        }
        match reflect(x, self.w).zip(reflect(y, self.h)) {
            Some((rx, ry)) => self.pixel_index_unchecked(rx, ry),
            None => panic!(
                "Point index {:?} out of bounds {:?}",
                (x, y),
                self.dimensions()
            ),
        }
    }

    /// Returns the element at column `x`, row `y`.
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the array.
    pub fn get(&self, x: u32, y: u32) -> f64 {
        let i = self.pixel_index(x, y);
        self.items[i]
    }

    /// Returns a mutable reference to the element at column `x`, row `y`.
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the array.
    pub fn get_mut(&mut self, x: u32, y: u32) -> &mut f64 {
        let i = self.pixel_index(x, y);
        &mut self.items[i]
    }

    /// Overwrites the element at column `x`, row `y` with `val`.
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the array.
    pub fn set(&mut self, x: u32, y: u32, val: f64) {
        let i = self.pixel_index(x, y);
        self.items[i] = val;
    }

    /// Returns the element at `(x, y)`, mirroring out-of-range coordinates back into the array.
    ///
    /// For example, with width 3, x = `-1` reads column 0, `-3` reads column 2, `3` reads
    /// column 2 and `7` reads column 1.
    ///
    /// # Panics
    /// Panics if the array has zero width or height.
    pub fn get_reflect(&self, x: isize, y: isize) -> f64 {
        let i = self.pixel_index_reflect(x, y);
        self.items[i]
    }

    /// Iterates over the elements of row `y`, left to right.
    ///
    /// # Panics
    /// Panics if `y` is not a valid row index.
    pub fn row(&self, y: u32) -> RowIter<'_> {
        assert!(
            y < self.h,
            "Row index {} out of bounds {:?}",
            y,
            self.dimensions()
        );
        let start = y as usize * self.w as usize;
        RowIter {
            arr: self,
            pos: start,
            step: 1,
            end: start + self.w as usize,
        }
    }

    /// Iterates over the elements of column `x`, top to bottom.
    ///
    /// # Panics
    /// Panics if `x` is not a valid column index. Without this check, an index `>= width`
    /// would silently walk a different column shifted down by a row.
    pub fn col(&self, x: u32) -> RowIter<'_> {
        assert!(
            x < self.w,
            "Column index {} out of bounds {:?}",
            x,
            self.dimensions()
        );
        RowIter {
            arr: self,
            pos: x as usize,
            step: self.w as usize,
            end: self.items.len(),
        }
    }

    /// Iterates over all rows, top to bottom.
    pub fn rows(&self) -> RowsIter<'_> {
        RowsIter { arr: self, pos: 0 }
    }

    /// Iterates over all columns, left to right.
    pub fn cols(&self) -> ColsIter<'_> {
        ColsIter { arr: self, pos: 0 }
    }

    /// Replaces every element `v` with `op(v)`, in storage (row-major) order.
    pub fn apply_op<F: FnMut(f64) -> f64>(&mut self, mut op: F) {
        for itm in self.items.iter_mut() {
            *itm = op(*itm);
        }
    }
}
impl std::fmt::Debug for Array2d {
    /// Writes `Array2(`, then one line per row of the form `[v0, v1, ...]` with every
    /// value right-aligned to a width of 10, and finally `)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Array2(")?;
        for line in self.rows() {
            f.write_str("\n[")?;
            // The separator is empty before the first value and ", " afterwards.
            let mut sep = "";
            for value in line {
                write!(f, "{sep}{value: >10}")?;
                sep = ", ";
            }
            f.write_str("]")?;
        }
        f.write_str(")")
    }
}
impl std::ops::Mul for Array2d {
    type Output = Array2d;

    /// Element-wise product with edge-clamping broadcast.
    ///
    /// The result has the larger width and the larger height of the two operands. Each
    /// operand is read at the output coordinate clamped to its own last column/row. A 1-wide
    /// or 1-tall operand is therefore broadcast across the other, and a smaller operand is
    /// extended by repeating its last row/column.
    ///
    /// # Panics
    /// Panics if the result is non-empty but one operand has zero width or height, because
    /// there is no element to broadcast from.
    fn mul(self, rhs: Self) -> Self::Output {
        let (wa, ha) = self.dimensions();
        let (wb, hb) = rhs.dimensions();
        let dim = (wa.max(wb), ha.max(hb));
        if dim.0 == 0 || dim.1 == 0 {
            // Empty result: nothing to sample, so return early. This also avoids the
            // `- 1` underflows below.
            return Array2d::new(dim);
        }
        assert!(
            wa > 0 && ha > 0 && wb > 0 && hb > 0,
            "cannot broadcast empty array {:?} against {:?}",
            self.dimensions(),
            rhs.dimensions()
        );
        // `from_cb` fills row by row, matching the row-major storage layout (cache friendly).
        Array2d::from_cb(dim, |(x, y)| {
            self.get(x.min(wa - 1), y.min(ha - 1)) * rhs.get(x.min(wb - 1), y.min(hb - 1))
        })
    }
}
/// Iterator over the rows of an [`Array2d`], top to bottom; each item is a [`RowIter`].
pub struct RowsIter<'a> {
    /// Array being traversed.
    arr: &'a Array2d,
    /// Index of the next row to yield.
    pos: u32,
}

/// Iterator over the columns of an [`Array2d`], left to right; each item is a [`RowIter`].
pub struct ColsIter<'a> {
    /// Array being traversed.
    arr: &'a Array2d,
    /// Index of the next column to yield.
    pos: u32,
}

/// Strided iterator over one row or one column of an [`Array2d`], yielding values.
///
/// It visits `items[pos]`, `items[pos + step]`, ... while the index stays below `end`.
pub struct RowIter<'a> {
    /// Array being traversed.
    arr: &'a Array2d,
    /// Flat index of the next element to yield.
    pos: usize,
    /// Distance between consecutive elements (1 for rows, the width for columns).
    step: usize,
    /// Exclusive upper bound on the flat index.
    end: usize,
}
impl<'a> Iterator for RowsIter<'a> {
    type Item = RowIter<'a>;

    /// Yields the next row, or `None` once all rows have been visited.
    fn next(&mut self) -> Option<Self::Item> {
        if self.pos >= self.arr.h {
            return None;
        }
        let row = self.arr.row(self.pos);
        self.pos += 1;
        Some(row)
    }

    /// Exact count of rows still to be yielded. This lets `collect` preallocate and enables
    /// `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.arr.h.saturating_sub(self.pos) as usize;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for RowsIter<'_> {}

// Once `pos` reaches the (immutable) height it never moves again, so `next` keeps
// returning `None`.
impl std::iter::FusedIterator for RowsIter<'_> {}
impl<'a> Iterator for ColsIter<'a> {
    type Item = RowIter<'a>;

    /// Yields the next column, or `None` once all columns have been visited.
    fn next(&mut self) -> Option<Self::Item> {
        if self.pos >= self.arr.w {
            return None;
        }
        let col = self.arr.col(self.pos);
        self.pos += 1;
        Some(col)
    }

    /// Exact count of columns still to be yielded. This lets `collect` preallocate and
    /// enables `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.arr.w.saturating_sub(self.pos) as usize;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for ColsIter<'_> {}

// Once `pos` reaches the (immutable) width it never moves again, so `next` keeps
// returning `None`.
impl std::iter::FusedIterator for ColsIter<'_> {}
impl Iterator for RowIter<'_> {
type Item = f64;
fn next(&mut self) -> Option<Self::Item> {
if self.pos < self.end {
let res = self.arr.items[self.pos];
self.pos += self.step;
Some(res)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::Array2d;

    /// Builds a `w x h` array holding `0.0, 1.0, 2.0, ...` in row-major order.
    fn iota(w: u32, h: u32) -> Array2d {
        Array2d::from_vec((0..w * h).map(f64::from).collect::<Vec<_>>(), (w, h))
    }

    #[test]
    fn rows() {
        let arr = iota(5, 5);
        let row_sums: Vec<f64> = arr.rows().map(|r| r.sum::<f64>()).collect();
        assert_eq!(row_sums, [10.0, 35.0, 60.0, 85.0, 110.0]);
        let col_sums: Vec<f64> = arr.cols().map(|c| c.sum::<f64>()).collect();
        assert_eq!(col_sums, [50.0, 55.0, 60.0, 65.0, 70.0]);
    }

    #[test]
    fn get_reflect() {
        let arr = iota(3, 2);
        assert_eq!(arr.get(0, 0), 0.0);
        // (x, expected value on row 0)
        let cases = [(0, 0.0), (2, 2.0), (-1, 0.0), (-2, 1.0), (-3, 2.0)];
        for (x, expected) in cases {
            assert_eq!(arr.get_reflect(x, 0), expected);
        }
    }
}