tract-core 0.19.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use crate::internal::*;
use std::ops::Range;
use tract_linalg::frame::PackingWriter;
use tract_linalg::mmm::{VirtualInput, VirtualInputSpec};

/// Specification of an im2col input that is gathered on demand instead of
/// being materialized.
///
/// `LazyIm2col` (built by `wrap_t` below) reads the element at `(k, n)`
/// as a `T` located `k_bytes_offsets[k] + n_bytes_offsets[n]` bytes past the
/// base pointer of the wrapped tensor view.
#[derive(Clone, Debug, Hash)]
pub struct LazyIm2colSpec {
    /// Byte offset contributed by each n (column) index; its length is the
    /// number of columns the lazy input exposes.
    pub n_bytes_offsets: Vec<isize>,
    /// Byte offset contributed by each k index.
    pub k_bytes_offsets: Vec<isize>,
}

// Hooks the type into the crate's dynamic-hash machinery (presumably built
// on the derived `Hash` above — see the `impl_dyn_hash!` definition).
impl_dyn_hash!(LazyIm2colSpec);

impl LazyIm2colSpec {
    /// Builds a typed `LazyIm2col<T>` that gathers from `view` using this
    /// spec's offsets.
    ///
    /// The result keeps raw pointers into `view`'s data and into this spec's
    /// two offset vectors, so both must outlive the returned object.
    fn wrap_t<T: Datum + Copy>(&self, view: &TensorView) -> Box<dyn VirtualInput> {
        let base: *const T = view.as_ptr().unwrap();
        let column_count = self.n_bytes_offsets.len();
        Box::new(LazyIm2col::<T> {
            ptr: base,
            n: column_count,
            n_byte_offsets: self.n_bytes_offsets.as_ptr(),
            k_byte_offsets: self.k_bytes_offsets.as_ptr(),
        })
    }
}

impl VirtualInputSpec for LazyIm2colSpec {
    /// Picks the `wrap_t::<T>` instantiation matching the view's runtime
    /// datum type and delegates to it.
    fn wrap(&self, view: &TensorView) -> Box<dyn VirtualInput> {
        let datum_type = view.datum_type();
        dispatch_copy!(Self::wrap_t(datum_type)(self, view))
    }
}

/// Typed, pointer-based view used to gather im2col elements lazily.
///
/// Built by `LazyIm2colSpec::wrap_t`: `ptr` comes from the wrapped
/// `TensorView`, and the two offset pointers come from the spec's `Vec`s.
/// NOTE(review): raw pointers carry no lifetime, so nothing here enforces
/// that the tensor and the spec outlive this value — TODO confirm at call
/// sites.
#[derive(Clone, Debug)]
struct LazyIm2col<T: Datum + Copy> {
    /// Base address of the input tensor data.
    ptr: *const T,
    /// Number of entries behind `n_byte_offsets` (set from the spec's
    /// `n_bytes_offsets.len()`).
    n: usize,
    /// Byte offset of each n (column) index.
    n_byte_offsets: *const isize,
    /// Byte offset of each k index.
    k_byte_offsets: *const isize,
}

// SAFETY: every access through these pointers in this file is a read; no
// method writes through them. NOTE(review): soundness additionally requires
// the pointed-to tensor data and offset vectors to stay alive and unmodified
// while the value is shared or sent across threads, which the type does not
// enforce — confirm with callers.
unsafe impl<T: Datum + Copy> Send for LazyIm2col<T> {}
unsafe impl<T: Datum + Copy> Sync for LazyIm2col<T> {}

