oximedia-gpu 0.1.8

GPU compute pipeline using WGPU for OxiMedia - cross-platform acceleration
Documentation
// Transform shaders for DCT and FFT operations

// Source buffer, read-only. The DCT entry points read one f32 per element;
// the FFT, magnitude and phase entry points read interleaved
// (real, imaginary) f32 pairs at indices 2*i and 2*i + 1.
@group(0) @binding(0) var<storage, read> input_data: array<f32>;
// Destination buffer; every entry point in this file writes its result here.
@group(0) @binding(1) var<storage, read_write> output_data: array<f32>;
// Per-dispatch parameters (dimensions, row stride, mode fields).
@group(0) @binding(2) var<uniform> params: TransformParams;

// Uniform parameters shared by the entry points in this file.
// Eight u32 fields, i.e. 32 bytes in total.
struct TransformParams {
    // Elements per row. Used for bounds checks, as the row DCT length and as
    // the FFT length.
    width: u32,
    // Number of rows. Used for bounds checks and as the column DCT length.
    height: u32,
    // Not read by any entry point visible in this file.
    block_size: u32,    // 8 for 8x8 DCT, power of 2 for FFT
    // Not read by any entry point visible in this file.
    transform_type: u32, // 0=DCT, 1=IDCT, 2=FFT, 3=IFFT
    // Distance between the starts of consecutive rows, in elements
    // (indices are formed as `row * stride + column`).
    stride: u32,
    // Not read by any entry point visible in this file.
    is_inverse: u32,
    // Not read by the shader code visible here; presumably layout padding.
    padding1: u32,
    padding2: u32,
}

// Pi, used to form the DCT basis angles and the FFT twiddle angles.
const PI: f32 = 3.14159265359;

// 2D forward DCT (Type-II) over 8x8 blocks.
//
// Each invocation produces one coefficient of one block: global_id.xy / 8
// selects the block and global_id.xy % 8 selects the frequency index (u, v)
// inside it. Samples are read from `input_data` and the coefficient is
// written to `output_data`, both addressed as `row * params.stride + column`.
//
// Only whole blocks are processed: the block count is width / 8 by
// height / 8, so invocations that fall in a trailing partial block return
// without writing anything.
@compute @workgroup_size(8, 8, 1)
fn dct_8x8(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let block_x = global_id.x / 8u;
    let block_y = global_id.y / 8u;
    let u = global_id.x % 8u;
    let v = global_id.y % 8u;

    let blocks_x = params.width / 8u;
    let blocks_y = params.height / 8u;

    if (block_x >= blocks_x || block_y >= blocks_y) {
        return;
    }

    let base_x = block_x * 8u;
    let base_y = block_y * 8u;

    // The frequency indices are fixed for the whole invocation, so convert
    // them once instead of once per sample.
    let fu = f32(u);
    let fv = f32(v);

    var sum = 0.0;

    for (var y = 0u; y < 8u; y = y + 1u) {
        // Index of the first sample of this block row.
        let row_start = (base_y + y) * params.stride + base_x;

        // The vertical basis term depends only on (y, v); evaluate it once
        // per row rather than once per sample.
        let cos_v = cos((2.0 * f32(y) + 1.0) * fv * PI / 16.0);

        for (var x = 0u; x < 8u; x = x + 1u) {
            let cos_u = cos((2.0 * f32(x) + 1.0) * fu * PI / 16.0);

            sum += input_data[row_start + x] * cos_u * cos_v;
        }
    }

    // Orthonormal scaling for N = 8: sqrt(1/8) for the DC index,
    // sqrt(2/8) = 0.5 for every other index.
    var alpha_u = 0.5;
    var alpha_v = 0.5;
    if (u == 0u) { alpha_u = 0.353553390593; } // 1/sqrt(2) / 2
    if (v == 0u) { alpha_v = 0.353553390593; }

    let out_idx = (base_y + v) * params.stride + base_x + u;

    output_data[out_idx] = alpha_u * alpha_v * sum;
}

