tract-linalg 0.23.0-dev.4

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use crate::LADatum;
use std::alloc::*;
use tract_data::TractResult;

/// Applies `f` in place over `vec`, only ever handing `f` slices whose length is a multiple
/// of `nr`.
///
/// The input is split in three parts:
/// * an unaligned prefix (the elements before the first `alignment_bytes` boundary),
/// * an aligned middle whose length is the largest multiple of `nr` that fits,
/// * a tail of fewer than `nr` elements.
///
/// `f` runs directly on the middle part. Prefix and tail are copied into the thread-local
/// scratch buffer `TMP` (`nr` elements, allocated with `alignment_bytes` alignment), `f` runs
/// on that scratch slice, and the meaningful lanes are copied back. Scratch lanes past the
/// copied data keep whatever the buffer held before; `f` still processes them, but its output
/// for those lanes is discarded.
///
/// # Parameters
/// * `vec` - data transformed in place.
/// * `f` - kernel called on `&mut [T]` slices whose length is a multiple of `nr`.
/// * `nr` - chunk length required by `f`; must be non-zero.
/// * `alignment_bytes` - alignment requested for the scratch buffer and used to locate the
///   aligned part of `vec`.
///
/// # Returns
/// Always `Ok(())`; the `TractResult` wrapper never carries an error from this function.
///
/// # Panics
/// Panics if `nr` is zero while `vec` is non-empty, or if `f` re-enters one of the helpers
/// sharing `TMP` on the same thread (the `RefCell` is already mutably borrowed).
pub(crate) fn map_slice_with_alignment<T>(
    vec: &mut [T],
    f: impl Fn(&mut [T]),
    nr: usize,
    alignment_bytes: usize,
) -> TractResult<()>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(());
    }
    assert!(nr > 0, "map_slice_with_alignment: nr must be non-zero");
    TMP.with(|buffer| {
        let mut buffer = buffer.borrow_mut();
        buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
        // SAFETY: `ensure` was just asked for `nr * size_of` bytes at `alignment_bytes`
        // alignment, and the `RefCell` borrow held until the end of this closure prevents the
        // buffer from being reallocated while `tmp` is alive.
        // NOTE(review): assumes `T::datum_type().size_of() == size_of::<T>()` and
        // `alignment_bytes >= align_of::<T>()` — confirm against callers.
        let tmp = unsafe { std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr) };
        let mut compute_via_temp_buffer = |slice: &mut [T]| {
            tmp[..slice.len()].copy_from_slice(slice);
            f(tmp);
            slice.copy_from_slice(&tmp[..slice.len()])
        };
        // `align_offset` may legitimately return more than `nr` elements, or `usize::MAX`
        // when it cannot compute an offset, so the prefix is walked in scratch-sized chunks
        // instead of being assumed to fit in the scratch buffer.
        let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
        for chunk in vec[..prefix_len].chunks_mut(nr) {
            compute_via_temp_buffer(chunk);
        }
        let aligned_len = (vec.len() - prefix_len) / nr * nr;
        if aligned_len > 0 {
            f(&mut vec[prefix_len..][..aligned_len]);
        }
        // What is left is shorter than `nr` by construction of `aligned_len`.
        let tail_start = prefix_len + aligned_len;
        if tail_start < vec.len() {
            compute_via_temp_buffer(&mut vec[tail_start..]);
        }
    });
    Ok(())
}

