wagahai_lut 0.1.0

CUBE LUT parser and image processing library with SIMD
Documentation
/*
 * SPDX-FileCopyrightText: © 2026 Jinwoo Park (pmnxis@gmail.com)
 *
 * SPDX-License-Identifier: MIT
 */

//! 3D LUT data structures with Structure of Arrays (SoA) for cache efficiency

use super::Rgb;
use crate::error::{CubeError, Result};

/// 3D color lookup table stored as a Structure of Arrays (SoA).
///
/// Instead of a single vector of RGB triplets, each channel lives in its own vector
/// (`r`, `g`, `b`), a layout intended to improve cache locality during trilinear
/// interpolation. Each vector holds `size³` samples addressed by a flat index in which the
/// red coordinate changes fastest: `r + size * g + size * size * b`
/// (see [`Lut3D::to_flat_index`]).
///
/// The fields are public. Code that builds a `Lut3D` by hand instead of through
/// [`Lut3D::new`] must keep every channel vector at `size³` elements. The accessor methods,
/// in particular the `unsafe` unchecked ones, index the vectors based on `size` alone.
#[derive(Clone)]
pub struct Lut3D {
    /// Red channel values, one per grid point (`size³` entries).
    pub r: Vec<f32>,
    /// Green channel values, one per grid point (`size³` entries).
    pub g: Vec<f32>,
    /// Blue channel values, one per grid point (`size³` entries).
    pub b: Vec<f32>,
    /// Number of grid points along each axis; the table holds `size³` entries.
    pub size: usize,
}

impl Lut3D {
    /// Create a new zero-filled 3D LUT with `size` grid points along each axis.
    ///
    /// All three channel vectors are allocated with `size³` entries set to `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`CubeError::LutSizeOutOfRange`] when `size` is outside `2..=256`.
    pub fn new(size: usize) -> Result<Self> {
        if !(2..=256).contains(&size) {
            return Err(CubeError::LutSizeOutOfRange {
                size,
                min: 2,
                max: 256,
            });
        }

        // At most 256³ (about 16.8M) points per channel, so this cannot overflow `usize`
        // on 32-bit or wider targets.
        let total_points = size * size * size;
        Ok(Lut3D {
            r: vec![0.0; total_points],
            g: vec![0.0; total_points],
            b: vec![0.0; total_points],
            size,
        })
    }

    /// Get size/dimension of LUT (e.g., 33 for a 33x33x33 LUT).
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get dimension of LUT (alias for [`Lut3D::size`]).
    #[inline]
    pub fn dimension(&self) -> usize {
        self.size
    }

    /// Get raw pointers to the first element of the red, green and blue channel buffers.
    ///
    /// # Safety
    ///
    /// Obtaining the pointers is harmless in itself. When dereferencing them, the caller
    /// must stay within each channel vector's length (`size³` for a LUT built by
    /// [`Lut3D::new`]). The caller must not write through the pointers, and must not use
    /// them after the channel vectors have been mutated, resized or dropped, because the
    /// underlying buffers may then be reallocated or freed.
    #[inline]
    pub unsafe fn channel_pointers(&self) -> (*const f32, *const f32, *const f32) {
        (self.r.as_ptr(), self.g.as_ptr(), self.b.as_ptr())
    }

    /// Get the value of one channel (0 = red, 1 = green, 2 = blue) at a flat index.
    ///
    /// # Panics
    ///
    /// Panics if `channel > 2` or if `flat_idx` is out of range for that channel's vector.
    #[inline]
    pub fn get_channel(&self, flat_idx: usize, channel: usize) -> f32 {
        match channel {
            0 => self.r[flat_idx],
            1 => self.g[flat_idx],
            2 => self.b[flat_idx],
            _ => panic!("Invalid channel: {}", channel),
        }
    }

