// nexar 0.1.2
//
// Distributed runtime with QUIC transport, stream-multiplexed messaging, and built-in collectives.
// Documentation
use crate::client::NexarClient;
use crate::collective::helpers::{CollectiveTag, ceil_log2, collective_recv, collective_send};
use crate::error::{NexarError, Result};
use crate::reduce::{identity_slice, reduce_slice};
use crate::types::{DataType, ReduceOp};

/// Inclusive prefix scan: rank `i` holds the reduction of ranks 0..=i.
///
/// Uses a parallel prefix (Hillis-Steele) algorithm with O(log N) rounds.
/// In round `k` (distance `2^k`) each rank sends its current partial result to
/// rank `rank + 2^k` when that rank exists, and folds in the partial result
/// received from rank `rank - 2^k` when that rank exists. Send and receive of
/// the same round run concurrently.
///
/// The buffer at `ptr` is staged once before the first round and written back
/// once after the last round. With a world size of 0 or 1 the function returns
/// immediately and the buffer is left untouched.
///
/// # Errors
/// Propagates any error from staging, sending, receiving, reducing or writing
/// back. Returns [`NexarError::BufferSizeMismatch`] when a peer's payload is
/// not exactly `count * dtype.size_in_bytes()` bytes long.
///
/// # Safety
/// `ptr` must be valid for at least `count * dtype.size_in_bytes()` bytes.
pub(crate) async unsafe fn inclusive_scan(
    client: &NexarClient,
    ptr: u64,
    count: usize,
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let rank = client.rank() as usize;

    if world <= 1 {
        return Ok(());
    }

    let total_bytes = count * dtype.size_in_bytes();

    // SAFETY: the caller guarantees `ptr` is valid for `total_bytes` bytes.
    let mut buf = unsafe { client.adapter().stage_for_send(ptr, total_bytes)? };

    let num_rounds = ceil_log2(world as u32) as usize;

    for round in 0..num_rounds {
        let distance = 1usize << round;

        // `None` means "no such peer in this round" rather than relying on a
        // wrapped-around placeholder rank.
        let dest = Some(rank + distance).filter(|&d| d < world);
        let source = rank.checked_sub(distance);

        let received = match (dest, source) {
            (Some(dest), Some(source)) => {
                // Both futures are complete (and dropped) once `try_join!`
                // returns, so the shared borrow of `buf` ends before the
                // in-place reduction below; no copy of the payload is needed.
                let (_, received) = tokio::try_join!(
                    collective_send(client, dest as u32, &buf, "scan", tag),
                    collective_recv(client, source as u32, "scan", tag),
                )?;
                received
            }
            (Some(dest), None) => {
                // Low ranks only forward their (already final) prefix.
                collective_send(client, dest as u32, &buf, "scan", tag).await?;
                continue;
            }
            (None, Some(source)) => collective_recv(client, source as u32, "scan", tag).await?,
            // Nothing to exchange for this rank in this round.
            (None, None) => continue,
        };

        if received.len() != total_bytes {
            return Err(NexarError::BufferSizeMismatch {
                expected: total_bytes,
                actual: received.len(),
            });
        }
        reduce_slice(&mut buf, &received, count, dtype, op)?;
    }

    // SAFETY: same contract as above; `buf` was staged from `ptr` with
    // `total_bytes` bytes and is written back to the same location.
    unsafe { client.adapter().receive_to_device(&buf, ptr)? };

    Ok(())
}

/// Exclusive prefix scan: rank `i` ends up with the reduction of ranks 0..i.
///
/// The rank without a predecessor (rank 0) gets the identity element of `op`.
/// The result is produced by running [`inclusive_scan`] in place and then
/// passing every rank's inclusive result to its right-hand neighbour. A world
/// of size 0 or 1 just writes the identity element to `ptr`.
///
/// # Errors
/// Propagates errors from the inclusive scan, staging, the neighbour exchange
/// and the write-back. Returns [`NexarError::BufferSizeMismatch`] when the
/// payload received from the previous rank is not exactly
/// `count * dtype.size_in_bytes()` bytes long.
///
/// # Safety
/// `ptr` must be valid for at least `count * dtype.size_in_bytes()` bytes.
pub(crate) async unsafe fn exclusive_scan(
    client: &NexarClient,
    ptr: u64,
    count: usize,
    dtype: DataType,
    op: ReduceOp,
    tag: CollectiveTag,
) -> Result<()> {
    let world = client.world_size() as usize;
    let rank = client.rank() as usize;
    let total_bytes = count * dtype.size_in_bytes();

    if world <= 1 {
        // No peers: the exclusive prefix of the only rank is the identity.
        let neutral = identity_slice(count, dtype, op)?;
        // SAFETY: the caller guarantees `ptr` is valid for `total_bytes` bytes.
        unsafe { client.adapter().receive_to_device(&neutral, ptr)? };
        return Ok(());
    }

    // Step 1: inclusive scan, in place.
    // SAFETY: forwarded under the same caller contract on `ptr`.
    unsafe { inclusive_scan(client, ptr, count, dtype, op, tag).await? };

    // Step 2: hand every inclusive result to the next rank.
    // SAFETY: the caller guarantees `ptr` is valid for `total_bytes` bytes.
    let staged = unsafe { client.adapter().stage_for_send(ptr, total_bytes)? };

    let has_successor = rank + 1 < world;
    let has_predecessor = rank > 0;

    if has_predecessor {
        let prev = (rank - 1) as u32;
        let shifted = if has_successor {
            // Interior rank: forward our value while fetching the predecessor's.
            let next = (rank + 1) as u32;
            let (_, from_prev) = tokio::try_join!(
                collective_send(client, next, &staged, "exscan", tag),
                collective_recv(client, prev, "exscan", tag),
            )?;
            from_prev
        } else {
            // Last rank: nothing to forward, only receive.
            collective_recv(client, prev, "exscan", tag).await?
        };

        if shifted.len() != total_bytes {
            return Err(NexarError::BufferSizeMismatch {
                expected: total_bytes,
                actual: shifted.len(),
            });
        }
        // SAFETY: `shifted` is exactly `total_bytes` long and `ptr` is valid
        // for that many bytes per the caller contract.
        unsafe { client.adapter().receive_to_device(&shifted, ptr)? };
    } else if has_successor {
        // First rank: forward our value, then overwrite it with the identity.
        let next = (rank + 1) as u32;
        collective_send(client, next, &staged, "exscan", tag).await?;
        let neutral = identity_slice(count, dtype, op)?;
        // SAFETY: the caller guarantees `ptr` is valid for `total_bytes` bytes.
        unsafe { client.adapter().receive_to_device(&neutral, ptr)? };
    } else {
        unreachable!("world > 1")
    }

    Ok(())
}