use crate::{GpuDevice, Result};
/// The kinds of image transform a [`TransformKernel`] can describe.
///
/// Only `DCT` and `IDCT` are currently executable through
/// [`TransformKernel::execute`]; every other variant is classified by the
/// kernel's helper predicates but yields a "not yet implemented" error when
/// executed.
// The all-caps variant names are public API; renaming them to `Dct` etc.
// would break existing callers, so the clippy naming lint is silenced here.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Forward 2D discrete cosine transform.
    DCT,
    /// Inverse 2D discrete cosine transform.
    IDCT,
    /// Forward fast Fourier transform.
    FFT,
    /// Inverse fast Fourier transform.
    IFFT,
    /// Rotation by 90 degrees.
    Rotate90,
    /// Rotation by 180 degrees.
    Rotate180,
    /// Rotation by 270 degrees.
    Rotate270,
    /// Horizontal flip.
    FlipHorizontal,
    /// Vertical flip.
    FlipVertical,
    /// Swap of rows and columns.
    Transpose,
    /// General affine warp.
    Affine,
    /// Perspective (projective) warp.
    Perspective,
}
pub struct TransformKernel {
transform_type: TransformType,
}
impl TransformKernel {
#[must_use]
pub fn new(transform_type: TransformType) -> Self {
Self { transform_type }
}
#[must_use]
pub fn dct() -> Self {
Self::new(TransformType::DCT)
}
#[must_use]
pub fn idct() -> Self {
Self::new(TransformType::IDCT)
}
#[must_use]
pub fn rotate(degrees: i32) -> Self {
let transform_type = match degrees % 360 {
90 | -270 => TransformType::Rotate90,
180 | -180 => TransformType::Rotate180,
270 | -90 => TransformType::Rotate270,
_ => TransformType::Rotate90, };
Self::new(transform_type)
}
#[must_use]
pub fn flip(horizontal: bool) -> Self {
let transform_type = if horizontal {
TransformType::FlipHorizontal
} else {
TransformType::FlipVertical
};
Self::new(transform_type)
}
pub fn execute(
&self,
device: &GpuDevice,
input: &[f32],
output: &mut [f32],
width: u32,
height: u32,
) -> Result<()> {
match self.transform_type {
TransformType::DCT => {
crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
}
TransformType::IDCT => {
crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
}
_ => Err(crate::GpuError::NotSupported(format!(
"Transform type {:?} not yet implemented",
self.transform_type
))),
}
}
#[must_use]
pub fn transform_type(&self) -> TransformType {
self.transform_type
}
#[must_use]
pub fn is_frequency_domain(&self) -> bool {
matches!(
self.transform_type,
TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
)
}
#[must_use]
pub fn is_geometric(&self) -> bool {
matches!(
self.transform_type,
TransformType::Rotate90
| TransformType::Rotate180
| TransformType::Rotate270
| TransformType::FlipHorizontal
| TransformType::FlipVertical
| TransformType::Transpose
| TransformType::Affine
| TransformType::Perspective
)
}
#[must_use]
pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
let n = u64::from(width) * u64::from(height);
match transform_type {
TransformType::DCT | TransformType::IDCT => {
let log_n = (n as f64).log2().ceil() as u64;
n * n * log_n
}
TransformType::FFT | TransformType::IFFT => {
let log_n = (n as f64).log2().ceil() as u64;
n * log_n * 5 }
_ => {
n
}
}
}
}
/// A 2D affine transform stored as the top two rows of a 3×3 homogeneous
/// matrix.
///
/// `elements` is row-major `[a, b, tx, c, d, ty]`, representing
///
/// ```text
/// | a  b  tx |
/// | c  d  ty |
/// | 0  0  1  |
/// ```
///
/// so a point `(x, y)` maps to `(a*x + b*y + tx, c*x + d*y + ty)`.
#[derive(Debug, Clone, Copy)]
pub struct AffineMatrix {
    /// Row-major `[a, b, tx, c, d, ty]`; the implicit third row is `[0, 0, 1]`.
    pub elements: [f32; 6],
}

impl AffineMatrix {
    /// The identity transform: every point maps to itself.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        }
    }

    /// Pure translation by `(tx, ty)`.
    #[must_use]
    pub fn translation(tx: f32, ty: f32) -> Self {
        Self {
            elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
        }
    }

    /// Rotation about the origin by `angle_radians`.
    ///
    /// Uses the standard `[cos -sin; sin cos]` matrix. Positive angles
    /// therefore turn counter-clockwise when the y axis points up, and
    /// clockwise on screen when y points down.
    #[must_use]
    pub fn rotation(angle_radians: f32) -> Self {
        let cos = angle_radians.cos();
        let sin = angle_radians.sin();
        Self {
            elements: [cos, -sin, 0.0, sin, cos, 0.0],
        }
    }

    /// Scaling about the origin by `sx` along x and `sy` along y.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32) -> Self {
        Self {
            elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
        }
    }

    /// Matrix product `self * other`.
    ///
    /// The result applies `other` first and then `self`:
    /// `a.combine(&b)` maps a point `p` to `a(b(p))`.
    #[must_use]
    pub fn combine(&self, other: &Self) -> Self {
        let [a00, a01, a02, a10, a11, a12] = self.elements;
        let [b00, b01, b02, b10, b11, b12] = other.elements;
        Self {
            elements: [
                a00 * b00 + a01 * b10,
                a00 * b01 + a01 * b11,
                a00 * b02 + a01 * b12 + a02,
                a10 * b00 + a11 * b10,
                a10 * b01 + a11 * b11,
                a10 * b02 + a11 * b12 + a12,
            ],
        }
    }

    /// The raw row-major `[a, b, tx, c, d, ty]` coefficients.
    #[must_use]
    pub fn as_array(&self) -> [f32; 6] {
        self.elements
    }
}

