1use anyhow::Result;
2use burn::prelude::{Backend, Tensor, TensorData};
3use burn::tensor;
4use rs_cinic_10_index::images::{RgbImageBatch, load_bhwc_rgbimagebatch};
5use rs_cinic_10_index::index::DatasetIndex;
6use std::path::Path;
7
8fn batch_to_tensordata(batch: RgbImageBatch) -> TensorData {
9 TensorData::from_bytes(batch.data, batch.shape, tensor::DType::U8)
10}
11
12pub fn load_bhwc_u8_tensordata_image_batch<P>(paths: &[P]) -> Result<TensorData>
13where
14 P: AsRef<Path>,
15{
16 let batch = load_bhwc_rgbimagebatch(paths)?;
17 let tensor_data = batch_to_tensordata(batch);
18 Ok(tensor_data)
19}
20
21pub fn load_bhwc_u8_tensor_image_batch<B, P>(
22 paths: &[P],
23 device: &B::Device,
24) -> Result<Tensor<B, 4>>
25where
26 B: Backend,
27 P: AsRef<Path>,
28{
29 let data = load_bhwc_u8_tensordata_image_batch(paths)?;
30 let tensor = Tensor::from_data(data, device);
31 Ok(tensor)
32}
33
34pub fn load_hwc_u8_tensor_image<B, P>(
35 path: P,
36 device: &B::Device,
37) -> Result<Tensor<B, 3>>
38where
39 B: Backend,
40 P: AsRef<Path>,
41{
42 let paths = vec![path.as_ref()];
43
44 let batch = load_bhwc_u8_tensor_image_batch(&paths, device)?;
45 let tensor = batch.squeeze(0);
46
47 Ok(tensor)
48}
49
50pub trait WithTensorBatches {
51 fn load_tensor<B>(
52 &self,
53 index: usize,
54 device: &B::Device,
55 ) -> Result<Tensor<B, 3>>
56 where
57 B: Backend,
58 {
59 Ok(self.load_tensor_batch(&[index], device)?.squeeze(0))
60 }
61
62 fn load_tensor_batch<B>(
63 &self,
64 indexes: &[usize],
65 device: &B::Device,
66 ) -> Result<Tensor<B, 4>>
67 where
68 B: Backend;
69}
70
71impl WithTensorBatches for DatasetIndex {
72 fn load_tensor_batch<B>(
73 &self,
74 indexes: &[usize],
75 device: &B::Device,
76 ) -> Result<Tensor<B, 4>>
77 where
78 B: Backend,
79 {
80 let paths = self.indices_to_paths(indexes);
81 let tensor = load_bhwc_u8_tensor_image_batch(&paths, device)?;
82 Ok(tensor)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    use rs_cinic_10_index::index::{CHANNELS, HEIGHT, ObjectClass, SAMPLES_PER_CLASS, WIDTH};
    use rs_cinic_10_index::{Cinic10Index, default_data_path_or_panic};

    /// Path, relative to the dataset root, of a known training image used by the loader tests.
    const AIRPLANE_SAMPLE: &str = "train/airplane/cifar10-train-3318.png";

    #[test]
    fn test_load_image() -> Result<()> {
        let path = default_data_path_or_panic().join(AIRPLANE_SAMPLE);

        let device = Default::default();
        let tensor: Tensor<NdArray, 3> = load_hwc_u8_tensor_image(&path, &device)?;

        // Use the dataset's declared image dimensions, matching `test_load_test_batch`,
        // instead of repeating literal sizes.
        assert_eq!(tensor.dims(), [HEIGHT, WIDTH, CHANNELS]);

        Ok(())
    }

    #[test]
    fn test_load_image_batch() -> Result<()> {
        let sample = default_data_path_or_panic().join(AIRPLANE_SAMPLE);
        // Loading the same image twice is enough to check that the batch axis counts paths.
        let paths = vec![sample.clone(), sample];

        let device = Default::default();
        let tensor: Tensor<NdArray, 4> = load_bhwc_u8_tensor_image_batch(&paths, &device)?;

        assert_eq!(tensor.dims(), [paths.len(), HEIGHT, WIDTH, CHANNELS]);

        Ok(())
    }

    #[test]
    fn test_load_test_batch() -> Result<()> {
        let cinic: Cinic10Index = Default::default();

        let expected = vec![
            ObjectClass::Airplane,
            ObjectClass::Automobile,
            ObjectClass::Bird,
        ];
        // Index `i * SAMPLES_PER_CLASS` is expected to land in the i-th class; the class
        // assertion below checks that assumption.
        let indices: Vec<usize> = (0..expected.len()).map(|i| i * SAMPLES_PER_CLASS).collect();

        let device = Default::default();
        let tensor: Tensor<NdArray, 4> = cinic.test.load_tensor_batch(&indices, &device)?;
        let classes = cinic.test.indices_to_classes(&indices);

        assert_eq!(tensor.dims(), [indices.len(), HEIGHT, WIDTH, CHANNELS]);
        assert_eq!(classes, expected);

        Ok(())
    }
}