//! instmodel_inference 0.8.0
//!
//! High-performance neural network inference library with instruction-based execution.
//!
//! Documentation
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

use crate::errors::{ParallelPredictError, ParallelPredictResult};
use crate::instruction_model::InstructionModel;

/// A raw `*mut f32` that may be moved into (and shared between) worker threads.
///
/// Raw pointers are neither `Send` nor `Sync`; this wrapper exists so scoped worker threads can
/// write into disjoint regions of one shared buffer. It performs no synchronization itself:
/// every dereference goes through an `unsafe` method whose caller must guarantee bounds,
/// liveness, and that concurrent accesses never overlap.
#[derive(Clone, Copy)]
struct SendPtr {
    ptr: *mut f32,
}

impl SendPtr {
    /// Wraps `ptr` without dereferencing it.
    fn new(ptr: *mut f32) -> Self {
        Self { ptr }
    }

    /// Returns the pointer `offset` elements past the wrapped pointer.
    ///
    /// # Safety
    ///
    /// Same contract as `<*mut f32>::add`: the result must stay within (or one past the end of)
    /// the allocation the wrapped pointer points into, and `offset * size_of::<f32>()` must not
    /// overflow `isize`.
    unsafe fn add(self, offset: usize) -> *mut f32 {
        // SAFETY: the in-bounds requirement is forwarded to the caller (see `# Safety`).
        unsafe { self.ptr.add(offset) }
    }

    /// Reborrows `len` elements starting at `offset` as a mutable slice.
    ///
    /// The returned lifetime `'a` is unbounded and picked by the caller; it must not be allowed
    /// to outlive the buffer the pointer came from.
    ///
    /// # Safety
    ///
    /// - `offset..offset + len` must lie inside one live allocation of initialized `f32`s.
    /// - That allocation must stay alive (and not move) for all of `'a`.
    /// - For all of `'a`, no other reference or pointer may read or write those elements.
    unsafe fn as_slice_mut<'a>(self, offset: usize, len: usize) -> &'a mut [f32] {
        // SAFETY: bounds, liveness and exclusivity are forwarded to the caller (see `# Safety`).
        unsafe { std::slice::from_raw_parts_mut(self.ptr.add(offset), len) }
    }
}

// SAFETY: `SendPtr` is just a pointer value. Every access through it happens in an `unsafe`
// method whose caller must keep concurrent accesses to disjoint, in-bounds regions.
unsafe impl Send for SendPtr {}
// SAFETY: sharing the pointer value itself is harmless; see the `Send` justification above.
unsafe impl Sync for SendPtr {}

/// Options controlling [`execute_parallel_predict`].
///
/// Every option is optional; unset values are resolved to defaults when the prediction runs.
#[derive(Default, Clone, Debug)]
pub struct PredictConfig {
    /// Requested worker thread count; `None` means "use the available parallelism".
    threads: Option<usize>,
    /// `(start, end)` range of each sample's computation buffer that is copied into the
    /// output; `None` means "from the model's output index start to its required memory".
    slice_result_buffer: Option<(usize, usize)>,
}

impl PredictConfig {
    /// Creates a configuration with every option unset; identical to `PredictConfig::default()`.
    pub fn new() -> Self {
        // Delegate to the derived `Default` so the two constructors can never drift apart.
        Self::default()
    }

    /// Sets the number of worker threads.
    ///
    /// Zero is accepted here but rejected by [`execute_parallel_predict`] with
    /// `InvalidThreadCount`.
    #[must_use = "builder methods return a new config; the original is consumed"]
    pub fn with_threads(mut self, threads: usize) -> Self {
        self.threads = Some(threads);
        self
    }

    /// Selects the range `start..end` of each sample's computation buffer to return.
    ///
    /// The range is not validated here; [`execute_parallel_predict`] rejects `start >= end`
    /// and an `end` past the model's `required_memory()`.
    #[must_use = "builder methods return a new config; the original is consumed"]
    pub fn with_slice_result_buffer(mut self, start: usize, end: usize) -> Self {
        self.slice_result_buffer = Some((start, end));
        self
    }

    /// Returns the configured thread count.
    ///
    /// If none was set, returns [`std::thread::available_parallelism`], falling back to `1` when
    /// that cannot be determined.
    pub fn get_threads(&self) -> usize {
        self.threads.unwrap_or_else(|| {
            thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1)
        })
    }

    /// Returns the `(start, end)` range copied out of each sample's computation buffer.
    ///
    /// This is the configured range if set, otherwise
    /// `(model.get_output_index_start(), model.required_memory())`.
    pub fn get_slice_range(&self, model: &InstructionModel) -> (usize, usize) {
        self.slice_result_buffer
            .unwrap_or_else(|| (model.get_output_index_start(), model.required_memory()))
    }
}

