tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Normalization operation compute shaders

// Uniform parameters read by `layer_norm` and `layer_norm_simple`.
// The field order and types define the uniform buffer layout, so the
// host-side struct must match it exactly.
struct NormalizationParams {
    // Length of each row that is normalized independently. Rows are stored
    // contiguously in `input`, and `gamma`/`beta` are indexed by position
    // within a row.
    normalized_size: u32,
    // Added to the variance before the square root, which keeps the divisor
    // non-zero for constant rows.
    epsilon: f32,
}

// Resource bindings used by `layer_norm` and `layer_norm_simple`.
// NOTE(review): the other kernels in this module declare different variables
// on the same @group(0) binding numbers. That is only valid if no single entry
// point statically uses two variables sharing a slot — confirm against the
// WGSL implementation in use.
// Flattened rows: row r occupies [r * normalized_size, (r + 1) * normalized_size).
@group(0) @binding(0) var<storage, read> input: array<f32>;
// Per-position scale, indexed by an element's position within its row.
@group(0) @binding(1) var<storage, read> gamma: array<f32>;
// Per-position shift, indexed the same way as `gamma`.
@group(0) @binding(2) var<storage, read> beta: array<f32>;
// Normalized result, written at the same flat index as the input element.
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<uniform> params: NormalizationParams;

// Workgroup scratch space for `ln_workgroup_sum`.
// WGSL only permits `var<workgroup>` at module scope; declaring it inside the
// kernel body is a shader-creation error. The length must equal the kernel's
// @workgroup_size (256).
var<workgroup> ln_partials: array<f32, 256>;

// Sums `value` across all 256 invocations of the current workgroup and
// returns the total to every invocation.
// It contains workgroup barriers, so it must be called from uniform control
// flow by every invocation of the workgroup.
fn ln_workgroup_sum(local_idx: u32, value: f32) -> f32 {
    ln_partials[local_idx] = value;
    workgroupBarrier();
    // Tree reduction; 128 is half the workgroup size.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (local_idx < stride) {
            ln_partials[local_idx] += ln_partials[local_idx + stride];
        }
        workgroupBarrier();
    }
    let total = ln_partials[0];
    // Keep a following call from overwriting slot 0 before every invocation
    // has read the total.
    workgroupBarrier();
    return total;
}

// Layer normalization kernel.
// Each workgroup normalizes one row of `params.normalized_size` contiguous
// elements. The row index is workgroup_id.x, so the host presumably
// dispatches one workgroup per row (TODO confirm against the caller).
//
// Statistics use the two-pass form: first the mean, then the mean squared
// deviation. This matches `layer_norm_simple`. The former E[x^2] - E[x]^2
// formulation cancels catastrophically when |mean| >> std and could produce
// a negative variance, making sqrt() return NaN.
@compute @workgroup_size(256)
fn layer_norm(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>,
              @builtin(workgroup_id) group_id: vec3<u32>) {
    let normalized_size = params.normalized_size;
    let epsilon = params.epsilon;
    let local_idx = local_id.x;
    // Must match @workgroup_size above.
    let workgroup_size = 256u;

    let row_start = group_id.x * normalized_size;
    let input_len = arrayLength(&input);
    let count = f32(normalized_size);

    // Pass 1: mean. Each invocation accumulates a strided subset of the row.
    var partial_sum = 0.0;
    for (var i = local_idx; i < normalized_size; i += workgroup_size) {
        let idx = row_start + i;
        if (idx < input_len) {
            partial_sum += input[idx];
        }
    }
    let mean = ln_workgroup_sum(local_idx, partial_sum) / count;

    // Pass 2: population variance from squared deviations about the mean.
    var partial_sq = 0.0;
    for (var i = local_idx; i < normalized_size; i += workgroup_size) {
        let idx = row_start + i;
        if (idx < input_len) {
            let diff = input[idx] - mean;
            partial_sq += diff * diff;
        }
    }
    let variance = ln_workgroup_sum(local_idx, partial_sq) / count;
    let std_dev = sqrt(variance + epsilon);

    // Pass 3: normalize, then apply the per-position scale and shift.
    for (var i = local_idx; i < normalized_size; i += workgroup_size) {
        let idx = row_start + i;
        if (idx < input_len) {
            let normalized = (input[idx] - mean) / std_dev;
            output[idx] = gamma[i] * normalized + beta[i];
        }
    }
}

