tenflowers-dataset 0.1.1

Data pipeline and dataset utilities for TenfloweRS
Documentation
// GPU-accelerated color jitter shader
// Adjusts brightness, contrast, saturation, and hue of RGB images

// Parameters supplied by the host for one color-jitter dispatch.
// NOTE(review): field order and types presumably mirror a host-side struct
// byte-for-byte — confirm before reordering or resizing anything here.
struct Uniforms {
    // Image width in pixels; upper bound for the x invocation index.
    width: u32,
    // Image height in pixels; upper bound for the y invocation index.
    height: u32,
    // Channel count. Not read anywhere in this shader, which always
    // processes channel planes 0, 1 and 2.
    channels: u32,
    // Not read by this shader; presumably present only for alignment.
    padding: u32,
    // Offset added to every RGB component.
    brightness: f32,
    // Scale applied to every RGB component around the 0.5 midpoint.
    contrast: f32,
    // Multiplier applied to the HSV saturation.
    saturation: f32,
    // Offset added to the HSV hue, which is expressed in degrees.
    hue: f32,
};

// Source image as a flat f32 array, read-only. Indexed plane by plane:
// channel * height * width + y * width + x.
@group(0) @binding(0) var<storage, read> input_data: array<f32>;
// Destination image as a flat f32 array, indexed the same way as the input.
@group(0) @binding(1) var<storage, read_write> output_data: array<f32>;
// Image dimensions and jitter parameters for this dispatch.
@group(0) @binding(2) var<uniform> uniforms: Uniforms;

// Reads one component of the source image.
//
// The buffer is addressed plane by plane: each channel occupies a contiguous
// run of height * width values, stored row by row. A computed offset past the
// end of the buffer yields 0.0 instead of an out-of-bounds read.
fn get_input_pixel(channel: u32, y: u32, x: u32) -> f32 {
    let plane_size = uniforms.height * uniforms.width;
    let offset = channel * plane_size + y * uniforms.width + x;
    if (offset < arrayLength(&input_data)) {
        return input_data[offset];
    }
    return 0.0;
}

// Writes one component of the destination image.
//
// Uses the same plane-by-plane addressing as the input buffer. A computed
// offset past the end of the buffer is silently ignored.
fn set_output_pixel(channel: u32, y: u32, x: u32, value: f32) {
    let plane_size = uniforms.height * uniforms.width;
    let offset = channel * plane_size + y * uniforms.width + x;
    if (offset >= arrayLength(&output_data)) {
        return;
    }
    output_data[offset] = value;
}

// Converts an RGB color to HSV.
//
// Returns vec3(h, s, v) where h is the hue in degrees (nominally [0, 360)),
// s = (max - min) / max is the saturation and v is the largest component.
//
// Colors with no chroma (all components equal) return h = 0 and s = 0.
// The same result is returned when the largest component is not positive:
// dividing by it would produce an infinite or negative saturation, which
// turns into NaN further down the pipeline. Such inputs only occur when the
// caller passes values outside [0, 1].
fn rgb_to_hsv(rgb: vec3<f32>) -> vec3<f32> {
    let max_val = max(max(rgb.r, rgb.g), rgb.b);
    let min_val = min(min(rgb.r, rgb.g), rgb.b);
    let delta = max_val - min_val;

    // Hue and saturation are undefined here; report a neutral color.
    if (delta == 0.0 || max_val <= 0.0) {
        return vec3<f32>(0.0, 0.0, max_val);
    }

    let s = delta / max_val;

    // Position on the color wheel in units of 60-degree sectors, chosen by
    // which component is the largest.
    var h: f32;
    if (rgb.r == max_val) {
        h = (rgb.g - rgb.b) / delta;
    } else if (rgb.g == max_val) {
        h = 2.0 + (rgb.b - rgb.r) / delta;
    } else {
        h = 4.0 + (rgb.r - rgb.g) / delta;
    }

    h = h * 60.0;
    // The red sector can come out negative; wrap it to the top of the wheel.
    if (h < 0.0) {
        h = h + 360.0;
    }

    return vec3<f32>(h, s, max_val);
}

