use std::{
cmp::Ordering,
error::Error,
fmt::{self, Write as _},
marker::PhantomData,
ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
slice,
};
use crate::ArrayExt;
pub mod generics;
use generics::{ConstShape, DynShape, Norm, Normalisation, Shape, Unnorm};
pub mod io;
pub mod iter;
use iter::Indices;
mod em;
/// Largest absolute deviation from one that the sum of the values may have for
/// `into_normalised` to accept the spectrum as normalised.
const NORMALISATION_TOLERANCE: f64 = 10. * f64::EPSILON;
/// Creates a 1D unnormalised SFS.
///
/// Two forms are accepted:
///
/// - `sfs1d![elem; n]` forwards to `USfs::from_elem(elem, [n])`, i.e. `n` copies of `elem`.
/// - `sfs1d![a, b, c]` (trailing comma allowed) forwards to `USfs::from_vec` with the listed
///   values.
#[macro_export]
macro_rules! sfs1d {
// Repeated-element form.
($elem:expr; $n:expr) => {
$crate::sfs::USfs::from_elem($elem, [$n])
};
// Explicit list form; at least one value is required.
($($x:expr),+ $(,)?) => {
$crate::sfs::USfs::from_vec(vec![$($x),+])
};
}
/// Creates a 2D unnormalised SFS from one or more bracketed rows of literals.
///
/// The rows are forwarded to the crate's `matrix!` macro, which yields `(cols, vec)`. The shape
/// is taken to be `[cols.len(), cols[0]]`.
///
/// NOTE(review): presumably `cols` holds the length of each row and `vec` the flattened values;
/// confirm against the definition of `matrix!`.
///
/// # Panics
///
/// Panics if `from_vec_shape` rejects the values for the computed shape, since its result is
/// unwrapped.
#[macro_export]
macro_rules! sfs2d {
($([$($x:literal),+ $(,)?]),+ $(,)?) => {{
let (cols, vec) = $crate::matrix!($([$($x),+]),+);
// Only the first entry of `cols` is used for the second dimension.
let shape = [cols.len(), cols[0]];
$crate::sfs::SfsBase::from_vec_shape(vec, shape).unwrap()
}};
}
/// An SFS stored as a flat vector of values together with its shape and strides.
///
/// `S` is the shape type and `N` is a marker type recording the normalisation state; `N` is only
/// carried as `PhantomData` and holds no data.
#[derive(Clone, Debug, PartialEq)]
pub struct SfsBase<S: Shape, N: Normalisation> {
/// The values of the spectrum in flat storage order.
values: Vec<f64>,
/// The shape of the spectrum.
pub(crate) shape: S,
/// The strides of the spectrum, obtained from `shape.strides()` on construction.
pub(crate) strides: S,
/// Marker for the normalisation state.
norm: PhantomData<N>,
}
/// An SFS with a `ConstShape<D>` shape in the `Norm` state.
pub type Sfs<const D: usize> = SfsBase<ConstShape<D>, Norm>;
/// An SFS with a `ConstShape<D>` shape in the `Unnorm` state.
pub type USfs<const D: usize> = SfsBase<ConstShape<D>, Unnorm>;
/// An SFS with a `DynShape` shape in the `Norm` state.
pub type DynSfs = SfsBase<DynShape, Norm>;
/// An SFS with a `DynShape` shape in the `Unnorm` state.
pub type DynUSfs = SfsBase<DynShape, Unnorm>;
impl<S: Shape, N: Normalisation> SfsBase<S, N> {
    /// Returns the values of the SFS as a flat slice in storage order.
    #[inline]
    pub fn as_slice(&self) -> &[f64] {
        self.values.as_slice()
    }

