tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Segmented operations for ragged tensor support

struct SegmentParams {
    data_size: u32,
    num_segments: u32,
    _padding: array<u32, 2>,
}

@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read> segment_ids: array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: SegmentParams;

// Segmented sum operation
// Each thread processes one data element and atomically adds to its segment
@compute @workgroup_size(256)
fn segment_sum(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    
    if (idx >= params.data_size) {
        return;
    }
    
    let segment_id = segment_ids[idx];
    if (segment_id >= 0 && segment_id < i32(params.num_segments)) {
        let data_val = data[idx];
        
        // Atomic add to accumulate sum for this segment
        atomicAdd(&output[segment_id], data_val);
    }
}

// Segmented max operation
// Uses atomic compare-exchange to find maximum
@compute @workgroup_size(256)
fn segment_max(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    
    if (idx >= params.data_size) {
        return;
    }
    
    let segment_id = segment_ids[idx];
    if (segment_id >= 0 && segment_id < i32(params.num_segments)) {
        let data_val = data[idx];
        
        // Atomic max operation
        var current_max = atomicLoad(&output[segment_id]);
        var new_max = max(current_max, data_val);
        
        // Keep trying until we successfully update or find a larger value
        while (new_max > current_max) {
            let old_val = atomicCompareExchangeWeak(&output[segment_id], current_max, new_max);
            if (old_val == current_max) {
                break; // Successfully updated
            }
            current_max = old_val;
            new_max = max(current_max, data_val);
        }
    }
}

// Segmented mean operation - uses two buffers for sum and count
@group(0) @binding(3) var<storage, read_write> count_buffer: array<u32>;
@group(0) @binding(4) var<uniform> mean_params: SegmentParams;

@compute @workgroup_size(256)
fn segment_mean(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    
    if (idx >= mean_params.data_size) {
        return;
    }
    
    let segment_id = segment_ids[idx];
    if (segment_id >= 0 && segment_id < i32(mean_params.num_segments)) {
        let data_val = data[idx];
        
        // Atomic add to accumulate sum for this segment
        atomicAdd(&output[segment_id], data_val);
        
        // Atomic add to count for this segment
        atomicAdd(&count_buffer[segment_id], 1u);
    }
}

// Second pass for segment_mean: divide sum by count
@compute @workgroup_size(64)
fn segment_mean_finalize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let segment_id = global_id.x;
    
    if (segment_id >= mean_params.num_segments) {
        return;
    }
    
    let count = count_buffer[segment_id];
    if (count > 0u) {
        output[segment_id] = output[segment_id] / f32(count);
    }
}

// Optimized segmented reduction using workgroup shared memory
// This version processes segments in parallel within workgroups
var<workgroup> shared_sums: array<f32, 256>;
var<workgroup> shared_counts: array<u32, 256>;

@compute @workgroup_size(256)
fn segment_sum_optimized(@builtin(global_invocation_id) global_id: vec3<u32>,
                        @builtin(local_invocation_id) local_id: vec3<u32>,
                        @builtin(workgroup_id) workgroup_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let workgroup_size = 256u;
    
    // Initialize shared memory
    shared_sums[tid] = 0.0;
    shared_counts[tid] = 0u;
    workgroupBarrier();
    
    // Each thread processes multiple elements
    for (var i = gid; i < params.data_size; i += workgroup_size * 16u) {
        if (i < params.data_size) {
            let segment_id = segment_ids[i];
            if (segment_id >= 0 && segment_id < i32(params.num_segments)) {
                let data_val = data[i];
                
                // Accumulate in shared memory if segment maps to this thread
                if (u32(segment_id) % workgroup_size == tid) {
                    shared_sums[tid] += data_val;
                }
            }
        }
    }
    
    workgroupBarrier();
    
    // Write results from shared memory to global memory
    if (tid < params.num_segments) {
        atomicAdd(&output[tid], shared_sums[tid]);
    }
}

// Segmented argmax operation - finds index of maximum value in each segment
@group(0) @binding(2) var<storage, read_write> argmax_output: array<u32>;

@compute @workgroup_size(256)
fn segment_argmax(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    
    if (idx >= params.data_size) {
        return;
    }
    
    let segment_id = segment_ids[idx];
    if (segment_id >= 0 && segment_id < i32(params.num_segments)) {
        let data_val = data[idx];
        
        // Try to update both max value and argmax index
        var current_max = atomicLoad(&output[segment_id]);
        
        while (data_val > current_max) {
            // Try to update max value
            let old_max = atomicCompareExchangeWeak(&output[segment_id], current_max, data_val);
            if (old_max == current_max) {
                // Successfully updated max, now update argmax
                atomicStore(&argmax_output[segment_id], idx);
                break;
            }
            current_max = old_max;
        }
    }
}

// Segmented scan operations - cumulative operations within segments
@compute @workgroup_size(256)
fn segment_cumsum(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    
    if (idx >= params.data_size) {
        return;
    }
    
    let segment_id = segment_ids[idx];
    if (segment_id >= 0 && segment_id < i32(params.num_segments)) {
        let data_val = data[idx];
        
        // Simple implementation: sum all previous elements in same segment
        var cum_sum = 0.0;
        for (var i = 0u; i <= idx; i++) {
            if (i < params.data_size && segment_ids[i] == segment_id) {
                cum_sum += data[i];
            }
        }
        
        output[idx] = cum_sum;
    }
}

// Segmented prefix sum with efficient segment boundary handling
@compute @workgroup_size(256)
fn segment_prefix_sum(@builtin(global_invocation_id) global_id: vec3<u32>,
                     @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let workgroup_size = 256u;
    
    if (gid >= params.data_size) {
        return;
    }
    
    // Load data into shared memory
    shared_sums[tid] = data[gid];
    workgroupBarrier();
    
    // Parallel prefix sum within workgroup
    var step = 1u;
    while (step < workgroup_size) {
        if (tid >= step && gid >= step) {
            // Only add if we're in the same segment
            if (segment_ids[gid] == segment_ids[gid - step]) {
                shared_sums[tid] += shared_sums[tid - step];
            }
        }
        step *= 2u;
        workgroupBarrier();
    }
    
    // Store result
    output[gid] = shared_sums[tid];
}