use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Default number of threads (lanes) in a [`WorkGroup`].
///
/// Used as the group size by [`WorkDispatcher::with_default_group_size`], as the
/// upper bound on the group size chosen by [`LoadBalancer::new`], and as the unit
/// in which [`LoadBalancer::estimate_workers`] sizes its estimate.
pub const WORK_GROUP_SIZE: usize = 32;
/// One unit of work to dispatch: a caller-supplied payload, the number of
/// subtasks it splits into, and a priority value.
#[derive(Clone, Debug)]
pub struct WorkItem<T> {
    /// Caller-supplied payload carried alongside the item.
    pub data: T,
    /// How many subtasks this item contributes to its queue's total.
    pub num_subtasks: usize,
    /// Priority value; `0.0` when the item is built with [`WorkItem::new`].
    pub priority: f32,
}

impl<T> WorkItem<T> {
    /// Builds an item with the default priority `0.0`.
    pub fn new(data: T, num_subtasks: usize) -> Self {
        Self::with_priority(data, num_subtasks, 0.0)
    }

    /// Builds an item with an explicit `priority`.
    pub fn with_priority(data: T, num_subtasks: usize, priority: f32) -> Self {
        WorkItem {
            num_subtasks,
            priority,
            data,
        }
    }
}
/// A fixed list of [`WorkItem`]s handed out one index at a time through a
/// shared atomic cursor.
///
/// [`WorkQueue::request_next`] may be called from many threads at once; each
/// index in `0..len()` is claimed by exactly one call until
/// [`WorkQueue::reset`] rewinds the cursor.
#[derive(Debug)]
pub struct WorkQueue<T> {
    /// The items to dispatch; never modified after construction.
    items: Vec<WorkItem<T>>,
    /// Index of the next unclaimed item; kept within `0..=items.len()`.
    dispatch_index: AtomicUsize,
    /// Sum of `num_subtasks` over `items`, computed once in [`WorkQueue::new`].
    total_subtasks: usize,
}

impl<T> WorkQueue<T> {
    /// Wraps `items` in a queue whose cursor starts at the first item.
    pub fn new(items: Vec<WorkItem<T>>) -> Self {
        let total_subtasks = items.iter().map(|item| item.num_subtasks).sum();
        Self {
            items,
            dispatch_index: AtomicUsize::new(0),
            total_subtasks,
        }
    }

    /// Number of items in the queue, dispatched or not.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if the queue holds no items.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Sum of `num_subtasks` across all items.
    pub fn total_subtasks(&self) -> usize {
        self.total_subtasks
    }

    /// Claims the next undispatched item and returns its index, or `None`
    /// once every item has been handed out.
    ///
    /// Safe to call concurrently: the atomic update guarantees no two calls
    /// receive the same index between resets.
    pub fn request_next(&self) -> Option<usize> {
        let len = self.items.len();
        // Advance the cursor only while items remain. An unconditional
        // `fetch_add` would keep incrementing on every failed poll. After
        // enough idle polling (reachable on 32-bit targets), the counter would
        // wrap past `usize::MAX` and re-issue indices that were already
        // dispatched.
        self.dispatch_index
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |next| {
                (next < len).then_some(next + 1)
            })
            .ok()
    }

    /// Returns the item at `index`, or `None` if `index` is out of range.
    pub fn get(&self, index: usize) -> Option<&WorkItem<T>> {
        self.items.get(index)
    }

    /// Rewinds the cursor so every item can be dispatched again.
    pub fn reset(&self) {
        self.dispatch_index.store(0, Ordering::Release);
    }

    /// Returns `(dispatched, total)` item counts.
    pub fn progress(&self) -> (usize, usize) {
        let total = self.items.len();
        // `request_next` never moves the cursor past `total`; clamp anyway so
        // the reported count can never exceed the item count.
        let dispatched = self.dispatch_index.load(Ordering::Acquire).min(total);
        (dispatched, total)
    }
}
/// One thread's view of a fixed-size work group: the group's size and id,
/// plus this thread's rank within it.
///
/// Rank 0 is the group's dispatcher (see [`WorkGroup::is_dispatcher`]).
/// [`WorkGroup::subtask_range`] splits an item's subtasks across ranks with a
/// stride equal to the group size.
#[derive(Debug)]
pub struct WorkGroup {
    /// Number of threads in the group.
    size: usize,
    /// Identifier of this group.
    id: usize,
    /// This thread's index within the group; meaningful when in `0..size`.
    thread_rank: usize,
}