    /// Returns a folded copy of the SFS.
    ///
    /// Every entry is paired with the entry at the mirrored flat position. Entries whose index
    /// sum lies before the midpoint receive the sum of the pair. If the total count is even,
    /// entries exactly on the midpoint receive the mean of the pair; otherwise they also receive
    /// the sum. Entries past the midpoint are left at zero.
    pub fn fold(&self) -> Self {
        let n = self.values.len();
        let dims: &[usize] = self.shape.as_ref();

        let total_count = self.shape.iter().sum::<usize>() - self.shape.len();
        let mid_count = total_count / 2;
        let has_diagonal = total_count % 2 == 0;

        // Results go into a zeroed copy, so entries past the midpoint stay at zero and reads
        // always see the unfolded values.
        let mut folded = Self::new_unchecked(vec![0.0; n], self.shape.clone());

        for i in 0..n {
            let rev_i = n - 1 - i;
            let count = compute_index_sum_unchecked(i, n, dims);

            let value = match count.cmp(&mid_count) {
                // Past the midpoint: nothing is written.
                Ordering::Greater => continue,
                // Exactly on the midpoint of a spectrum that has one: average the pair.
                Ordering::Equal if has_diagonal => {
                    0.5 * self.values[i] + 0.5 * self.values[rev_i]
                }
                // Before the midpoint, or on the last line before it when there is no diagonal.
                Ordering::Equal | Ordering::Less => self.values[i] + self.values[rev_i],
            };
            folded.values[i] = value;
        }

        folded
    }

    /// Formats the values as a single string.
    ///
    /// Each value is written with `precision` decimals, and consecutive values are separated by
    /// `sep`. An SFS without values gives an empty string.
    pub fn format_flat(&self, sep: &str, precision: usize) -> String {
        match self.values.split_first() {
            None => String::new(),
            Some((first, rest)) => {
                // Rough size estimate to limit reallocation while writing.
                let cap = self.values.len() * (precision + 3);
                let mut out = String::with_capacity(cap);

                write!(out, "{first:.precision$}").unwrap();
                for x in rest {
                    out.push_str(sep);
                    write!(out, "{x:.precision$}").unwrap();
                }

                out
            }
        }
    }

    /// Returns a reference to the value at `index`, or `None` if the index is out of bounds.
    #[inline]
    pub fn get(&self, index: &S) -> Option<&f64> {
        let flat = compute_flat(index, &self.shape)?;
        self.values.get(flat)
    }

    /// Converts into a normalised SFS if the values sum to one.
    ///
    /// # Errors
    ///
    /// Returns a [`NormError`] holding the actual sum if it differs from one by more than the
    /// normalisation tolerance.
    #[inline]
    pub fn into_normalised(self) -> Result<SfsBase<S, Norm>, NormError> {
        let sum = self.sum();
        let deviation = (sum - 1.).abs();

        if deviation <= NORMALISATION_TOLERANCE {
            Ok(self.into_normalised_unchecked())
        } else {
            Err(NormError { sum })
        }
    }

    /// Re-tags the SFS as normalised without checking the values.
    #[inline]
    fn into_normalised_unchecked(self) -> SfsBase<S, Norm> {
        let Self {
            values,
            shape,
            strides,
            ..
        } = self;

        SfsBase {
            values,
            shape,
            strides,
            norm: PhantomData,
        }
    }

    /// Re-tags the SFS as unnormalised; the values are left untouched.
    #[inline]
    pub fn into_unnormalised(self) -> SfsBase<S, Unnorm> {
        let Self {
            values,
            shape,
            strides,
            ..
        } = self;

        SfsBase {
            values,
            shape,
            strides,
            norm: PhantomData,
        }
    }

