use std::{cmp::Ordering, error::Error, fmt, io};
use angsd_saf as saf;
use rand::Rng;
use crate::ArrayExt;
pub mod iter;
use iter::{BlockIter, ParBlockIter, ParSiteIter, SiteIter};
mod site;
pub use site::{AsSiteView, Site, SiteView};
// Private helper module for the "sealed" pattern: the items are `pub` so they
// can appear in the public `Lifetime` trait signature, but the module itself
// is private to this file.
mod sealed {
/// Marker trait bounding the hidden `SELF` type parameter of `Lifetime`.
pub trait Sealed: Sized {}
/// Wrapper type whose only role is to implement `Sealed` for any `T`.
pub struct Bounds<T>(T);
impl<T> Sealed for Bounds<T> {}
}
use sealed::{Bounds, Sealed};
/// Helper trait naming the type `Item` associated with a lifetime `'a`.
///
/// The `SELF` parameter defaults to `Bounds<&'a Self>`. Mentioning `&'a Self`
/// there presumably introduces an implied `Self: 'a` bound, so that
/// `for<'a> Lifetime<'a>` can be used as a supertrait bound in place of a
/// generic associated type. NOTE(review): confirm this intent.
pub trait Lifetime<'a, SELF: Sealed = Bounds<&'a Self>> {
/// The type associated with the lifetime `'a`.
type Item;
}
/// Conversion into a borrowed [`SafView`] with `N` shape entries.
///
/// The supertrait bound fixes `<Self as Lifetime<'a>>::Item` to
/// `SafView<'a, N>` for every lifetime `'a`.
pub trait AsSafView<const N: usize>: for<'a> Lifetime<'a, Item = SafView<'a, N>> {
/// Returns a `SafView` borrowing from `self`.
fn as_saf_view(&self) -> <Self as Lifetime<'_>>::Item;
}
/// Creates a `Saf<1>` from rows of literals, one bracketed row per site.
///
/// Trailing commas are accepted both inside rows and after the last row.
///
/// # Panics
///
/// Panics if `Saf::new` rejects the resulting values and shape.
#[macro_export]
macro_rules! saf1d {
($([$($x:literal),+ $(,)?]),+ $(,)?) => {{
// Presumably `matrix!` yields a (shape, flat values) pair whose first shape
// entry is the row width; see its definition to confirm.
let (shape, vec) = $crate::matrix!($([$($x),+]),+);
$crate::saf::Saf::new(vec, [shape[0]]).unwrap()
}};
}
/// Creates a `Saf<2>` from rows of literals, one bracketed row per site.
///
/// Each row has the form `[x, x, ...; y, y, ...]`: the literals before the `;`
/// belong to the first shape entry and those after it to the second.
///
/// # Panics
///
/// Panics if rows disagree on the number of `x` literals or on the number of
/// `y` literals, or if `Saf::new` rejects the resulting values and shape.
#[macro_export]
macro_rules! saf2d {
($([$($x:literal),+ $(,)?; $($y:literal),+ $(,)?]),+ $(,)?) => {{
// One entry per row; presumably `matrix!(count: ...)` counts the literals it
// is given (confirm against its definition).
let x_cols = vec![$($crate::matrix!(count: $($x),+)),+];
let y_cols = vec![$($crate::matrix!(count: $($y),+)),+];
// Every row must have the same number of x values and of y values.
for cols in [&x_cols, &y_cols] {
assert!(cols.windows(2).all(|w| w[0] == w[1]));
}
// Flatten row by row: all x values of a row followed by all its y values.
let vec = vec![$($($x),+, $($y),+),+];
$crate::saf::Saf::new(vec, [x_cols[0], y_cols[0]]).unwrap()
}};
}
// Generates the read-only methods shared by the inherent impls of `Saf` and
// `SafView`. The bodies only rely on a `values` field that can be used as a
// `[f32]` slice and on a `shape: [usize; N]` field.
macro_rules! impl_shared_saf_methods {
() => {
/// Returns all values as a flat slice.
pub fn as_slice(&self) -> &[f32] {
&self.values
}
/// Returns an iterator over all values.
pub fn iter(&self) -> ::std::slice::Iter<f32> {
self.values.iter()
}
/// Returns a view of the site at `index`.
///
/// # Panics
///
/// Panics if the range of values for site `index` is out of bounds.
pub fn get_site(&self, index: usize) -> SiteView<N> {
let width = self.width();
// Site `index` occupies `width` consecutive values starting at `index * width`.
SiteView::new_unchecked(&self.values[index * width..][..width], self.shape)
}
/// Returns the number of sites, i.e. the number of values divided by the width.
///
/// # Panics
///
/// Panics if the width is zero (integer division by zero).
#[inline]
pub fn sites(&self) -> usize {
self.values.len() / self.width()
}
/// Returns the shape.
#[inline]
pub fn shape(&self) -> [usize; N] {
self.shape
}
/// Returns the number of values per site, i.e. the sum of the shape.
#[inline]
pub(self) fn width(&self) -> usize {
self.shape.iter().sum()
}
};
}
/// Owned SAF values with `N` shape entries, stored flat and site by site.
///
/// Each site occupies `shape.iter().sum()` consecutive values; `Saf::new`
/// checks that the number of values is a multiple of that width.
#[derive(Clone, Debug, PartialEq)]
pub struct Saf<const N: usize> {
// Flat storage: the values of site 0, then site 1, and so on.
values: Vec<f32>,
// Number of values per site contributed by each of the `N` entries.
shape: [usize; N],
}
impl<const N: usize> Saf<N> {
    /// Returns all values as a flat, mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        &mut self.values
    }

    /// Returns a mutable iterator over all values.
    pub fn iter_mut(&mut self) -> ::std::slice::IterMut<f32> {
        self.values.iter_mut()
    }

    /// Returns a block iterator created from a view of this SAF and `block_size`.
    pub fn iter_blocks(&self, block_size: usize) -> BlockIter<N> {
        BlockIter::new(self.view(), block_size)
    }

    /// Returns a site iterator created from a view of this SAF.
    pub fn iter_sites(&self) -> SiteIter<N> {
        SiteIter::new(self.view())
    }

    /// Creates a new SAF from flat `values` and a `shape`.
    ///
    /// The width of a site is the sum of `shape`, and `values` is interpreted
    /// as consecutive sites of that width.
    ///
    /// # Errors
    ///
    /// Returns a [`ShapeError`] if the number of values is not a multiple of
    /// the width, or if the width is zero.
    pub fn new(values: Vec<f32>, shape: [usize; N]) -> Result<Self, ShapeError<N>> {
        let len = values.len();
        let width: usize = shape.iter().sum();
        // `checked_rem` is `None` for a zero width, which would otherwise panic
        // with a division by zero here and again later in `sites`.
        match len.checked_rem(width) {
            Some(0) => Ok(Self::new_unchecked(values, shape)),
            _ => Err(ShapeError { len, shape }),
        }
    }

    /// Creates a new SAF without checking that `values` fits `shape`.
    pub(crate) fn new_unchecked(values: Vec<f32>, shape: [usize; N]) -> Self {
        Self { values, shape }
    }

    /// Returns a parallel block iterator created from a view of this SAF and `block_size`.
    pub fn par_iter_blocks(&self, block_size: usize) -> ParBlockIter<N> {
        ParBlockIter::new(self.view(), block_size)
    }

    /// Returns a parallel site iterator created from a view of this SAF.
    pub fn par_iter_sites(&self) -> ParSiteIter<N> {
        ParSiteIter::new(self.view())
    }

    /// Reads a SAF from `N` version 3 readers.
    ///
    /// Each record item is appended to the values unchanged; see
    /// `read_inner_impl` for the shared reading logic.
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced while reading records.
    ///
    /// # Panics
    ///
    /// Panics if `N` is zero.
    pub fn read<R>(readers: [saf::ReaderV3<R>; N]) -> io::Result<Self>
    where
        R: io::BufRead + io::Seek,
    {
        Self::read_inner_impl(readers, |values, item, _| {
            values.extend_from_slice(item);
        })
    }

    /// Reads a SAF from `N` version 4 (banded) readers.
    ///
    /// Each record item is expanded with `into_full`, using negative infinity
    /// as the fill value, before it is appended to the values.
    ///
    /// # Errors
    ///
    /// Returns any I/O error produced while reading records.
    ///
    /// # Panics
    ///
    /// Panics if `N` is zero.
    pub fn read_from_banded<R>(readers: [saf::ReaderV4<R>; N]) -> io::Result<Self>
    where
        R: io::BufRead + io::Seek,
    {
        Self::read_inner_impl(readers, |values, item, alleles| {
            let full_likelihoods = &item.clone().into_full(alleles, f32::NEG_INFINITY);
            values.extend_from_slice(full_likelihoods);
        })
    }

    /// Shared implementation for reading from any reader version.
    ///
    /// `f` is called once per reader for every record set; it receives the
    /// output vector, the record item and that reader's number of alleles,
    /// and is expected to append the item's values.
    fn read_inner_impl<R, V, F>(readers: [saf::Reader<R, V>; N], f: F) -> io::Result<Self>
    where
        R: io::BufRead + io::Seek,
        V: saf::version::Version,
        F: Fn(&mut Vec<f32>, &V::Item, usize),
    {
        assert!(N > 0);

        // No more sites can be read than the smallest reader holds, so this
        // bounds the allocation below.
        let max_sites = readers
            .iter()
            .map(|reader| reader.index().total_sites())
            .min()
            .expect("N > 0 was asserted, so there is at least one reader");

        // Each reader contributes `alleles + 1` values per site.
        let shape = readers.by_ref().map(|reader| reader.index().alleles() + 1);

        let capacity = shape.iter().map(|width| width * max_sites).sum();
        let mut values = Vec::with_capacity(capacity);

        let mut intersect = saf::Intersect::new(Vec::from(readers));
        let mut bufs = intersect.create_record_bufs();

        while intersect.read_records(&mut bufs)?.is_not_done() {
            for (buf, alleles) in bufs.iter().zip(shape.iter().map(|x| x - 1)) {
                f(&mut values, buf.item(), alleles)
            }
        }

        // The capacity was an upper bound, so release whatever was not used.
        values.shrink_to_fit();

        // Stored values are exponentiated once here (presumably the input is
        // in log-space; confirm against the file format).
        values.iter_mut().for_each(|x| *x = x.exp());

        Ok(Self::new_unchecked(values, shape))
    }

    /// Shuffles the sites in place using `rng`.
    ///
    /// Whole sites are moved; the values within a site keep their order.
    pub fn shuffle<R>(&mut self, rng: &mut R)
    where
        R: Rng,
    {
        let width = self.width();

        // Walk from the last site down, swapping each site with a uniformly
        // chosen site at or before it.
        for i in (1..self.sites()).rev() {
            let j = rng.gen_range(0..i + 1);
            self.swap_sites(i, j, width);
        }
    }

    /// Swaps the sites at indices `i` and `j`.
    ///
    /// `width` must equal `self.width()`; it is passed in so that callers
    /// swapping many sites compute it only once.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not a valid site index.
    fn swap_sites(&mut self, i: usize, j: usize, width: usize) {
        debug_assert_eq!(width, self.width());

        let sites = self.sites();
        if i >= sites || j >= sites {
            panic!("index out of bounds for swapping sites")
        }

        // Order the indices so the slice can be split between the two sites.
        let (lo, hi) = match i.cmp(&j) {
            Ordering::Less => (i, j),
            Ordering::Equal => return,
            Ordering::Greater => (j, i),
        };

        // Splitting at the start of the higher site gives two disjoint mutable
        // slices: `lo` lies in the head and `hi` begins the tail.
        let (hd, tl) = self.as_mut_slice().split_at_mut(hi * width);
        let left = &mut hd[lo * width..][..width];
        let right = &mut tl[..width];
        left.swap_with_slice(right)
    }

    /// Returns a borrowed view of the entire SAF.
    pub fn view(&self) -> SafView<N> {
        SafView {
            values: self.values.as_slice(),
            shape: self.shape,
        }
    }

    impl_shared_saf_methods! {}
}
/// Associates `Saf<N>` with the view type `SafView<'a, N>` for every lifetime `'a`.
impl<'a, const N: usize> Lifetime<'a> for Saf<N> {
type Item = SafView<'a, N>;
}
/// An owned `Saf` is viewed by borrowing all of its values via `Saf::view`.
impl<const N: usize> AsSafView<N> for Saf<N> {
#[inline]
fn as_saf_view(&self) -> <Self as Lifetime<'_>>::Item {
self.view()
}
}
/// Associates `&Saf<N>` with the view type `SafView<'a, N>` for every lifetime `'a`.
impl<'a, 'b, const N: usize> Lifetime<'a> for &'b Saf<N> {
type Item = SafView<'a, N>;
}
/// A reference to a `Saf` is viewed the same way as the `Saf` itself.
impl<'a, const N: usize> AsSafView<N> for &'a Saf<N> {
#[inline]
fn as_saf_view(&self) -> <Self as Lifetime<'_>>::Item {
// Method resolution auto-dereferences `&&Saf<N>` to reach `Saf::view`.
self.view()
}
}
/// Borrowed, copyable view of SAF values with `N` shape entries.
///
/// The layout matches `Saf`: values are stored flat, site by site, and each
/// site occupies `shape.iter().sum()` consecutive values.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SafView<'a, const N: usize> {
// Borrowed flat storage: the values of site 0, then site 1, and so on.
values: &'a [f32],
// Number of values per site contributed by each of the `N` entries.
shape: [usize; N],
}
impl<'a, const N: usize> SafView<'a, N> {
    /// Returns a block iterator created from this view and `block_size`.
    pub fn iter_blocks(&self, block_size: usize) -> BlockIter<'a, N> {
        BlockIter::new(*self, block_size)
    }

    /// Returns a site iterator created from this view.
    pub fn iter_sites(&self) -> SiteIter<'a, N> {
        SiteIter::new(*self)
    }

    /// Creates a new view from flat `values` and a `shape`.
    ///
    /// The width of a site is the sum of `shape`, and `values` is interpreted
    /// as consecutive sites of that width.
    ///
    /// # Errors
    ///
    /// Returns a [`ShapeError`] if the number of values is not a multiple of
    /// the width, or if the width is zero.
    pub fn new(values: &'a [f32], shape: [usize; N]) -> Result<Self, ShapeError<N>> {
        let len = values.len();
        let width: usize = shape.iter().sum();
        // `checked_rem` is `None` for a zero width, which would otherwise panic
        // with a division by zero here and again later in `sites`.
        match len.checked_rem(width) {
            Some(0) => Ok(Self::new_unchecked(values, shape)),
            _ => Err(ShapeError { len, shape }),
        }
    }

    /// Creates a new view without checking that `values` fits `shape`.
    pub(crate) fn new_unchecked(values: &'a [f32], shape: [usize; N]) -> Self {
        Self { values, shape }
    }

    /// Returns a parallel block iterator created from this view and `block_size`.
    ///
    /// Like `iter_blocks`, the iterator borrows the underlying values for
    /// `'a` rather than for the duration of the `&self` borrow.
    pub fn par_iter_blocks(&self, block_size: usize) -> ParBlockIter<'a, N> {
        ParBlockIter::new(*self, block_size)
    }

    /// Returns a parallel site iterator created from this view.
    ///
    /// Like `iter_sites`, the iterator borrows the underlying values for `'a`
    /// rather than for the duration of the `&self` borrow.
    pub fn par_iter_sites(&self) -> ParSiteIter<'a, N> {
        ParSiteIter::new(*self)
    }

    impl_shared_saf_methods! {}
}
/// Associates `SafView<'b, N>` with the view type `SafView<'a, N>` for every lifetime `'a`.
impl<'a, 'b, const N: usize> Lifetime<'a> for SafView<'b, N> {
type Item = SafView<'a, N>;
}
/// A view is already a `SafView`; it is `Copy`, so viewing it returns a copy.
impl<'a, const N: usize> AsSafView<N> for SafView<'a, N> {
#[inline]
fn as_saf_view(&self) -> <Self as Lifetime<'_>>::Item {
*self
}
}
/// Error returned when a number of values cannot be arranged into a shape.
#[derive(Clone, Debug)]
pub struct ShapeError<const N: usize> {
    /// The shape that was requested.
    shape: [usize; N],
    /// The number of values that was provided.
    len: usize,
}