/// Gathering routines that stream im2col elements into a packing writer.
///
/// Element `(k, n)` is read as a `T` at byte address
/// `self.ptr + k_byte_offsets[k] + n_byte_offsets[n]`. Values are always
/// emitted k-outer, n-inner.
impl<T: Datum + Copy> LazyIm2col<T> {
    /// Writes the 8 columns `n..n + 8` for every `k` in `k_range`.
    ///
    /// The eight n-offsets are read once, before the k loop. All eight values
    /// of a row are loaded before any is written.
    fn input_8n(&self, writer: &mut impl PackingWriter<T>, k_range: Range<isize>, n: isize) {
        // SAFETY: no bounds checks. `write` only calls this when the clamped
        // n range is exactly 8 wide and ends at or before `self.n`, so
        // `n..n + 8` stays within `n_byte_offsets`. NOTE(review): `k_range`
        // is assumed to lie within `k_byte_offsets` — not checked here.
        unsafe {
            let o1 = *self.n_byte_offsets.offset(n);
            let o2 = *self.n_byte_offsets.offset(n + 1);
            let o3 = *self.n_byte_offsets.offset(n + 2);
            let o4 = *self.n_byte_offsets.offset(n + 3);
            let o5 = *self.n_byte_offsets.offset(n + 4);
            let o6 = *self.n_byte_offsets.offset(n + 5);
            let o7 = *self.n_byte_offsets.offset(n + 6);
            let o8 = *self.n_byte_offsets.offset(n + 7);
            for k in k_range.start..k_range.end {
                // Row base: data pointer shifted by this k's byte offset.
                let ptr = (self.ptr as *const u8).offset(*self.k_byte_offsets.offset(k));
                let v1 = *(ptr.offset(o1) as *const T);
                let v2 = *(ptr.offset(o2) as *const T);
                let v3 = *(ptr.offset(o3) as *const T);
                let v4 = *(ptr.offset(o4) as *const T);
                let v5 = *(ptr.offset(o5) as *const T);
                let v6 = *(ptr.offset(o6) as *const T);
                let v7 = *(ptr.offset(o7) as *const T);
                let v8 = *(ptr.offset(o8) as *const T);
                writer.write(v1);
                writer.write(v2);
                writer.write(v3);
                writer.write(v4);
                writer.write(v5);
                writer.write(v6);
                writer.write(v7);
                writer.write(v8);
            }
        }
    }

    /// Writes the 6 columns `n..n + 6` for every `k` in `k_range`; same
    /// scheme as `input_8n`.
    fn input_6n(&self, writer: &mut impl PackingWriter<T>, k_range: Range<isize>, n: isize) {
        // SAFETY: same contract as `input_8n`, with a width of 6.
        unsafe {
            let o1 = *self.n_byte_offsets.offset(n);
            let o2 = *self.n_byte_offsets.offset(n + 1);
            let o3 = *self.n_byte_offsets.offset(n + 2);
            let o4 = *self.n_byte_offsets.offset(n + 3);
            let o5 = *self.n_byte_offsets.offset(n + 4);
            let o6 = *self.n_byte_offsets.offset(n + 5);
            for k in k_range.start..k_range.end {
                let ptr = (self.ptr as *const u8).offset(*self.k_byte_offsets.offset(k));
                let v1 = *(ptr.offset(o1) as *const T);
                let v2 = *(ptr.offset(o2) as *const T);
                let v3 = *(ptr.offset(o3) as *const T);
                let v4 = *(ptr.offset(o4) as *const T);
                let v5 = *(ptr.offset(o5) as *const T);
                let v6 = *(ptr.offset(o6) as *const T);
                writer.write(v1);
                writer.write(v2);
                writer.write(v3);
                writer.write(v4);
                writer.write(v5);
                writer.write(v6);
            }
        }
    }