    /// Returns an iterator over the values in flat storage order.
    #[inline]
    pub fn iter(&self) -> slice::Iter<'_, f64> {
        self.as_slice().iter()
    }

    /// Creates an SFS from values and shape without checking that they are compatible.
    #[inline]
    fn new_unchecked(values: Vec<f64>, shape: S) -> Self {
        Self {
            // Evaluated before `shape` is moved into the struct below.
            strides: shape.strides(),
            values,
            shape,
            norm: PhantomData,
        }
    }

    /// Multiplies every value by `scale` and returns the result as an unnormalised SFS.
    #[inline]
    #[must_use = "returns scaled SFS, doesn't modify in-place"]
    pub fn scale(mut self, scale: f64) -> SfsBase<S, Unnorm> {
        for x in &mut self.values {
            *x *= scale;
        }
        self.into_unnormalised()
    }

    /// Returns the shape of the SFS.
    pub fn shape(&self) -> &S {
        &self.shape
    }

    /// Returns the sum of all values.
    #[inline]
    fn sum(&self) -> f64 {
        self.values.iter().sum()
    }
}
impl<const D: usize, N: Normalisation> SfsBase<ConstShape<D>, N> {
    /// Returns an iterator over the frequencies corresponding to each index of the SFS.
    ///
    /// For every index yielded by [`indices`](Self::indices), each coordinate is divided by the
    /// length of the matching dimension minus one.
    pub fn frequencies(&self) -> impl Iterator<Item = [f64; D]> {
        let denominators = self.shape.map(|len| len - 1);

        self.indices().map(move |index| {
            index
                .array_zip(denominators)
                .map(|(i, max)| i as f64 / max as f64)
        })
    }