impl WorkGroup {
    /// Creates the view of thread `thread_rank` in group `id` of `size` threads.
    pub fn new(size: usize, id: usize, thread_rank: usize) -> Self {
        Self {
            size,
            id,
            thread_rank,
        }
    }

    /// Number of threads in the group.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Identifier of the group.
    pub fn id(&self) -> usize {
        self.id
    }

    /// This thread's rank within the group.
    pub fn thread_rank(&self) -> usize {
        self.thread_rank
    }

    /// Returns `true` for the thread with rank 0.
    pub fn is_dispatcher(&self) -> bool {
        self.thread_rank == 0
    }

    /// Returns the subtask indices in `0..num_subtasks` owned by this thread:
    /// `thread_rank`, `thread_rank + size`, `thread_rank + 2 * size`, and so on.
    ///
    /// Across ranks `0..size`, every subtask is yielded exactly once. A rank
    /// outside `0..size` (which includes every rank of a zero-sized group) owns
    /// no subtasks and gets an empty iterator.
    pub fn subtask_range(&self, num_subtasks: usize) -> impl Iterator<Item = usize> {
        // A bare `(rank..n).step_by(size)` has two failure modes. `step_by(0)`
        // panics for an empty group. A rank >= size re-yields subtasks already
        // owned by rank `thread_rank % size`, so they would be processed twice.
        let (start, end) = if self.thread_rank < self.size {
            (self.thread_rank, num_subtasks)
        } else {
            (0, 0)
        };
        // In the in-group branch `size >= 1`, so `max(1)` only matters for the
        // empty range, where any non-zero step is fine.
        (start..end).step_by(self.size.max(1))
    }
}
/// Owns a shared [`WorkQueue`] together with the launch geometry (number of
/// groups and threads per group) used to process it.
#[derive(Debug)]
pub struct WorkDispatcher<T> {
    /// Queue shared with worker threads via [`WorkDispatcher::queue_handle`].
    queue: Arc<WorkQueue<T>>,
    /// Number of work groups.
    num_groups: usize,
    /// Threads per work group.
    group_size: usize,
}

impl<T> WorkDispatcher<T> {
    /// Creates a dispatcher over `items` with `num_groups` groups of
    /// `group_size` threads each.
    pub fn new(items: Vec<WorkItem<T>>, num_groups: usize, group_size: usize) -> Self {
        Self {
            queue: Arc::new(WorkQueue::new(items)),
            num_groups,
            group_size,
        }
    }

    /// Like [`WorkDispatcher::new`], using [`WORK_GROUP_SIZE`] as the group size.
    pub fn with_default_group_size(items: Vec<WorkItem<T>>, num_groups: usize) -> Self {
        Self::new(items, num_groups, WORK_GROUP_SIZE)
    }

    /// Borrows the underlying queue.
    pub fn queue(&self) -> &WorkQueue<T> {
        &self.queue
    }

    /// Returns a new shared handle to the queue, e.g. to move into worker threads.
    ///
    /// Only the `Arc` reference count is bumped and the items are not copied,
    /// so this is available for every `T` (it previously required `T: Clone`).
    pub fn queue_handle(&self) -> Arc<WorkQueue<T>> {
        Arc::clone(&self.queue)
    }

    /// Number of work groups.
    pub fn num_groups(&self) -> usize {
        self.num_groups
    }

    /// Threads per work group.
    pub fn group_size(&self) -> usize {
        self.group_size
    }

    /// Creates the [`WorkGroup`] view for `thread_rank` in group `group_id`,
    /// using this dispatcher's group size.
    pub fn create_group(&self, group_id: usize, thread_rank: usize) -> WorkGroup {
        WorkGroup::new(self.group_size, group_id, thread_rank)
    }

    /// Rewinds the queue so every item can be dispatched again.
    pub fn reset(&self) {
        self.queue.reset();
    }

