tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Parallel scan (prefix sum) operations for GPU

// Uniform parameters shared by the scan kernels in this file.
struct ScanParams {
    // Element count; every kernel uses it as the upper bound for the indices it processes.
    size: u32,
    // Not read by any kernel visible in this file; presumably the index of the
    // scanned axis — TODO confirm against the host code.
    axis: u32,
    // Distance, in elements, between consecutive entries along the scanned axis
    // (read by cumsum_axis / cumprod_axis).
    stride: u32,
    // Number of entries along the scanned axis (read by cumsum_axis / cumprod_axis).
    axis_size: u32,
}

// Source values; declared read-only.
@group(0) @binding(0) var<storage, read> input: array<f32>;
// Destination buffer; every kernel stores its scan result here.
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
// Per-workgroup scratch: cumsum_up_sweep writes one entry per workgroup and
// cumsum_down_sweep reads the entry of the preceding workgroup.
@group(0) @binding(2) var<storage, read_write> temp: array<f32>;
// Scan configuration (see ScanParams).
@group(0) @binding(3) var<uniform> params: ScanParams;

// Workgroup-level shared memory for parallel scan
// (256 entries, indexed by local invocation id in the 256-wide kernels).
var<workgroup> shared_data: array<f32, 256>;

// Up-sweep (reduce) phase of parallel scan.
//
// Each 256-invocation workgroup loads one tile of `input` into shared memory,
// tree-reduces it, and stores the tile total in `temp[workgroup_id.x]`.
// Positions at or beyond `params.size` contribute 0.0.
@compute @workgroup_size(256)
fn cumsum_up_sweep(@builtin(global_invocation_id) global_id: vec3<u32>,
                   @builtin(local_invocation_id) local_id: vec3<u32>,
                   @builtin(workgroup_id) workgroup_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let workgroup_size = 256u;
    let in_range = gid < params.size;

    // Out-of-range invocations must not return early: every invocation of the
    // workgroup has to reach each workgroupBarrier() below, so they take part
    // in the reduction with the additive identity instead.
    var value = 0.0;
    if (in_range) {
        value = input[gid];
    }
    shared_data[tid] = value;
    workgroupBarrier();

    // Up-sweep phase: at each level the right-hand node of every pair of
    // subtrees (index (k + 1) * 2 * step - 1) absorbs the total of its
    // left-hand sibling, which sits `step` slots below it. After the last
    // level the final slot holds the sum of the whole tile.
    var step = 1u;
    while (step < workgroup_size) {
        let right = (tid + 1u) * 2u * step - 1u;
        if (right < workgroup_size) {
            shared_data[right] += shared_data[right - step];
        }
        step *= 2u;
        workgroupBarrier();
    }

    // Store the partial sum for this workgroup. Workgroups whose first element
    // is already out of range store nothing, so `temp` is only written for
    // tiles that contain data.
    if (tid == 0u && in_range) {
        temp[workgroup_id.x] = shared_data[workgroup_size - 1u];
    }
}