    /// Returns an iterator over the indices of the SFS, built from its shape.
    pub fn indices(&self) -> Indices<ConstShape<D>> {
        let shape = self.shape;
        Indices::from_shape(shape)
    }
}
impl<S: Shape> SfsBase<S, Norm> {
    /// Creates a normalised SFS of the given shape in which every value is one over the number
    /// of elements.
    pub fn uniform(shape: S) -> SfsBase<S, Norm> {
        let len = shape.iter().product::<usize>();
        let values = vec![1.0 / len as f64; len];

        Self::new_unchecked(values, shape)
    }
}
impl<S: Shape> SfsBase<S, Unnorm> {
    /// Returns the values of the SFS as a mutable flat slice in storage order.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f64] {
        &mut self.values
    }

    /// Creates an SFS of the given shape with every value set to `elem`.
    pub fn from_elem(elem: f64, shape: S) -> Self {
        let n = shape.iter().product();

        Self::new_unchecked(vec![elem; n], shape)
    }

    /// Creates an SFS of the given shape from an iterator of values.
    ///
    /// The iterator is collected into a vector and handed to
    /// [`from_vec_shape`](Self::from_vec_shape).
    ///
    /// # Errors
    ///
    /// Returns a [`ShapeError`] if the number of items does not match the product of `shape`.
    pub fn from_iter_shape<I>(iter: I, shape: S) -> Result<Self, ShapeError<S>>
    where
        I: IntoIterator<Item = f64>,
    {
        Self::from_vec_shape(iter.into_iter().collect(), shape)
    }

    /// Creates an SFS of the given shape from a vector of values in flat storage order.
    ///
    /// # Errors
    ///
    /// Returns a [`ShapeError`] if `vec.len()` does not equal the product of `shape`. The error
    /// records the number of elements that were actually provided.
    pub fn from_vec_shape(vec: Vec<f64>, shape: S) -> Result<Self, ShapeError<S>> {
        let expected: usize = shape.iter().product();

        if vec.len() == expected {
            Ok(Self::new_unchecked(vec, shape))
        } else {
            // The error message reads "... from {n} elements", so it must carry the length of the
            // provided vector rather than the length the shape requires.
            Err(ShapeError::new(vec.len(), shape))
        }
    }

    /// Returns a mutable reference to the value at `index`, or `None` if the index is out of
    /// bounds.
    #[inline]
    pub fn get_mut(&mut self, index: &S) -> Option<&mut f64> {
        let flat = compute_flat(index, &self.shape)?;

        self.values.get_mut(flat)
    }

    /// Returns an iterator over mutable references to the values in flat storage order.
    #[inline]
    pub fn iter_mut(&mut self) -> slice::IterMut<'_, f64> {
        self.values.iter_mut()
    }

    /// Divides every value by the sum of all values and returns the result as a normalised SFS.
    ///
    /// The sum is not checked: if it is zero or not finite, the resulting values follow from
    /// ordinary floating-point division and will not be a valid normalised spectrum.
    #[inline]
    #[must_use = "returns normalised SFS, doesn't modify in-place"]
    pub fn normalise(mut self) -> SfsBase<S, Norm> {
        let sum = self.sum();

        self.iter_mut().for_each(|x| *x /= sum);

        self.into_normalised_unchecked()
    }

    /// Creates an SFS of the given shape with every value set to zero.
    pub fn zeros(shape: S) -> Self {
        Self::from_elem(0.0, shape)
    }
}
impl SfsBase<ConstShape<1>, Unnorm> {
    /// Creates a 1D SFS from a vector of values; the shape is the length of the vector.
    pub fn from_vec(values: Vec<f64>) -> Self {
        let len = values.len();

        Self::new_unchecked(values, [len])
    }
}
impl SfsBase<ConstShape<2>, Norm> {
    /// Returns the f2 statistic of the 2D SFS.
    ///
    /// This is the sum over all entries of the value multiplied by the squared difference
    /// between the two frequencies at that entry.
    pub fn f2(&self) -> f64 {
        self.values
            .iter()
            .zip(self.frequencies())
            .map(|(&value, [f_i, f_j])| {
                let diff = f_i - f_j;
                value * diff.powi(2)
            })
            .sum()
    }
}
// Implements an arithmetic operator and its assigning counterpart for `SfsBase`.
//
// Arguments: the operator trait and method (e.g. `Add`, `add`) followed by the assigning trait
// and method (e.g. `AddAssign`, `add_assign`).
//
// Four impls are generated: the assigning operator with a borrowed and an owned right-hand side
// (only for an unnormalised left-hand side), and the consuming operator with an owned and a
// borrowed right-hand side (for any normalisation state, always yielding an unnormalised SFS).
macro_rules! impl_op {
($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident) => {
// Assigning operator, borrowed right-hand side: applied element by element.
impl<S: Shape, N: Normalisation> $assign_trait<&SfsBase<S, N>> for SfsBase<S, Unnorm> {
#[inline]
fn $assign_method(&mut self, rhs: &SfsBase<S, N>) {
// Panics if the two operands differ in shape.
assert_eq!(self.shape, rhs.shape);
self.iter_mut()
.zip(rhs.iter())
.for_each(|(x, rhs)| x.$assign_method(rhs));
}
}
// Assigning operator, owned right-hand side: defers to the borrowed form.
impl<S: Shape, N: Normalisation> $assign_trait<SfsBase<S, N>> for SfsBase<S, Unnorm> {
#[inline]
fn $assign_method(&mut self, rhs: SfsBase<S, N>) {
self.$assign_method(&rhs);
}
}
// Consuming operator, owned right-hand side: the left-hand side is re-tagged as
// unnormalised and then updated in place.
impl<S: Shape, N: Normalisation, M: Normalisation> $trait<SfsBase<S, M>> for SfsBase<S, N> {
type Output = SfsBase<S, Unnorm>;
#[inline]
fn $method(self, rhs: SfsBase<S, M>) -> Self::Output {
let mut sfs = self.into_unnormalised();
sfs.$assign_method(&rhs);
sfs
}
}
// Consuming operator, borrowed right-hand side.
impl<S: Shape, N: Normalisation, M: Normalisation> $trait<&SfsBase<S, M>>
for SfsBase<S, N>
{
type Output = SfsBase<S, Unnorm>;
#[inline]
fn $method(self, rhs: &SfsBase<S, M>) -> Self::Output {
let mut sfs = self.into_unnormalised();
sfs.$assign_method(rhs);
sfs
}
}
};
}
// Element-wise addition and subtraction.
impl_op!(Add, add, AddAssign, add_assign);
impl_op!(Sub, sub, SubAssign, sub_assign);
impl<S: Shape, N: Normalisation> Index<S> for SfsBase<S, N> {
    type Output = f64;