impl Default for AffineMatrix {
    /// Defaults to [`AffineMatrix::identity`].
    fn default() -> Self {
        Self::identity()
    }
}

/// A geometric warp described by a single [`AffineMatrix`].
///
/// NOTE(review): nothing in this type fixes whether the matrix is used as a
/// forward (source → destination) or inverse (destination → source) mapping.
/// Confirm this with the code that consumes [`WarpKernel::matrix`].
#[derive(Debug, Clone, Copy)]
pub struct WarpKernel {
    matrix: AffineMatrix,
}

impl WarpKernel {
    /// Wraps an arbitrary affine matrix.
    #[must_use]
    pub fn new(matrix: AffineMatrix) -> Self {
        Self { matrix }
    }

    /// Rotation by `angle_degrees` about `(center_x, center_y)`, which stays
    /// fixed.
    #[must_use]
    pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let rotate = AffineMatrix::rotation(angle_degrees.to_radians());
        let back = AffineMatrix::translation(center_x, center_y);
        // `combine` applies its argument first. To pivot on the centre we move
        // it to the origin, rotate, then move it back: T(c) * R * T(-c).
        // The reverse order (T(-c) * R * T(c)) would pivot on -c instead.
        Self::new(back.combine(&rotate).combine(&to_origin))
    }

    /// Scaling by `(sx, sy)` about `(center_x, center_y)`, which stays fixed.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let scale = AffineMatrix::scaling(sx, sy);
        let back = AffineMatrix::translation(center_x, center_y);
        // Same composition order as `rotation`: T(c) * S * T(-c).
        Self::new(back.combine(&scale).combine(&to_origin))
    }

    /// The affine matrix describing this warp.
    #[must_use]
    pub fn matrix(&self) -> &AffineMatrix {
        &self.matrix
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::idct();
        assert_eq!(kernel.transform_type(), TransformType::IDCT);
        assert!(kernel.is_frequency_domain());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    #[test]
    fn test_rotate_normalizes_angles() {
        // Negative and over-full-turn angles map to the equivalent quarter turn.
        let rot = |deg| TransformKernel::rotate(deg).transform_type();
        assert_eq!(rot(90), TransformType::Rotate90);
        assert_eq!(rot(-270), TransformType::Rotate90);
        assert_eq!(rot(450), TransformType::Rotate90);
        assert_eq!(rot(180), TransformType::Rotate180);
        assert_eq!(rot(-180), TransformType::Rotate180);
        assert_eq!(rot(270), TransformType::Rotate270);
        assert_eq!(rot(-90), TransformType::Rotate270);
    }

    #[test]
    fn test_flip() {
        assert_eq!(
            TransformKernel::flip(true).transform_type(),
            TransformType::FlipHorizontal
        );
        assert_eq!(
            TransformKernel::flip(false).transform_type(),
            TransformType::FlipVertical
        );
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
        assert_eq!(AffineMatrix::default().as_array(), elements);
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    #[test]
    fn test_affine_matrix_rotation_zero_is_identity() {
        assert_eq!(
            AffineMatrix::rotation(0.0).as_array(),
            AffineMatrix::identity().as_array()
        );
    }

    #[test]
    fn test_affine_matrix_combination() {
        // translation(10, 20) * scaling(2, 2): scale first, then translate.
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);
        let combined = t1.combine(&s);
        assert_eq!(combined.as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);
    }

    #[test]
    fn test_warp_kernel_new_keeps_matrix() {
        let m = AffineMatrix::translation(3.0, 4.0);
        assert_eq!(WarpKernel::new(m).matrix().as_array(), m.as_array());
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);
        assert!(flops_dct > 0);
        assert!(flops_rotate > 0);
        assert!(flops_dct > flops_rotate);

        // n = 4, ceil(log2(4)) = 2.
        assert_eq!(TransformKernel::estimate_flops(2, 2, TransformType::DCT), 32);
        assert_eq!(TransformKernel::estimate_flops(2, 2, TransformType::FFT), 40);
        assert_eq!(TransformKernel::estimate_flops(2, 2, TransformType::Rotate90), 4);
    }
}