// Converts an HSV color to RGB.
//
// hsv.x is the hue in degrees, hsv.y the saturation and hsv.z the value.
// Any hue is accepted: it is wrapped onto the 360-degree color wheel first,
// so -120 and 240 give the same color. A saturation of exactly zero returns
// the gray (v, v, v).
fn hsv_to_rgb(hsv: vec3<f32>) -> vec3<f32> {
    let s = hsv.y;
    let v = hsv.z;

    if (s == 0.0) {
        return vec3<f32>(v, v, v);
    }

    // Floor-based modulo keeps the result non-negative. Without it a negative
    // hue would make the u32 conversion below clamp to sector 0 and leave a
    // negative fractional part.
    let h = hsv.x - 360.0 * floor(hsv.x / 360.0);

    let h_sector = h / 60.0;
    let i = u32(floor(h_sector));
    let f = h_sector - f32(i);

    let p = v * (1.0 - s);
    let q = v * (1.0 - s * f);
    let t = v * (1.0 - s * (1.0 - f));

    // The modulo still matters: rounding can make the wrapped hue land on
    // exactly 360, i.e. sector 6, which is the same as sector 0.
    switch (i % 6u) {
        case 0u: { return vec3<f32>(v, t, p); }
        case 1u: { return vec3<f32>(q, v, p); }
        case 2u: { return vec3<f32>(p, v, t); }
        case 3u: { return vec3<f32>(p, q, v); }
        case 4u: { return vec3<f32>(t, p, v); }
        default: { return vec3<f32>(v, p, q); }
    }
}

// Compute entry point: jitters one pixel per invocation.
//
// Invocation (x, y) reads channel planes 0, 1 and 2 of the input as RGB,
// applies brightness, contrast, saturation and hue in that order, and writes
// the result, clamped to [0, 1], to the same position in the output.
// Workgroups are 8x8, so the grid may overhang the image; invocations outside
// width x height do nothing.
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let x = global_id.x;
    let y = global_id.y;

    if (x >= uniforms.width || y >= uniforms.height) {
        return;
    }

    // Get RGB values
    var rgb = vec3<f32>(
        get_input_pixel(0u, y, x),
        get_input_pixel(1u, y, x),
        get_input_pixel(2u, y, x)
    );

    // Brightness is an additive offset; contrast scales around mid-gray.
    rgb = rgb + uniforms.brightness;
    rgb = (rgb - 0.5) * uniforms.contrast + 0.5;

    // The HSV conversion is only meaningful for components in [0, 1]. Left
    // unclamped, a dark pixel pushed below zero can reach the conversion with
    // a zero maximum and non-zero spread, giving an infinite saturation and
    // NaN output.
    rgb = clamp(rgb, vec3<f32>(0.0), vec3<f32>(1.0));

    // Convert to HSV for saturation and hue adjustments
    var hsv = rgb_to_hsv(rgb);

    // Keep the scaled saturation inside its valid range so the conversion
    // back preserves the hue instead of producing out-of-gamut components.
    hsv.y = clamp(hsv.y * uniforms.saturation, 0.0, 1.0);

    // Shift the hue (degrees) and wrap it onto the color wheel. Floor-based
    // modulo handles shifts of any magnitude and sign, not just a single turn.
    let shifted_hue = hsv.x + uniforms.hue;
    hsv.x = shifted_hue - 360.0 * floor(shifted_hue / 360.0);

    // Convert back to RGB
    rgb = hsv_to_rgb(hsv);

    // Clamp values to [0, 1]
    rgb = clamp(rgb, vec3<f32>(0.0), vec3<f32>(1.0));

    // Store result
    set_output_pixel(0u, y, x, rgb.r);
    set_output_pixel(1u, y, x, rgb.g);
    set_output_pixel(2u, y, x, rgb.b);
}