    /// Returns a reference to the value at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    fn index(&self, index: S) -> &Self::Output {
        let entry = self.get(&index);
        entry.unwrap()
    }
}
impl<S: Shape> IndexMut<S> for SfsBase<S, Unnorm> {
    /// Returns a mutable reference to the value at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    fn index_mut(&mut self, index: S) -> &mut Self::Output {
        let entry = self.get_mut(&index);
        entry.unwrap()
    }
}
impl<const D: usize, N: Normalisation> From<SfsBase<ConstShape<D>, N>> for SfsBase<DynShape, N> {
    /// Converts a const-shaped SFS into a dyn-shaped one, keeping values and normalisation state.
    fn from(sfs: SfsBase<ConstShape<D>, N>) -> Self {
        let SfsBase {
            values,
            shape,
            strides,
            ..
        } = sfs;

        Self {
            values,
            shape: shape.into(),
            strides: strides.into(),
            norm: PhantomData,
        }
    }
}
impl<const D: usize, N: Normalisation> TryFrom<SfsBase<DynShape, N>> for SfsBase<ConstShape<D>, N> {
    /// On failure the input SFS is handed back unchanged.
    type Error = SfsBase<DynShape, N>;

    /// Converts a dyn-shaped SFS into a const-shaped one with `D` dimensions.
    ///
    /// # Errors
    ///
    /// Returns the input if its shape and strides cannot be converted to arrays of length `D`.
    fn try_from(sfs: SfsBase<DynShape, N>) -> Result<Self, Self::Error> {
        let shape = <[usize; D]>::try_from(&sfs.shape[..]);
        let strides = <[usize; D]>::try_from(&sfs.strides[..]);

        match (shape, strides) {
            (Err(_), Err(_)) => Err(sfs),
            (Ok(shape), Ok(strides)) => Ok(SfsBase {
                values: sfs.values,
                shape,
                strides,
                norm: PhantomData,
            }),
            (Err(_), Ok(_)) | (Ok(_), Err(_)) => {
                unreachable!("conversion of dyn shape and strides succeeds or fails together")
            }
        }
    }
}
/// An error associated with constructing an SFS from values that do not fit a shape.
#[derive(Clone, Copy, Debug)]
pub struct ShapeError<S: Shape> {
/// The element count reported in the error message.
n: usize,
/// The shape reported in the error message.
shape: S,
}
impl<S: Shape> ShapeError<S> {
fn new(n: usize, shape: S) -> Self {
Self { n, shape }
}
}
impl<S: Shape> fmt::Display for ShapeError<S> {
    /// Writes the dimensionality, the `/`-separated shape, and the element count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { n, shape } = self;

        let d = shape.as_ref().len();
        let dims: Vec<String> = shape.iter().map(ToString::to_string).collect();
        let shape_fmt = dims.join("/");

        write!(
            f,
            "cannot create {d}D SFS with shape {shape_fmt} from {n} elements"
        )
    }
}
// Marks `ShapeError` as a standard error type; no trait methods are overridden.
impl<S: Shape> Error for ShapeError<S> {}
/// An error associated with converting to a normalised SFS from values that do not sum to one.
#[derive(Clone, Copy, Debug)]
pub struct NormError {
/// The sum of the values that failed the normalisation check.
sum: f64,
}
impl fmt::Display for NormError {
    /// Writes a message that includes the offending sum.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sum } = self;