// Down-sweep phase of parallel scan.
//
// Each 256-invocation workgroup computes a work-efficient (up-sweep followed
// by down-sweep) exclusive scan of its tile in shared memory, then writes the
// inclusive result `exclusive + carry + own value` to `output`. The carry for
// workgroup w > 0 is read from `temp[w - 1]`.
// Positions at or beyond `params.size` contribute 0.0 and are not written.
@compute @workgroup_size(256)
fn cumsum_down_sweep(@builtin(global_invocation_id) global_id: vec3<u32>,
                     @builtin(local_invocation_id) local_id: vec3<u32>,
                     @builtin(workgroup_id) workgroup_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let workgroup_size = 256u;
    let in_range = gid < params.size;

    // Out-of-range invocations must not return early: every invocation of the
    // workgroup has to reach each workgroupBarrier() below, so they take part
    // in the scan with the additive identity instead.
    var value = 0.0;
    if (in_range) {
        value = input[gid];
    }
    shared_data[tid] = value;
    workgroupBarrier();

    // Up-sweep phase: the right-hand node of every pair of subtrees
    // (index (k + 1) * 2 * step - 1) absorbs its left-hand sibling's total.
    var step = 1u;
    while (step < workgroup_size) {
        let right = (tid + 1u) * 2u * step - 1u;
        if (right < workgroup_size) {
            shared_data[right] += shared_data[right - step];
        }
        step *= 2u;
        workgroupBarrier();
    }

    // Clear the last element (the tree root) so the down-sweep produces an
    // exclusive scan.
    if (tid == 0u) {
        shared_data[workgroup_size - 1u] = 0.0;
    }
    workgroupBarrier();

    // Down-sweep phase: walking back down the same tree, each node hands its
    // current value to its left child and gives the right child that value
    // plus the left subtree's total.
    step = workgroup_size / 2u;
    while (step > 0u) {
        let right = (tid + 1u) * 2u * step - 1u;
        if (right < workgroup_size) {
            let left_total = shared_data[right - step];
            shared_data[right - step] = shared_data[right];
            shared_data[right] += left_total;
        }
        step /= 2u;
        workgroupBarrier();
    }

    // Add the prefix sum from previous workgroups.
    // NOTE(review): cumsum_up_sweep stores only each workgroup's own total in
    // `temp`. Reading temp[w - 1] is a complete carry for more than two
    // workgroups only if `temp` is turned into running totals between the two
    // passes — confirm against the host code.
    var prefix_sum = 0.0;
    if (workgroup_id.x > 0u) {
        prefix_sum = temp[workgroup_id.x - 1u];
    }

    // Store result: exclusive scan + carry + own value = inclusive scan.
    if (in_range) {
        output[gid] = shared_data[tid] + prefix_sum + value;
    }
}

// Simple cumulative sum for small arrays (single workgroup).
//
// Inclusive scan of up to 256 elements using a doubling-offset scan in shared
// memory. If more than one workgroup is dispatched, each workgroup scans its
// own 256-element tile independently; no carry is propagated between tiles.
@compute @workgroup_size(256)
fn cumsum_simple(@builtin(global_invocation_id) global_id: vec3<u32>,
                 @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let in_range = gid < params.size;

    // Out-of-range invocations must not return early: every invocation of the
    // workgroup has to reach each workgroupBarrier() below. They carry the
    // additive identity, which never reaches an in-range result because
    // values only flow from lower to higher indices.
    var running = 0.0;
    if (in_range) {
        running = input[gid];
    }
    shared_data[tid] = running;
    workgroupBarrier();

    for (var step = 1u; step < 256u; step *= 2u) {
        // Read the partner value first and publish the update only after a
        // barrier; otherwise a neighbour could overwrite the slot while it is
        // still being read for this step.
        var addend = 0.0;
        if (tid >= step) {
            addend = shared_data[tid - step];
        }
        workgroupBarrier();
        running += addend;
        shared_data[tid] = running;
        workgroupBarrier();
    }

    // Store result
    if (in_range) {
        output[gid] = running;
    }
}

// Cumulative product operations.
//
// Inclusive prefix product of up to 256 elements using a doubling-offset scan
// in shared memory. If more than one workgroup is dispatched, each workgroup
// scans its own 256-element tile independently; no carry is propagated
// between tiles.
@compute @workgroup_size(256)
fn cumprod_simple(@builtin(global_invocation_id) global_id: vec3<u32>,
                  @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let in_range = gid < params.size;

    // Out-of-range invocations must not return early: every invocation of the
    // workgroup has to reach each workgroupBarrier() below. They carry the
    // multiplicative identity, which never reaches an in-range result because
    // values only flow from lower to higher indices.
    var running = 1.0;
    if (in_range) {
        running = input[gid];
    }
    shared_data[tid] = running;
    workgroupBarrier();

    for (var step = 1u; step < 256u; step *= 2u) {
        // Read the partner value first and publish the update only after a
        // barrier; otherwise a neighbour could overwrite the slot while it is
        // still being read for this step.
        var factor = 1.0;
        if (tid >= step) {
            factor = shared_data[tid - step];
        }
        workgroupBarrier();
        running *= factor;
        shared_data[tid] = running;
        workgroupBarrier();
    }

    // Store result
    if (in_range) {
        output[gid] = running;
    }
}