    /// Like [`Lut3D::get_channel`], but without bounds checking on `flat_idx`.
    ///
    /// # Safety
    ///
    /// `flat_idx` must be less than the length of the selected channel vector (`size³` for
    /// a LUT built by [`Lut3D::new`]). Otherwise the behavior is undefined.
    ///
    /// # Panics
    ///
    /// Still panics if `channel > 2`, because the channel selector is always validated.
    #[inline]
    pub unsafe fn get_channel_unchecked(&self, flat_idx: usize, channel: usize) -> f32 {
        let data = match channel {
            0 => &self.r,
            1 => &self.g,
            2 => &self.b,
            _ => panic!("Invalid channel: {}", channel),
        };
        debug_assert!(flat_idx < data.len(), "flat index out of bounds");
        // SAFETY: the caller guarantees `flat_idx < data.len()`.
        unsafe { *data.get_unchecked(flat_idx) }
    }

    /// Get the RGB triplet stored at a flat index.
    ///
    /// # Panics
    ///
    /// Panics if `flat_idx` is out of range for any channel vector.
    #[inline]
    pub fn get_rgb_flat(&self, flat_idx: usize) -> Rgb {
        [self.r[flat_idx], self.g[flat_idx], self.b[flat_idx]]
    }

    /// Like [`Lut3D::get_rgb_flat`], but without bounds checking.
    ///
    /// # Safety
    ///
    /// `flat_idx` must be less than the length of all three channel vectors (`size³` for a
    /// LUT built by [`Lut3D::new`]). Otherwise the behavior is undefined.
    #[inline]
    pub unsafe fn get_rgb_flat_unchecked(&self, flat_idx: usize) -> Rgb {
        debug_assert!(
            flat_idx < self.r.len() && flat_idx < self.g.len() && flat_idx < self.b.len(),
            "flat index out of bounds"
        );
        // SAFETY: the caller guarantees `flat_idx` is in bounds for every channel vector.
        unsafe {
            [
                *self.r.get_unchecked(flat_idx),
                *self.g.get_unchecked(flat_idx),
                *self.b.get_unchecked(flat_idx),
            ]
        }
    }

    /// Set the value of one channel (0 = red, 1 = green, 2 = blue) at a flat index.
    ///
    /// # Panics
    ///
    /// Panics if `channel > 2` or if `flat_idx` is out of range for that channel's vector.
    #[inline]
    pub fn set_channel(&mut self, flat_idx: usize, channel: usize, value: f32) {
        match channel {
            0 => self.r[flat_idx] = value,
            1 => self.g[flat_idx] = value,
            2 => self.b[flat_idx] = value,
            _ => panic!("Invalid channel: {}", channel),
        }
    }

    /// Set the RGB triplet at a flat index.
    ///
    /// # Panics
    ///
    /// Panics if `flat_idx` is out of range for any channel vector.
    #[inline]
    pub fn set_rgb_flat(&mut self, flat_idx: usize, rgb: Rgb) {
        self.r[flat_idx] = rgb[0];
        self.g[flat_idx] = rgb[1];
        self.b[flat_idx] = rgb[2];
    }

    /// Convert 3D grid coordinates to a flat index (red changes fastest).
    ///
    /// Computes `r + size * g + size * size * b`. No bounds checking is performed. The
    /// result only addresses the intended grid point when every coordinate is `< size`.
    #[inline]
    pub fn to_flat_index(&self, r: usize, g: usize, b: usize) -> usize {
        r + self.size * g + self.size * self.size * b
    }

    /// Return an error unless `channel` selects red (0), green (1) or blue (2).
    fn validate_channel(channel: usize) -> Result<()> {
        if channel > 2 {
            return Err(CubeError::IndexOutOfBounds {
                index: channel,
                len: 3,
            });
        }
        Ok(())
    }

    /// Return an error unless `(r, g, b)` addresses a grid point inside this LUT.
    ///
    /// The reported `index` is the largest coordinate. When any coordinate is out of range,
    /// the largest one is too.
    fn validate_coords(&self, r: usize, g: usize, b: usize) -> Result<()> {
        if r >= self.size || g >= self.size || b >= self.size {
            return Err(CubeError::IndexOutOfBounds {
                index: r.max(g).max(b),
                len: self.size,
            });
        }
        Ok(())
    }