// Layer norm for small rows, with no workgroup memory and no barriers.
// Each invocation handles one whole row (row index = global_invocation_id.x)
// using a sequential two-pass mean/variance. Invocations whose row starts at
// or past the end of `input` exit immediately.
@compute @workgroup_size(64)
fn layer_norm_simple(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row_len = params.normalized_size;
    let row_start = global_id.x * row_len;
    let input_len = arrayLength(&input);

    if (row_start >= input_len) {
        return;
    }

    // Both statistics divide by the full row length.
    let n = f32(row_len);

    // First pass: mean of the in-bounds elements.
    var total = 0.0;
    for (var k = 0u; k < row_len; k++) {
        let pos = row_start + k;
        if (pos >= input_len) {
            continue;
        }
        total += input[pos];
    }
    let mean = total / n;

    // Second pass: population variance from squared deviations.
    var sq_total = 0.0;
    for (var k = 0u; k < row_len; k++) {
        let pos = row_start + k;
        if (pos >= input_len) {
            continue;
        }
        let dev = input[pos] - mean;
        sq_total += dev * dev;
    }
    let std_dev = sqrt(sq_total / n + params.epsilon);

    // Final pass: normalize, then scale and shift by position within the row.
    for (var k = 0u; k < row_len; k++) {
        let pos = row_start + k;
        if (pos >= input_len) {
            continue;
        }
        let scaled = (input[pos] - mean) / std_dev;
        output[pos] = gamma[k] * scaled + beta[k];
    }
}

// Group normalization kernel
// Uniform parameters read by `group_norm`. Field order and types define the
// uniform buffer layout the host must match.
struct GroupNormParams {
    // NOTE(review): not read by `group_norm`. The batch index comes from
    // workgroup_id.x, so the dispatch size decides how many items are processed.
    batch_size: u32,
    // Channels per batch item.
    num_channels: u32,
    // Channels per group is num_channels / num_groups (integer division). If
    // num_channels is not divisible, the remainder channels are not normalized
    // (assuming one workgroup per group along y).
    num_groups: u32,
    // Contiguous elements per channel in the [batch][channel][spatial] layout.
    spatial_size: u32,
    // Added to the variance before the square root.
    epsilon: f32,
}

// Resource bindings used by `group_norm`.
// NOTE(review): these reuse @group(0) binding numbers that other kernels in
// this module also declare — confirm the implementation allows this.
// Flattened input laid out as [batch][channel][spatial].
@group(0) @binding(0) var<storage, read> group_input: array<f32>;
// Per-channel scale, indexed by channel.
@group(0) @binding(1) var<storage, read> group_gamma: array<f32>;
// Per-channel shift, indexed by channel.
@group(0) @binding(2) var<storage, read> group_beta: array<f32>;
// Normalized result, written at the same flat index as the input element.
@group(0) @binding(3) var<storage, read_write> group_output: array<f32>;
@group(0) @binding(4) var<uniform> group_params: GroupNormParams;

// Workgroup scratch space for `gn_workgroup_sum`.
// WGSL only permits `var<workgroup>` at module scope; declaring it inside the
// kernel body is a shader-creation error. The length must equal the kernel's
// @workgroup_size (256).
var<workgroup> gn_partials: array<f32, 256>;

// Sums `value` across all 256 invocations of the current workgroup and
// returns the total to every invocation.
// It contains workgroup barriers, so it must be called from uniform control
// flow by every invocation of the workgroup.
fn gn_workgroup_sum(local_idx: u32, value: f32) -> f32 {
    gn_partials[local_idx] = value;
    workgroupBarrier();
    // Tree reduction; 128 is half the workgroup size.
    for (var stride = 128u; stride > 0u; stride >>= 1u) {
        if (local_idx < stride) {
            gn_partials[local_idx] += gn_partials[local_idx + stride];
        }
        workgroupBarrier();
    }
    let total = gn_partials[0];
    // Keep a following call from overwriting slot 0 before every invocation
    // has read the total.
    workgroupBarrier();
    return total;
}