/// Reduces `vec` to a single value, only ever handing `f` slices whose length is a multiple
/// of `nr`.
///
/// The input is split into an unaligned prefix (elements before the first `alignment_bytes`
/// boundary), an aligned middle whose length is the largest multiple of `nr` that fits, and
/// a tail of fewer than `nr` elements. `f` runs directly on the middle part. Prefix and tail
/// are copied into the thread-local scratch buffer `TMP`, whose unused lanes are filled with
/// `neutral` before `f` runs on it. Partial results are combined with `reduce`, starting
/// from `neutral`.
///
/// # Parameters
/// * `vec` - data to reduce; left untouched.
/// * `f` - kernel reducing a slice whose length is a multiple of `nr` to one value.
/// * `nr` - chunk length required by `f`; must be non-zero.
/// * `alignment_bytes` - alignment requested for the scratch buffer and used to locate the
///   aligned part of `vec`.
/// * `neutral` - initial accumulator and padding value for unused scratch lanes; returned
///   as is when `vec` is empty.
/// * `reduce` - combines two partial results.
///
/// # Returns
/// `Ok` with the combined result; this function never produces an `Err`.
///
/// # Panics
/// Panics if `nr` is zero while `vec` is non-empty, or if `f` or `reduce` re-enters one of
/// the helpers sharing `TMP` on the same thread (the `RefCell` is already mutably borrowed).
pub(crate) fn reduce_slice_with_alignment<T>(
    vec: &[T],
    f: impl Fn(&[T]) -> T,
    nr: usize,
    alignment_bytes: usize,
    neutral: T,
    reduce: impl Fn(T, T) -> T,
) -> TractResult<T>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(neutral);
    }
    assert!(nr > 0, "reduce_slice_with_alignment: nr must be non-zero");
    let mut red = neutral;
    TMP.with(|buffer| {
        let mut buffer = buffer.borrow_mut();
        buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
        // SAFETY: `ensure` was just asked for `nr * size_of` bytes at `alignment_bytes`
        // alignment, and the `RefCell` borrow held until the end of this closure prevents the
        // buffer from being reallocated while `tmp` is alive.
        // NOTE(review): assumes `T::datum_type().size_of() == size_of::<T>()` and
        // `alignment_bytes >= align_of::<T>()` — confirm against callers.
        let tmp = unsafe { std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr) };
        let mut compute_via_temp_buffer = |slice: &[T], red: &mut T| {
            tmp[..slice.len()].copy_from_slice(slice);
            // Pad so that the lanes `f` reads beyond the real data hold `neutral`.
            tmp[slice.len()..].fill(neutral);
            *red = reduce(*red, f(tmp));
        };
        // `align_offset` may legitimately return more than `nr` elements, or `usize::MAX`
        // when it cannot compute an offset, so the prefix is walked in scratch-sized chunks
        // instead of being assumed to fit in the scratch buffer.
        let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
        for chunk in vec[..prefix_len].chunks(nr) {
            compute_via_temp_buffer(chunk, &mut red);
        }
        let aligned_len = (vec.len() - prefix_len) / nr * nr;
        if aligned_len > 0 {
            let t = f(&vec[prefix_len..][..aligned_len]);
            red = reduce(red, t);
        }
        // What is left is shorter than `nr` by construction of `aligned_len`.
        let tail_start = prefix_len + aligned_len;
        if tail_start < vec.len() {
            compute_via_temp_buffer(&vec[tail_start..], &mut red);
        }
    });
    Ok(red)
}

/// Transforms `vec` in place and reduces it to a single value in the same pass, only ever
/// handing `f` slices whose length is a multiple of `nr`.
///
/// The input is split into an unaligned prefix (elements before the first `alignment_bytes`
/// boundary), an aligned middle whose length is the largest multiple of `nr` that fits, and
/// a tail of fewer than `nr` elements. `f` runs directly on the middle part. Prefix and tail
/// are copied into the thread-local scratch buffer `TMP`, whose unused lanes are filled with
/// `map_neutral`; `f` runs on the scratch slice and the meaningful lanes are copied back.
/// The values returned by `f` are combined with `reduce`, starting from `neutral`.
///
/// # Parameters
/// * `vec` - data transformed in place.
/// * `f` - kernel that rewrites a slice whose length is a multiple of `nr` and returns its
///   partial reduction.
/// * `nr` - chunk length required by `f`; must be non-zero.
/// * `alignment_bytes` - alignment requested for the scratch buffer and used to locate the
///   aligned part of `vec`.
/// * `map_neutral` - padding written to unused scratch lanes before `f` runs; presumably
///   chosen so that those lanes do not influence the value `f` returns — verify against
///   callers.
/// * `neutral` - initial accumulator for `reduce`; returned as is when `vec` is empty.
/// * `reduce` - combines two partial results.
///
/// # Returns
/// `Ok` with the combined result; this function never produces an `Err`.
///
/// # Panics
/// Panics if `nr` is zero while `vec` is non-empty, or if `f` or `reduce` re-enters one of
/// the helpers sharing `TMP` on the same thread (the `RefCell` is already mutably borrowed).
pub(crate) fn map_reduce_slice_with_alignment<T>(
    vec: &mut [T],
    f: impl Fn(&mut [T]) -> T,
    nr: usize,
    alignment_bytes: usize,
    map_neutral: T,
    neutral: T,
    reduce: impl Fn(T, T) -> T,
) -> TractResult<T>
where
    T: LADatum,
{
    if vec.is_empty() {
        return Ok(neutral);
    }
    assert!(nr > 0, "map_reduce_slice_with_alignment: nr must be non-zero");
    let mut red = neutral;
    TMP.with(|buffer| {
        let mut buffer = buffer.borrow_mut();
        buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
        // SAFETY: `ensure` was just asked for `nr * size_of` bytes at `alignment_bytes`
        // alignment, and the `RefCell` borrow held until the end of this closure prevents the
        // buffer from being reallocated while `tmp` is alive.
        // NOTE(review): assumes `T::datum_type().size_of() == size_of::<T>()` and
        // `alignment_bytes >= align_of::<T>()` — confirm against callers.
        let tmp = unsafe { std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr) };
        let mut compute_via_temp_buffer = |slice: &mut [T], red: &mut T| {
            tmp[..slice.len()].copy_from_slice(slice);
            // Pad the lanes `f` touches beyond the real data.
            tmp[slice.len()..].fill(map_neutral);
            *red = reduce(*red, f(tmp));
            slice.copy_from_slice(&tmp[..slice.len()]);
        };
        // `align_offset` may legitimately return more than `nr` elements, or `usize::MAX`
        // when it cannot compute an offset, so the prefix is walked in scratch-sized chunks
        // instead of being assumed to fit in the scratch buffer.
        let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
        for chunk in vec[..prefix_len].chunks_mut(nr) {
            compute_via_temp_buffer(chunk, &mut red);
        }
        let aligned_len = (vec.len() - prefix_len) / nr * nr;
        if aligned_len > 0 {
            let t = f(&mut vec[prefix_len..][..aligned_len]);
            red = reduce(red, t);
        }
        // What is left is shorter than `nr` by construction of `aligned_len`.
        let tail_start = prefix_len + aligned_len;
        if tail_start < vec.len() {
            compute_via_temp_buffer(&mut vec[tail_start..], &mut red);
        }
    });
    Ok(red)
}

