// vyre-libs 0.6.3
//
// vyre Category A library ecosystem - pure-IR compositions over vyre-ops hardware primitives
// Documentation
use vyre::ir::{BufferAccess, Program};
use vyre_primitives::reduce::multi_block_prefix_scan::{
    multi_block_prefix_scan_sum_u32, pass_a_local_scan, pass_c_broadcast_offsets, BLOCK_LANES,
};

use super::GpuDispatcher;

/// Reusable host-side buffers for the prefix-scan entry points in this module.
///
/// Every field starts empty (via `Default`) and is cleared or overwritten by
/// the scan functions before use, so one instance can be passed to repeated
/// calls to reuse its allocations.
#[derive(Default)]
pub(super) struct PrefixScanScratch {
    /// Zero bytes (`n` words) passed as the second input of the single-dispatch
    /// path when the dispatcher reports `requires_output_inputs()`.
    small_zero: Vec<u8>,
    /// Outputs written by the single-dispatch path; exactly one entry is expected.
    small_outputs: Vec<Vec<u8>>,
    /// Zero bytes (`num_blocks * BLOCK_LANES` words) passed as the second input
    /// of pass A when the dispatcher reports `requires_output_inputs()`.
    pass_a_partials_zero: Vec<u8>,
    /// Zero bytes (`num_blocks` words) passed as the third input of pass A when
    /// the dispatcher reports `requires_output_inputs()`.
    pass_a_totals_zero: Vec<u8>,
    /// Outputs written by pass A; exactly two entries are expected. Entry 0 is
    /// fed to pass C as its first input, entry 1 is copied into
    /// `block_totals_input`.
    pass_a_outputs: Vec<Vec<u8>>,
    /// Copy of `pass_a_outputs[1]`, used as the input of the nested scan.
    block_totals_input: Vec<u8>,
    /// Result of the nested scan over `block_totals_input`; second input of pass C.
    scanned_block_totals: Vec<u8>,
    /// Scratch for the nested scan over the block totals, created on first use.
    nested: Option<Box<PrefixScanScratch>>,
    /// Zero bytes (`n` words) passed as the third input of pass C when the
    /// dispatcher reports `requires_output_inputs()`.
    pass_c_zero: Vec<u8>,
    /// Outputs written by pass C; exactly one entry is expected and is copied
    /// into the caller's output buffer.
    pass_c_outputs: Vec<Vec<u8>>,
}

impl PrefixScanScratch {
    /// Resets `out` so that it holds exactly `byte_len` zero bytes.
    ///
    /// Any previous contents are discarded first. Capacity is then requested
    /// through the fallible `try_reserve_exact`, so a reservation failure is
    /// reported to the caller before `resize` is attempted.
    ///
    /// # Errors
    ///
    /// Returns a message containing `byte_len` and the reservation error when
    /// the capacity cannot be reserved. `out` has already been cleared at that
    /// point.
    fn prepare_zero(out: &mut Vec<u8>, byte_len: usize) -> Result<(), String> {
        out.clear();
        match out.try_reserve_exact(byte_len) {
            Ok(()) => {
                // Capacity is in place, so filling with zeros cannot reallocate.
                out.resize(byte_len, 0);
                Ok(())
            }
            Err(error) => Err(format!(
                "prefix scan: could not reserve {byte_len} zero-staging bytes: {error:?}. Fix: shard the GPU prefix scan input."
            )),
        }
    }
}

/// Converts a count of `u32` words into a host byte length, checking for overflow.
///
/// `field` names the buffer being sized and only appears in the error text.
///
/// # Errors
///
/// Returns a message naming `field` and `word_count` when the byte length does
/// not fit in `usize`.
fn prefix_scan_word_bytes(word_count: u32, field: &'static str) -> Result<usize, String> {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    match (word_count as usize).checked_mul(WORD_BYTES) {
        Some(byte_len) => Ok(byte_len),
        None => Err(format!(
            "prefix scan: {field} word count {word_count} overflows host byte sizing. Fix: shard the GPU prefix scan input."
        )),
    }
}

