tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Quantization operation compute shaders
// These kernels implement quantization and dequantization operations

// Buffers for the generic `quantize` entry point: f32 values in, i32 values out.
// `params` is also read by quantize_int8, quantize_int4, dequantize and
// fake_quantize; the latter four read only the elements they need
// ([0] and [1], plus [2] and [3] in fake_quantize).
// Other declarations in this file reuse the same (group, binding) slots; each
// entry point references only one variable per slot.
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<i32>;
@group(0) @binding(2) var<storage, read> params: array<f32>; // [scale, zero_point, qmin, qmax]

// Affine quantization kernel (float -> integer).
// Per element: output[i] = i32(clamp(round(input[i] / scale + zero_point), qmin, qmax)),
// with params = [scale, zero_point, qmin, qmax].
// One invocation handles one element; invocations whose index is at or beyond
// arrayLength(&output) write nothing.
@compute @workgroup_size(64)
fn quantize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output)) {
        // Rescale onto the integer grid and shift by the zero point.
        let scaled = input[i] / params[0];
        let rounded = round(scaled + params[1]);
        // Saturate to [qmin, qmax] before the conversion to i32.
        output[i] = i32(clamp(rounded, params[2], params[3]));
    }
}

// Buffers for the fixed-range quantization entry points.
// Both quantize_int8 and quantize_int4 read `input_f32` and write `output_i8`;
// each quantized value occupies a full i32 element.
@group(0) @binding(0) var<storage, read> input_f32: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_i8: array<i32>; // Using i32 for i8 values

// INT8 quantization: same affine mapping as `quantize`, but the clamp range is
// fixed at [-128, 127]. Only params[0] (scale) and params[1] (zero_point) are read.
// Invocations whose index is at or beyond arrayLength(&output_i8) write nothing.
@compute @workgroup_size(64)
fn quantize_int8(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output_i8)) {
        let scaled = input_f32[i] / params[0];
        let rounded = round(scaled + params[1]);
        // Saturate to the signed 8-bit range; the result is stored as i32.
        output_i8[i] = i32(clamp(rounded, -128.0, 127.0));
    }
}

// INT4 quantization: same affine mapping as `quantize`, clamped to [-8, 7].
// Reads `input_f32` and writes `output_i8` (the buffers used by quantize_int8);
// each 4-bit value is stored in its own i32 element, no packing is performed.
// Only params[0] (scale) and params[1] (zero_point) are read.
@compute @workgroup_size(64)
fn quantize_int4(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output_i8)) {
        let scaled = input_f32[i] / params[0];
        let rounded = round(scaled + params[1]);
        // Saturate to the signed 4-bit range.
        output_i8[i] = i32(clamp(rounded, -8.0, 7.0));
    }
}

// Buffers for `dequantize`, which converts quantized integers back to floats:
// i32 quantized values in, f32 values out.
@group(0) @binding(0) var<storage, read> input_quantized: array<i32>;
@group(0) @binding(1) var<storage, read_write> output_float: array<f32>;

// Dequantization kernel (integer -> float).
// Per element: output_float[i] = (f32(input_quantized[i]) - zero_point) * scale,
// with scale = params[0] and zero_point = params[1].
// Invocations whose index is at or beyond arrayLength(&output_float) write nothing.
@compute @workgroup_size(64)
fn dequantize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output_float)) {
        // Remove the zero-point offset first, then rescale.
        let centered = f32(input_quantized[i]) - params[1];
        output_float[i] = centered * params[0];
    }
}

// Buffers for `fake_quantize`, which simulates quantization effects during
// training: f32 in, f32 out (values are quantized and immediately dequantized).
@group(0) @binding(0) var<storage, read> input_fake: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_fake: array<f32>;

// Fake-quantization kernel: quantize then dequantize, staying in f32.
// Per element:
//   q   = clamp(round(x / scale + zero_point), qmin, qmax)
//   out = (q - zero_point) * scale
// with params = [scale, zero_point, qmin, qmax].
// Invocations whose index is at or beyond arrayLength(&output_fake) write nothing.
@compute @workgroup_size(64)
fn fake_quantize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output_fake)) {
        let step = params[0];
        let offset = params[1];
        // Snap onto the integer grid and saturate to the representable range.
        let snapped = clamp(round((input_fake[i] / step) + offset), params[2], params[3]);
        // Map the grid point back to the float domain.
        output_fake[i] = (snapped - offset) * step;
    }
}

// Buffers for `dynamic_quantize`, which computes min/max on the GPU and derives
// its scale from them: f32 in, i32 out, plus a two-element f32 buffer that
// receives the computed [min, max].
@group(0) @binding(0) var<storage, read> input_dynamic: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_dynamic: array<i32>;
@group(0) @binding(2) var<storage, read_write> min_max_output: array<f32>; // [min, max]

