tenflowers-dataset 0.1.1

Data pipeline and dataset utilities for TenfloweRS
Documentation
// Rotation compute shader for GPU-accelerated image rotation

// Parameters for one rotation pass, read by the shader through the
// `uniforms` binding. Field order and types define the buffer layout, so
// they must stay in sync with whatever fills the buffer on the host side.
struct Uniforms {
    // Image width in pixels; also the row stride used when indexing the buffers.
    width: u32,
    // Image height in pixels.
    height: u32,
    // Number of channel planes; invocations with z >= channels exit early.
    channels: u32,
    // Not read anywhere in this shader — presumably present only to align
    // the f32 fields that follow (TODO confirm against the host-side struct).
    padding: u32,
    // Cosine of the rotation angle applied to each output coordinate.
    cos_angle: f32,
    // Sine of the rotation angle applied to each output coordinate.
    sin_angle: f32,
    // X coordinate (in pixels) of the point the image is rotated about.
    center_x: f32,
    // Y coordinate (in pixels) of the point the image is rotated about.
    center_y: f32,
}

// Source image samples (read-only). The shader indexes it as
// channel * width * height + y * width + x, i.e. one full plane per channel.
@group(0) @binding(0) var<storage, read> input_data: array<f32>;
// Destination image, written with the same indexing scheme as `input_data`.
@group(0) @binding(1) var<storage, read_write> output_data: array<f32>;
// Per-dispatch parameters: image dimensions, rotation terms and rotation centre.
@group(0) @binding(2) var<uniform> uniforms: Uniforms;

// Bilinearly samples one channel plane of `input_data` at the fractional
// pixel position (x, y).
//
// x, y:    sample position in pixel coordinates; the caller is expected to
//          pass non-negative values.
// channel: index of the plane to read from.
//
// Returns the weighted blend of the four surrounding texels, or 0.0 when the
// top-left texel lies outside the image. The right/bottom neighbours are
// clamped to the last column/row, so samples on the far edge reuse the edge
// texel instead of reading past the plane.
fn bilinear_interpolate(x: f32, y: f32, channel: u32) -> f32 {
    let col0 = u32(floor(x));
    let row0 = u32(floor(y));

    // Nothing to sample when the base texel is beyond the image extent.
    if (col0 >= uniforms.width || row0 >= uniforms.height) {
        return 0.0;
    }

    // Neighbouring column/row, clamped to the image border.
    let col1 = min(col0 + 1u, uniforms.width - 1u);
    let row1 = min(row0 + 1u, uniforms.height - 1u);

    // Fractional distance from the base texel along each axis.
    let tx = x - f32(col0);
    let ty = y - f32(row0);

    // Offsets of the channel plane and of the two rows within it.
    let plane_base = channel * uniforms.width * uniforms.height;
    let top_row = plane_base + row0 * uniforms.width;
    let bottom_row = plane_base + row1 * uniforms.width;

    let top_left = input_data[top_row + col0];
    let top_right = input_data[top_row + col1];
    let bottom_left = input_data[bottom_row + col0];
    let bottom_right = input_data[bottom_row + col1];

    // Blend horizontally along each row, then vertically between the rows.
    let top_mix = top_left * (1.0 - tx) + top_right * tx;
    let bottom_mix = bottom_left * (1.0 - tx) + bottom_right * tx;

    return top_mix * (1.0 - ty) + bottom_mix * ty;
}

// Entry point: each invocation produces one output value, addressed by
// (x, y, channel) = global_invocation_id. The output coordinate is rotated
// about (center_x, center_y) using the supplied cosine/sine terms to find
// the position to sample in the input; positions that land outside the
// input image are written as 0.0.
@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // Invocations in the padded part of the dispatch grid have nothing to do.
    let extent = vec3<u32>(uniforms.width, uniforms.height, uniforms.channels);
    if (any(global_id >= extent)) {
        return;
    }

    let pivot = vec2<f32>(uniforms.center_x, uniforms.center_y);

    // Output pixel position relative to the rotation centre.
    let offset = vec2<f32>(f32(global_id.x), f32(global_id.y)) - pivot;

    // Rotate the offset, then move back into image coordinates.
    let src = vec2<f32>(
        offset.x * uniforms.cos_angle - offset.y * uniforms.sin_angle,
        offset.x * uniforms.sin_angle + offset.y * uniforms.cos_angle
    ) + pivot;

    // Only positions inside the input image are sampled; everything else
    // keeps the black (0.0) default.
    let inside = src.x >= 0.0 && src.x < f32(uniforms.width) &&
                 src.y >= 0.0 && src.y < f32(uniforms.height);

    var value: f32 = 0.0;
    if (inside) {
        value = bilinear_interpolate(src.x, src.y, global_id.z);
    }

    // Destination slot: channel plane, then row, then column.
    let dst_index = global_id.z * uniforms.width * uniforms.height
        + global_id.y * uniforms.width
        + global_id.x;
    output_data[dst_index] = value;
}