/// Computes the host byte length of `left * right` `u32` words, checking for overflow.
///
/// Both the word product and the following conversion to bytes are checked.
/// `field` names the buffer being sized and only appears in the error text.
///
/// # Errors
///
/// Returns a message naming `field`, `left` and `right` when either
/// multiplication does not fit in `usize`.
fn prefix_scan_product_word_bytes(
    left: u32,
    right: u32,
    field: &'static str,
) -> Result<usize, String> {
    if let Some(words) = (left as usize).checked_mul(right as usize) {
        if let Some(byte_len) = words.checked_mul(std::mem::size_of::<u32>()) {
            return Ok(byte_len);
        }
    }
    Err(format!(
        "prefix scan: {field} word product {left} x {right} overflows host byte sizing. Fix: shard the GPU prefix scan input."
    ))
}

/// Runs the prefix-scan program over `n` words of `input_words_le` and copies
/// the single resulting buffer into `out`.
///
/// When `n` exceeds `BLOCK_LANES` the work is handed to
/// `inclusive_prefix_scan_u32_large_into`. Otherwise one program built by
/// `multi_block_prefix_scan_sum_u32` is dispatched. If the dispatcher reports
/// `requires_output_inputs()`, a zero-filled buffer of `n` words is supplied as
/// an extra input for the output binding.
///
/// `scratch` provides the reusable staging buffers; `out` is cleared before the
/// result is written.
/// NOTE(review): the parameter name suggests `input_words_le` holds
/// little-endian `u32` words; its length is not validated here — confirm with callers.
///
/// # Errors
///
/// Propagates sizing, reservation and dispatch failures, and fails when the
/// backend does not return exactly one output buffer.
pub(super) fn inclusive_prefix_scan_u32_into(
    dispatcher: &dyn GpuDispatcher,
    input_words_le: &[u8],
    n: u32,
    scratch: &mut PrefixScanScratch,
    out: &mut Vec<u8>,
) -> Result<(), String> {
    if n > BLOCK_LANES {
        return inclusive_prefix_scan_u32_large_into(dispatcher, input_words_le, n, scratch, out);
    }

    let program = multi_block_prefix_scan_sum_u32("scan_in", "scan_out", n);
    if dispatcher.requires_output_inputs() {
        // This backend wants the output binding supplied as an input as well.
        let zero_len = prefix_scan_word_bytes(n, "small output")?;
        PrefixScanScratch::prepare_zero(&mut scratch.small_zero, zero_len)?;
        let inputs = [input_words_le, scratch.small_zero.as_slice()];
        dispatcher.dispatch_borrowed_into(&program, &inputs, &mut scratch.small_outputs)?;
    } else {
        let inputs = [input_words_le];
        dispatcher.dispatch_borrowed_into(&program, &inputs, &mut scratch.small_outputs)?;
    }

    // Both dispatch variants must yield exactly one buffer: scan_out.
    let [scan_out] = scratch.small_outputs.as_slice() else {
        return Err(format!(
            "prefix scan: expected exactly 1 output, got {}. Fix: backend must return only scan_out.",
            scratch.small_outputs.len()
        ));
    };
    out.clear();
    out.extend_from_slice(scan_out);
    Ok(())
}