        write!(
            f,
            "cannot create normalised SFS using values summing to {}",
            sum
        )
    }
}
// Marks `NormError` as a standard error type; no trait methods are overridden.
impl Error for NormError {}
/// Computes the flat position of `index` within `shape`.
///
/// Returns `None` if any coordinate of `index` is not strictly smaller than the matching
/// dimension of `shape`.
///
/// # Panics
///
/// Panics if `index` and `shape` do not have the same number of dimensions.
fn compute_flat<S: Shape>(index: &S, shape: &S) -> Option<usize> {
    assert_eq!(index.len(), shape.len());

    // Every axis is checked, including the first. Leaving the first axis to a later length check
    // on the flat position is not enough: a very large first coordinate makes the flat position
    // overflow, which panics in debug builds and can wrap to an in-range offset in release
    // builds.
    let in_bounds = index.iter().zip(shape.iter()).all(|(i, n)| i < n);

    if in_bounds {
        Some(compute_flat_unchecked(index, shape))
    } else {
        None
    }
}
/// Computes the flat position of `index` within `shape` without checking bounds.
///
/// Starting from the first coordinate, each later axis multiplies the running position by the
/// length of that axis and then adds the coordinate along it.
///
/// # Panics
///
/// Panics if `index` has no dimensions, or if `shape` has fewer dimensions than `index`.
fn compute_flat_unchecked<S: Shape>(index: &S, shape: &S) -> usize {
    let coords: &[usize] = index.as_ref();
    let dims: &[usize] = shape.as_ref();

    (1..index.len()).fold(coords[0], |flat, axis| flat * dims[axis] + coords[axis])
}
/// Computes the sum of the coordinates of the index at flat position `flat`.
///
/// `n` is the total number of elements and `shape` the length of each dimension. The coordinates
/// are recovered one axis at a time by repeated division, and only their sum is kept.
///
/// # Panics
///
/// Panics on division by zero, which happens if `shape` contains a zero or if `n` is smaller
/// than the product of the dimensions processed so far.
fn compute_index_sum_unchecked(flat: usize, n: usize, shape: &[usize]) -> usize {
    let (sum, _, _) = shape
        .iter()
        .fold((0, flat, n), |(sum, rest, block), &dim| {
            // Number of flat positions spanned by one step along the current axis.
            let block = block / dim;
            (sum + rest / block, rest % block, block)
        });

    sum
}
// Unit tests for indexing, the f2 statistic, element-wise arithmetic, and folding.
#[cfg(test)]
mod tests {
use super::*;
// `get` returns each value in a 1D SFS and `None` one past the end.
#[test]
fn test_index_1d() {
let sfs = sfs1d![0., 1., 2., 3., 4., 5.];
assert_eq!(sfs.get(&[0]), Some(&0.));
assert_eq!(sfs.get(&[2]), Some(&2.));
assert_eq!(sfs.get(&[5]), Some(&5.));
assert_eq!(sfs.get(&[6]), None);
}
// `get` on a 2x3 SFS, including an out-of-bounds index along each axis.
#[test]
fn test_index_2d() {
let sfs = sfs2d![[0., 1., 2.], [3., 4., 5.]];
assert_eq!(sfs.get(&[0, 0]), Some(&0.));
assert_eq!(sfs.get(&[1, 0]), Some(&3.));
assert_eq!(sfs.get(&[1, 1]), Some(&4.));
assert_eq!(sfs.get(&[1, 2]), Some(&5.));
assert_eq!(sfs.get(&[2, 0]), None);
assert_eq!(sfs.get(&[0, 3]), None);
}
// f2 of a normalised 2x3 SFS against a precomputed value.
#[test]
fn test_f2() {
#[rustfmt::skip]
let sfs = sfs2d![
[0., 1., 2.],
[3., 4., 5.]
].normalise();
assert!((sfs.f2() - 0.4166667).abs() < 1e-6);
}
// `+` and `+=` with owned and borrowed right-hand sides.
#[test]
fn test_sfs_addition() {
let mut lhs = sfs1d![0., 1., 2.];
let rhs = sfs1d![5., 6., 7.];
let sum = sfs1d![5., 7., 9.];
assert_eq!(lhs.clone() + rhs.clone(), sum);
assert_eq!(lhs.clone() + &rhs, sum);
lhs += rhs.clone();
assert_eq!(lhs, sum);
lhs += &rhs;
assert_eq!(lhs, sum + rhs);
}
// `-` and `-=` with owned and borrowed right-hand sides.
#[test]
fn test_sfs_subtraction() {
let mut lhs = sfs1d![5., 6., 7.];
let rhs = sfs1d![0., 1., 2.];
let sub = sfs1d![5., 5., 5.];
assert_eq!(lhs.clone() - rhs.clone(), sub);
assert_eq!(lhs.clone() - &rhs, sub);
lhs -= rhs.clone();
assert_eq!(lhs, sub);
lhs -= &rhs;
assert_eq!(lhs, sub - rhs);
}
// 1D fold with an even number of entries: no entry lies exactly on the midpoint.
#[test]
fn test_fold_4() {
let sfs = sfs1d![0., 1., 2., 3.];
assert_eq!(sfs.fold(), sfs1d![3., 3., 0., 0.],);
}
// 1D fold with an odd number of entries: the middle entry is averaged with itself.
#[test]
fn test_fold_5() {
let sfs = sfs1d![0., 1., 2., 3., 4.];
assert_eq!(sfs.fold(), sfs1d![4., 4., 2., 0., 0.],);
}
// 2D fold with a diagonal (3x3).
#[test]
fn test_fold_3x3() {
#[rustfmt::skip]
let sfs = sfs2d![
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.],
];
#[rustfmt::skip]
let expected = sfs2d![
[8., 8., 4.],
[8., 4., 0.],
[4., 0., 0.],
];
assert_eq!(sfs.fold(), expected);
}
// 2D fold with a diagonal (2x4).
#[test]
fn test_fold_2x4() {
#[rustfmt::skip]
let sfs = sfs2d![
[0., 1., 2., 3.],
[4., 5., 6., 7.],
];
#[rustfmt::skip]
let expected = sfs2d![
[7., 7., 3.5, 0.],
[7., 3.5, 0., 0.],
];
assert_eq!(sfs.fold(), expected);
}
// 2D fold without a diagonal (3x4).
#[test]
fn test_fold_3x4() {
#[rustfmt::skip]
let sfs = sfs2d![
[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
];
#[rustfmt::skip]
let expected = sfs2d![
[11., 11., 11., 0.],
[11., 11., 0., 0.],
[11., 0., 0., 0.],
];
assert_eq!(sfs.fold(), expected);
}
// 2D fold with a diagonal (3x7).
#[test]
fn test_fold_3x7() {
#[rustfmt::skip]
let sfs = sfs2d![
[ 0., 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12., 13.],
[14., 15., 16., 17., 18., 19., 20.],
];
#[rustfmt::skip]
let expected = sfs2d![
[20., 20., 20., 20., 10., 0., 0.],
[20., 20., 20., 10., 0., 0., 0.],
[20., 20., 10., 0., 0., 0., 0.],
];
assert_eq!(sfs.fold(), expected);
}
// 3D fold without a diagonal (2x2x2).
#[test]
fn test_fold_2x2x2() {
let sfs = USfs::from_iter_shape((0..8).map(|x| x as f64), [2, 2, 2]).unwrap();
#[rustfmt::skip]
let expected = USfs::from_vec_shape(
vec![
7., 7.,
7., 0.,
7., 0.,
0., 0.,
],
[2, 2, 2]
).unwrap();
assert_eq!(sfs.fold(), expected);
}
// 3D fold with a diagonal (2x3x2).
#[test]
fn test_fold_2x3x2() {
let sfs = USfs::from_iter_shape((0..12).map(|x| x as f64), [2, 3, 2]).unwrap();
#[rustfmt::skip]
let expected = USfs::from_vec_shape(
vec![
11., 11.,
11., 5.5,
5.5, 0.,
11., 5.5,
5.5, 0.,
0., 0.,
],
[2, 3, 2]
).unwrap();
assert_eq!(sfs.fold(), expected);
}
// 3D fold with a diagonal (3x3x3).
#[test]
fn test_fold_3x3x3() {
let sfs = USfs::from_iter_shape((0..27).map(|x| x as f64), [3, 3, 3]).unwrap();
#[rustfmt::skip]
let expected = USfs::from_vec_shape(
vec![
26., 26., 26.,
26., 26., 13.,
26., 13., 0.,
26., 26., 13.,
26., 13., 0.,
13., 0., 0.,
26., 13., 0.,
13., 0., 0.,
0., 0., 0.,
],
[3, 3, 3]
).unwrap();
assert_eq!(sfs.fold(), expected);
}
}