// 2D inverse DCT (Type-III) over 8x8 blocks.
//
// Each invocation reconstructs one sample of one block: global_id.xy / 8
// selects the block and global_id.xy % 8 selects the sample position (x, y)
// inside it. Coefficients are read from `input_data` and the sample is
// written to `output_data`, both addressed as `row * params.stride + column`.
//
// Only whole blocks are processed: the block count is width / 8 by
// height / 8, so invocations that fall in a trailing partial block return
// without writing anything.
@compute @workgroup_size(8, 8, 1)
fn idct_8x8(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let block_x = global_id.x / 8u;
    let block_y = global_id.y / 8u;
    let x = global_id.x % 8u;
    let y = global_id.y % 8u;

    let blocks_x = params.width / 8u;
    let blocks_y = params.height / 8u;

    if (block_x >= blocks_x || block_y >= blocks_y) {
        return;
    }

    let base_x = block_x * 8u;
    let base_y = block_y * 8u;

    // The sample position is fixed for the whole invocation, so convert it
    // once instead of once per coefficient.
    let fx = f32(x);
    let fy = f32(y);

    var sum = 0.0;

    for (var v = 0u; v < 8u; v = v + 1u) {
        // Index of the first coefficient of this block row.
        let row_start = (base_y + v) * params.stride + base_x;

        // Everything that depends only on v is evaluated once per row.
        // Scaling for N = 8: sqrt(1/8) for the DC index, 0.5 otherwise.
        var alpha_v = 0.5;
        if (v == 0u) { alpha_v = 0.353553390593; }
        let cos_v = cos((2.0 * fy + 1.0) * f32(v) * PI / 16.0);

        for (var u = 0u; u < 8u; u = u + 1u) {
            var alpha_u = 0.5;
            if (u == 0u) { alpha_u = 0.353553390593; }
            let cos_u = cos((2.0 * fx + 1.0) * f32(u) * PI / 16.0);

            sum += alpha_u * alpha_v * input_data[row_start + u] * cos_u * cos_v;
        }
    }

    let out_idx = (base_y + y) * params.stride + base_x + x;

    output_data[out_idx] = sum;
}

// Radix-2 butterfly on complex values given as separate real/imaginary parts.
//
// Multiplies b by the twiddle factor, then stores a + (b * twiddle) through
// `out_a` and a - (b * twiddle) through `out_b`, each as vec2(real, imaginary).
fn fft_butterfly(
    a_real: f32, a_imag: f32,
    b_real: f32, b_imag: f32,
    twiddle_real: f32, twiddle_imag: f32,
    out_a: ptr<function, vec2<f32>>,
    out_b: ptr<function, vec2<f32>>
) {
    // Complex product b * twiddle.
    let rotated = vec2<f32>(
        b_real * twiddle_real - b_imag * twiddle_imag,
        b_real * twiddle_imag + b_imag * twiddle_real
    );
    let a = vec2<f32>(a_real, a_imag);

    *out_a = a + rotated;
    *out_b = a - rotated;
}

// 1D FFT along each row: iterative radix-2 Cooley-Tukey (decimation in time)
// with a bit-reversal permutation applied while loading the row.
//
// One invocation transforms one complete row; global_id.x is the row index.
// Data is interleaved complex: element i of row y has its real part at
// index 2 * (y * params.stride + i) and its imaginary part right after it.
// The twiddle angle is always negative and no scaling is applied;
// params.is_inverse and params.transform_type are not consulted here.
//
// The row length is params.width. Rows are skipped (output left untouched)
// when the length is zero, is not a power of two, or exceeds the 1024-entry
// per-invocation scratch arrays, because the radix-2 indexing below is only
// valid for power-of-two lengths that fit in the scratch storage.
@compute @workgroup_size(256, 1, 1)
fn fft_horizontal(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let y = global_id.x;

    if (y >= params.height) {
        return;
    }

    let n = params.width;

    // Reject lengths the algorithm cannot handle. 1024 must match the size
    // of real_buffer / imag_buffer declared below.
    if (n == 0u || n > 1024u || (n & (n - 1u)) != 0u) {
        return;
    }

    // n is a power of two, so the index of its only set bit is exactly
    // log2(n). This avoids truncating a slightly-too-small floating-point
    // log2 result down to the wrong exponent.
    let log2n = firstTrailingBit(n);

    // Per-invocation scratch copy of the row.
    var real_buffer: array<f32, 1024>;
    var imag_buffer: array<f32, 1024>;

    let row_base = y * params.stride;

    // Load the row in bit-reversed order.
    for (var i = 0u; i < n; i = i + 1u) {
        var reversed = 0u;
        var remaining = i;
        for (var b = 0u; b < log2n; b = b + 1u) {
            reversed = (reversed << 1u) | (remaining & 1u);
            remaining = remaining >> 1u;
        }

        let idx = row_base + i;
        real_buffer[reversed] = input_data[idx * 2u];
        imag_buffer[reversed] = input_data[idx * 2u + 1u];
    }

    // Butterfly stages: stage s combines sub-transforms of length m / 2 into
    // sub-transforms of length m = 2^s.
    for (var s = 1u; s <= log2n; s = s + 1u) {
        let m = 1u << s;
        let m2 = m / 2u;

        // Iterate twiddle-major: the twiddle depends only on j, so it is
        // evaluated once and reused for every group k. Butterflies within a
        // stage touch disjoint index pairs, so their order does not matter.
        for (var j = 0u; j < m2; j = j + 1u) {
            let t = -2.0 * PI * f32(j) / f32(m);
            let twiddle_real = cos(t);
            let twiddle_imag = sin(t);

            for (var k = 0u; k < n; k = k + m) {
                let idx1 = k + j;
                let idx2 = k + j + m2;

                var out_a: vec2<f32>;
                var out_b: vec2<f32>;

                fft_butterfly(
                    real_buffer[idx1], imag_buffer[idx1],
                    real_buffer[idx2], imag_buffer[idx2],
                    twiddle_real, twiddle_imag,
                    &out_a, &out_b
                );

                real_buffer[idx1] = out_a.x;
                imag_buffer[idx1] = out_a.y;
                real_buffer[idx2] = out_b.x;
                imag_buffer[idx2] = out_b.y;
            }
        }
    }

    // Write the transformed row back in natural order.
    for (var i = 0u; i < n; i = i + 1u) {
        let idx = row_base + i;
        output_data[idx * 2u] = real_buffer[i];
        output_data[idx * 2u + 1u] = imag_buffer[i];
    }
}