/// Multi-dispatch prefix scan used when `n` exceeds `BLOCK_LANES`.
///
/// The work is split into three stages over `num_blocks = ceil(n / BLOCK_LANES)`:
///
/// 1. Pass A (`pass_a_local_scan`) is dispatched on `input_words_le` and must
///    return two buffers: `scan_partials` and `scan_block_totals`.
/// 2. The block totals are copied into scratch and scanned by calling
///    `inclusive_prefix_scan_u32_into` again with a nested scratch.
/// 3. Pass C (`pass_c_broadcast_offsets`) is dispatched on the partials and the
///    scanned block totals and must return one buffer, which is copied into `out`.
///
/// When the dispatcher reports `requires_output_inputs()`, zero-filled buffers
/// are supplied for the output bindings of each pass, and the pass A program is
/// first rewritten by `live_out_readwrite_buffers`.
///
/// # Errors
///
/// Propagates sizing and reservation failures unchanged, prefixes dispatch
/// failures with `"pass A: "` / `"pass C: "`, and fails when a pass returns an
/// unexpected number of output buffers.
fn inclusive_prefix_scan_u32_large_into(
    dispatcher: &dyn GpuDispatcher,
    input_words_le: &[u8],
    n: u32,
    scratch: &mut PrefixScanScratch,
    out: &mut Vec<u8>,
) -> Result<(), String> {
    let num_blocks = n.div_ceil(BLOCK_LANES);

    // ---- Pass A: dispatch the per-block program -------------------------
    let pass_a = pass_a_local_scan(
        "scan_in",
        "scan_partials",
        "scan_block_totals",
        n,
        num_blocks,
    );
    if dispatcher.requires_output_inputs() {
        let pass_a = live_out_readwrite_buffers(pass_a, &["scan_partials", "scan_block_totals"]);
        let partials_len =
            prefix_scan_product_word_bytes(num_blocks, BLOCK_LANES, "pass A partials")?;
        let totals_len = prefix_scan_word_bytes(num_blocks, "pass A block totals")?;
        PrefixScanScratch::prepare_zero(&mut scratch.pass_a_partials_zero, partials_len)?;
        PrefixScanScratch::prepare_zero(&mut scratch.pass_a_totals_zero, totals_len)?;
        let inputs = [
            input_words_le,
            scratch.pass_a_partials_zero.as_slice(),
            scratch.pass_a_totals_zero.as_slice(),
        ];
        dispatcher
            .dispatch_borrowed_into(&pass_a, &inputs, &mut scratch.pass_a_outputs)
            .map_err(|err| format!("pass A: {err}"))?;
    } else {
        let inputs = [input_words_le];
        dispatcher
            .dispatch_borrowed_into(&pass_a, &inputs, &mut scratch.pass_a_outputs)
            .map_err(|err| format!("pass A: {err}"))?;
    }
    let [partials, block_totals] = scratch.pass_a_outputs.as_slice() else {
        return Err(format!(
            "pass A: expected exactly 2 outputs, got {}. Fix: backend must return scan_partials/scan_block_totals and no extras.",
            scratch.pass_a_outputs.len()
        ));
    };

    // ---- Nested scan over the per-block totals --------------------------
    scratch.block_totals_input.clear();
    scratch.block_totals_input.extend_from_slice(block_totals);
    let nested = scratch
        .nested
        .get_or_insert_with(|| Box::new(PrefixScanScratch::default()));
    inclusive_prefix_scan_u32_into(
        dispatcher,
        scratch.block_totals_input.as_slice(),
        num_blocks,
        nested,
        &mut scratch.scanned_block_totals,
    )?;

    // ---- Pass C: combine partials with the scanned block totals ---------
    let pass_c = pass_c_broadcast_offsets(
        "scan_partials",
        "scan_block_totals_scanned",
        "scan_out",
        n,
        num_blocks,
    );
    if dispatcher.requires_output_inputs() {
        let zero_len = prefix_scan_word_bytes(n, "pass C output")?;
        PrefixScanScratch::prepare_zero(&mut scratch.pass_c_zero, zero_len)?;
        let inputs = [
            partials.as_slice(),
            scratch.scanned_block_totals.as_slice(),
            scratch.pass_c_zero.as_slice(),
        ];
        dispatcher
            .dispatch_borrowed_into(&pass_c, &inputs, &mut scratch.pass_c_outputs)
            .map_err(|err| format!("pass C: {err}"))?;
    } else {
        let inputs = [
            partials.as_slice(),
            scratch.scanned_block_totals.as_slice(),
        ];
        dispatcher
            .dispatch_borrowed_into(&pass_c, &inputs, &mut scratch.pass_c_outputs)
            .map_err(|err| format!("pass C: {err}"))?;
    }
    let [scan_out] = scratch.pass_c_outputs.as_slice() else {
        return Err(format!(
            "pass C: expected exactly 1 output, got {}. Fix: backend must return only scan_out.",
            scratch.pass_c_outputs.len()
        ));
    };

    out.clear();
    out.extend_from_slice(scan_out);
    Ok(())
}