/// Results of a parallel prediction, stored as one flat buffer.
///
/// The buffer holds `num_samples` consecutive rows of `slice_size` values each; row `i`
/// occupies `buffer[i * slice_size..(i + 1) * slice_size]`.
#[derive(Debug, Clone)]
pub struct ParallelPredictOutput {
    buffer: Vec<f32>,
    num_samples: usize,
    slice_size: usize,
}

impl ParallelPredictOutput {
    /// Wraps a flat result buffer of `num_samples` rows, each `slice_size` values long.
    fn new(buffer: Vec<f32>, num_samples: usize, slice_size: usize) -> Self {
        Self { buffer, num_samples, slice_size }
    }

    /// Number of samples (rows) in the result.
    pub fn num_samples(&self) -> usize {
        self.num_samples
    }

    /// Number of values stored per sample.
    pub fn slice_size(&self) -> usize {
        self.slice_size
    }

    /// The whole result as one flat slice, sample after sample.
    pub fn as_slice(&self) -> &[f32] {
        self.buffer.as_slice()
    }

    /// Copies the flat result buffer into `dest`.
    ///
    /// # Errors
    ///
    /// Returns `DestinationBufferSizeMismatch` when `dest` is not exactly as long as the buffer.
    pub fn copy_results(&self, dest: &mut [f32]) -> ParallelPredictResult<()> {
        let expected = self.buffer.len();
        let actual = dest.len();
        if actual == expected {
            dest.copy_from_slice(&self.buffer);
            Ok(())
        } else {
            Err(ParallelPredictError::DestinationBufferSizeMismatch { expected, actual })
        }
    }

    /// Returns one owned `Vec` per sample, in sample order.
    pub fn copy_results_to_vec(&self) -> Vec<Vec<f32>> {
        (0..self.num_samples)
            .map(|row_index| self.sample_row(row_index).to_vec())
            .collect()
    }

    /// Borrows the results of sample `index`.
    ///
    /// # Errors
    ///
    /// Returns `ResultIndexOutOfBounds` when `index >= num_samples()`.
    pub fn get_result(&self, index: usize) -> ParallelPredictResult<&[f32]> {
        if index < self.num_samples {
            Ok(self.sample_row(index))
        } else {
            Err(ParallelPredictError::ResultIndexOutOfBounds {
                index,
                num_samples: self.num_samples,
            })
        }
    }

    /// Consumes the output and returns the flat result buffer.
    pub fn into_buffer(self) -> Vec<f32> {
        self.buffer
    }

    /// Row `row_index` of the flat buffer; panics if the row lies outside the buffer.
    fn sample_row(&self, row_index: usize) -> &[f32] {
        let first = row_index * self.slice_size;
        &self.buffer[first..first + self.slice_size]
    }
}

