use burn_tensor::{ElementConversion, Tensor, backend::Backend};
use image::{Rgb, RgbImage};
use std::fs;
use std::path::Path;

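/// Options controlling how a tensor is rendered and saved as an RGB image.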
pub struct TensorDisplayOptions {
    /// How the tensor's dimensions map to batch/channel/height/width.
    pub dim_order: ImageDimOrder,
    /// How sampled channel values are mapped to RGB pixel colors.
    pub color_opts: ColorDisplayOpts,
    /// How a batch dimension is laid out in the output; `None` behaves like `Aggregated`.
    pub batch_opts: Option<BatchDisplayOpts>,
    /// Width of the output image in pixels.
    pub width_out: usize,
    /// Height of the output image in pixels (per batch element when tiling).
    pub height_out: usize,
}

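/// Layout of the tensor's dimensions (`N` = batch, `C` = channel, `H` = height, `W` = width).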
pub enum ImageDimOrder {
    Hw,
    Chw,
    Hwc,
    Nhw,
    Nchw,
    Nhwc,
}

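/// How sampled channel values are converted into RGB pixel colors.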
pub enum ColorDisplayOpts {
    /// Map up to three channels directly onto the R, G and B components.
    Rgb,
    /// Sum all channels and linearly interpolate between a `min` and `max` color.
    Monochrome {
        min: [f32; 3],
        max: [f32; 3],
    },
}

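/// How the batch dimension is arranged in the output image.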
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum BatchDisplayOpts {
    /// Stack the batch elements vertically, one tile per element.
    Tiled,
    /// Draw every batch element onto the same canvas (later elements overwrite earlier ones).
    Aggregated,
}

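/// Saves `tensor` as an RGB image at `path`, creating parent directories as needed.
///
/// The tensor is min-max normalized, rearranged to NCHW according to
/// `opts.dim_order`, resampled to `opts.width_out` x `opts.height_out` with
/// nearest-neighbor sampling, and colored according to `opts.color_opts`.
///
/// # Example
///
/// A minimal sketch, assuming the `burn-ndarray` backend is available as a
/// dependency; any `Backend` implementation can be substituted.
///
/// ```ignore
/// use burn_ndarray::NdArray;
/// use burn_tensor::Tensor;
///
/// type B = NdArray<f32>;
///
/// // A 2x2 single-channel gradient, interpreted as HW.
/// let tensor = Tensor::<B, 2>::from_floats([[0.0, 1.0], [2.0, 3.0]], &Default::default());
/// save_tensor_as_image(
///     tensor,
///     TensorDisplayOptions {
///         dim_order: ImageDimOrder::Hw,
///         color_opts: ColorDisplayOpts::Monochrome {
///             min: [0.0, 0.0, 0.0],
///             max: [1.0, 1.0, 1.0],
///         },
///         batch_opts: None,
///         width_out: 64,
///         height_out: 64,
///     },
///     "out/gradient.png",
/// )
/// .unwrap();
/// ```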
pub fn save_tensor_as_image<B: Backend, const D: usize, P: AsRef<std::ffi::OsStr>>(
    tensor: Tensor<B, D>,
    opts: TensorDisplayOptions,
    path: P,
) -> Result<(), Box<dyn std::error::Error>> {
    let path = Path::new(&path);
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }

    // Rescale values to [0, 1] so they map cleanly onto 8-bit color.
    let tensor = normalize(tensor);

    // Bring the tensor into NCHW layout with explicit batch and channel dimensions.
    let tensor: Tensor<B, 4> = match opts.dim_order {
        ImageDimOrder::Hw => {
            let [h, w] = tensor.shape().dims();
            tensor.reshape([1, 1, h, w])
        }
        ImageDimOrder::Chw => {
            let [c, h, w] = tensor.shape().dims();
            tensor.reshape([1, c, h, w])
        }
        ImageDimOrder::Hwc => {
            let [h, w, c] = tensor.shape().dims();
            tensor.swap_dims(0, 2).swap_dims(1, 2).reshape([1, c, h, w])
        }
        ImageDimOrder::Nhw => {
            let [n, h, w] = tensor.shape().dims();
            tensor.reshape([n, 1, h, w])
        }
        ImageDimOrder::Nchw => {
            let [n, c, h, w] = tensor.shape().dims();
            tensor.reshape([n, c, h, w])
        }
        ImageDimOrder::Nhwc => {
            let [n, h, w, c] = tensor.shape().dims();
            tensor.swap_dims(1, 3).swap_dims(2, 3).reshape([n, c, h, w])
        }
    };

    let data = tensor.to_data();
    let shape = data.shape.clone();
    let (batch, channels, src_height, src_width) = (shape[0], shape[1], shape[2], shape[3]);

    // Tiled output stacks the batch elements vertically; otherwise every element
    // is drawn onto a single canvas of the requested size.
    let mut img = if let Some(batch_opts) = &opts.batch_opts
        && BatchDisplayOpts::Tiled == *batch_opts
    {
        RgbImage::new(opts.width_out as u32, (opts.height_out * batch) as u32)
    } else {
        RgbImage::new(opts.width_out as u32, opts.height_out as u32)
    };

    // Assumes the backend stores float elements as f32; `to_vec` fails otherwise.
    let data_vec = data.to_vec::<f32>().expect("tensor data should be f32");

    let mut channel_vals = vec![0.0_f32; channels];

    for n in 0..batch {
        for x in 0..opts.width_out {
            for y in 0..opts.height_out {
                // Nearest-neighbor sampling from the source resolution to the output resolution.
                let i = ((x as f32) / (opts.width_out as f32) * (src_width as f32))
                    .floor()
                    .clamp(0.0, (src_width - 1) as f32) as usize;
                let j = ((y as f32) / (opts.height_out as f32) * (src_height as f32))
                    .floor()
                    .clamp(0.0, (src_height - 1) as f32) as usize;

                // Read all channels for this source pixel from the flat NCHW buffer:
                // index = ((n * C + c) * H + j) * W + i.
                for c in 0..channels {
                    channel_vals[c] =
                        data_vec[i + (j + (n * channels + c) * src_height) * src_width];
                }

                // When tiling, offset this batch element into its own vertical tile.
                let (x, y) = if let Some(batch_opts) = opts.batch_opts
                    && BatchDisplayOpts::Tiled == batch_opts
                {
                    let batch_x = 0;
                    let batch_y = n as u32 * opts.height_out as u32;
                    (x as u32 + batch_x, y as u32 + batch_y)
                } else {
                    (x as u32, y as u32)
                };

                let mut pixel = [0.0_f32; 3];
                match opts.color_opts {
                    // Map channels directly onto R, G, B; missing channels stay at zero.
                    ColorDisplayOpts::Rgb => match channels {
                        1 => {
                            pixel[0] = channel_vals[0];
                            pixel[1] = 0.0;
                            pixel[2] = 0.0;
                        }
                        2 => {
                            pixel[0] = channel_vals[0];
                            pixel[1] = channel_vals[1];
                            pixel[2] = 0.0;
                        }
                        3 => {
                            pixel[0] = channel_vals[0];
                            pixel[1] = channel_vals[1];
                            pixel[2] = channel_vals[2];
                        }
                        _ => unimplemented!("More than 3 channels not supported ({channels})"),
                    },
                    // Sum the channels and interpolate between the `min` and `max` colors.
                    ColorDisplayOpts::Monochrome { min, max } => {
                        let val: f32 = channel_vals.iter().sum();
                        pixel[0] = min[0] * (1.0 - val) + max[0] * val;
                        pixel[1] = min[1] * (1.0 - val) + max[1] * val;
                        pixel[2] = min[2] * (1.0 - val) + max[2] * val;
                    }
                }

                // Quantize to 8-bit color (the `as u8` cast saturates out-of-range values).
                let pixel = [
                    (pixel[0] * 255.0) as u8,
                    (pixel[1] * 255.0) as u8,
                    (pixel[2] * 255.0) as u8,
                ];
                img.put_pixel(x, y, Rgb(pixel));
            }
        }
    }

    img.save(path)?;
    Ok(())
}

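/// Min-max normalizes the tensor into `[0, 1]`; a constant tensor becomes all zeros.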
fn normalize<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let min = tensor.clone().min().into_scalar().elem::<f32>();
    let max = tensor.clone().max().into_scalar().elem::<f32>();
    let range = if max - min == 0.0 { 1.0 } else { max - min };

    tensor.sub_scalar(min).div_scalar(range)
}