/// Returns `program` with every buffer whose name appears in `names`
/// re-declared as a read-write pipeline live-out.
///
/// For each matching buffer the copy has `is_output` cleared,
/// `pipeline_live_out` set, `output_byte_range` reset to `None`, and `access`
/// set to `BufferAccess::ReadWrite`. Buffers that do not match are cloned
/// unchanged.
fn live_out_readwrite_buffers(program: Program, names: &[&str]) -> Program {
    let rewritten = program
        .buffers()
        .iter()
        .map(|declared| {
            let mut patched = declared.clone();
            let is_target = names.iter().any(|name| *name == patched.name());
            if !is_target {
                return patched;
            }
            patched.is_output = false;
            patched.pipeline_live_out = true;
            patched.output_byte_range = None;
            patched.access = BufferAccess::ReadWrite;
            patched
        })
        .collect();
    program.with_rewritten_buffers(rewritten)
}

/// Host-only tests for the byte-sizing helpers and zero-staging preparation.
///
/// None of these tests dispatch a program; they cover the arithmetic and
/// buffer-reset logic, including the error paths.
#[cfg(test)]
mod tests {
    use super::{
        prefix_scan_product_word_bytes, prefix_scan_word_bytes, PrefixScanScratch, BLOCK_LANES,
    };

    /// Small word counts convert to the expected byte lengths.
    #[test]
    fn prefix_scan_word_bytes_uses_checked_u32_sizing() {
        assert_eq!(
            prefix_scan_word_bytes(3, "test").expect("Fix: small word count should fit"),
            12
        );
        assert_eq!(
            prefix_scan_product_word_bytes(2, BLOCK_LANES, "partials")
                .expect("Fix: small partial count should fit"),
            (2 * BLOCK_LANES as usize) * std::mem::size_of::<u32>()
        );
    }

    /// `u32::MAX * u32::MAX` words cannot be expressed in bytes on any
    /// supported pointer width, so the checked path must report an error
    /// rather than wrap.
    #[test]
    fn prefix_scan_product_word_bytes_reports_overflow() {
        let error = prefix_scan_product_word_bytes(u32::MAX, u32::MAX, "partials")
            .expect_err("Fix: u32::MAX squared words must overflow host byte sizing");
        assert!(
            error.contains("partials word product 4294967295 x 4294967295 overflows"),
            "unexpected error text: {error}"
        );
    }

    /// A fresh buffer ends up with exactly the requested number of zero bytes.
    #[test]
    fn prefix_scan_zero_staging_reserves_before_resize() {
        let mut out = Vec::with_capacity(16);
        PrefixScanScratch::prepare_zero(&mut out, 12)
            .expect("Fix: small zero staging reservation should fit");
        assert_eq!(out, vec![0; 12]);
        assert!(out.capacity() >= 12);
    }

    /// A reused buffer must not leak bytes from its previous contents.
    #[test]
    fn prefix_scan_zero_staging_discards_stale_bytes() {
        let mut out = vec![0xFF_u8; 20];
        PrefixScanScratch::prepare_zero(&mut out, 4)
            .expect("Fix: shrinking zero staging should fit");
        assert_eq!(out, vec![0_u8; 4]);
    }

    /// A length that can never be reserved is reported as an error, and the
    /// buffer is left empty.
    #[test]
    fn prefix_scan_zero_staging_reports_unreservable_length() {
        let mut out = vec![1_u8; 3];
        let error = PrefixScanScratch::prepare_zero(&mut out, usize::MAX)
            .expect_err("Fix: usize::MAX bytes must not be reservable");
        assert!(
            error.starts_with("prefix scan: could not reserve"),
            "unexpected error text: {error}"
        );
        assert!(out.is_empty());
    }
}