use crate::soa::STRING_STRIDE;
/// Fraction of the available VRAM that `BatchSizer::new` budgets by default.
const DEFAULT_VRAM_UTILIZATION: f32 = 0.75;
/// Value returned by `BatchSizer::min_batch_for_gpu`.
///
/// NOTE(review): presumably the smallest batch worth dispatching to the GPU;
/// confirm against the callers that consume it.
pub const GPU_BATCH_MIN: usize = 1_000;
/// Sizing policy that turns an available-VRAM figure into a maximum number of
/// pairs per batch.
#[derive(Debug, Clone)]
pub struct BatchSizer {
/// Fraction of the available VRAM that may be budgeted for a batch.
///
/// `with_utilization` only accepts values in `(0, 1]`; because this field is
/// public, values written to it directly are not validated.
pub vram_utilization: f32,
}
impl Default for BatchSizer {
fn default() -> Self {
Self::new()
}
}
impl BatchSizer {
    /// Creates a sizer that budgets `DEFAULT_VRAM_UTILIZATION` of the available VRAM.
    pub fn new() -> Self {
        Self {
            vram_utilization: DEFAULT_VRAM_UTILIZATION,
        }
    }

    /// Returns the sizer with its VRAM budget fraction replaced by `fraction`.
    ///
    /// # Panics
    ///
    /// Panics if `fraction` is not in `(0, 1]`. NaN is rejected as well, because
    /// it fails the lower-bound comparison.
    #[must_use]
    pub fn with_utilization(mut self, fraction: f32) -> Self {
        assert!(
            fraction > 0.0 && fraction <= 1.0,
            "utilization must be in (0, 1]"
        );
        self.vram_utilization = fraction;
        self
    }

    /// Returns how many pairs fit in the budgeted share of `available_vram_bytes`.
    ///
    /// The budget is `available_vram_bytes * vram_utilization`, and one pair is
    /// charged
    /// `num_fields * (2 * STRING_STRIDE + 2 * 2 + 4) + 2 * 8 + 2 * 4` bytes.
    /// NOTE(review): the individual terms presumably mirror the compare-pool
    /// buffer layout; confirm against that layout before changing them.
    ///
    /// The result is never zero: when not even one pair fits, `1` is returned.
    /// The per-pair cost is computed with saturating 64-bit arithmetic, so an
    /// extreme `num_fields` yields `1` instead of overflowing, and the result is
    /// clamped to `usize::MAX` rather than truncated on narrow targets.
    #[must_use]
    pub fn max_batch_size(&self, available_vram_bytes: u64, num_fields: usize) -> usize {
        // Bytes charged once per pair, independent of the field count.
        const FIXED_BYTES_PER_PAIR: u64 = 2 * 8 + 2 * 4;
        // Bytes charged per field in addition to the two string slots.
        const EXTRA_BYTES_PER_FIELD: u64 = 2 * 2 + 4;

        // `usize` is never wider than 64 bits, so these widenings are lossless.
        let stride = STRING_STRIDE as u64;
        let fields = num_fields as u64;

        let bytes_per_field = stride
            .saturating_mul(2)
            .saturating_add(EXTRA_BYTES_PER_FIELD);
        // Always >= FIXED_BYTES_PER_PAIR, so the division below cannot be by zero.
        let bytes_per_pair = fields
            .saturating_mul(bytes_per_field)
            .saturating_add(FIXED_BYTES_PER_PAIR);

        // Float-to-int `as` saturates and maps NaN to 0, so this cannot panic.
        let usable = (available_vram_bytes as f64 * f64::from(self.vram_utilization)) as u64;

        let pairs = (usable / bytes_per_pair).max(1);
        // Clamp before narrowing so 32-bit targets saturate instead of truncating.
        pairs.min(usize::MAX as u64) as usize
    }

    /// Returns [`GPU_BATCH_MIN`].
    ///
    /// NOTE(review): presumably the batch size below which GPU dispatch is not
    /// worthwhile; confirm against callers.
    pub const fn min_batch_for_gpu() -> usize {
        GPU_BATCH_MIN
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One gibibyte, in bytes.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// More VRAM must yield a strictly larger batch for the same field count.
    #[test]
    fn max_batch_grows_with_vram() {
        let sizer = BatchSizer::new();
        let with_one_gib = sizer.max_batch_size(GIB, 10);
        let with_eight_gib = sizer.max_batch_size(8 * GIB, 10);
        assert!(with_eight_gib > with_one_gib);
    }

    /// A budget too small for even one pair still reports a batch of one.
    #[test]
    fn max_batch_never_zero() {
        assert_eq!(BatchSizer::new().max_batch_size(1, 1000), 1);
    }

    /// Three GiB with ten fields leaves room for more than a million pairs.
    #[test]
    fn three_gb_vram_fits_millions() {
        let max = BatchSizer::new().max_batch_size(3 * GIB, 10);
        assert!(max > 1_000_000, "expected >1M pairs, got {max}");
    }

    /// The GPU minimum batch size must be a usable, non-zero count.
    #[test]
    fn min_batch_constant_is_positive() {
        assert_ne!(BatchSizer::min_batch_for_gpu(), 0);
    }

    /// Halving the utilization fraction must shrink the batch.
    #[test]
    fn utilization_scales_result() {
        let batch_at = |fraction: f32| {
            BatchSizer::new()
                .with_utilization(fraction)
                .max_batch_size(1_000_000, 5)
        };
        let at_full = batch_at(1.0);
        let at_half = batch_at(0.5);
        assert!(at_full > at_half);
    }

    /// Pins the per-pair byte cost for a single field and checks that exactly
    /// that many bytes hold exactly one pair.
    #[test]
    fn formula_matches_compare_pool_layout() {
        let bytes_per_pair_1field = 2 * STRING_STRIDE + 2 * 2 + 16 + 8 + 4;
        assert_eq!(bytes_per_pair_1field, 160);

        let max = BatchSizer::new()
            .with_utilization(1.0)
            .max_batch_size(160, 1);
        assert_eq!(
            max, 1,
            "exactly one pair should fit in 160 bytes for 1 field"
        );
    }
}