std::thread_local! {
    static TMP: std::cell::RefCell<TempBuffer> = std::cell::RefCell::new(TempBuffer::default());
}

/// Manually managed, grow-only scratch allocation with a caller-chosen alignment.
///
/// The buffer starts out unallocated (`buffer` is null) and is (re)allocated by
/// [`TempBuffer::ensure`]. The allocation is released when the value is dropped.
///
/// Both fields are public, but `ensure` and `Drop` rely on `layout` being exactly the layout
/// `buffer` was allocated with; code that writes to the fields directly must preserve that.
pub struct TempBuffer {
    /// Layout of the current allocation (size 0, alignment 1 before the first allocation).
    pub layout: Layout,
    /// Start of the current allocation, or null when nothing is allocated.
    pub buffer: *mut u8,
}

impl Default for TempBuffer {
    /// Creates an empty buffer: no allocation, zero-sized layout.
    fn default() -> Self {
        TempBuffer { layout: Layout::new::<()>(), buffer: std::ptr::null_mut() }
    }
}

impl TempBuffer {
    /// Makes sure the buffer holds at least `size` bytes aligned to at least `alignment`.
    ///
    /// Nothing happens when the current allocation already satisfies both requirements.
    /// Otherwise the old allocation is released and a new, zero-initialised one is made
    /// whose size and alignment are the maximum of the requested and the previous values.
    /// Previous contents are NOT preserved across a reallocation.
    ///
    /// When the resulting size is zero no allocation is made and `buffer` stays null.
    ///
    /// # Panics
    /// Panics if `alignment` is not a power of two, if the size overflows when rounded up to
    /// the alignment, or if the allocator returns null.
    pub fn ensure(&mut self, size: usize, alignment: usize) {
        if size <= self.layout.size() && alignment <= self.layout.align() {
            return;
        }
        let size = size.max(self.layout.size());
        let alignment = alignment.max(self.layout.align());
        // Validate the request before touching the current allocation, so a rejected
        // request leaves `self` exactly as it was.
        let layout = Layout::from_size_align(size, alignment)
            .expect("TempBuffer::ensure: invalid size/alignment request");
        if !self.buffer.is_null() {
            // SAFETY: `buffer` is non-null, hence it was returned by the allocation below
            // with `self.layout`, and it has not been freed since.
            unsafe { std::alloc::dealloc(self.buffer, self.layout) };
            // Reset both fields so a later panic cannot lead `Drop` to free this block again.
            self.buffer = std::ptr::null_mut();
            self.layout = Layout::new::<()>();
        }
        if layout.size() == 0 {
            // The global allocator must not be called with a zero-sized layout; only
            // remember the requested alignment.
            self.layout = layout;
            return;
        }
        // Zeroed so that callers viewing the buffer as typed scratch space never read
        // uninitialised bytes.
        // SAFETY: `layout` has a non-zero size (checked just above).
        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
        assert!(!ptr.is_null());
        self.buffer = ptr;
        self.layout = layout;
    }
}

impl Drop for TempBuffer {
    /// Releases the allocation, if any.
    fn drop(&mut self) {
        if !self.buffer.is_null() {
            // SAFETY: a non-null `buffer` was allocated by `ensure` with `self.layout` and is
            // freed exactly once, here.
            unsafe { std::alloc::dealloc(self.buffer, self.layout) }
        }
    }
}