Skip to main content

rs_cinic_10_burn/
lib.rs

1use anyhow::Result;
2use burn::prelude::{Backend, Tensor, TensorData};
3use burn::tensor;
4use rs_cinic_10_index::images::{RgbImageBatch, load_bhwc_rgbimagebatch};
5use rs_cinic_10_index::index::DatasetIndex;
6use std::path::Path;
7
8fn batch_to_tensordata(batch: RgbImageBatch) -> TensorData {
9    TensorData::from_bytes(batch.data, batch.shape, tensor::DType::U8)
10}
11
12pub fn load_bhwc_u8_tensordata_image_batch<P>(paths: &[P]) -> Result<TensorData>
13where
14    P: AsRef<Path>,
15{
16    let batch = load_bhwc_rgbimagebatch(paths)?;
17    let tensor_data = batch_to_tensordata(batch);
18    Ok(tensor_data)
19}
20
21pub fn load_bhwc_u8_tensor_image_batch<B, P>(
22    paths: &[P],
23    device: &B::Device,
24) -> Result<Tensor<B, 4>>
25where
26    B: Backend,
27    P: AsRef<Path>,
28{
29    let data = load_bhwc_u8_tensordata_image_batch(paths)?;
30    let tensor = Tensor::from_data(data, device);
31    Ok(tensor)
32}
33
34pub fn load_hwc_u8_tensor_image<B, P>(
35    path: P,
36    device: &B::Device,
37) -> Result<Tensor<B, 3>>
38where
39    B: Backend,
40    P: AsRef<Path>,
41{
42    let paths = vec![path.as_ref()];
43
44    let batch = load_bhwc_u8_tensor_image_batch(&paths, device)?;
45    let tensor = batch.squeeze(0);
46
47    Ok(tensor)
48}
49
50pub trait WithTensorBatches {
51    fn load_tensor<B>(
52        &self,
53        index: usize,
54        device: &B::Device,
55    ) -> Result<Tensor<B, 3>>
56    where
57        B: Backend,
58    {
59        Ok(self.load_tensor_batch(&[index], device)?.squeeze(0))
60    }
61
62    fn load_tensor_batch<B>(
63        &self,
64        indexes: &[usize],
65        device: &B::Device,
66    ) -> Result<Tensor<B, 4>>
67    where
68        B: Backend;
69}
70
71impl WithTensorBatches for DatasetIndex {
72    fn load_tensor_batch<B>(
73        &self,
74        indexes: &[usize],
75        device: &B::Device,
76    ) -> Result<Tensor<B, 4>>
77    where
78        B: Backend,
79    {
80        let paths = self.indices_to_paths(indexes);
81        let tensor = load_bhwc_u8_tensor_image_batch(&paths, device)?;
82        Ok(tensor)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    use rs_cinic_10_index::index::{CHANNELS, HEIGHT, ObjectClass, SAMPLES_PER_CLASS, WIDTH};
    use rs_cinic_10_index::{Cinic10Index, default_data_path_or_panic};

    /// Path (relative to the dataset root) of a known training image used by the file-based tests.
    const SAMPLE_IMAGE: &str = "train/airplane/cifar10-train-3318.png";

    #[test]
    fn test_load_image() -> Result<()> {
        let path = default_data_path_or_panic().join(SAMPLE_IMAGE);

        let device = Default::default();
        let tensor: Tensor<NdArray, 3> = load_hwc_u8_tensor_image(&path, &device)?;

        // Use the dataset constants (as test_load_test_batch does) instead of literal 32/32/3.
        assert_eq!(tensor.dims(), [HEIGHT, WIDTH, CHANNELS]);

        Ok(())
    }

    #[test]
    fn test_load_image_batch() -> Result<()> {
        let root_path = default_data_path_or_panic();
        let paths = vec![root_path.join(SAMPLE_IMAGE), root_path.join(SAMPLE_IMAGE)];

        let device = Default::default();
        let tensor: Tensor<NdArray, 4> = load_bhwc_u8_tensor_image_batch(&paths, &device)?;

        assert_eq!(tensor.dims(), [paths.len(), HEIGHT, WIDTH, CHANNELS]);

        Ok(())
    }

    #[test]
    fn test_load_test_batch() -> Result<()> {
        let cinic: Cinic10Index = Default::default();
        let indices = (0..3).map(|i| i * SAMPLES_PER_CLASS).collect::<Vec<_>>();

        let device = Default::default();
        let tensor: Tensor<NdArray, 4> = cinic.test.load_tensor_batch(&indices, &device)?;
        let classes = cinic.test.indices_to_classes(&indices);

        assert_eq!(tensor.dims(), [3, HEIGHT, WIDTH, CHANNELS]);
        assert_eq!(
            classes,
            vec![
                ObjectClass::Airplane,
                ObjectClass::Automobile,
                ObjectClass::Bird
            ]
        );

        Ok(())
    }

    /// Exercises the trait's default `load_tensor`, which the other tests do not cover.
    #[test]
    fn test_load_single_test_sample() -> Result<()> {
        let cinic: Cinic10Index = Default::default();

        let device = Default::default();
        let tensor: Tensor<NdArray, 3> = cinic.test.load_tensor(0, &device)?;

        assert_eq!(tensor.dims(), [HEIGHT, WIDTH, CHANNELS]);

        Ok(())
    }
}