tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Fused operation compute shaders for kernel fusion optimization
// These kernels combine multiple operations to reduce memory bandwidth and improve performance

// ===== FUSED ADD + ACTIVATION KERNELS =====

// Uniform parameter block read by the fused binary-op (add/mul) + activation
// kernels through the `params` binding. Each of those kernels compares its
// invocation index against `size` and, when a broadcast flag equals 1u,
// indexes that operand with `idx % size_x` instead of `idx`.
struct BinaryOpParams {
    size: u32,         // Invocations whose index is >= size return without writing
    broadcast_a: u32,  // 1 if tensor A should be broadcast, 0 otherwise
    broadcast_b: u32,  // 1 if tensor B should be broadcast, 0 otherwise
    size_a: u32,       // Size of tensor A
    size_b: u32,       // Size of tensor B
}

// Bind group 0 resources used by the fused binary-op kernels:
// two read-only operand buffers, one writable result buffer, and the
// uniform parameter block.
@group(0) @binding(0) var<storage, read> input_a: array<f32>;
@group(0) @binding(1) var<storage, read> input_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: BinaryOpParams;

// Fused Add + ReLU kernel
//
// One invocation per output element: writes max(0, a + b) to `output`.
// Invocations whose x index is at or beyond `params.size` do nothing.
@compute @workgroup_size(256)
fn add_relu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand wraps around its own length; otherwise it is
        // read at the same position as the output element.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        // Add, then rectify, in a single pass.
        let total = input_a[pos_a] + input_b[pos_b];
        output[i] = max(0.0, total);
    }
}

// Fused Add + Sigmoid kernel
//
// One invocation per output element: writes 1 / (1 + exp(-(a + b))).
// Invocations whose x index is at or beyond `params.size` do nothing.
@compute @workgroup_size(256)
fn add_sigmoid_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        let total = input_a[pos_a] + input_b[pos_b];
        // Logistic function, denominator named for readability.
        let denom = 1.0 + exp(-total);
        output[i] = 1.0 / denom;
    }
}

// Fused Add + Tanh kernel
//
// One invocation per output element: writes tanh(a + b).
// Invocations whose x index is at or beyond `params.size` do nothing.
// NOTE(review): for very large |a + b| the result depends on the backend's
// tanh saturating cleanly to +/-1 - confirm on the targeted backends.
@compute @workgroup_size(256)
fn add_tanh_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        let total = input_a[pos_a] + input_b[pos_b];
        output[i] = tanh(total);
    }
}

// Fused Add + GELU kernel
//
// One invocation per output element: computes x = a + b and writes the tanh
// approximation of GELU,
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Invocations whose x index is at or beyond `params.size` do nothing.
// A broadcast operand (flag == 1u) is indexed modulo its own size.
@compute @workgroup_size(256)
fn add_gelu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= params.size) {
        return;
    }

    // Handle broadcasting
    let idx_a = select(idx, idx % params.size_a, params.broadcast_a == 1u);
    let idx_b = select(idx, idx % params.size_b, params.broadcast_b == 1u);

    let sum = input_a[idx_a] + input_b[idx_b];
    let x_cubed = sum * sum * sum;
    let inner = sqrt(2.0 / 3.14159265359) * (sum + 0.044715 * x_cubed);

    // The tanh argument grows cubically with the input. A tanh that is
    // evaluated through exp() can overflow to inf/inf = NaN for arguments of
    // only a few dozen, so bound the argument first. In f32 tanh is already
    // indistinguishable from +/-1 beyond |x| of about 9, so the clamp does not
    // change finite results; a NaN input still yields NaN through `sum` below.
    let bounded = clamp(inner, -10.0, 10.0);
    output[idx] = 0.5 * sum * (1.0 + tanh(bounded));
}

// Fused Add + Swish (SiLU) kernel
//
// One invocation per output element: computes x = a + b and writes
// x * sigmoid(x). Invocations whose x index is at or beyond `params.size`
// do nothing.
@compute @workgroup_size(256)
fn add_swish_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        let total = input_a[pos_a] + input_b[pos_b];
        // Logistic gate applied to the same value it scales.
        let gate = 1.0 / (1.0 + exp(-total));
        output[i] = total * gate;
    }
}