/// Runs `model` on every sample in `inputs` in parallel and gathers the selected outputs.
///
/// `inputs` is read as consecutive samples of `model.get_feature_size()` values each. Every
/// worker thread owns a private scratch buffer of `model.required_memory()` values. For each
/// sample it claims, a worker:
/// 1. copies the sample's features into the front of that buffer;
/// 2. calls `model.predict_with_buffer`;
/// 3. copies `buffer[start..end]` (from [`PredictConfig::get_slice_range`]) into that sample's
///    row of the output.
///
/// Samples are handed out through a shared atomic counter, so faster threads simply claim more
/// of them. At most `min(config.get_threads(), num_samples)` threads are spawned.
///
/// Empty `inputs` yields an empty output (`num_samples == 0`, `slice_size == 0`) before the
/// thread count or slice range are validated.
///
/// # Errors
///
/// - `InputBufferSizeMismatch` if `inputs.len()` is not a multiple of the feature size.
/// - `InvalidThreadCount` if the configured thread count is zero.
/// - `InvalidSliceRange` if the slice range has `start >= end`.
/// - `SliceRangeOutOfBounds` if the slice range ends past `model.required_memory()`.
/// - `PredictionFailed` if the model fails on a sample. Workers then stop claiming new samples.
/// - `ThreadPanicked` if a worker thread panics.
///
/// When several workers fail, the error of the lowest-numbered failing worker is returned.
///
/// # Panics
///
/// Panics if the model's feature size is zero. Also panics if the output or scratch buffer size
/// overflows `usize`.
pub fn execute_parallel_predict(
    model: &InstructionModel,
    inputs: &[f32],
    config: &PredictConfig,
) -> ParallelPredictResult<ParallelPredictOutput> {
    let feature_size = model.get_feature_size();
    let required_memory = model.required_memory();
    // The sample count below divides by the feature size; fail with a clear message instead of
    // an opaque divide-by-zero.
    assert!(feature_size > 0, "model feature size must be non-zero");

    if !inputs.len().is_multiple_of(feature_size) {
        let num_samples = inputs.len() / feature_size;
        let expected = (num_samples + 1) * feature_size;
        return Err(ParallelPredictError::InputBufferSizeMismatch {
            expected,
            actual: inputs.len(),
            num_samples,
            feature_size,
        });
    }

    let num_samples = inputs.len() / feature_size;
    if num_samples == 0 {
        return Ok(ParallelPredictOutput::new(Vec::new(), 0, 0));
    }

    let num_threads = config.get_threads();
    if num_threads == 0 {
        return Err(ParallelPredictError::InvalidThreadCount { count: 0 });
    }
    // Workers beyond the sample count would never claim work; don't spawn them or allocate
    // their scratch buffers.
    let num_threads = num_threads.min(num_samples);

    let (slice_start, slice_end) = config.get_slice_range(model);
    if slice_start >= slice_end {
        return Err(ParallelPredictError::InvalidSliceRange {
            start: slice_start,
            end: slice_end,
        });
    }
    if slice_end > required_memory {
        return Err(ParallelPredictError::SliceRangeOutOfBounds {
            start: slice_start,
            end: slice_end,
            buffer_size: required_memory,
        });
    }

    let slice_size = slice_end - slice_start;
    // The workers write through raw pointers, so a wrapped (undersized) buffer length would
    // become out-of-bounds writes. Panic instead, like `Vec`'s own capacity-overflow check.
    let output_len = num_samples
        .checked_mul(slice_size)
        .expect("output buffer size overflows usize");
    let scratch_len = num_threads
        .checked_mul(required_memory)
        .expect("scratch buffer size overflows usize");
    let mut output_buffer = vec![0.0f32; output_len];
    let mut computation_buffers = vec![0.0f32; scratch_len];

    let sample_counter = AtomicUsize::new(0);

    thread::scope(|scope| {
        let sample_counter_ref = &sample_counter;
        let output_buffer_ptr = SendPtr::new(output_buffer.as_mut_ptr());
        let computation_buffers_ptr = SendPtr::new(computation_buffers.as_mut_ptr());

        let handles: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                scope.spawn(move || -> ParallelPredictResult<()> {
                    // SAFETY: `computation_buffers` holds `num_threads * required_memory`
                    // floats and outlives this scope. Each `thread_id` gets its own disjoint
                    // `required_memory`-sized window, so no other thread touches it.
                    let thread_buffer = unsafe {
                        computation_buffers_ptr
                            .as_slice_mut(thread_id * required_memory, required_memory)
                    };

                    loop {
                        // Relaxed is enough: the counter only needs to hand out unique
                        // indices. The output writes are published to the caller by `join`.
                        let sample_index = sample_counter_ref.fetch_add(1, Ordering::Relaxed);
                        if sample_index >= num_samples {
                            break;
                        }

                        let input_start = sample_index * feature_size;
                        thread_buffer[..feature_size]
                            .copy_from_slice(&inputs[input_start..input_start + feature_size]);

                        model.predict_with_buffer(thread_buffer).map_err(|e| {
                            // The whole call will fail, so stop the other workers from
                            // claiming further samples.
                            sample_counter_ref.store(num_samples, Ordering::Relaxed);
                            ParallelPredictError::PredictionFailed {
                                sample_index,
                                message: e.to_string(),
                            }
                        })?;

                        let output_start = sample_index * slice_size;
                        let result_slice = &thread_buffer[slice_start..slice_end];
                        // SAFETY: `output_buffer` holds `num_samples * slice_size` floats and
                        // outlives this scope. `sample_index < num_samples` was claimed by this
                        // thread alone, so the destination row is in bounds and written by no
                        // other thread. Source (scratch) and destination (output) are different
                        // allocations, so they cannot overlap.
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                result_slice.as_ptr(),
                                output_buffer_ptr.add(output_start),
                                slice_size,
                            );
                        }
                    }

                    Ok(())
                })
            })
            .collect();

        // Join every worker before reporting. Returning early would leave later workers
        // un-joined, and `thread::scope` panics (instead of letting us return an error) if
        // any of those had panicked.
        let mut first_error = None;
        for handle in handles {
            let outcome = match handle.join() {
                Ok(result) => result,
                Err(_) => Err(ParallelPredictError::ThreadPanicked),
            };
            if let Err(err) = outcome {
                if first_error.is_none() {
                    first_error = Some(err);
                }
            }
        }

        match first_error {
            Some(err) => Err(err),
            None => Ok(()),
        }
    })?;

    Ok(ParallelPredictOutput::new(
        output_buffer,
        num_samples,
        slice_size,
    ))
}