    /// Snapshot of queue progress and launch geometry.
    pub fn stats(&self) -> DispatchStats {
        let (dispatched_items, total_items) = self.queue.progress();
        DispatchStats {
            total_items,
            dispatched_items,
            total_subtasks: self.queue.total_subtasks(),
            num_groups: self.num_groups,
            group_size: self.group_size,
        }
    }
}
/// Point-in-time summary of a dispatcher's progress and launch geometry.
#[derive(Clone, Debug)]
pub struct DispatchStats {
    /// Number of items in the queue.
    pub total_items: usize,
    /// Number of items handed out so far.
    pub dispatched_items: usize,
    /// Sum of `num_subtasks` across all items.
    pub total_subtasks: usize,
    /// Number of work groups.
    pub num_groups: usize,
    /// Threads per work group.
    pub group_size: usize,
}

impl DispatchStats {
    /// Fraction of items dispatched, in `[0, 1]`; an empty queue counts as complete (`1.0`).
    pub fn completion_ratio(&self) -> f64 {
        match self.total_items {
            0 => 1.0,
            total => self.dispatched_items as f64 / total as f64,
        }
    }

    /// Total thread count: groups times threads per group.
    pub fn total_workers(&self) -> usize {
        self.group_size * self.num_groups
    }

    /// Mean subtasks per worker thread, or `0.0` when there are no workers.
    pub fn avg_subtasks_per_worker(&self) -> f64 {
        let workers = self.total_workers();
        if workers == 0 {
            return 0.0;
        }
        self.total_subtasks as f64 / workers as f64
    }
}
/// Chooses a work-group layout for a number of workers and builds
/// [`WorkDispatcher`]s with it.
///
/// Invariant: `group_size` is always at least 1 (enforced by both
/// constructors), so [`LoadBalancer::num_groups`] never divides by zero.
#[derive(Clone, Debug)]
pub struct LoadBalancer {
    /// Total number of worker threads to cover.
    num_workers: usize,
    /// Threads per group; always >= 1.
    group_size: usize,
}

impl LoadBalancer {
    /// Creates a balancer for `num_workers` workers, using groups of
    /// `min(WORK_GROUP_SIZE, num_workers)` threads.
    ///
    /// The group size is clamped to at least 1. As a result, `new(0)` describes
    /// zero groups instead of panicking with a division by zero in
    /// [`LoadBalancer::num_groups`].
    pub fn new(num_workers: usize) -> Self {
        let group_size = WORK_GROUP_SIZE.min(num_workers).max(1);
        Self {
            num_workers,
            group_size,
        }
    }

    /// Creates a balancer with an explicit `group_size`.
    ///
    /// # Panics
    ///
    /// Panics if `group_size` is 0. A group with no threads cannot hold any
    /// worker, and [`LoadBalancer::num_groups`] would divide by zero.
    pub fn with_group_size(num_workers: usize, group_size: usize) -> Self {
        assert!(group_size > 0, "group_size must be non-zero");
        Self {
            num_workers,
            group_size,
        }
    }

    /// Total number of worker threads.
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Groups needed to cover all workers: `ceil(num_workers / group_size)`.
    pub fn num_groups(&self) -> usize {
        self.num_workers.div_ceil(self.group_size)
    }

    /// Builds a dispatcher for `items` using this balancer's group layout.
    pub fn create_dispatcher<T>(&self, items: Vec<WorkItem<T>>) -> WorkDispatcher<T> {
        WorkDispatcher::new(items, self.num_groups(), self.group_size)
    }