// Group normalization kernel.
// One workgroup normalizes one (batch item, channel group) pair: the batch
// index is workgroup_id.x and the group index is workgroup_id.y. A group
// covers channels_per_group * spatial_size contiguous elements of the
// [batch][channel][spatial] input.
//
// Statistics use the two-pass form (mean, then mean squared deviation). The
// former E[x^2] - E[x]^2 formulation cancels catastrophically and could
// produce a negative variance, making sqrt() return NaN.
@compute @workgroup_size(256)
fn group_norm(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>,
              @builtin(workgroup_id) group_id: vec3<u32>) {
    let batch_idx = group_id.x;
    let group_idx = group_id.y;
    let local_idx = local_id.x;
    // Must match @workgroup_size above.
    let workgroup_size = 256u;

    let spatial_size = group_params.spatial_size;
    let num_channels = group_params.num_channels;
    let channels_per_group = num_channels / group_params.num_groups;
    let group_size = channels_per_group * spatial_size;
    let count = f32(group_size);

    let batch_offset = batch_idx * num_channels * spatial_size;
    let group_offset = group_idx * channels_per_group * spatial_size;
    let base = batch_offset + group_offset;
    let input_len = arrayLength(&group_input);

    // Pass 1: mean. Each invocation accumulates a strided subset of the group.
    var partial_sum = 0.0;
    for (var i = local_idx; i < group_size; i += workgroup_size) {
        let idx = base + i;
        if (idx < input_len) {
            partial_sum += group_input[idx];
        }
    }
    let mean = gn_workgroup_sum(local_idx, partial_sum) / count;

    // Pass 2: population variance from squared deviations about the mean.
    var partial_sq = 0.0;
    for (var i = local_idx; i < group_size; i += workgroup_size) {
        let idx = base + i;
        if (idx < input_len) {
            let diff = group_input[idx] - mean;
            partial_sq += diff * diff;
        }
    }
    let variance = gn_workgroup_sum(local_idx, partial_sq) / count;
    let std_dev = sqrt(variance + group_params.epsilon);

    // Pass 3: normalize, then apply the per-channel scale and shift.
    for (var i = local_idx; i < group_size; i += workgroup_size) {
        let idx = base + i;
        if (idx < input_len) {
            // Channel of this element = first channel of the group + the
            // channel offset within the group.
            let channel_idx = (group_idx * channels_per_group + i / spatial_size) % num_channels;
            let normalized = (group_input[idx] - mean) / std_dev;
            group_output[idx] = group_gamma[channel_idx] * normalized + group_beta[channel_idx];
        }
    }
}

// Synchronized Batch Normalization shaders
// NOTE(review): nothing in these kernels synchronizes across devices; they
// compute per-channel statistics of one input buffer. Presumably the host
// combines results across replicas — confirm.

// Uniform parameters read by `sync_batch_norm_compute_mean` and
// `sync_batch_norm_compute_var`. Field order and types define the uniform
// buffer layout the host must match.
struct BatchStatsParams {
    batch_size: u32,
    channels: u32,
    // height/width do not participate in the statistics computation; element
    // indexing uses `spatial_size`.
    height: u32,
    width: u32,
    // Contiguous elements per (batch, channel) plane.
    // NOTE(review): presumably height * width, but not enforced here.
    spatial_size: u32,
}

// Compute mean for synchronized batch norm
// Bindings for `sync_batch_norm_compute_mean`.
// Flattened input laid out as [batch][channel][spatial].
@group(0) @binding(0) var<storage, read> sync_input: array<f32>;
// Output: one mean per channel, indexed by channel.
@group(0) @binding(1) var<storage, read_write> sync_means: array<f32>;
@group(0) @binding(2) var<uniform> sync_params: BatchStatsParams;

// Per-channel mean over every batch item and spatial position.
// One invocation handles one channel (channel index = global_invocation_id.x);
// invocations past `channels` exit. Input is indexed as [batch][channel][spatial].
// Only in-bounds elements are counted. A channel with no in-bounds elements
// (e.g. batch_size == 0) gets 0.0 rather than leaving stale buffer contents.
// NOTE(review): one invocation serially sums batch_size * spatial_size values,
// which is slow and accumulates f32 rounding error for large inputs.
@compute @workgroup_size(256)
fn sync_batch_norm_compute_mean(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let channel_idx = global_id.x;
    let channels = sync_params.channels;

    if (channel_idx >= channels) {
        return;
    }

    let batch_size = sync_params.batch_size;
    let spatial_size = sync_params.spatial_size;
    let input_len = arrayLength(&sync_input);
    // Distance between consecutive batch items in the flattened input.
    let batch_stride = channels * spatial_size;
    let channel_offset = channel_idx * spatial_size;

    var sum = 0.0;
    var count = 0u;
    for (var b = 0u; b < batch_size; b++) {
        let plane_start = b * batch_stride + channel_offset;
        for (var s = 0u; s < spatial_size; s++) {
            let idx = plane_start + s;
            if (idx < input_len) {
                sum += sync_input[idx];
                count += 1u;
            }
        }
    }

    var mean = 0.0;
    if (count > 0u) {
        mean = sum / f32(count);
    }
    sync_means[channel_idx] = mean;
}

