vkml 0.0.2

High-level Vulkan-based machine learning library
#[derive(Clone, Copy, Debug)]
struct AxisState {
    axis: usize,
    dim: u64,
    limit: u32,
    dispatches: u64,
    local_size: u32,
}

pub fn optimal_workgroup_size(
    max_workgroup_size: [u32; 3],
    max_workgroup_invocations: u32,
    dims: [Option<u64>; 3],
) -> [u32; 3] {
    let mut axes: Vec<AxisState> = dims
        .iter()
        .enumerate()
        .filter_map(|(axis, &maybe_dim)| {
            maybe_dim.filter(|&dim| dim > 0).map(|dim| AxisState {
                axis,
                dim,
                limit: max_workgroup_size[axis],
                dispatches: 1,
                local_size: 1,
            })
        })
        .collect();

    if axes.is_empty() {
        return [1, 1, 1];
    }

    let max_inv = max_workgroup_invocations.max(1);

    for axis in &mut axes {
        let limit = axis.limit.min(max_inv).max(1);
        axis.limit = limit;

        if axis.dim <= limit as u64 {
            axis.local_size = axis.dim as u32;
            axis.dispatches = 1;
        } else {
            axis.dispatches = div_ceil_u64(axis.dim, limit as u64);
            axis.local_size = div_ceil_u32(axis.dim, axis.dispatches).min(limit).max(1);
        }
    }

    let mut product = total_product(&axes);

    while product > max_inv as u64 {
        let mut best: Option<(usize, u64, u64, u32, u64)> = None;

        for (idx, axis) in axes.iter().enumerate() {
            if axis.local_size <= 1 {
                continue;
            }

            let mut next_dispatches = axis.dispatches + 1;
            let mut next_size = div_ceil_u32(axis.dim, next_dispatches)
                .min(axis.limit)
                .max(1);

            while next_size == axis.local_size && next_dispatches < axis.dim {
                next_dispatches += 1;
                next_size = div_ceil_u32(axis.dim, next_dispatches)
                    .min(axis.limit)
                    .max(1);
            }

            if next_size >= axis.local_size {
                continue;
            }

            let new_product = product / axis.local_size.max(1) as u64 * next_size.max(1) as u64;

            let new_waste = projected_waste(&axes, idx, next_size, next_dispatches);

            let candidate = (idx, new_product, next_dispatches, next_size, new_waste);
            match &best {
                None => best = Some(candidate),
                Some((_, best_product, _, best_size, best_waste)) => {
                    if new_product < *best_product
                        || (new_product == *best_product && new_waste < *best_waste)
                        || (new_product == *best_product
                            && new_waste == *best_waste
                            && next_size < *best_size)
                    {
                        best = Some(candidate);
                    }
                }
            }
        }

        let Some((idx, _, next_dispatches, next_size, _)) = best else {
            break;
        };

        axes[idx].dispatches = next_dispatches;
        axes[idx].local_size = next_size.min(axes[idx].limit).max(1);
        product = total_product(&axes);
    }

    let mut result = [1u32; 3];
    for axis in axes {
        result[axis.axis] = axis.local_size;
    }

    result
}

fn div_ceil_u64(value: u64, divisor: u64) -> u64 {
    if divisor == 0 {
        return 1;
    }
    value.div_ceil(divisor)
}

fn div_ceil_u32(value: u64, divisor: u64) -> u32 {
    div_ceil_u64(value, divisor).min(u32::MAX as u64) as u32
}

fn axis_waste(dim: u64, local_size: u64, dispatches: u64) -> u64 {
    dispatches.saturating_mul(local_size).saturating_sub(dim)
}

fn projected_waste(axes: &[AxisState], idx: usize, next_size: u32, next_dispatches: u64) -> u64 {
    axes.iter()
        .enumerate()
        .map(|(i, axis)| {
            if i == idx {
                axis_waste(axis.dim, next_size as u64, next_dispatches)
            } else {
                axis_waste(axis.dim, axis.local_size as u64, axis.dispatches)
            }
        })
        .sum()
}

fn total_product(axes: &[AxisState]) -> u64 {
    axes.iter().fold(1u64, |acc, axis| {
        acc.saturating_mul(axis.local_size.max(1) as u64)
    })
}