use std::alloc::Layout;
use std::fmt::Debug;
use tract_data::internal::*;
use crate::LADatum;
use super::{BinOp, FusedKerSpec, FusedSpec, MatMatMulKer, OutputStoreKer};
use downcast_rs::{impl_downcast, Downcast};
use tract_data::internal::num_integer::Integer;
pub trait ScratchSpace: Downcast + Send {}
impl_downcast!(ScratchSpace);
#[derive(Debug)]
pub struct ScratchSpaceFusedNonLinear<TI: LADatum> {
uspecs: Vec<FusedKerSpec<TI>>,
layout: Layout,
buffer: *const u8,
loc_dependant: TVec<LocDependant>,
}
impl<TI: LADatum> Default for ScratchSpaceFusedNonLinear<TI> {
fn default() -> Self {
ScratchSpaceFusedNonLinear {
uspecs: vec![],
layout: unsafe { Layout::from_size_align_unchecked(0, 1) },
buffer: std::ptr::null(),
loc_dependant: tvec!(),
}
}
}
#[derive(Debug, new)]
struct LocDependant {
spec: usize,
uspec: usize,
loc: *const u8,
buffer: Option<*const u8>,
}
impl<TI: LADatum> ScratchSpace for ScratchSpaceFusedNonLinear<TI> {}
unsafe impl<TI: LADatum> Send for ScratchSpaceFusedNonLinear<TI> {}
impl<TI: LADatum> Drop for ScratchSpaceFusedNonLinear<TI> {
fn drop(&mut self) {
if !self.buffer.is_null() {
unsafe {
std::alloc::dealloc(self.buffer as _, self.layout);
}
}
}
}
struct AddMatMulTemp(*const u8, usize);
impl<TI: LADatum> ScratchSpaceFusedNonLinear<TI> {
pub unsafe fn prepare<K: MatMatMulKer<TI>>(&mut self, specs: &[FusedSpec]) {
use FusedKerSpec as FKS;
use FusedSpec as FS;
self.uspecs.clear();
self.loc_dependant.clear();
self.uspecs.reserve(specs.len() + 2);
self.uspecs.push(FusedKerSpec::Clear);
let mut offset = 0;
let mut align = std::mem::size_of::<*const ()>();
fn ld(spec: usize, uspec: usize, loc: *const u8) -> LocDependant {
LocDependant { spec, uspec, loc, buffer: None }
}
for (ix, spec) in specs.iter().enumerate() {
let uspec = match spec {
FS::BinScalar(t, op) => match op {
BinOp::Min => FKS::ScalarMin(*t.to_scalar_unchecked()),
BinOp::Max => FKS::ScalarMax(*t.to_scalar_unchecked()),
BinOp::Mul => FKS::ScalarMul(*t.to_scalar_unchecked()),
BinOp::Add => FKS::ScalarAdd(*t.to_scalar_unchecked()),
BinOp::Sub => FKS::ScalarSub(*t.to_scalar_unchecked()),
BinOp::SubF => FKS::ScalarSubF(*t.to_scalar_unchecked()),
},
FS::ShiftLeft(s) => FKS::ShiftLeft(*s),
FS::RoundingShiftRight(s, rp) => FKS::RoundingShiftRight(*s, *rp),
FS::QScale(s, rp, m) => FKS::QScale(*s, *rp, *m),
FS::BinPerRow(_, _) => {
self.loc_dependant.push(ld(ix, self.uspecs.len(), offset as _));
offset += TI::datum_type().size_of() * K::mr();
FusedKerSpec::Done
}
FS::BinPerCol(_, _) => {
self.loc_dependant.push(ld(ix, self.uspecs.len(), offset as _));
offset += TI::datum_type().size_of() * K::nr();
FusedKerSpec::Done
}
FS::AddRowColProducts(_, _) => {
self.loc_dependant.push(ld(ix, self.uspecs.len(), offset as _));
offset += TI::datum_type().size_of() * (K::mr() + K::nr());
FusedKerSpec::Done
}
FS::Store(_) | FS::AddUnicast(_) => {
self.loc_dependant.push(ld(ix, self.uspecs.len(), offset as _));
offset += TI::datum_type().size_of() * K::mr() * K::nr();
FusedKerSpec::Done
}
FS::AddMatMul { b, .. } => {
let mut ld = ld(ix, self.uspecs.len(), offset as _);
offset += std::mem::size_of::<AddMatMulTemp>();
if let Some(tmp) = b.scratch_panel_buffer_layout() {
align = tmp.align().lcm(&align);
offset = Integer::next_multiple_of(&offset, &tmp.align());
ld.buffer = Some(offset as _);
offset += tmp.size();
}
self.loc_dependant.push(ld);
FusedKerSpec::Done
}
};
self.uspecs.push(uspec);
}
self.uspecs.push(FKS::Done);
if offset > self.layout.size() || align > self.layout.align() {
if !self.buffer.is_null() {
std::alloc::dealloc(self.buffer as _, self.layout);
}
self.layout = Layout::from_size_align_unchecked(offset, align);
self.buffer = std::alloc::alloc(self.layout);
assert!(!self.buffer.is_null());
}
for LocDependant { loc, buffer, spec, .. } in &mut self.loc_dependant {
*loc = self.buffer.offset(*loc as _);
if let Some(b) = buffer {
*b = self.buffer.offset(*b as _);
}
let spec = specs.get_unchecked(*spec);
#[allow(clippy::single_match)]
match spec {
FS::AddMatMul { .. } => {
let scratch = *loc as *mut AddMatMulTemp;
(*scratch).1 = usize::MAX;
}
_ => (),
};
}
}
#[inline(always)]
pub unsafe fn for_valid_tile<K: MatMatMulKer<TI>>(
&mut self,
specs: &[FusedSpec],
down: usize,
right: usize,
) {
use FusedKerSpec as FKS;
use FusedSpec as FS;
let ScratchSpaceFusedNonLinear { uspecs, loc_dependant, .. } = self;
debug_assert!(specs.len() + 2 == uspecs.len());
for LocDependant { spec, uspec, loc, buffer } in loc_dependant.iter_mut() {
let spec = specs.get_unchecked(*spec);
*uspecs.get_unchecked_mut(*uspec) = match spec {
FS::BinPerRow(v, op) => {
let v = v.as_ptr_unchecked::<TI>().add(down * K::mr());
match op {
BinOp::Min => FKS::PerRowMin(v),
BinOp::Max => FKS::PerRowMax(v),
BinOp::Add => FKS::PerRowAdd(v),
BinOp::Mul => FKS::PerRowMul(v),
BinOp::Sub => FKS::PerRowSub(v),
BinOp::SubF => FKS::PerRowSubF(v),
}
}
FS::BinPerCol(v, op) => {
let v = v.as_ptr_unchecked::<TI>().add(right * K::nr());
match op {
BinOp::Min => FKS::PerColMin(v),
BinOp::Max => FKS::PerColMax(v),
BinOp::Add => FKS::PerColAdd(v),
BinOp::Mul => FKS::PerColMul(v),
BinOp::Sub => FKS::PerColSub(v),
BinOp::SubF => FKS::PerColSubF(v),
}
}
FS::AddRowColProducts(rows, cols) => {
let row_ptr = rows.as_ptr_unchecked::<TI>().add(down * K::mr());
let col_ptr = cols.as_ptr_unchecked::<TI>().add(right * K::nr());
FKS::AddRowColProducts(row_ptr, col_ptr)
}
FS::AddUnicast(store) => FKS::AddUnicast(store.tile_c(down, right)),
FS::Store(c_store) => FKS::Store(c_store.tile_c(down, right)),
FS::AddMatMul { k, a, b } => {
let pa = a.panel(down);
K::prefetch(pa as _, 512);
let scratch = *loc as *mut AddMatMulTemp;
if (*scratch).1 != right {
(*scratch).0 = b.panel_b(right, *buffer);
(*scratch).1 = right;
}
FKS::AddMatMul { k: *k, pa, pb: (*scratch).0, cpu_variant: 0 }
}
_ => std::hint::unreachable_unchecked(),
};
}
}
#[inline(never)]
pub unsafe fn for_border_tile<K: MatMatMulKer<TI>>(
&mut self,
specs: &[FusedSpec],
down: usize,
right: usize,
) {
use FusedKerSpec as FKS;
use FusedSpec as FS;
let ScratchSpaceFusedNonLinear { uspecs, loc_dependant, .. } = self;
debug_assert!(specs.len() + 2 == uspecs.len());
for LocDependant { spec, uspec, loc, buffer } in loc_dependant.iter_mut() {
let spec = specs.get_unchecked(*spec);
*uspecs.get_unchecked_mut(*uspec) = match spec {
FS::BinPerRow(v, op) => {
let buf = std::slice::from_raw_parts_mut(*loc as *mut TI, K::mr());
let have = v.len().saturating_sub(down * K::mr()).min(K::mr());
let ptr = if have < K::mr() {
if have > 0 {
buf.get_unchecked_mut(..have).copy_from_slice(
v.as_slice_unchecked()
.get_unchecked(down * K::mr()..)
.get_unchecked(..have),
);
}
if cfg!(debug_assertions) {
buf.get_unchecked_mut(have..).iter_mut().for_each(|x| *x = TI::zero());
}
buf.as_ptr()
} else {
v.as_ptr_unchecked::<TI>().add(down * K::mr())
};
match op {
BinOp::Min => FKS::PerRowMin(ptr),
BinOp::Max => FKS::PerRowMax(ptr),
BinOp::Add => FKS::PerRowAdd(ptr),
BinOp::Mul => FKS::PerRowMul(ptr),
BinOp::Sub => FKS::PerRowSub(ptr),
BinOp::SubF => FKS::PerRowSubF(ptr),
}
}
FS::BinPerCol(v, op) => {
let buf = std::slice::from_raw_parts_mut(*loc as *mut TI, K::nr());
let have = v.len().saturating_sub(right * K::nr()).min(K::nr());
let ptr = if have < K::nr() {
if have > 0 {
buf.get_unchecked_mut(..have).copy_from_slice(
v.as_slice_unchecked()
.get_unchecked(right * K::nr()..)
.get_unchecked(..have),
);
}
if cfg!(debug_assertions) {
buf.get_unchecked_mut(have..).iter_mut().for_each(|x| *x = TI::zero());
}
buf.as_ptr()
} else {
v.as_ptr_unchecked::<TI>().add(right * K::nr())
};
match op {
BinOp::Min => FKS::PerColMin(ptr),
BinOp::Max => FKS::PerColMax(ptr),
BinOp::Add => FKS::PerColAdd(ptr),
BinOp::Mul => FKS::PerColMul(ptr),
BinOp::Sub => FKS::PerColSub(ptr),
BinOp::SubF => FKS::PerColSubF(ptr),
}
}
FS::AddRowColProducts(rows, cols) => {
let r = std::slice::from_raw_parts_mut(*loc as *mut TI, K::mr());
let have = rows.len() - down * K::mr();
let row_ptr = if have < K::mr() {
r.get_unchecked_mut(..have).copy_from_slice(
rows
.as_slice_unchecked()
.get_unchecked(down * K::mr()..)
.get_unchecked(..have),
);
if cfg!(debug_assertions) {
r.get_unchecked_mut(have..).iter_mut().for_each(|x| *x = TI::zero());
}
r.as_ptr()
} else {
rows.as_ptr_unchecked::<TI>().add(down * K::mr())
};
let c = std::slice::from_raw_parts_mut(
(*loc as *mut TI).add(K::mr()),
K::nr(),
);
let have = cols.len() - right * K::nr();
let col_ptr = if have < K::nr() {
c.get_unchecked_mut(..have).copy_from_slice(
cols
.as_slice_unchecked()
.get_unchecked(right * K::nr()..)
.get_unchecked(..have),
);
if cfg!(debug_assertions) {
r.get_unchecked_mut(have..).iter_mut().for_each(|x| *x = TI::zero());
}
c.as_ptr()
} else {
cols.as_ptr_unchecked::<TI>().add(right * K::nr())
};
FKS::AddRowColProducts(row_ptr, col_ptr)
}
FS::AddUnicast(store) => {
let row_byte_stride = store.row_byte_stride;
let col_byte_stride = store.col_byte_stride;
let tile_offset = row_byte_stride * down as isize * K::mr() as isize
+ col_byte_stride * right as isize * K::nr() as isize;
let tile_ptr = store.ptr.offset(tile_offset);
let tmp_d_tile =
std::slice::from_raw_parts_mut(*loc as *mut TI, K::mr() * K::nr());
let m = (store.m - down * K::mr()).min(K::mr());
let n = (store.n - right * K::nr()).min(K::nr());
for r in 0..m as isize {
for c in 0..n as isize {
let inner_offset = c * col_byte_stride + r * row_byte_stride;
if inner_offset + tile_offset
< (store.item_size * store.item_count) as isize
{
*tmp_d_tile.get_unchecked_mut(r as usize + c as usize * K::mr()) =
*(tile_ptr.offset(inner_offset) as *const TI);
}
}
}
FKS::AddUnicast(OutputStoreKer {
ptr: tmp_d_tile.as_ptr() as _,
row_byte_stride: std::mem::size_of::<TI>() as isize,
col_byte_stride: (std::mem::size_of::<TI>() * K::mr()) as isize,
item_size: std::mem::size_of::<TI>(),
})
}
FS::Store(c_store) => {
let tmpc = OutputStoreKer {
ptr: *loc as _,
item_size: c_store.item_size,
row_byte_stride: c_store.item_size as isize,
col_byte_stride: (c_store.item_size * K::mr()) as isize,
};
FKS::Store(tmpc)
}
FS::AddMatMul { k, a, b } => {
let pa = a.panel(down);
K::prefetch(pa as _, 512);
let scratch = *loc as *mut AddMatMulTemp;
if (*scratch).1 != right {
(*scratch).0 = b.panel_b(right, *buffer);
(*scratch).1 = right;
}
FKS::AddMatMul { k: *k, pa, pb: (*scratch).0, cpu_variant: 0 }
}
_ => std::hint::unreachable_unchecked(),
};
}
}
#[inline]
pub fn uspecs(&self) -> &[FusedKerSpec<TI>] {
&self.uspecs
}
pub unsafe fn postprocess_tile<K: MatMatMulKer<TI>>(
&mut self,
specs: &[FusedSpec],
down: usize,
right: usize,
m_remnant: usize,
n_remnant: usize,
) where
TI: LADatum,
{
for LocDependant { spec, uspec, .. } in self.loc_dependant.iter() {
let spec = specs.get_unchecked(*spec);
let ker_spec = self.uspecs.get_unchecked(*uspec);
if let (FusedSpec::Store(c_store), FusedKerSpec::Store(tmp)) = (spec, ker_spec) {
c_store.set_from_tile(down, right, m_remnant, n_remnant, tmp)
}
}
}
}