// Row pass of a separable 2D DCT (Type-II) for arbitrary row lengths.
//
// One invocation computes one output coefficient: global_id.x is a flat
// index that is split into a row (index / width) and a frequency index
// (index % width). The transform length is params.width; input and output
// are both addressed as `row * params.stride + column`.
@compute @workgroup_size(256, 1, 1)
fn dct_row(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let cols = params.width;
    let flat = global_id.x;
    let row = flat / cols;
    let freq = flat % cols;

    if (row >= params.height || freq >= cols) {
        return;
    }

    let row_base = row * params.stride;
    let freq_f = f32(freq);
    let cols_f = f32(cols);

    // Accumulate sample * cos basis over the whole row.
    var acc = 0.0;
    var col = 0u;
    while (col < cols) {
        let cos_val = cos((2.0 * f32(col) + 1.0) * freq_f * PI / (2.0 * cols_f));
        acc += input_data[row_base + col] * cos_val;
        col = col + 1u;
    }

    // Scale by sqrt(1/n) for frequency 0 and by sqrt(2/n) otherwise.
    let scale = select(sqrt(2.0 / cols_f), sqrt(1.0 / cols_f), freq == 0u);

    output_data[row_base + freq] = scale * acc;
}

// Column pass of a separable 2D DCT (Type-II) for arbitrary column lengths.
//
// One invocation computes one output coefficient: global_id.x is a flat
// index that is split into a column (index / height) and a frequency index
// (index % height). The transform length is params.height; input and output
// are both addressed as `row * params.stride + column`.
@compute @workgroup_size(256, 1, 1)
fn dct_col(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let rows = params.height;
    let flat = global_id.x;
    let col = flat / rows;
    let freq = flat % rows;

    if (col >= params.width || freq >= rows) {
        return;
    }

    let freq_f = f32(freq);
    let rows_f = f32(rows);

    // Accumulate sample * cos basis down the whole column.
    var acc = 0.0;
    var row = 0u;
    while (row < rows) {
        let cos_val = cos((2.0 * f32(row) + 1.0) * freq_f * PI / (2.0 * rows_f));
        acc += input_data[row * params.stride + col] * cos_val;
        row = row + 1u;
    }

    // Scale by sqrt(1/n) for frequency 0 and by sqrt(2/n) otherwise.
    let scale = select(sqrt(2.0 / rows_f), sqrt(1.0 / rows_f), freq == 0u);

    output_data[freq * params.stride + col] = scale * acc;
}

// Magnitude of interleaved complex data.
//
// One invocation handles one complex element: it reads the (real, imaginary)
// pair at indices 2*i and 2*i + 1 of `input_data` and writes
// sqrt(re^2 + im^2) to `output_data[i]`. Elements at or beyond
// width * height are ignored.
// NOTE(review): indexing here is dense and does not use params.stride, while
// fft_horizontal addresses rows as y * stride + i -- confirm callers only use
// this with stride == width.
@compute @workgroup_size(256, 1, 1)
fn compute_magnitude(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let element = global_id.x;
    let element_count = params.width * params.height;

    if (element < element_count) {
        let re = input_data[element * 2u];
        let im = input_data[element * 2u + 1u];

        output_data[element] = sqrt(re * re + im * im);
    }
}

// Phase angle of interleaved complex data.
//
// One invocation handles one complex element: it reads the (real, imaginary)
// pair at indices 2*i and 2*i + 1 of `input_data` and writes
// atan2(imaginary, real) to `output_data[i]`. Elements at or beyond
// width * height are ignored.
// NOTE(review): indexing here is dense and does not use params.stride, while
// fft_horizontal addresses rows as y * stride + i -- confirm callers only use
// this with stride == width.
@compute @workgroup_size(256, 1, 1)
fn compute_phase(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let element = global_id.x;
    let element_count = params.width * params.height;

    if (element < element_count) {
        let re = input_data[element * 2u];
        let im = input_data[element * 2u + 1u];

        output_data[element] = atan2(im, re);
    }
}