use std::fmt::{Debug, Display};
use crate::{
layouts::{Data, HostDataMut, HostDataRef},
source::Source,
};
use bytemuck::Pod;
use rand_distr::num_traits::Zero;
/// Dimension metadata for a container of fixed-length polynomials.
///
/// Implementors supply the four basic sizes; the provided methods derive
/// further quantities from them.
pub trait ZnxInfos {
    /// Number of scalars in one polynomial (the length of the slices handed
    /// out by the view traits in this file).
    fn n(&self) -> usize;

    /// Returns `ceil(log2(n()))`, which is the exact base-2 logarithm when
    /// `n()` is a power of two.
    ///
    /// Both `n() == 0` and `n() == 1` yield `0`.
    fn log_n(&self) -> usize {
        // `saturating_sub` keeps a degenerate `n() == 0` from underflowing
        // (which would panic with overflow checks enabled and otherwise
        // report `usize::BITS`); for every `n() >= 1` the result is unchanged.
        (usize::BITS - self.n().saturating_sub(1).leading_zeros()) as usize
    }

    /// Number of rows.
    fn rows(&self) -> usize;

    /// Number of columns.
    fn cols(&self) -> usize;

    /// Size of the container along its third dimension
    /// (presumably the number of limbs per column; confirm with implementors).
    fn size(&self) -> usize;

    /// Total number of polynomials: `rows() * cols() * size()`.
    fn poly_count(&self) -> usize {
        self.rows() * self.cols() * self.size()
    }
}
/// Shared (read-only) access to the storage object backing a container.
pub trait DataView {
/// Concrete storage type; it must implement the crate's `Data` trait.
type D: Data;
/// Borrows the underlying storage.
fn data(&self) -> &Self::D;
}
/// Exclusive (mutable) access to the storage object backing a container.
///
/// Extends [`DataView`], so the storage type is the same associated type `D`.
pub trait DataViewMut: DataView {
/// Mutably borrows the underlying storage.
fn data_mut(&mut self) -> &mut Self::D;
}
/// Read-only, typed view over a container whose storage lives in host memory.
///
/// The backing storage is reinterpreted as a flat array of `Scalar` values,
/// addressed in chunks of `n()` scalars.
pub trait ZnxView: ZnxInfos + DataView<D: HostDataRef> {
    /// Element type the backing storage is reinterpreted as.
    type Scalar: Copy + Zero + Display + Debug + Pod;

    /// Returns a pointer to the first scalar of the backing storage.
    fn as_ptr(&self) -> *const Self::Scalar {
        let base = self.data().as_ref().as_ptr();
        base.cast::<Self::Scalar>()
    }

    /// Returns the whole storage as one slice of `n() * poly_count()` scalars.
    fn raw(&self) -> &[Self::Scalar] {
        let base = self.as_ptr();
        let total = self.n() * self.poly_count();
        // SAFETY: not checked here. This relies on the implementor's storage
        // holding at least `total` scalars and being suitably aligned for
        // `Scalar`.
        unsafe { std::slice::from_raw_parts(base, total) }
    }

    /// Returns a pointer to the first scalar of the polynomial at column `i`,
    /// index `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= cols()` or `j >= size()`.
    fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
        assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
        assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
        // Polynomials are laid out `n()` scalars apart, with the column index
        // varying fastest.
        let stride = self.n();
        let poly_index = j * self.cols() + i;
        let offset: usize = stride * poly_index;
        // SAFETY: `i` and `j` were bounds-checked above; this additionally
        // relies on the storage covering `n() * cols() * size()` scalars.
        unsafe { self.as_ptr().add(offset) }
    }

    /// Returns the `n()` scalars of the polynomial at column `i`, index `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= cols()` or `j >= size()`.
    fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
        let start = self.at_ptr(i, j);
        let len = self.n();
        // SAFETY: `start` comes from `at_ptr`, which bounds-checks `i` and
        // `j`; the `len` scalars that follow are assumed to lie inside the
        // backing storage.
        unsafe { std::slice::from_raw_parts(start, len) }
    }
}
/// Mutable, typed view over a container whose storage lives in host memory.
///
/// Mirrors the accessors of [`ZnxView`] with exclusive access.
pub trait ZnxViewMut: ZnxView + DataViewMut<D: HostDataMut> {
    /// Returns a mutable pointer to the first scalar of the backing storage.
    fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
        let base = self.data_mut().as_mut().as_mut_ptr();
        base.cast::<Self::Scalar>()
    }

    /// Returns the whole storage as one mutable slice of
    /// `n() * poly_count()` scalars.
    fn raw_mut(&mut self) -> &mut [Self::Scalar] {
        let base = self.as_mut_ptr();
        let total = self.n() * self.poly_count();
        // SAFETY: not checked here. This relies on the implementor's storage
        // holding at least `total` scalars and being suitably aligned for
        // `Scalar`; `&mut self` gives exclusive access for the returned
        // lifetime.
        unsafe { std::slice::from_raw_parts_mut(base, total) }
    }

    /// Returns a mutable pointer to the first scalar of the polynomial at
    /// column `i`, index `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= cols()` or `j >= size()`.
    fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
        assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
        assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
        // Same layout as the read-only accessor: `n()` scalars per
        // polynomial, column index varying fastest.
        let stride = self.n();
        let poly_index = j * self.cols() + i;
        let offset: usize = stride * poly_index;
        // SAFETY: `i` and `j` were bounds-checked above; this additionally
        // relies on the storage covering `n() * cols() * size()` scalars.
        unsafe { self.as_mut_ptr().add(offset) }
    }

    /// Returns the `n()` scalars of the polynomial at column `i`, index `j`,
    /// as a mutable slice.
    ///
    /// # Panics
    ///
    /// Panics if `i >= cols()` or `j >= size()`.
    fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
        let start = self.at_mut_ptr(i, j);
        let len = self.n();
        // SAFETY: `start` comes from `at_mut_ptr`, which bounds-checks `i`
        // and `j`; the `len` scalars that follow are assumed to lie inside
        // the backing storage, and `&mut self` guarantees exclusivity.
        unsafe { std::slice::from_raw_parts_mut(start, len) }
    }
}
/// Blanket implementation: every [`ZnxView`] whose storage is also mutably
/// accessible in host memory gets [`ZnxViewMut`] with its default methods.
impl<T: ZnxView + DataViewMut<D: HostDataMut>> ZnxViewMut for T {}
/// In-place zeroing operations for a container.
///
/// Only sized types can implement this trait.
pub trait ZnxZero: Sized {
    /// Zeroes the container in place (presumably every element; the exact
    /// extent is defined by the implementor).
    fn zero(&mut self);

    /// Zeroes the part of the container addressed by `(i, j)`.
    ///
    /// NOTE(review): `i` and `j` presumably follow the same (column, index)
    /// convention as `ZnxView::at`; confirm against implementors.
    fn zero_at(&mut self, i: usize, j: usize);
}
/// Containers that can be filled in place using a [`Source`].
pub trait FillUniform {
/// Fills `self` in place, drawing from `source`.
///
/// NOTE(review): presumably the values are sampled uniformly with a
/// magnitude bounded by `2^log_bound`; the exact range is defined by the
/// implementor, so confirm there.
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
}