// Axis-aware cumulative sum.
//
// One invocation per output element. The flat index is split into
// (outer, position along the axis, inner) using `params.axis_size` and
// `params.stride`; the invocation then adds up the inputs of its own
// (outer, inner) line from axis position 0 through its own position.
@compute @workgroup_size(64)
fn cumsum_axis(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let flat = global_id.x;

    if (flat < params.size) {
        let hop = params.stride;
        let extent = params.axis_size;

        // Decompose the flat index.
        let outer = flat / (extent * hop);
        let inner = flat % hop;
        let last = (flat / hop) % extent;

        // Start at axis position 0 of this line and walk forward one hop at a time.
        var cursor = outer * extent * hop + inner;
        var total = 0.0;
        for (var k = 0u; k <= last; k++) {
            total += input[cursor];
            cursor += hop;
        }

        output[flat] = total;
    }
}

// Axis-aware cumulative product.
//
// One invocation per output element. The flat index is split into
// (outer, position along the axis, inner) using `params.axis_size` and
// `params.stride`; the invocation then multiplies the inputs of its own
// (outer, inner) line from axis position 0 through its own position.
@compute @workgroup_size(64)
fn cumprod_axis(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let flat = global_id.x;

    if (flat < params.size) {
        let hop = params.stride;
        let extent = params.axis_size;

        // Decompose the flat index.
        let outer = flat / (extent * hop);
        let inner = flat % hop;
        let last = (flat / hop) % extent;

        // Start at axis position 0 of this line and walk forward one hop at a time.
        var cursor = outer * extent * hop + inner;
        var product = 1.0;
        for (var k = 0u; k <= last; k++) {
            product *= input[cursor];
            cursor += hop;
        }

        output[flat] = product;
    }
}

// Optimized scan for power-of-2 sizes.
//
// Every invocation owns two adjacent elements, so one 256-invocation
// workgroup produces an inclusive scan of up to 512 elements: the pair sums
// are scanned in shared memory and each invocation then expands the result
// back to its two outputs. If more than one workgroup is dispatched, each
// workgroup scans its own 512-element tile independently; no carry is
// propagated between tiles.
@compute @workgroup_size(256)
fn cumsum_optimized(@builtin(global_invocation_id) global_id: vec3<u32>,
                    @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x * 2u;
    let has_first = gid < params.size;
    let has_second = gid + 1u < params.size;

    // Load two elements per thread. Missing elements count as 0.0; invocations
    // with no element at all must still run to the end so that every
    // invocation of the workgroup reaches each workgroupBarrier() below.
    var first = 0.0;
    if (has_first) {
        first = input[gid];
    }
    var second = 0.0;
    if (has_second) {
        second = input[gid + 1u];
    }

    var running = first + second;
    shared_data[tid] = running;
    workgroupBarrier();

    // Inclusive doubling-offset scan over the pair sums.
    for (var step = 1u; step < 256u; step *= 2u) {
        // Read the partner value first and publish the update only after a
        // barrier; otherwise a neighbour could overwrite the slot while it is
        // still being read for this step.
        var addend = 0.0;
        if (tid >= step) {
            addend = shared_data[tid - step];
        }
        workgroupBarrier();
        running += addend;
        shared_data[tid] = running;
        workgroupBarrier();
    }

    // Sum of all pairs owned by lower-numbered invocations. Building both
    // outputs by addition from this value avoids recovering the first one by
    // subtracting `second` from the inclusive total.
    var before = 0.0;
    if (tid > 0u) {
        before = shared_data[tid - 1u];
    }

    // Store results
    if (has_first) {
        output[gid] = before + first;
    }
    if (has_second) {
        output[gid + 1u] = before + first + second;
    }
}