// Fused Add + Mish kernel
//
// One invocation per output element: computes x = a + b and writes
// x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x)).
// Invocations whose x index is at or beyond `params.size` do nothing.
// A broadcast operand (flag == 1u) is indexed modulo its own size.
@compute @workgroup_size(256)
fn add_mish_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= params.size) {
        return;
    }

    // Handle broadcasting
    let idx_a = select(idx, idx % params.size_a, params.broadcast_a == 1u);
    let idx_b = select(idx, idx % params.size_b, params.broadcast_b == 1u);

    let sum = input_a[idx_a] + input_b[idx_b];

    // Overflow-safe softplus: log(1 + e^x) == max(x, 0) + log(1 + e^-|x|).
    // exp() only ever receives a non-positive argument, so it cannot overflow
    // the way exp(sum) does once sum exceeds roughly 88 in f32.
    let softplus = max(sum, 0.0) + log(1.0 + exp(-abs(sum)));

    // softplus is non-negative and unbounded above. tanh is already 1.0 in
    // f32 beyond about 9, so cap the argument to keep an exp-based tanh from
    // overflowing to inf/inf = NaN. A NaN input still propagates via `sum`.
    output[idx] = sum * tanh(min(softplus, 10.0));
}

// Fused Add + LeakyReLU kernel
//
// Uniform parameter block for add_leaky_relu_kernel. The first five fields
// mirror BinaryOpParams; `negative_slope` is the extra activation parameter.
struct LeakyReLUParams {
    size: u32,            // Invocations whose index is >= size return without writing
    broadcast_a: u32,     // 1u selects modulo indexing (idx % size_a) for operand A
    broadcast_b: u32,     // 1u selects modulo indexing (idx % size_b) for operand B
    size_a: u32,          // Modulus used when operand A is broadcast
    size_b: u32,          // Modulus used when operand B is broadcast
    negative_slope: f32,  // Multiplier applied to the sum when it is not > 0
}

// Declared at the same group/binding slot as `params`; add_leaky_relu_kernel
// reads this block and never touches `params`.
@group(0) @binding(3) var<uniform> leaky_params: LeakyReLUParams;

// One invocation per output element: computes x = a + b and writes x when
// x > 0, otherwise negative_slope * x. Invocations whose x index is at or
// beyond `leaky_params.size` do nothing.
@compute @workgroup_size(256)
fn add_leaky_relu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < leaky_params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (leaky_params.broadcast_a == 1u) {
            pos_a = i % leaky_params.size_a;
        }
        var pos_b = i;
        if (leaky_params.broadcast_b == 1u) {
            pos_b = i % leaky_params.size_b;
        }

        let total = input_a[pos_a] + input_b[pos_b];

        // Start from the scaled (negative-side) value and replace it with the
        // identity branch only for strictly positive sums.
        var activated = leaky_params.negative_slope * total;
        if (total > 0.0) {
            activated = total;
        }
        output[i] = activated;
    }
}

// ===== FUSED MULTIPLY + ACTIVATION KERNELS =====

// Fused Mul + ReLU kernel
//
// One invocation per output element: writes max(0, a * b) to `output`.
// Invocations whose x index is at or beyond `params.size` do nothing.
@compute @workgroup_size(256)
fn mul_relu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        // Multiply, then rectify, in a single pass.
        let prod = input_a[pos_a] * input_b[pos_b];
        output[i] = max(0.0, prod);
    }
}

// Fused Mul + Sigmoid kernel
//
// One invocation per output element: writes 1 / (1 + exp(-(a * b))).
// Invocations whose x index is at or beyond `params.size` do nothing.
@compute @workgroup_size(256)
fn mul_sigmoid_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        let prod = input_a[pos_a] * input_b[pos_b];
        // Logistic function, denominator named for readability.
        let denom = 1.0 + exp(-prod);
        output[i] = 1.0 / denom;
    }
}

// Fused Mul + Tanh kernel
//
// One invocation per output element: writes tanh(a * b).
// Invocations whose x index is at or beyond `params.size` do nothing.
// NOTE(review): for very large |a * b| the result depends on the backend's
// tanh saturating cleanly to +/-1 - confirm on the targeted backends.
@compute @workgroup_size(256)
fn mul_tanh_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < params.size) {
        // A broadcast operand is indexed modulo its own size.
        var pos_a = i;
        if (params.broadcast_a == 1u) {
            pos_a = i % params.size_a;
        }
        var pos_b = i;
        if (params.broadcast_b == 1u) {
            pos_b = i % params.size_b;
        }

        let prod = input_a[pos_a] * input_b[pos_b];
        output[i] = tanh(prod);
    }
}

// ===== FUSED NORMALIZATION + ACTIVATION KERNELS =====

// Uniform parameter block read by the fused LayerNorm + activation kernels
// through the `norm_params` binding.
struct NormActivationParams {
    size: u32,              // Invocations whose index is >= size return without writing
    normalized_shape: u32,  // Size of the last dimension being normalized
    eps: f32,              // Epsilon for numerical stability
}

// Bind group 0 resources used by the fused LayerNorm kernels. Bindings 0-3
// reuse the slot numbers of the binary-op buffers declared earlier; the
// LayerNorm entry points reference only the `norm_*` variables.
// `norm_weight` and `norm_bias` are indexed by position within the
// normalized dimension (idx % normalized_shape).
@group(0) @binding(0) var<storage, read> norm_input: array<f32>;
@group(0) @binding(1) var<storage, read> norm_weight: array<f32>;
@group(0) @binding(2) var<storage, read> norm_bias: array<f32>;
@group(0) @binding(3) var<storage, read_write> norm_output: array<f32>;
@group(0) @binding(4) var<uniform> norm_params: NormActivationParams;