impl<const N: usize> fmt::Display for ShapeError<N> {
    /// Writes the shape entries separated by `/`, followed by the value count.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered_shape = self
            .shape
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("/");

        write!(
            f,
            "cannot construct shape {} from {} values",
            rendered_shape, self.len,
        )
    }
}

impl<const N: usize> Error for ShapeError<N> {}
#[cfg(test)]
mod tests {
    /// Builds the one-dimensional fixture shared by the tests below: six sites
    /// of width three, where every value of site `k` equals `k`.
    fn six_sites_1d() -> super::Saf<1> {
        saf1d![
            [0., 0., 0.],
            [1., 1., 1.],
            [2., 2., 2.],
            [3., 3., 3.],
            [4., 4., 4.],
            [5., 5., 5.],
        ]
    }

    #[test]
    fn test_swap_1d() {
        let mut saf = six_sites_1d();

        let width = saf.width();
        assert_eq!(width, 3);

        // Swapping a site with itself leaves it untouched.
        saf.swap_sites(3, 3, width);
        assert_eq!(saf.get_site(3).as_slice(), &[3., 3., 3.]);

        // Lower index first.
        saf.swap_sites(0, 1, width);
        assert_eq!(saf.get_site(0).as_slice(), &[1., 1., 1.]);
        assert_eq!(saf.get_site(1).as_slice(), &[0., 0., 0.]);

        // Higher index first; site 0 holds the old site 1 at this point.
        saf.swap_sites(5, 0, width);
        assert_eq!(saf.get_site(0).as_slice(), &[5., 5., 5.]);
        assert_eq!(saf.get_site(5).as_slice(), &[1., 1., 1.]);
    }

    #[test]
    fn test_swap_2d() {
        #[rustfmt::skip]
        let mut saf = saf2d![
            [0., 0., 0.; 10., 10.],
            [1., 1., 1.; 11., 11.],
            [2., 2., 2.; 12., 12.],
            [3., 3., 3.; 13., 13.],
            [4., 4., 4.; 14., 14.],
            [5., 5., 5.; 15., 15.],
        ];

        let width = saf.width();
        assert_eq!(width, 5);

        // Swap the first and last sites, then swap them back.
        saf.swap_sites(0, 5, width);
        assert_eq!(saf.get_site(0).as_slice(), &[5., 5., 5., 15., 15.,]);

        saf.swap_sites(5, 0, width);
        assert_eq!(saf.get_site(0).as_slice(), &[0., 0., 0., 10., 10.,]);
    }

    #[test]
    #[should_panic]
    fn test_swap_panics_out_of_bounds() {
        let mut saf = six_sites_1d();

        // Index 6 is one past the last of the six sites.
        let width = saf.width();
        saf.swap_sites(6, 5, width);
    }
}