    /// Get the value of one channel at 3D grid coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`CubeError::IndexOutOfBounds`] if `channel > 2` (this is checked first) or
    /// if any coordinate is `>= size`.
    pub fn get(&self, r: usize, g: usize, b: usize, channel: usize) -> Result<f32> {
        Self::validate_channel(channel)?;
        self.validate_coords(r, g, b)?;
        let flat_idx = self.to_flat_index(r, g, b);
        Ok(self.get_channel(flat_idx, channel))
    }

    /// Get the value of one channel at 3D grid coordinates without bounds checking.
    ///
    /// # Safety
    ///
    /// `r`, `g` and `b` must each be `< size`, and every channel vector must hold at least
    /// `size³` values (as [`Lut3D::new`] guarantees). Otherwise the behavior is undefined.
    ///
    /// # Panics
    ///
    /// Panics if `channel > 2`.
    #[inline]
    pub unsafe fn get_unchecked(&self, r: usize, g: usize, b: usize, channel: usize) -> f32 {
        debug_assert!(
            r < self.size && g < self.size && b < self.size,
            "grid coordinate out of bounds"
        );
        let flat_idx = self.to_flat_index(r, g, b);
        // SAFETY: in-range coordinates give `flat_idx < size³ <= channel length`, which the
        // caller guarantees.
        unsafe { self.get_channel_unchecked(flat_idx, channel) }
    }

    /// Set the value of one channel at 3D grid coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`CubeError::IndexOutOfBounds`] if `channel > 2` (this is checked first) or
    /// if any coordinate is `>= size`. Nothing is written in that case.
    pub fn set(&mut self, r: usize, g: usize, b: usize, channel: usize, value: f32) -> Result<()> {
        Self::validate_channel(channel)?;
        self.validate_coords(r, g, b)?;
        let flat_idx = self.to_flat_index(r, g, b);
        self.set_channel(flat_idx, channel, value);
        Ok(())
    }

    /// Get the full RGB triplet at 3D grid coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`CubeError::IndexOutOfBounds`] if any coordinate is `>= size`.
    pub fn get_rgb(&self, r: usize, g: usize, b: usize) -> Result<Rgb> {
        self.validate_coords(r, g, b)?;
        let flat_idx = self.to_flat_index(r, g, b);
        Ok(self.get_rgb_flat(flat_idx))
    }

    /// Get the full RGB triplet at 3D grid coordinates without bounds checking.
    ///
    /// # Safety
    ///
    /// `r`, `g` and `b` must each be `< size`, and every channel vector must hold at least
    /// `size³` values (as [`Lut3D::new`] guarantees). Otherwise the behavior is undefined.
    #[inline]
    pub unsafe fn get_rgb_unchecked(&self, r: usize, g: usize, b: usize) -> Rgb {
        debug_assert!(
            r < self.size && g < self.size && b < self.size,
            "grid coordinate out of bounds"
        );
        let flat_idx = self.to_flat_index(r, g, b);
        // SAFETY: in-range coordinates give `flat_idx < size³ <= channel length`, which the
        // caller guarantees.
        unsafe { self.get_rgb_flat_unchecked(flat_idx) }
    }

    /// Set the full RGB triplet at 3D grid coordinates.
    ///
    /// # Errors
    ///
    /// Returns [`CubeError::IndexOutOfBounds`] if any coordinate is `>= size`. Nothing is
    /// written in that case.
    pub fn set_rgb(&mut self, r: usize, g: usize, b: usize, rgb: Rgb) -> Result<()> {
        self.validate_coords(r, g, b)?;
        let flat_idx = self.to_flat_index(r, g, b);
        self.set_rgb_flat(flat_idx, rgb);
        Ok(())
    }
}

impl std::fmt::Debug for Lut3D {
    /// Prints only the grid dimensions, e.g. `Lut3D (size: 33x33x33)`.
    /// The channel sample data is omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let n = self.size;
        formatter.write_fmt(format_args!("Lut3D (size: {}x{}x{})", n, n, n))
    }
}

/// Legacy name for [`Lut3D`], kept so that code using the `Lut3DSoA` name keeps compiling.
///
/// `Lut3D` itself already uses the Structure-of-Arrays layout; the two names denote the
/// same type.
pub type Lut3DSoA = Lut3D;