    /// Writes the 4 columns `n..n + 4` for every `k` in `k_range`; same
    /// scheme as `input_8n`.
    fn input_4n(&self, writer: &mut impl PackingWriter<T>, k_range: Range<isize>, n: isize) {
        // SAFETY: same contract as `input_8n`, with a width of 4.
        unsafe {
            let o1 = *self.n_byte_offsets.offset(n);
            let o2 = *self.n_byte_offsets.offset(n + 1);
            let o3 = *self.n_byte_offsets.offset(n + 2);
            let o4 = *self.n_byte_offsets.offset(n + 3);
            for k in k_range.start..k_range.end {
                let ptr = (self.ptr as *const u8).offset(*self.k_byte_offsets.offset(k));
                let v1 = *(ptr.offset(o1) as *const T);
                let v2 = *(ptr.offset(o2) as *const T);
                let v3 = *(ptr.offset(o3) as *const T);
                let v4 = *(ptr.offset(o4) as *const T);
                writer.write(v1);
                writer.write(v2);
                writer.write(v3);
                writer.write(v4);
            }
        }
    }

    /// Writes the 2 columns `n..n + 2` for every `k` in `k_range`; same
    /// scheme as `input_8n`.
    fn input_2n(&self, writer: &mut impl PackingWriter<T>, k_range: Range<isize>, n: isize) {
        // SAFETY: same contract as `input_8n`, with a width of 2.
        unsafe {
            let o1 = *self.n_byte_offsets.offset(n);
            let o2 = *self.n_byte_offsets.offset(n + 1);
            for k in k_range.start..k_range.end {
                let ptr = (self.ptr as *const u8).offset(*self.k_byte_offsets.offset(k));
                let v1 = *(ptr.offset(o1) as *const T);
                let v2 = *(ptr.offset(o2) as *const T);
                writer.write(v1);
                writer.write(v2);
            }
        }
    }