    /// Heuristic worker count for `num_items` items averaging `avg_subtasks`
    /// subtasks each.
    ///
    /// The heuristic works as follows:
    /// - It aims for one group per four items, and at least one group.
    /// - It never uses more groups than `ceil(total_work / WORK_GROUP_SIZE)`.
    /// - The result is a multiple of [`WORK_GROUP_SIZE`].
    /// - It returns 0 when there is no work (`num_items` or `avg_subtasks` is 0).
    /// - Arithmetic saturates at `usize::MAX` instead of overflowing.
    pub fn estimate_workers(num_items: usize, avg_subtasks: usize) -> usize {
        // Saturate instead of overflowing (a panic in debug builds) on huge inputs.
        let total_work = num_items.saturating_mul(avg_subtasks);
        let min_groups = num_items.div_ceil(4).max(1);
        let max_groups = total_work.div_ceil(WORK_GROUP_SIZE);
        min_groups.min(max_groups).saturating_mul(WORK_GROUP_SIZE)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the dispatch types, including a multi-threaded check
    //! that the queue hands out every index exactly once.
    use super::*;

    #[test]
    fn test_work_item_creation() {
        let item = WorkItem::new(42, 10);
        assert_eq!(item.data, 42);
        assert_eq!(item.num_subtasks, 10);
        assert_eq!(item.priority, 0.0);
    }

    #[test]
    fn test_work_item_with_priority() {
        let item = WorkItem::with_priority("data", 5, 1.5);
        assert_eq!(item.data, "data");
        assert_eq!(item.num_subtasks, 5);
        assert_eq!(item.priority, 1.5);
    }

    #[test]
    fn test_work_queue_creation() {
        let items = vec![
            WorkItem::new(1, 10),
            WorkItem::new(2, 20),
            WorkItem::new(3, 30),
        ];
        let queue = WorkQueue::new(items);
        assert_eq!(queue.len(), 3);
        assert_eq!(queue.total_subtasks(), 60);
        assert!(!queue.is_empty());
    }

    #[test]
    fn test_work_queue_dispatch() {
        let items = vec![WorkItem::new(1, 10), WorkItem::new(2, 20)];
        let queue = WorkQueue::new(items);
        assert_eq!(queue.request_next(), Some(0));
        assert_eq!(queue.request_next(), Some(1));
        assert_eq!(queue.request_next(), None);
    }

    #[test]
    fn test_work_queue_reset() {
        let items = vec![WorkItem::new(1, 10)];
        let queue = WorkQueue::new(items);
        assert_eq!(queue.request_next(), Some(0));
        assert_eq!(queue.request_next(), None);
        queue.reset();
        assert_eq!(queue.request_next(), Some(0));
    }

    #[test]
    fn test_work_group() {
        let group = WorkGroup::new(32, 0, 0);
        assert_eq!(group.size(), 32);
        assert_eq!(group.id(), 0);
        assert!(group.is_dispatcher());
        let non_dispatcher = WorkGroup::new(32, 0, 5);
        assert!(!non_dispatcher.is_dispatcher());
    }

    #[test]
    fn test_work_group_subtask_range() {
        // Rank 1 of a 4-wide group takes every 4th subtask starting at 1.
        let group = WorkGroup::new(4, 0, 1);
        let subtasks: Vec<_> = group.subtask_range(10).collect();
        assert_eq!(subtasks, vec![1, 5, 9]);
    }

    #[test]
    fn test_work_dispatcher() {
        let items = vec![WorkItem::new(1, 10), WorkItem::new(2, 20)];
        let dispatcher = WorkDispatcher::with_default_group_size(items, 4);
        assert_eq!(dispatcher.num_groups(), 4);
        assert_eq!(dispatcher.group_size(), 32);
    }

    #[test]
    fn test_dispatch_stats() {
        let items = vec![WorkItem::new(1, 10), WorkItem::new(2, 20)];
        let dispatcher = WorkDispatcher::new(items, 4, 8);
        let stats = dispatcher.stats();
        assert_eq!(stats.total_items, 2);
        assert_eq!(stats.total_subtasks, 30);
        assert_eq!(stats.total_workers(), 32);
        assert!((stats.avg_subtasks_per_worker() - 0.9375).abs() < 0.01);
    }

    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new(128);
        assert_eq!(balancer.num_workers(), 128);
        assert_eq!(balancer.num_groups(), 4);
    }

    #[test]
    fn test_load_balancer_create_dispatcher() {
        let balancer = LoadBalancer::new(64);
        let items = vec![WorkItem::new(1, 10)];
        let dispatcher = balancer.create_dispatcher(items);
        assert_eq!(dispatcher.num_groups(), 2);
    }

    #[test]
    fn test_estimate_workers() {
        // 10 items x 5 subtasks = 50 units of work: the ceil(10 / 4) = 3 group
        // floor is capped at ceil(50 / 32) = 2 groups, i.e. 64 workers.
        assert_eq!(LoadBalancer::estimate_workers(10, 5), 64);
        // 1000 items x 100 subtasks: ceil(1000 / 4) = 250 groups, under the
        // ceil(100_000 / 32) = 3125 cap, i.e. 8000 workers.
        assert_eq!(LoadBalancer::estimate_workers(1000, 100), 8000);
    }

