tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Activation function compute shaders

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

// ReLU activation
@compute @workgroup_size(64)
fn relu_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    output[index] = max(input[index], 0.0);
}

// Sigmoid activation
@compute @workgroup_size(64)
fn sigmoid_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    output[index] = 1.0 / (1.0 + exp(-input[index]));
}

// Tanh activation
@compute @workgroup_size(64)
fn tanh_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    output[index] = tanh(input[index]);
}

// GELU activation (approximation)
@compute @workgroup_size(64)
fn gelu_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    // GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    let sqrt_2_over_pi = 0.7978845608;
    let a = 0.044715;
    let inner = sqrt_2_over_pi * (x + a * x * x * x);
    output[index] = 0.5 * x * (1.0 + tanh(inner));
}

// Swish activation
@compute @workgroup_size(64)
fn swish_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    output[index] = x / (1.0 + exp(-x));
}

// ELU activation
@compute @workgroup_size(64)
fn elu_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    if (x >= 0.0) {
        output[index] = x;
    } else {
        output[index] = exp(x) - 1.0;
    }
}

// Leaky ReLU activation
@compute @workgroup_size(64)
fn leaky_relu_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    let negative_slope = 0.01;
    output[index] = max(x, negative_slope * x);
}

// Mish activation
@compute @workgroup_size(64)
fn mish_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    // mish(x) = x * tanh(ln(1 + exp(x)))
    // For numerical stability, we use ln(1 + exp(x)) = softplus(x)
    let softplus = log(1.0 + exp(x));
    output[index] = x * tanh(softplus);
}

// HardSwish activation
@compute @workgroup_size(64)
fn hard_swish_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= arrayLength(&input)) {
        return;
    }
    
    let x = input[index];
    // hard_swish(x) = x * relu6(x + 3) / 6
    // relu6(y) = min(max(y, 0), 6)
    let y = x + 3.0;
    let relu6_val = clamp(y, 0.0, 6.0);
    output[index] = x * relu6_val / 6.0;
}

// PReLU activation with separate alpha buffer and metadata for channel-wise alpha
@group(0) @binding(2) var<storage, read> alpha: array<f32>;

// Metadata struct for tensor shape information
struct ActivationMetadata {
    total_elements: u32,
    batch_size: u32,
    channels: u32, 
    height: u32,
    width: u32,
    is_channelwise: u32, // 0 = scalar alpha, 1 = channel-wise alpha
    padding1: u32,
    padding2: u32,
}

@group(0) @binding(3) var<uniform> metadata: ActivationMetadata;

@compute @workgroup_size(64)
fn prelu_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= metadata.total_elements) {
        return;
    }
    
    let x = input[index];
    
    var alpha_val: f32;
    
    if (metadata.is_channelwise == 1u) {
        // Calculate channel index for channel-wise alpha
        // Assuming NCHW format: [batch, channels, height, width]
        let spatial_size = metadata.height * metadata.width;
        let channel_idx = (index / spatial_size) % metadata.channels;
        alpha_val = alpha[channel_idx];
    } else {
        // Scalar alpha - use alpha[0] for all elements
        alpha_val = alpha[0];
    }
    
    if (x >= 0.0) {
        output[index] = x;
    } else {
        output[index] = alpha_val * x;
    }
}