use std::f32::consts::PI;
use std::iter;
use std::marker::PhantomData;
use rand::Rng;
use rayon::join;
use rayon::prelude::*;
use crate::access::Access;
use crate::ops::{Enqueue, Op, ReadValue, SliceSpec, ViewSpec};
use crate::{
stackvec, strides_for, AccessMut, Axes, BufferConverter, CType, Error, Float, Range, Shape,
Strides,
};
use super::buffer::Buffer;
use super::platform::{Heap, Host, Stack};
use super::{SliceConverter, StackVec, VEC_MIN_SIZE};
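/// Dispatch a host op: enqueue on the [`Stack`] platform when `$cond` holds (small outputs),
/// otherwise on the [`Heap`] platform.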
macro_rules! host_enqueue {
($this:expr, $cond:expr, $t:ty) => {
if $cond {
Enqueue::<Stack, $t>::enqueue($this).map(Buffer::Stack)
} else {
Enqueue::<Heap, $t>::enqueue($this).map(Buffer::Heap)
}
};
}
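/// A casting op which converts each value read from `access` from type `IT` to type `OT`
/// by way of `f64`.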
pub struct Cast<A, IT, OT> {
access: A,
dtype: PhantomData<(IT, OT)>,
}
impl<A, IT, OT> Cast<A, IT, OT> {
pub fn new(access: A) -> Self {
Self {
access,
dtype: PhantomData,
}
}
}
impl<A: Access<IT>, IT: CType, OT: CType> Op for Cast<A, IT, OT> {
fn size(&self) -> usize {
self.access.size()
}
}
impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Heap, OT> for Cast<A, IT, OT> {
type Buffer = Vec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.into_par_iter()
.map(|n| n.to_f64())
.map(OT::from_f64)
.collect()
})
}
}
impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Stack, OT> for Cast<A, IT, OT> {
type Buffer = StackVec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.into_iter()
.map(|n| n.to_f64())
.map(OT::from_f64)
.collect()
})
}
}
impl<A: Access<IT>, IT: CType, OT: CType> Enqueue<Host, OT> for Cast<A, IT, OT> {
type Buffer = Buffer<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
}
}
impl<A: Access<IT>, IT: CType, OT: CType> ReadValue<Host, OT> for Cast<A, IT, OT> {
fn read_value(&self, offset: usize) -> Result<OT, Error> {
self.access
.read_value(offset)
.map(|n| n.to_f64())
.map(OT::from_f64)
}
}
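/// An element-wise binary op which combines two accessors of equal size with `zip`
/// (arithmetic, comparison, or boolean logic), producing values of type `OT`.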
pub struct Dual<L, R, IT, OT> {
left: L,
right: R,
zip: fn(IT, IT) -> OT,
}
impl<L, R, IT, OT> Op for Dual<L, R, IT, OT>
where
L: Access<IT>,
R: Access<IT>,
IT: CType,
OT: CType,
{
fn size(&self) -> usize {
self.left.size()
}
}
impl<L, R, T: CType> Dual<L, R, T, T> {
pub fn add(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::add,
}
}
pub fn div(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::div,
}
}
pub fn log(left: L, right: R) -> Self {
Self {
left,
right,
zip: |a, b| T::from_float(a.to_float().log(b.to_float())),
}
}
pub fn mul(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::mul,
}
}
pub fn pow(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::pow,
}
}
pub fn rem(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::rem,
}
}
pub fn sub(left: L, right: R) -> Self {
Self {
left,
right,
zip: T::sub,
}
}
}
impl<L, R, T: CType> Dual<L, R, T, u8> {
pub fn and(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l != T::ZERO && r != T::ZERO { 1 } else { 0 },
}
}
pub fn or(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l != T::ZERO || r != T::ZERO { 1 } else { 0 },
}
}
pub fn xor(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| {
if (l != T::ZERO) ^ (r != T::ZERO) {
1
} else {
0
}
},
}
}
}
impl<L, R, T: CType> Dual<L, R, T, u8> {
pub fn eq(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l == r { 1 } else { 0 },
}
}
pub fn ge(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l >= r { 1 } else { 0 },
}
}
pub fn gt(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l > r { 1 } else { 0 },
}
}
pub fn le(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l <= r { 1 } else { 0 },
}
}
pub fn lt(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l < r { 1 } else { 0 },
}
}
pub fn ne(left: L, right: R) -> Self {
Self {
left,
right,
zip: |l, r| if l != r { 1 } else { 0 },
}
}
}
impl<L, R, IT, OT> Enqueue<Stack, OT> for Dual<L, R, IT, OT>
where
L: Access<IT>,
R: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = StackVec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let left = self.left.read()?.to_slice()?;
let right = self.right.read()?.to_slice()?;
exec_dual(self.zip, left, right)
}
}
impl<L, R, IT, OT> Enqueue<Heap, OT> for Dual<L, R, IT, OT>
where
L: Access<IT>,
R: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Vec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let (left, right) = try_join_read(&self.left, &self.right)?;
exec_dual_parallel(self.zip, left, right)
}
}
impl<L, R, IT, OT> Enqueue<Host, OT> for Dual<L, R, IT, OT>
where
L: Access<IT>,
R: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Buffer<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
}
}
impl<L, R, IT, OT> ReadValue<Host, OT> for Dual<L, R, IT, OT>
where
L: Access<IT>,
R: Access<IT>,
IT: CType,
OT: CType,
{
fn read_value(&self, offset: usize) -> Result<OT, Error> {
try_join_value(&self.left, &self.right, offset).map(|(l, r)| (self.zip)(l, r))
}
}
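/// An element-wise conditional (select) op: where `cond` is nonzero the output is taken
/// from `then`, otherwise from `or_else`.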
pub struct Cond<A, L, R, T> {
cond: A,
then: L,
or_else: R,
dtype: PhantomData<T>,
}
impl<A, L, R, T> Cond<A, L, R, T> {
pub fn new(cond: A, then: L, or_else: R) -> Self {
Self {
cond,
then,
or_else,
dtype: PhantomData,
}
}
}
impl<A, L, R, T> Op for Cond<A, L, R, T>
where
A: Access<u8>,
L: Access<T>,
R: Access<T>,
T: CType,
{
fn size(&self) -> usize {
debug_assert_eq!(self.cond.size(), self.then.size());
debug_assert_eq!(self.cond.size(), self.or_else.size());
self.cond.size()
}
}
impl<A, L, R, T> Enqueue<Stack, T> for Cond<A, L, R, T>
where
A: Access<u8>,
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let cond = self.cond.read()?.to_slice()?;
let then = self.then.read()?.to_slice()?;
let or_else = self.or_else.read()?.to_slice()?;
let output = cond
.into_iter()
.copied()
.zip(then.into_iter().zip(or_else.into_iter()))
.map(
|(cond, (then, or_else))| {
if cond != 0 {
then
} else {
or_else
}
},
)
.copied()
.collect();
Ok(output)
}
}
impl<A, L, R, T> Enqueue<Heap, T> for Cond<A, L, R, T>
where
A: Access<u8>,
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let (cond, (then, or_else)) = join(
|| self.cond.read().and_then(|buf| buf.to_slice()),
|| {
join(
|| self.then.read().and_then(|buf| buf.to_slice()),
|| self.or_else.read().and_then(|buf| buf.to_slice()),
)
},
);
let (cond, (then, or_else)) = (cond?, (then?, or_else?));
let output = cond
.into_par_iter()
.copied()
.zip(then.into_par_iter().zip(or_else.into_par_iter()))
.map(
|(cond, (then, or_else))| {
if cond != 0 {
then
} else {
or_else
}
},
)
.copied()
.collect();
Ok(output)
}
}
impl<A, L, R, T> Enqueue<Host, T> for Cond<A, L, R, T>
where
A: Access<u8>,
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
}
}
impl<A, L, R, T> ReadValue<Host, T> for Cond<A, L, R, T>
where
A: Access<u8>,
L: Access<T>,
R: Access<T>,
T: CType,
{
fn read_value(&self, offset: usize) -> Result<T, Error> {
let (cond, (then, or_else)) = join(
|| self.cond.read_value(offset),
|| {
join(
|| self.then.read_value(offset),
|| self.or_else.read_value(offset),
)
},
);
let (cond, (then, or_else)) = (cond?, (then?, or_else?));
if cond != 0 {
Ok(then)
} else {
Ok(or_else)
}
}
}
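/// A linear sequence constructor: the element at offset `i` has the value `start + i * step`.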
pub struct Linear<T> {
start: T,
step: f64,
size: usize,
}
impl<T> Linear<T> {
pub fn new(start: T, step: f64, size: usize) -> Self {
Self { start, step, size }
}
#[inline]
fn value_at(&self, offset: usize) -> T
where
T: CType,
{
T::add(self.start, T::from_f64((offset as f64) * self.step))
}
}
impl<T: Send + Sync> Op for Linear<T> {
fn size(&self) -> usize {
self.size
}
}
impl<T: CType> Enqueue<Stack, T> for Linear<T> {
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
        // use value_at so that the Stack and Heap buffers (and read_value) agree exactly
        let buffer = (0..self.size)
            .map(|offset| self.value_at(offset))
            .collect();
Ok(buffer)
}
}
impl<T: CType> Enqueue<Heap, T> for Linear<T> {
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let buffer = (0..self.size)
.into_par_iter()
.map(|offset| self.value_at(offset))
.collect();
Ok(buffer)
}
}
impl<T: CType> Enqueue<Host, T> for Linear<T> {
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size < VEC_MIN_SIZE, T)
}
}
impl<T: CType> ReadValue<Host, T> for Linear<T> {
fn read_value(&self, offset: usize) -> Result<T, Error> {
Ok(self.value_at(offset))
}
}
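/// An op which extracts the main diagonal of each `dim x dim` matrix in a batch of
/// `batch_size` matrices.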
pub struct MatDiag<A, T> {
access: A,
dim: usize,
batch_size: usize,
dtype: PhantomData<T>,
}
impl<A, T> MatDiag<A, T> {
pub fn new(access: A, batch_size: usize, dim: usize) -> Self {
Self {
access,
dim,
batch_size,
dtype: PhantomData,
}
}
}
impl<A: Access<T>, T: CType> Op for MatDiag<A, T> {
fn size(&self) -> usize {
debug_assert_eq!(self.access.size(), self.batch_size * self.dim * self.dim);
self.batch_size * self.dim
}
}
impl<A: Access<T>, T: CType> Enqueue<Heap, T> for MatDiag<A, T> {
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let input = self.access.read()?.to_slice()?;
let diagonals = input
.par_chunks_exact(self.dim * self.dim)
.map(|matrix| {
matrix
.par_chunks_exact(self.dim)
.enumerate()
.map(|(i, row)| row[i])
})
.flatten()
.collect();
Ok(diagonals)
}
}
impl<A: Access<T>, T: CType> Enqueue<Stack, T> for MatDiag<A, T> {
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let input = self.access.read()?.to_slice()?;
let diagonals = input
.chunks_exact(self.dim * self.dim)
.map(|matrix| {
matrix
.chunks_exact(self.dim)
.enumerate()
.map(|(i, row)| row[i])
})
.flatten()
.collect();
Ok(diagonals)
}
}
impl<A: Access<T>, T: CType> Enqueue<Host, T> for MatDiag<A, T> {
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
}
}
impl<A: Access<T>, T: CType> ReadValue<Host, T> for MatDiag<A, T> {
fn read_value(&self, offset: usize) -> Result<T, Error> {
        // the output holds `dim` diagonal entries per batch
        let batch = offset / self.dim;
        let i = offset % self.dim;
        let source_offset = (batch * self.dim * self.dim) + (i * self.dim) + i;
self.access.read_value(source_offset)
}
}
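/// A batched matrix multiplication op: multiplies `batch_size` pairs of `[a, b]` and `[b, c]`
/// matrices to produce `[a, c]` outputs.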
pub struct MatMul<L, R, T> {
left: L,
right: R,
batch_size: usize,
dims: [usize; 3],
dtype: PhantomData<T>,
}
impl<L, R, T> MatMul<L, R, T> {
pub fn new(left: L, right: R, dims: [usize; 4]) -> Self {
let [batch_size, a, b, c] = dims;
Self {
left,
right,
batch_size,
dims: [a, b, c],
dtype: PhantomData,
}
}
}
impl<L, R, T> Op for MatMul<L, R, T>
where
L: Send + Sync,
R: Send + Sync,
T: Send + Sync,
{
fn size(&self) -> usize {
self.batch_size * self.dims[0] * self.dims[2]
}
}
impl<L, R, T> Enqueue<Stack, T> for MatMul<L, R, T>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let left = self.left.read()?.to_slice()?;
let right = self.right.read()?.to_slice()?;
let [a, b, c] = self.dims;
let mut product = StackVec::with_capacity(self.batch_size * a * c);
        // naive triple loop per batch: product[x][z] = sum over y of left[x][y] * right[y][z]
        for batch in 0..self.batch_size {
            // offsets of this batch's matrices within the left and right buffers
            let l_start = batch * a * b;
            let r_start = batch * b * c;

            for x in 0..a {
                for z in 0..c {
                    let mut sum = T::ZERO;

                    for y in 0..b {
                        let l_offset = l_start + (x * b) + y;
                        let r_offset = r_start + (y * c) + z;
                        sum = T::add(sum, T::mul(left[l_offset], right[r_offset]));
                    }

                    product.push(sum)
                }
            }
        }
debug_assert_eq!(product.len(), self.size());
Ok(product)
}
}
impl<L, R, T> Enqueue<Heap, T> for MatMul<L, R, T>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let [a, b, c] = self.dims;
let (left, right) = try_join_read(&self.left, &self.right)?;
let right_size = b * c;
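        // transpose each `b x c` right-hand matrix to `c x b` so that each column of the
        // original right matrix is contiguous in the dot products below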
let right_matrices = right.par_chunks_exact(right_size).map(|right| {
let mut right_t = vec![T::ZERO; right_size];
transpose::transpose(right, &mut right_t[..], c, b);
right_t
});
let left_size = a * b;
let left_matrices = left.par_chunks_exact(left_size);
let output_size = a * c;
let mut output = Vec::<T>::with_capacity(self.batch_size * output_size);
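        // for each batch, compute each output entry as the dot product of a left row with a
        // transposed right row, accumulating in blocks of 8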
let output_matrices = left_matrices
.zip(right_matrices)
.map(|(lm, rm)| {
let mut out = Vec::<T>::with_capacity(output_size);
let product = lm
.par_chunks_exact(b)
.map(|row| {
rm.par_chunks_exact(b).map(move |col| {
let col = col.par_chunks(8).map(|cc| cc.into_iter().copied());
row.par_chunks(8)
.zip(col)
.map(|(rc, cc)| {
rc.into_iter()
.copied()
.zip(cc)
.map(|(r, c)| T::mul(r, c))
.reduce(T::add)
.expect("sum")
})
.reduce(|| T::ZERO, T::add)
})
})
.flatten();
out.par_extend(product);
out
})
.flatten();
output.par_extend(output_matrices);
debug_assert_eq!(output.len(), self.batch_size * output_size);
Ok(output)
}
}
impl<L, R, T> Enqueue<Host, T> for MatMul<L, R, T>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(
self,
self.left.size() < VEC_MIN_SIZE && self.right.size() < VEC_MIN_SIZE,
T
)
}
}
impl<L, R, T> ReadValue<Host, T> for MatMul<L, R, T>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
fn read_value(&self, _offset: usize) -> Result<T, Error> {
Err(Error::Bounds(
"reading an individual value from a matrix multiplication is not implemented"
.to_string(),
))
}
}
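/// An element-wise op which combines each value read from `access` with a constant `scalar`
/// using `op`, with the element as the left-hand argument and the scalar as the right.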
pub struct Scalar<A, IT, OT> {
access: A,
scalar: IT,
op: fn(IT, IT) -> OT,
}
impl<A, IT, OT> Scalar<A, IT, OT> {
fn new(access: A, scalar: IT, op: fn(IT, IT) -> OT) -> Self {
Self { access, scalar, op }
}
}
impl<A, T: CType> Scalar<A, T, T> {
pub fn add(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::add)
}
pub fn div(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::div)
}
pub fn log(access: A, scalar: T) -> Self {
Self::new(access, scalar, |a, b| {
T::from_float(a.to_float().log(b.to_float()))
})
}
pub fn mul(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::mul)
}
pub fn pow(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::pow)
}
pub fn rem(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::rem)
}
pub fn sub(access: A, scalar: T) -> Self {
Self::new(access, scalar, T::sub)
}
}
impl<A, T> Scalar<A, T, u8> {
pub fn and(access: A, scalar: T) -> Self
where
T: CType,
{
Self::new(access, scalar, |l, r| {
if (l != T::ZERO) && (r != T::ZERO) {
1
} else {
0
}
})
}
pub fn or(access: A, scalar: T) -> Self
where
T: CType,
{
Self::new(access, scalar, |l, r| {
if (l != T::ZERO) || (r != T::ZERO) {
1
} else {
0
}
})
}
pub fn xor(access: A, scalar: T) -> Self
where
T: CType,
{
Self::new(access, scalar, |l, r| {
if (l != T::ZERO) ^ (r != T::ZERO) {
1
} else {
0
}
})
}
pub fn eq(access: A, scalar: T) -> Self
where
T: PartialEq,
{
Self::new(access, scalar, |l, r| if l == r { 1 } else { 0 })
}
pub fn ge(access: A, scalar: T) -> Self
where
T: PartialOrd,
{
Self::new(access, scalar, |l, r| if l >= r { 1 } else { 0 })
}
pub fn gt(access: A, scalar: T) -> Self
where
T: PartialOrd,
{
Self::new(access, scalar, |l, r| if l > r { 1 } else { 0 })
}
pub fn le(access: A, scalar: T) -> Self
where
T: PartialOrd,
{
Self::new(access, scalar, |l, r| if l <= r { 1 } else { 0 })
}
pub fn lt(access: A, scalar: T) -> Self
where
T: PartialOrd,
{
Self::new(access, scalar, |l, r| if l < r { 1 } else { 0 })
}
pub fn ne(access: A, scalar: T) -> Self
where
T: PartialEq,
{
Self::new(access, scalar, |l, r| if l != r { 1 } else { 0 })
}
}
impl<A, IT, OT> Op for Scalar<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
fn size(&self) -> usize {
self.access.size()
}
}
impl<A, IT, OT> Enqueue<Heap, OT> for Scalar<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Vec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.as_ref()
.into_par_iter()
.copied()
.map(|l| (self.op)(l, self.scalar))
.collect()
})
}
}
impl<A, IT, OT> Enqueue<Stack, OT> for Scalar<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = StackVec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.as_ref()
.into_iter()
.copied()
.map(|l| (self.op)(l, self.scalar))
.collect()
})
}
}
impl<A, IT, OT> Enqueue<Host, OT> for Scalar<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Buffer<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
}
}
impl<A, IT, OT> ReadValue<Host, OT> for Scalar<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
fn read_value(&self, offset: usize) -> Result<OT, Error> {
self.access
.read_value(offset)
.map(|n| (self.op)(n, self.scalar))
}
}
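/// A constructor for `size` pseudorandom samples drawn from the standard normal distribution
/// via the Box-Muller transform.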
pub struct RandomNormal {
size: usize,
}
impl RandomNormal {
pub fn new(size: usize) -> Self {
Self { size }
}
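    /// Map two uniform samples in `[0, 1)` to two independent standard-normal samples using
    /// the Box-Muller transform: `r = sqrt(-2 ln u1)`, `theta = 2 * PI * u2`.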
fn box_muller(u: [f32; 2]) -> [f32; 2] {
let [u1, u2] = u;
let r = (u1.ln() * -2.).sqrt();
let theta = 2. * PI * u2;
[r * theta.cos(), r * theta.sin()]
}
}
impl Op for RandomNormal {
fn size(&self) -> usize {
self.size
}
}
impl Enqueue<Heap, f32> for RandomNormal {
type Buffer = Vec<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
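        // generate an even number of uniform samples, transform them in pairs,
        // then drop the extra sample if `size` is odd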
let mut u = vec![
0.0f32;
if self.size % 2 == 0 {
self.size
} else {
self.size + 1
}
];
rand::thread_rng().fill(&mut u[..]);
let mut output = u
.par_chunks_exact(2)
.map(|u| {
let u: [f32; 2] = u.try_into().expect("u");
Self::box_muller(u)
})
.flatten()
.collect::<Vec<f32>>();
if output.len() > self.size {
output.pop();
}
debug_assert_eq!(output.len(), self.size);
Ok(output)
}
}
impl Enqueue<Stack, f32> for RandomNormal {
type Buffer = StackVec<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let mut rng = rand::thread_rng();
let mut output = iter::repeat_with(|| [rng.gen(), rng.gen()])
.take(self.size.div_ceil(2))
.map(Self::box_muller)
.flatten()
.collect::<StackVec<f32>>();
if output.len() > self.size {
output.pop();
}
debug_assert_eq!(output.len(), self.size);
Ok(output)
}
}
impl Enqueue<Host, f32> for RandomNormal {
type Buffer = Buffer<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
}
}
impl ReadValue<Host, f32> for RandomNormal {
fn read_value(&self, _offset: usize) -> Result<f32, Error> {
Err(Error::Bounds(
"cannot calculate an individual value of a random normal distribution".to_string(),
))
}
}
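/// A constructor for `size` pseudorandom samples drawn uniformly from `[0, 1)`.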
pub struct RandomUniform {
size: usize,
}
impl RandomUniform {
pub fn new(size: usize) -> Self {
Self { size }
}
}
impl Op for RandomUniform {
fn size(&self) -> usize {
self.size
}
}
impl Enqueue<Heap, f32> for RandomUniform {
type Buffer = Vec<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let mut data = vec![0.; self.size];
rand::thread_rng().fill(&mut data[..]);
Ok(data)
}
}
impl Enqueue<Stack, f32> for RandomUniform {
type Buffer = StackVec<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let mut data = stackvec![0.; self.size];
rand::thread_rng().fill(&mut data[..]);
Ok(data)
}
}
impl Enqueue<Host, f32> for RandomUniform {
type Buffer = Buffer<f32>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size < VEC_MIN_SIZE, f32)
}
}
impl ReadValue<Host, f32> for RandomUniform {
fn read_value(&self, _offset: usize) -> Result<f32, Error> {
Ok(rand::thread_rng().gen())
}
}
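/// An axis reduction op: each output element is the fold of `stride` contiguous input
/// elements with `reduce`, starting from the identity `id`.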
pub struct Reduce<A, T> {
access: A,
stride: usize,
reduce: fn(T, T) -> T,
id: T,
}
impl<A, T> Reduce<A, T>
where
T: CType,
{
pub fn max(access: A, stride: usize) -> Self {
Self {
access,
stride,
reduce: CType::max,
id: T::MIN,
}
}
pub fn min(access: A, stride: usize) -> Self {
Self {
access,
stride,
reduce: CType::min,
id: T::MAX,
}
}
pub fn product(access: A, stride: usize) -> Self {
Self {
access,
stride,
reduce: T::mul,
id: T::ONE,
}
}
pub fn sum(access: A, stride: usize) -> Self {
Self {
access,
stride,
reduce: T::add,
id: T::ZERO,
}
}
}
impl<A: Access<T>, T: CType> Op for Reduce<A, T> {
fn size(&self) -> usize {
debug_assert_eq!(self.access.size() % self.stride, 0);
self.access.size() / self.stride
}
}
impl<A: Access<T>, T: CType> Enqueue<Heap, T> for Reduce<A, T> {
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
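        // reduce each `stride`-sized chunk, parallelizing within the chunk in blocks of 8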
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.chunks_exact(self.stride)
.map(|chunk| {
chunk
.par_chunks(8)
.map(|chunk| {
chunk.iter().copied().reduce(self.reduce).expect("reduced")
})
.reduce(|| self.id, self.reduce)
})
.collect()
})
}
}
impl<A: Access<T>, T: CType> Enqueue<Stack, T> for Reduce<A, T> {
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|slice| {
slice
.chunks_exact(self.stride)
.map(|chunk| chunk.iter().copied().reduce(self.reduce).expect("reduced"))
.collect()
})
}
}
impl<A: Access<T>, T: CType> Enqueue<Host, T> for Reduce<A, T> {
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(
self,
self.stride < VEC_MIN_SIZE && self.size() < VEC_MIN_SIZE,
T
)
}
}
impl<A: Access<T>, T: CType> ReadValue<Host, T> for Reduce<A, T> {
fn read_value(&self, offset: usize) -> Result<T, Error> {
let offset = offset * self.stride;
if offset < self.access.size() {
(offset..(offset + self.stride))
.into_par_iter()
.map(|offset| self.access.read_value(offset))
.try_reduce(|| self.id, |r, v| Ok((self.reduce)(r, v)))
} else {
Err(Error::Bounds(format!(
"invalid offset {offset} for a reduce op with size {}",
self.size()
)))
}
}
}
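/// A slice op which reads, and (for a mutable accessor) writes, the elements of `access`
/// selected by a [`Range`], as described by its [`SliceSpec`].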
pub struct Slice<A, T> {
access: A,
spec: SliceSpec,
dtype: PhantomData<T>,
}
impl<A, T> Slice<A, T> {
pub fn new(access: A, shape: &[usize], range: Range) -> Self {
let spec = SliceSpec::new(shape, range);
Self {
access,
spec,
dtype: PhantomData,
}
}
}
impl<A: Send + Sync, T: Copy + Send + Sync> Slice<A, T> {
fn read(&self, source: &[T]) -> Result<StackVec<T>, Error> {
        let output = (0..self.size())
.map(|offset_out| self.spec.source_offset(offset_out))
.map(|offset_in| source[offset_in])
.collect();
Ok(output)
}
fn read_parallel(&self, source: &[T]) -> Result<Vec<T>, Error> {
let output = (0..self.size())
.into_par_iter()
.map(|offset_out| self.spec.source_offset(offset_out))
.map(|offset_in| source[offset_in])
.collect();
Ok(output)
}
}
impl<A, T> Slice<A, T>
where
T: CType,
A: AccessMut<T>,
{
fn overwrite<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
if data.len() == self.size() {
let data = data.to_slice()?;
for (offset, value) in data.into_iter().copied().enumerate() {
let source_offset = self.spec.source_offset(offset);
self.access.write_value_at(source_offset, value)?;
}
Ok(())
} else {
Err(Error::Bounds(format!(
"cannot overwrite a slice of size {} with a buffer of size {}",
self.size(),
data.len(),
)))
}
}
fn overwrite_value(&mut self, value: T) -> Result<(), Error> {
        // iterate over the offsets of the slice itself, not the source buffer
        for offset in 0..self.size() {
let source_offset = self.spec.source_offset(offset);
self.access.write_value_at(source_offset, value)?;
}
Ok(())
}
fn overwrite_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
let source_offset = self.spec.source_offset(offset);
self.access.write_value_at(source_offset, value)
}
}
impl<A: Send + Sync, T: Send + Sync> Op for Slice<A, T> {
fn size(&self) -> usize {
self.spec.size()
}
}
impl<A: Access<T>, T: CType> Enqueue<Heap, T> for Slice<A, T> {
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.and_then(|buf| self.read_parallel(&*buf))
}
}
impl<A: Access<T>, T: CType> Enqueue<Stack, T> for Slice<A, T> {
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.and_then(|buf| self.read(&*buf))
}
}
impl<A: Access<T>, T: CType> Enqueue<Host, T> for Slice<A, T> {
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
}
}
impl<A: Access<T>, T: CType> ReadValue<Host, T> for Slice<A, T> {
fn read_value(&self, offset: usize) -> Result<T, Error> {
let offset = self.spec.source_offset(offset);
self.access.read_value(offset)
}
}
impl<A, T> crate::ops::Write<Heap, T> for Slice<A, T>
where
T: CType,
A: AccessMut<T>,
{
fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
self.overwrite(data)
}
fn write_value(&mut self, value: T) -> Result<(), Error> {
self.overwrite_value(value)
}
fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
self.overwrite_value_at(offset, value)
}
}
impl<A, T> crate::ops::Write<Stack, T> for Slice<A, T>
where
T: CType,
A: AccessMut<T>,
{
fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
self.overwrite(data)
}
fn write_value(&mut self, value: T) -> Result<(), Error> {
self.overwrite_value(value)
}
fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
self.overwrite_value_at(offset, value)
}
}
impl<A, T> crate::ops::Write<Host, T> for Slice<A, T>
where
T: CType,
A: AccessMut<T>,
{
fn write<'a>(&mut self, data: BufferConverter<'a, T>) -> Result<(), Error> {
self.overwrite(data)
}
fn write_value(&mut self, value: T) -> Result<(), Error> {
self.overwrite_value(value)
}
fn write_value_at(&mut self, offset: usize, value: T) -> Result<(), Error> {
self.overwrite_value_at(offset, value)
}
}
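/// An element-wise unary op which applies `op` to each value read from `access`.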
pub struct Unary<A, IT, OT> {
access: A,
op: fn(IT) -> OT,
}
impl<A: Access<T>, T: CType> Unary<A, T, T> {
pub fn abs(access: A) -> Self {
Self {
access,
op: CType::abs,
}
}
pub fn exp(access: A) -> Self {
Self {
access,
op: |n| T::from_float(n.to_float().exp()),
}
}
pub fn ln(access: A) -> Self {
Self {
access,
op: |n| T::from_float(n.to_float().ln()),
}
}
pub fn round(access: A) -> Self {
Self {
access,
op: CType::round,
}
}
}
impl<A: Access<T>, T: CType> Unary<A, T, T::Float> {
pub fn sin(access: A) -> Self {
Self {
access,
op: |n| n.to_float().sin(),
}
}
pub fn asin(access: A) -> Self {
Self {
access,
op: |n| n.to_float().asin(),
}
}
pub fn sinh(access: A) -> Self {
Self {
access,
op: |n| n.to_float().sinh(),
}
}
pub fn cos(access: A) -> Self {
Self {
access,
op: |n| n.to_float().cos(),
}
}
pub fn acos(access: A) -> Self {
Self {
access,
op: |n| n.to_float().acos(),
}
}
pub fn cosh(access: A) -> Self {
Self {
access,
op: |n| n.to_float().cosh(),
}
}
pub fn tan(access: A) -> Self {
Self {
access,
op: |n| n.to_float().tan(),
}
}
pub fn atan(access: A) -> Self {
Self {
access,
op: |n| n.to_float().atan(),
}
}
pub fn tanh(access: A) -> Self {
Self {
access,
op: |n| n.to_float().tanh(),
}
}
}
impl<A: Access<T>, T: CType> Unary<A, T, u8> {
pub fn not(access: A) -> Self {
Self {
access,
op: |n| if n == T::ZERO { 1 } else { 0 },
}
}
}
impl<A: Access<T>, T: Float> Unary<A, T, u8> {
pub fn inf(access: A) -> Self {
Self {
access,
op: |n| if n.is_inf() { 1 } else { 0 },
}
}
pub fn nan(access: A) -> Self {
Self {
access,
op: |n| if n.is_nan() { 1 } else { 0 },
}
}
}
impl<A, IT, OT> Op for Unary<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
fn size(&self) -> usize {
self.access.size()
}
}
impl<A, IT, OT> Enqueue<Heap, OT> for Unary<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Vec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|input| input.into_par_iter().copied().map(self.op).collect())
}
}
impl<A, IT, OT> Enqueue<Stack, OT> for Unary<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = StackVec<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
self.access
.read()
.and_then(|buf| buf.to_slice())
.map(|input| input.into_iter().copied().map(self.op).collect())
}
}
impl<A, IT, OT> Enqueue<Host, OT> for Unary<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
type Buffer = Buffer<OT>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, OT)
}
}
impl<A, IT, OT> ReadValue<Host, OT> for Unary<A, IT, OT>
where
A: Access<IT>,
IT: CType,
OT: CType,
{
fn read_value(&self, offset: usize) -> Result<OT, Error> {
self.access.read_value(offset).map(|n| (self.op)(n))
}
}
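/// A view op (broadcast or transpose) which maps each output offset to a source offset
/// via its [`ViewSpec`].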
pub struct View<A, T> {
access: A,
spec: ViewSpec,
dtype: PhantomData<T>,
}
impl<A: Access<T>, T: CType> View<A, T> {
pub fn broadcast(access: A, shape: Shape, broadcast: Shape) -> Self {
let source_strides = strides_for(&shape, shape.len()).collect();
Self {
access,
spec: ViewSpec::new(broadcast, source_strides),
dtype: PhantomData,
}
}
pub fn transpose(access: A, shape: Shape, axes: Axes) -> Self {
let strides = strides_for(&shape, shape.len()).collect::<Strides>();
        // the transposed shape is a permutation of the source shape
        let shape = axes.iter().copied().map(|x| shape[x]).collect::<Shape>();
let source_strides = axes.into_iter().map(|x| strides[x]).collect::<Strides>();
Self {
access,
spec: ViewSpec::new(shape, source_strides),
dtype: PhantomData,
}
}
}
impl<A: Access<T>, T: CType> Op for View<A, T> {
fn size(&self) -> usize {
self.spec.size()
}
}
impl<A: Access<T>, T: CType> Enqueue<Stack, T> for View<A, T> {
type Buffer = StackVec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let source = self.access.read().and_then(|source| source.to_slice())?;
        let buffer = (0..self.spec.size())
.map(|offset| self.spec.source_offset(offset))
.map(|source_offset| source[source_offset])
.collect();
Ok(buffer)
}
}
impl<A: Access<T>, T: CType> Enqueue<Heap, T> for View<A, T> {
type Buffer = Vec<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
let source = self.access.read().and_then(|source| source.to_slice())?;
let buffer = (0..self.spec.size())
.into_par_iter()
.map(|offset| self.spec.source_offset(offset))
.map(|source_offset| source[source_offset])
.collect();
Ok(buffer)
}
}
impl<A: Access<T>, T: CType> Enqueue<Host, T> for View<A, T> {
type Buffer = Buffer<T>;
fn enqueue(&self) -> Result<Self::Buffer, Error> {
host_enqueue!(self, self.size() < VEC_MIN_SIZE, T)
}
}
impl<A: Access<T>, T: CType> ReadValue<Host, T> for View<A, T> {
fn read_value(&self, offset: usize) -> Result<T, Error> {
self.access.read_value(self.spec.source_offset(offset))
}
}
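/// Zip two slices element-wise with `zip`, sequentially, into a [`StackVec`].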
fn exec_dual<IT: CType, OT: CType>(
zip: fn(IT, IT) -> OT,
left: SliceConverter<IT>,
right: SliceConverter<IT>,
) -> Result<StackVec<OT>, Error> {
let output = left
.into_iter()
.copied()
.zip(right.into_iter().copied())
.map(|(l, r)| (zip)(l, r))
.collect();
Ok(output)
}
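/// Zip two slices element-wise with `zip`, in parallel with Rayon, into a [`Vec`].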
fn exec_dual_parallel<IT: CType, OT: CType>(
zip: fn(IT, IT) -> OT,
left: SliceConverter<IT>,
right: SliceConverter<IT>,
) -> Result<Vec<OT>, Error> {
let output = left
.into_par_iter()
.copied()
.zip(right.into_par_iter().copied())
.map(|(l, r)| (zip)(l, r))
.collect();
Ok(output)
}
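/// Read both accessors concurrently with [`rayon::join`] and return their slices.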
#[inline]
fn try_join_read<'a, L, R, T>(
left: &'a L,
right: &'a R,
) -> Result<(SliceConverter<'a, T>, SliceConverter<'a, T>), Error>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
let (l, r) = join(
|| left.read().and_then(|buf| buf.to_slice()),
|| right.read().and_then(|buf| buf.to_slice()),
);
Ok((l?, r?))
}
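/// Read the value at `offset` from both accessors concurrently and return the pair.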
#[inline]
fn try_join_value<'a, L, R, T>(left: &'a L, right: &'a R, offset: usize) -> Result<(T, T), Error>
where
L: Access<T>,
R: Access<T>,
T: CType,
{
let (l, r) = join(|| left.read_value(offset), || right.read_value(offset));
Ok((l?, r?))
}