burn-vision 0.20.1

Vision processing operations for burn tensors
Documentation
use std::path::PathBuf;

use burn_tensor::{Shape, Tensor, TensorData, backend::Backend};
use image::{DynamicImage, ImageBuffer, Luma, Rgb};

use burn_tensor::{Bool, Int};

/// CPU (`burn_ndarray`) backend for the tests.
///
/// Active when `test-cpu` or `ndarray` is enabled and neither `test-wgpu` nor
/// `test-cuda` is. Unlike the GPU aliases, this one is not gated on `cfg(test)`.
#[cfg(all(
    any(feature = "ndarray", feature = "test-cpu"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda")
))]
pub type TestBackend = burn_ndarray::NdArray<f32, i32>;

/// WGPU backend for the tests, selected by the `test-wgpu` feature in test builds.
#[cfg(all(feature = "test-wgpu", test))]
pub type TestBackend = burn_wgpu::Wgpu;

/// CUDA backend for the tests, selected by the `test-cuda` feature in test builds.
///
/// `test-wgpu` takes precedence. Without the `not(...)` guard, enabling both GPU test
/// features would activate this alias and the wgpu one together, defining
/// `TestBackend` twice and failing to compile.
#[cfg(all(test, feature = "test-cuda", not(feature = "test-wgpu")))]
pub type TestBackend = burn_cuda::Cuda;

// Each alias carries its own `#[allow(unused)]`: an attribute applies only to the
// item directly after it. Individual test files may not use every alias.

/// Rank-`D` tensor on [`TestBackend`] with `Tensor`'s default kind.
#[allow(unused)]
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
/// Rank-`D` integer (`Int` kind) tensor on [`TestBackend`].
#[allow(unused)]
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
/// Rank-`D` boolean (`Bool` kind) tensor on [`TestBackend`].
#[allow(unused)]
pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;

#[allow(unused)]
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;

/// Builds values of type `$ty` by calling `$ty::new` on every leaf expression.
///
/// Accepts either a single expression or a bracketed list of token trees. A list may
/// have a trailing comma and may be nested. The list structure is mirrored as
/// (nested) arrays:
///
/// ```text
/// as_type![T: x]              => T::new(x)
/// as_type![T: [a, b,]]        => [T::new(a), T::new(b)]
/// as_type![T: [[a], [b]]]     => [[T::new(a)], [T::new(b)]]
/// ```
///
/// List elements must each be a single token tree. Input like `[-1.0, 2.0]` does not
/// match the list arm and falls through to the leaf arm as one array expression.
/// Each leaf expands to a block that glob-imports `cubecl::prelude` before calling
/// `$ty::new`. Presumably this lets `$ty` and its `new` resolve through that
/// prelude; confirm with the call sites.
#[macro_export]
macro_rules! as_type {
    // List form (trailing comma optional): recurse on each element so nested lists
    // become nested arrays.
    ($ty:ident: [$($elem:tt),* $(,)?]) => {
        [$($crate::as_type![$ty: $elem]),*]
    };
    // Leaf form: a single expression.
    ($ty:ident: $elem:expr) => {{
        use cubecl::prelude::*;

        $ty::new($elem)
    }};
}

/// Loads `tests/images/<name>` into a `[height, width, channels]` tensor on `device`.
///
/// With `luma == true`, the image is converted with `to_luma32f` and has one channel.
/// Otherwise it is converted with `to_rgb32f` and has three channels. Pixel values are
/// whatever those conversions produce.
///
/// The path is relative, so it is resolved against the current working directory.
///
/// # Panics
///
/// Panics, naming the file, if the image cannot be opened or decoded.
#[allow(unused)]
pub fn test_image<B: Backend>(name: &str, device: &B::Device, luma: bool) -> Tensor<B, 3> {
    let file = PathBuf::from("tests/images").join(name);
    let img = image::open(&file)
        .unwrap_or_else(|err| panic!("failed to open test image {}: {err}", file.display()));

    // Only the pixel conversion differs between modes. Collect the dimensions,
    // channel count and flat f32 buffer, then build the tensor once.
    let (width, height, channels, pixels) = if luma {
        let buf = img.to_luma32f();
        (buf.width(), buf.height(), 1, buf.into_vec())
    } else {
        let buf = img.to_rgb32f();
        (buf.width(), buf.height(), 3, buf.into_vec())
    };

    let shape = Shape::new([height as usize, width as usize, channels]);
    Tensor::from_data(TensorData::new(pixels, shape), device)
}

/// Writes a `[height, width, channels]` tensor to `tests/images/<name>` as an 8-bit image.
///
/// With `luma == true`, the tensor must have exactly one channel and is saved via
/// `to_luma8`. Otherwise it must have three channels and is saved via `to_rgb8`.
/// Tensor data is converted to `f32` before the `image` crate maps it to 8 bits.
///
/// # Panics
///
/// Panics if the channel count does not match `luma`, if the width or height exceeds
/// `u32::MAX`, if the data cannot be read back as `f32`, or if the file cannot be
/// written.
#[allow(unused)]
pub fn save_test_image<B: Backend>(name: &str, tensor: Tensor<B, 3>, luma: bool) {
    let file = PathBuf::from("tests/images").join(name);
    let [h, w, channels] = tensor.shape().dims();

    // `ImageBuffer::from_raw` accepts buffers larger than width * height * channels.
    // A channel mismatch would therefore silently write a scrambled image instead of
    // failing, so check it explicitly.
    let expected_channels = if luma { 1 } else { 3 };
    assert_eq!(
        channels, expected_channels,
        "save_test_image: expected {expected_channels} channel(s) for luma={luma}, got {channels}"
    );
    let width = u32::try_from(w).expect("image width must fit in u32");
    let height = u32::try_from(h).expect("image height must fit in u32");

    let data = tensor
        .into_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("tensor data should be readable as f32 after conversion");

    let saved = if luma {
        let image = ImageBuffer::<Luma<f32>, _>::from_raw(width, height, data)
            .expect("pixel buffer should cover the checked [h, w, 1] shape");
        DynamicImage::from(image).to_luma8().save(&file)
    } else {
        let image = ImageBuffer::<Rgb<f32>, _>::from_raw(width, height, data)
            .expect("pixel buffer should cover the checked [h, w, 3] shape");
        DynamicImage::from(image).to_rgb8().save(&file)
    };
    saved.unwrap_or_else(|err| panic!("failed to save test image {}: {err}", file.display()));
}