// Workgroup-shared staging area, one slot per invocation of the 64-wide
// workgroup used by `dynamic_quantize`.
var<workgroup> shared_data: array<f32, 64>;

// Dynamic symmetric INT8 quantization.
//
// Each 64-invocation workgroup stages its elements in `shared_data`; every
// invocation then scans the staged values to obtain the workgroup's min and
// max, and quantizes its own element with scale = max(|min|, |max|) / 127,
// clamped to [-128, 127].
//
// Edge cases:
//   * Lanes past the end of `input_dynamic` contribute a copy of the last valid
//     element to the reduction, so padding never distorts the min or max.
//     Their own output (if `output_dynamic` is longer than the input) is 0.
//   * A workgroup that covers no valid input element does not touch
//     `min_max_output`.
//   * When the range is all zero (scale == 0) or not comparable (NaN), the
//     division is skipped and the element quantizes to 0.
//
// NOTE(review): min/max and the scale are per workgroup, not per tensor. With
// more than one workgroup, each one quantizes with its own scale and all of
// them write min_max_output[0..1] without synchronization, so the stored pair
// belongs to whichever workgroup wrote last. A tensor-wide range requires a
// separate reduction pass before quantizing.
@compute @workgroup_size(64)
fn dynamic_quantize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    let local_idx = global_id.x % 64u;
    let input_len = arrayLength(&input_dynamic);
    let in_bounds = index < input_len;

    // `val` is this lane's own element (0.0 when past the input).
    // `staged` is what the lane contributes to the min/max reduction.
    var val = 0.0;
    var staged = 0.0;
    if (in_bounds) {
        val = input_dynamic[index];
        staged = val;
    } else if (input_len > 0u) {
        // Replicating a real element is neutral for both min and max.
        staged = input_dynamic[input_len - 1u];
    }

    shared_data[local_idx] = staged;
    // All lanes reach this barrier: the branches above have reconverged.
    workgroupBarrier();

    // Every lane scans the full staging area, so each ends up with the same
    // workgroup-wide min and max.
    var min_val = staged;
    var max_val = staged;
    for (var i = 0u; i < 64u; i++) {
        let shared_val = shared_data[i];
        min_val = min(min_val, shared_val);
        max_val = max(max_val, shared_val);
    }

    // Lane 0 holds the workgroup's lowest index; if it is out of bounds the
    // whole workgroup is padding and must not publish a range.
    if (local_idx == 0u && in_bounds) {
        min_max_output[0] = min_val;
        min_max_output[1] = max_val;
    }

    // Symmetric range: the largest magnitude maps to 127.
    let abs_max = max(abs(min_val), abs(max_val));
    let scale = abs_max / 127.0;

    if (index < arrayLength(&output_dynamic)) {
        var quantized = 0.0;
        // Skip the division when scale is 0 (all-zero range) or NaN.
        if (scale > 0.0) {
            quantized = round(val / scale);
        }
        output_dynamic[index] = i32(clamp(quantized, -128.0, 127.0));
    }
}

// Buffers for `per_channel_quantize`: f32 in, i32 out, plus per-channel
// parameters stored as interleaved (scale, zero_point) pairs and a small
// metadata buffer describing how elements map to channels.
@group(0) @binding(0) var<storage, read> input_per_channel: array<f32>;
@group(0) @binding(1) var<storage, read_write> output_per_channel: array<i32>;
@group(0) @binding(2) var<storage, read> channel_params: array<f32>; // [scale0, zero_point0, scale1, zero_point1, ...]
@group(0) @binding(3) var<storage, read> channel_metadata: array<u32>; // [channel_axis, num_channels, stride]

// Per-channel INT8 quantization.
// The channel of element i is (i / stride) % num_channels, where
// num_channels = channel_metadata[1] and stride = channel_metadata[2];
// channel_metadata[0] (channel_axis) is not consulted by the index math.
// Channel c uses scale = channel_params[2c] and zero_point = channel_params[2c + 1],
// and the result is clamped to [-128, 127].
// Invocations whose index is at or beyond arrayLength(&output_per_channel)
// write nothing.
@compute @workgroup_size(64)
fn per_channel_quantize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output_per_channel)) {
        // Locate this element's channel and its (scale, zero_point) pair.
        let channel = (i / channel_metadata[2]) % channel_metadata[1];
        let pair_base = channel * 2u;

        let scaled = input_per_channel[i] / channel_params[pair_base];
        let rounded = round(scaled + channel_params[pair_base + 1u]);
        output_per_channel[i] = i32(clamp(rounded, -128.0, 127.0));
    }
}