    /// Writes the tile `k_range` x `mn_range` to `writer`, k-outer and
    /// n-inner, with the n end clamped to `self.n`.
    ///
    /// Tiles exactly 8, 6, 4 or 2 columns wide are routed to the specialised
    /// `input_*n` routines, which read their n-offsets once. Any other width
    /// (including an empty range) uses the generic loop below. That loop
    /// re-reads the offsets for every k and consumes the columns in chunks of
    /// 8, then 6, then 4, then one at a time.
    fn write(
        &self,
        writer: &mut impl PackingWriter<T>,
        k_range: std::ops::Range<isize>,
        mn_range: std::ops::Range<isize>,
    ) {
        // Columns at or past `self.n` do not exist; drop them.
        let mn_end = mn_range.end.min(self.n as isize);
        let n_range = mn_range.start..mn_end;
        match n_range.len() {
            8 => return self.input_8n(writer, k_range, n_range.start),
            6 => return self.input_6n(writer, k_range, n_range.start),
            4 => return self.input_4n(writer, k_range, n_range.start),
            2 => return self.input_2n(writer, k_range, n_range.start),
            _ => (),
        }
        // SAFETY: `n` stays within `n_range`, whose end is clamped to
        // `self.n` above. NOTE(review): assumes `mn_range.start >= 0` (the
        // caller in this file derives it from a `usize`) and that `k_range`
        // lies within `k_byte_offsets` — neither is checked here.
        unsafe {
            for k in k_range.start..k_range.end {
                let ptr = (self.ptr as *const u8).offset(*self.k_byte_offsets.offset(k));
                let mut n = n_range.start;
                // Chunks of 8 columns.
                while n + 8 <= n_range.end {
                    let o1 = *self.n_byte_offsets.offset(n);
                    let o2 = *self.n_byte_offsets.offset(n + 1);
                    let o3 = *self.n_byte_offsets.offset(n + 2);
                    let o4 = *self.n_byte_offsets.offset(n + 3);
                    let o5 = *self.n_byte_offsets.offset(n + 4);
                    let o6 = *self.n_byte_offsets.offset(n + 5);
                    let o7 = *self.n_byte_offsets.offset(n + 6);
                    let o8 = *self.n_byte_offsets.offset(n + 7);
                    let v1 = *(ptr.offset(o1) as *const T);
                    let v2 = *(ptr.offset(o2) as *const T);
                    let v3 = *(ptr.offset(o3) as *const T);
                    let v4 = *(ptr.offset(o4) as *const T);
                    let v5 = *(ptr.offset(o5) as *const T);
                    let v6 = *(ptr.offset(o6) as *const T);
                    let v7 = *(ptr.offset(o7) as *const T);
                    let v8 = *(ptr.offset(o8) as *const T);
                    writer.write(v1);
                    writer.write(v2);
                    writer.write(v3);
                    writer.write(v4);
                    writer.write(v5);
                    writer.write(v6);
                    writer.write(v7);
                    writer.write(v8);
                    n += 8;
                }
                // Chunks of 6 columns.
                while n + 6 <= n_range.end {
                    let o1 = *self.n_byte_offsets.offset(n);
                    let o2 = *self.n_byte_offsets.offset(n + 1);
                    let o3 = *self.n_byte_offsets.offset(n + 2);
                    let o4 = *self.n_byte_offsets.offset(n + 3);
                    let o5 = *self.n_byte_offsets.offset(n + 4);
                    let o6 = *self.n_byte_offsets.offset(n + 5);
                    let v1 = *(ptr.offset(o1) as *const T);
                    let v2 = *(ptr.offset(o2) as *const T);
                    let v3 = *(ptr.offset(o3) as *const T);
                    let v4 = *(ptr.offset(o4) as *const T);
                    let v5 = *(ptr.offset(o5) as *const T);
                    let v6 = *(ptr.offset(o6) as *const T);
                    writer.write(v1);
                    writer.write(v2);
                    writer.write(v3);
                    writer.write(v4);
                    writer.write(v5);
                    writer.write(v6);
                    n += 6;
                }
                // Chunks of 4 columns.
                while n + 4 <= n_range.end {
                    let o1 = *self.n_byte_offsets.offset(n);
                    let o2 = *self.n_byte_offsets.offset(n + 1);
                    let o3 = *self.n_byte_offsets.offset(n + 2);
                    let o4 = *self.n_byte_offsets.offset(n + 3);
                    let v1 = *(ptr.offset(o1) as *const T);
                    let v2 = *(ptr.offset(o2) as *const T);
                    let v3 = *(ptr.offset(o3) as *const T);
                    let v4 = *(ptr.offset(o4) as *const T);
                    writer.write(v1);
                    writer.write(v2);
                    writer.write(v3);
                    writer.write(v4);
                    n += 4;
                }
                // Remaining columns, one at a time.
                while n < n_range.end {
                    let o1 = *self.n_byte_offsets.offset(n);
                    let v1 = *(ptr.offset(o1) as *const T);
                    writer.write(v1);
                    n += 1;
                }
            }
        }
    }
}

impl<T: Datum + Copy> VirtualInput for LazyIm2col<T> {
    /// Packs the `k_range` x `mn_range` tile of this lazy input into `packed`.
    ///
    /// The n span is clamped to `self.n`. If the clamped span is exactly one
    /// panel (`packer.r`) wide and starts on a panel boundary, a single-panel
    /// writer is used. Otherwise a k-outer writer sized to the tile is used.
    fn input(
        &self,
        packer: &tract_linalg::frame::Packer,
        packed: *mut u8,
        k_range: std::ops::Range<usize>,
        mn_range: std::ops::Range<usize>,
    ) {
        let clamped_end = mn_range.end.min(self.n) as isize;
        let n_span = mn_range.start as isize..clamped_end;
        let k_span = k_range.start as isize..k_range.end as isize;
        let dst = packed as *mut T;
        let aligned_full_panel = n_span.len() == packer.r && mn_range.start % packer.r == 0;
        if aligned_full_panel {
            let mut panel = packer.write_single_panel_with_k_outer(dst);
            self.write(&mut panel, k_span, n_span)
        } else {
            let mut tile = packer.write_with_k_outer(dst, k_range.len(), n_span.len());
            self.write(&mut tile, k_span, n_span)
        }
    }
}