    #[test]
    fn test_concurrent_dispatch() {
        use std::thread;
        let items: Vec<_> = (0..100).map(|i| WorkItem::new(i, 1)).collect();
        let dispatcher = WorkDispatcher::with_default_group_size(items, 4);
        let queue = dispatcher.queue_handle();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let q = Arc::clone(&queue);
                thread::spawn(move || {
                    let mut claimed = Vec::new();
                    while let Some(index) = q.request_next() {
                        claimed.push(index);
                    }
                    claimed
                })
            })
            .collect();
        // `join` fails only if the worker thread panicked.
        let mut all: Vec<usize> = handles
            .into_iter()
            .flat_map(|h| h.join().expect("dispatch worker thread panicked"))
            .collect();
        all.sort_unstable();
        // Every index is handed out exactly once across all threads.
        assert_eq!(all, (0..100).collect::<Vec<_>>());
    }
}
// Property-based tests (using the `proptest` crate) for the dispatch types.
#[cfg(test)]
mod property_tests {
use super::*;
use proptest::prelude::*;
// Each property below runs 50 generated cases (see `proptest_config`).
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
// `WorkItem::new` stores data and subtask count verbatim and uses priority 0.0.
#[test]
fn work_item_stores_data(data in 0i32..1000, num_subtasks in 0usize..100) {
let item = WorkItem::new(data, num_subtasks);
prop_assert_eq!(item.data, data);
prop_assert_eq!(item.num_subtasks, num_subtasks);
prop_assert_eq!(item.priority, 0.0);
}
// `WorkItem::with_priority` stores all three fields.
#[test]
fn work_item_priority(data in 0i32..1000, num_subtasks in 0usize..100, priority in -10.0f32..10.0) {
let item = WorkItem::with_priority(data, num_subtasks, priority);
prop_assert_eq!(item.data, data);
prop_assert_eq!(item.num_subtasks, num_subtasks);
prop_assert!((item.priority - priority).abs() < 1e-6);
}
// `total_subtasks` is the sum of per-item counts; `len` is the item count.
#[test]
fn work_queue_subtask_sum(subtasks in proptest::collection::vec(0usize..50, 1..20)) {
let items: Vec<_> = subtasks.iter().map(|&s| WorkItem::new(s, s)).collect();
let expected_total: usize = subtasks.iter().sum();
let queue = WorkQueue::new(items);
prop_assert_eq!(queue.total_subtasks(), expected_total);
prop_assert_eq!(queue.len(), subtasks.len());
}
// Draining the queue on one thread yields each index in 0..num_items exactly once.
#[test]
fn work_queue_dispatch_all(num_items in 1usize..50) {
let items: Vec<_> = (0..num_items).map(|i| WorkItem::new(i, 1)).collect();
let queue = WorkQueue::new(items);
let mut dispatched = Vec::new();
while let Some(idx) = queue.request_next() {
dispatched.push(idx);
}
prop_assert_eq!(dispatched.len(), num_items);
dispatched.sort();
let expected: Vec<_> = (0..num_items).collect();
prop_assert_eq!(dispatched, expected);
}
// After the queue is exhausted, `reset` allows a complete second pass.
#[test]
fn work_queue_reset_enables_redispatch(num_items in 1usize..20) {
let items: Vec<_> = (0..num_items).map(|i| WorkItem::new(i, 1)).collect();
let queue = WorkQueue::new(items);
while queue.request_next().is_some() {}
prop_assert!(queue.request_next().is_none());
queue.reset();
let mut count = 0;
while queue.request_next().is_some() {
count += 1;
}
prop_assert_eq!(count, num_items);
}
// `progress` reports the dispatched count clamped to the item count, even
// when more requests were made than there are items.
#[test]
fn work_queue_progress_accurate(num_items in 2usize..20, dispatch_count in 0usize..20) {
let items: Vec<_> = (0..num_items).map(|i| WorkItem::new(i, 1)).collect();
let queue = WorkQueue::new(items);
for _ in 0..dispatch_count {
queue.request_next();
}
let (dispatched, total) = queue.progress();
prop_assert_eq!(total, num_items);
prop_assert_eq!(dispatched, dispatch_count.min(num_items));
}
// Only rank 0 reports itself as the dispatcher.
#[test]
fn work_group_thread_zero_is_dispatcher(size in 1usize..64, id in 0usize..100) {
let group = WorkGroup::new(size, id, 0);
prop_assert!(group.is_dispatcher());
for rank in 1..size {
let non_dispatch = WorkGroup::new(size, id, rank);
prop_assert!(!non_dispatch.is_dispatcher());
}
}
// The union of `subtask_range` over ranks 0..size is exactly 0..num_subtasks.
#[test]
fn work_group_subtask_range_covers_all(
size in 2usize..8,
num_subtasks in 1usize..50
) {
let mut all_subtasks = std::collections::HashSet::new();
for rank in 0..size {
let group = WorkGroup::new(size, 0, rank);
for subtask in group.subtask_range(num_subtasks) {
all_subtasks.insert(subtask);
}
}
let expected: std::collections::HashSet<_> = (0..num_subtasks).collect();
prop_assert_eq!(all_subtasks, expected);
}
// Each rank's sequence starts at its rank and advances by the group size.
// Cases with rank >= size are discarded by `prop_assume!`.
#[test]
fn work_group_subtask_step(size in 2usize..16, rank in 0usize..16, num_subtasks in 10usize..50) {
prop_assume!(rank < size);
let group = WorkGroup::new(size, 0, rank);
let subtasks: Vec<_> = group.subtask_range(num_subtasks).collect();
if !subtasks.is_empty() {
prop_assert_eq!(subtasks[0], rank);
}
for i in 1..subtasks.len() {
prop_assert_eq!(subtasks[i] - subtasks[i-1], size);
}
}
// `stats()` mirrors the constructor arguments. Item i has i + 1 subtasks, so
// the total is 1 + 2 + ... + num_items.
#[test]
fn work_dispatcher_stats_consistent(
num_items in 1usize..30,
num_groups in 1usize..8,
group_size in 1usize..32
) {
let items: Vec<_> = (0..num_items).map(|i| WorkItem::new(i, i + 1)).collect();
let dispatcher = WorkDispatcher::new(items, num_groups, group_size);
let stats = dispatcher.stats();
prop_assert_eq!(stats.total_items, num_items);
prop_assert_eq!(stats.num_groups, num_groups);
prop_assert_eq!(stats.group_size, group_size);
prop_assert_eq!(stats.total_workers(), num_groups * group_size);
let expected_subtasks: usize = (1..=num_items).sum();
prop_assert_eq!(stats.total_subtasks, expected_subtasks);
}
// The completion ratio starts at 0 and equals (num_items / 2) / num_items
// after that many requests (integer division for the request count).
#[test]
fn work_dispatcher_completion_ratio(num_items in 2usize..20) {
let items: Vec<_> = (0..num_items).map(|i| WorkItem::new(i, 1)).collect();
let dispatcher = WorkDispatcher::with_default_group_size(items, 1);
let initial = dispatcher.stats().completion_ratio();
prop_assert!((initial - 0.0).abs() < 0.01);
for _ in 0..num_items/2 {
dispatcher.queue().request_next();
}
let half = dispatcher.stats().completion_ratio();
let expected_half = (num_items / 2) as f64 / num_items as f64;
prop_assert!((half - expected_half).abs() < 0.01);
}
// num_groups == ceil(num_workers / min(WORK_GROUP_SIZE, num_workers)).
#[test]
fn load_balancer_num_groups(num_workers in 1usize..1000) {
let balancer = LoadBalancer::new(num_workers);
let group_size = WORK_GROUP_SIZE.min(num_workers);
let expected_groups = num_workers.div_ceil(group_size);
prop_assert_eq!(balancer.num_groups(), expected_groups);
}
// For non-empty work, the estimate is non-zero and at least
// min(WORK_GROUP_SIZE, num_items * avg_subtasks).
#[test]
fn load_balancer_estimate_reasonable(num_items in 1usize..1000, avg_subtasks in 1usize..100) {
let workers = LoadBalancer::estimate_workers(num_items, avg_subtasks);
prop_assert!(workers >= WORK_GROUP_SIZE.min(num_items * avg_subtasks));
prop_assert!(workers >= 1);
}
// avg_subtasks_per_worker == total_subtasks / (num_groups * group_size).
#[test]
fn dispatch_stats_avg_subtasks(
total_subtasks in 1usize..1000,
num_groups in 1usize..10,
group_size in 1usize..32
) {
let stats = DispatchStats {
total_items: 10,
dispatched_items: 5,
total_subtasks,
num_groups,
group_size,
};
let total_workers = num_groups * group_size;
let expected_avg = total_subtasks as f64 / total_workers as f64;
prop_assert!((stats.avg_subtasks_per_worker() - expected_avg).abs() < 1e-10);
}
}
}