// Compute variance for synchronized batch norm
// Bindings for `sync_batch_norm_compute_var`.
// Flattened input laid out as [batch][channel][spatial].
@group(0) @binding(0) var<storage, read> sync_var_input: array<f32>;
// Per-channel means, indexed by channel (read, not written, here).
@group(0) @binding(1) var<storage, read> sync_channel_means: array<f32>;
// Output: one variance per channel, indexed by channel.
@group(0) @binding(2) var<storage, read_write> sync_vars: array<f32>;
@group(0) @binding(3) var<uniform> sync_var_params: BatchStatsParams;

// Per-channel population variance (divides by the element count, not
// count - 1) about the precomputed channel mean in `sync_channel_means`.
// One invocation handles one channel (channel index = global_invocation_id.x);
// invocations past `channels` exit. Input is indexed as [batch][channel][spatial].
// Only in-bounds elements are counted. A channel with no in-bounds elements
// gets 0.0 rather than leaving stale buffer contents.
@compute @workgroup_size(256)
fn sync_batch_norm_compute_var(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let channel_idx = global_id.x;
    let channels = sync_var_params.channels;

    if (channel_idx >= channels) {
        return;
    }

    let batch_size = sync_var_params.batch_size;
    let spatial_size = sync_var_params.spatial_size;
    let input_len = arrayLength(&sync_var_input);
    // Distance between consecutive batch items in the flattened input.
    let batch_stride = channels * spatial_size;
    let channel_offset = channel_idx * spatial_size;
    let mean = sync_channel_means[channel_idx];

    var sum_squared_diff = 0.0;
    var count = 0u;
    for (var b = 0u; b < batch_size; b++) {
        let plane_start = b * batch_stride + channel_offset;
        for (var s = 0u; s < spatial_size; s++) {
            let idx = plane_start + s;
            if (idx < input_len) {
                let diff = sync_var_input[idx] - mean;
                sum_squared_diff += diff * diff;
                count += 1u;
            }
        }
    }

    var variance = 0.0;
    if (count > 0u) {
        variance = sum_squared_diff / f32(count);
    }
    sync_vars[channel_idx] = variance;
}

// Apply synchronized batch normalization
// Uniform parameters read by `sync_batch_norm_apply`. Field order and types
// define the uniform buffer layout the host must match.
struct SyncBatchNormParams {
    // Not used in the per-element computation of `sync_batch_norm_apply`.
    batch_size: u32,
    channels: u32,
    // The kernel derives the per-channel plane size as height * width.
    height: u32,
    width: u32,
    // Added to the variance before the square root.
    epsilon: f32,
}

// Bindings for `sync_batch_norm_apply`.
// Flattened input laid out as [batch][channel][height * width].
@group(0) @binding(0) var<storage, read> sync_apply_input: array<f32>;
// Per-channel statistics and affine parameters, all indexed by channel.
@group(0) @binding(1) var<storage, read> sync_apply_means: array<f32>;
@group(0) @binding(2) var<storage, read> sync_apply_vars: array<f32>;
@group(0) @binding(3) var<storage, read> sync_apply_gamma: array<f32>;
@group(0) @binding(4) var<storage, read> sync_apply_beta: array<f32>;
// Normalized result, written at the same flat index as the input element.
@group(0) @binding(5) var<storage, read_write> sync_apply_output: array<f32>;
@group(0) @binding(6) var<uniform> sync_apply_params: SyncBatchNormParams;

// Elementwise batch-norm application: one invocation per input element
// (flat index = global_invocation_id.x).
// The element's channel is recovered from its position in the
// [batch][channel][height * width] layout; that channel's mean, variance,
// gamma and beta produce gamma * (x - mean) / sqrt(var + eps) + beta.
@compute @workgroup_size(256)
fn sync_batch_norm_apply(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let element = global_id.x;
    if (element >= arrayLength(&sync_apply_input)) {
        return;
    }

    let channels = sync_apply_params.channels;
    let plane = sync_apply_params.height * sync_apply_params.width;

    // Offset within this element's batch item, divided by the plane size,
    // gives the channel. The batch index itself is not needed.
    let channel = (element % (channels * plane)) / plane;
    if (channel >= channels) {
        return;
    }

    let centered = sync_apply_input[element] - sync_apply_means[channel];
    let denom = sqrt(sync_apply_vars[channel] + sync_apply_params.epsilon);
    let scale = sync_apply_gamma[channel];
    let shift = sync_apply_beta[channel];
    sync_apply_output[element] = scale * (centered / denom) + shift;
}