// Fused LayerNorm + ReLU kernel
//
// One invocation per output element. The input is treated as consecutive
// rows of `normalized_shape` elements; each invocation recomputes the mean
// and variance of its own row, normalizes its element, applies the per-feature
// weight and bias, and writes max(0, result).
// Invocations whose x index is at or beyond `norm_params.size` do nothing.
// NOTE(review): every invocation loops over its whole row, so work per row is
// quadratic in `normalized_shape`; a workgroup reduction would need a
// different dispatch layout.
// NOTE(review): `normalized_shape == 0` is not handled here - presumably the
// host never dispatches that configuration; confirm against the caller.
@compute @workgroup_size(256)
fn layernorm_relu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= norm_params.size) {
        return;
    }

    let normalized_shape = norm_params.normalized_shape;
    let batch_idx = idx / normalized_shape;
    let feature_idx = idx % normalized_shape;
    // First element of the row this invocation belongs to (loop-invariant).
    let row_start = batch_idx * normalized_shape;

    // Accumulate sum and sum of squares over the row in a single pass.
    var sum = 0.0;
    var sum_sq = 0.0;

    for (var i: u32 = 0u; i < normalized_shape; i++) {
        let val = norm_input[row_start + i];
        sum += val;
        sum_sq += val * val;
    }

    let count = f32(normalized_shape);
    let mean = sum / count;
    // E[x^2] - mean^2 can round to a small negative number when the row is
    // (nearly) constant. Clamp at zero so sqrt() below never receives a
    // negative argument, which would turn the whole row into NaN.
    let variance = max((sum_sq / count) - (mean * mean), 0.0);
    let std_dev = sqrt(variance + norm_params.eps);

    // Normalize and apply scale/bias
    let normalized = (norm_input[idx] - mean) / std_dev;
    let scaled = normalized * norm_weight[feature_idx] + norm_bias[feature_idx];

    // Apply ReLU activation
    norm_output[idx] = max(0.0, scaled);
}

// Fused LayerNorm + GELU kernel
//
// One invocation per output element. The input is treated as consecutive
// rows of `normalized_shape` elements; each invocation recomputes the mean
// and variance of its own row, normalizes its element, applies the per-feature
// weight and bias, and writes the tanh approximation of GELU,
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Invocations whose x index is at or beyond `norm_params.size` do nothing.
// NOTE(review): every invocation loops over its whole row, so work per row is
// quadratic in `normalized_shape`; a workgroup reduction would need a
// different dispatch layout.
// NOTE(review): `normalized_shape == 0` is not handled here - presumably the
// host never dispatches that configuration; confirm against the caller.
@compute @workgroup_size(256)
fn layernorm_gelu_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;

    if (idx >= norm_params.size) {
        return;
    }

    let normalized_shape = norm_params.normalized_shape;
    let batch_idx = idx / normalized_shape;
    let feature_idx = idx % normalized_shape;
    // First element of the row this invocation belongs to (loop-invariant).
    let row_start = batch_idx * normalized_shape;

    // Accumulate sum and sum of squares over the row in a single pass.
    var sum = 0.0;
    var sum_sq = 0.0;

    for (var i: u32 = 0u; i < normalized_shape; i++) {
        let val = norm_input[row_start + i];
        sum += val;
        sum_sq += val * val;
    }

    let count = f32(normalized_shape);
    let mean = sum / count;
    // E[x^2] - mean^2 can round to a small negative number when the row is
    // (nearly) constant. Clamp at zero so sqrt() below never receives a
    // negative argument, which would turn the whole row into NaN.
    let variance = max((sum_sq / count) - (mean * mean), 0.0);
    let std_dev = sqrt(variance + norm_params.eps);

    // Normalize and apply scale/bias
    let normalized = (norm_input[idx] - mean) / std_dev;
    let scaled = normalized * norm_weight[feature_idx] + norm_bias[feature_idx];

    // Apply GELU activation
    let x_cubed = scaled * scaled * scaled;
    let inner = sqrt(2.0 / 3.14159265359) * (scaled + 0.044715 * x_cubed);
    // The tanh argument grows cubically. A tanh evaluated through exp() can
    // overflow to inf/inf = NaN, so bound the argument first. In f32 tanh is
    // already indistinguishable from +/-1 beyond |x| of about 9, and a NaN
    // input still propagates through `scaled` in the product below.
    let bounded = clamp(inner, -10.0, 10.0);
    norm_output[idx] = 0.5 * scaled * (1.0 + tanh(bounded));
}