Skip to main content

burn_vision/utils/
save.rs

1//! Utilities for saving tensors as images
2
3use burn_tensor::{ElementConversion, Tensor, backend::Backend};
4use image::{Rgb, RgbImage};
5use std::fs;
6use std::path::Path;
7
8/// How to save a tensor as an image
9pub struct TensorDisplayOptions {
10    /// How should the dimensions be interpreted
11    pub dim_order: ImageDimOrder,
12    /// What colors should be used
13    pub color_opts: ColorDisplayOpts,
14    /// How to handle batches
15    pub batch_opts: Option<BatchDisplayOpts>,
16    /// Output image width
17    pub width_out: usize,
18    /// Output image height
19    pub height_out: usize,
20}
21
/// How to interpret the dimensions of an image tensor.
///
/// Layouts without a batch dimension describe a single image, and layouts without a channel
/// dimension describe single-channel images.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ImageDimOrder {
    /// dims: (height, width)
    Hw,
    /// dims: (channels, height, width)
    Chw,
    /// dims: (height, width, channels)
    Hwc,
    /// dims: (batch_size, height, width)
    Nhw,
    /// dims: (batch_size, channels, height, width)
    Nchw,
    /// dims: (batch_size, height, width, channels)
    Nhwc,
}
37
/// How to translate (normalized) tensor values to colors.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ColorDisplayOpts {
    /// Channels 0, 1 and 2 are shown as red, green and blue respectively; channels the tensor
    /// does not have are drawn as 0. At most 3 channels are supported.
    Rgb,
    /// Each pixel value is linearly interpolated between two colors.
    ///
    /// If the tensor has several channels, their values are summed before interpolating.
    Monochrome {
        /// Color assigned to the minimum value, as RGB components in `[0, 1]`
        min: [f32; 3],
        /// Color assigned to the maximum value, as RGB components in `[0, 1]`
        max: [f32; 3],
    },
}
50
/// How to handle tensors holding a batch of several images.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BatchDisplayOpts {
    /// Items are stacked vertically, one output-sized tile per item
    Tiled,
    /// All items are drawn into a single output-sized image
    Aggregated,
}
59
60/// Save a tensor of a batch of images as an image
61///
62/// * `tensor` - Image batch with shape (N, height, width)
63/// * `opts` - Options for how to draw the tensor
64/// * `path` - The file path to use
65pub fn save_tensor_as_image<B: Backend, const D: usize, P: AsRef<std::ffi::OsStr>>(
66    tensor: Tensor<B, D>,
67    opts: TensorDisplayOptions,
68    path: P,
69) -> Result<(), Box<dyn std::error::Error>> {
70    // Output file
71    let path = Path::new(&path);
72    if let Some(parent) = path.parent() {
73        fs::create_dir_all(parent)?;
74    }
75
76    let tensor = normalize(tensor);
77
78    // convert to (N,C,H,W) format
79    let tensor: Tensor<B, 4> = match opts.dim_order {
80        ImageDimOrder::Hw => {
81            let [h, w] = tensor.shape().dims();
82            tensor.reshape([1, 1, h, w])
83        }
84        ImageDimOrder::Chw => {
85            let [c, h, w] = tensor.shape().dims();
86            tensor.reshape([1, c, h, w])
87        }
88        ImageDimOrder::Hwc => {
89            let [h, w, c] = tensor.shape().dims();
90            tensor.swap_dims(0, 2).swap_dims(1, 2).reshape([1, c, h, w])
91        }
92        ImageDimOrder::Nhw => {
93            let [n, h, w] = tensor.shape().dims();
94            tensor.reshape([n, 1, h, w])
95        }
96        ImageDimOrder::Nchw => tensor.reshape([0, 0, 0, 0]),
97        ImageDimOrder::Nhwc => tensor.swap_dims(1, 3).swap_dims(2, 3).reshape([0, 0, 0, 0]),
98    };
99
100    let data = tensor.to_data();
101    let shape = data.shape.clone();
102    let (batch, channels, src_height, src_width) = (shape[0], shape[1], shape[2], shape[3]);
103
104    let mut img = if let Some(batch_opts) = &opts.batch_opts
105        && BatchDisplayOpts::Tiled == *batch_opts
106    {
107        RgbImage::new(opts.width_out as u32, (opts.height_out * batch) as u32)
108    } else {
109        RgbImage::new(opts.width_out as u32, opts.height_out as u32)
110    };
111
112    let data_vec = data.to_vec::<f32>().unwrap();
113
114    let mut channel_vals = vec![0 as f32; channels]; // value for each channel in a given pixel
115    for n in 0..batch {
116        for x in 0..opts.width_out {
117            for y in 0..opts.height_out {
118                let i = ((x as f32) / (opts.width_out as f32) * (src_width as f32))
119                    .floor()
120                    .clamp(0.0, src_width as f32) as usize;
121                let j = ((y as f32) / (opts.height_out as f32) * (src_height as f32))
122                    .floor()
123                    .clamp(0.0, src_height as f32) as usize;
124
125                for c in 0..channels {
126                    channel_vals[c] =
127                        data_vec[i + (j + (n * channels + c) * src_height) * src_width];
128                }
129
130                let (x, y) = if let Some(batch_opts) = opts.batch_opts
131                    && BatchDisplayOpts::Tiled == batch_opts
132                {
133                    let batch_x = 0;
134                    let batch_y = n as u32 * opts.height_out as u32;
135                    (x as u32 + batch_x, y as u32 + batch_y)
136                } else {
137                    (x as u32, y as u32)
138                };
139
140                let mut pixel = [0 as f32; 3];
141                match opts.color_opts {
142                    ColorDisplayOpts::Rgb => match channels {
143                        1 => {
144                            pixel[0] = channel_vals[0];
145                            pixel[1] = 0.0;
146                            pixel[2] = 0.0;
147                        }
148                        2 => {
149                            pixel[0] = channel_vals[0];
150                            pixel[1] = channel_vals[1];
151                            pixel[2] = 0.0;
152                        }
153                        3 => {
154                            pixel[0] = channel_vals[0];
155                            pixel[1] = channel_vals[1];
156                            pixel[2] = channel_vals[2];
157                        }
158                        _ => unimplemented!("More than 3 channels not supported ({channels})"),
159                    },
160                    ColorDisplayOpts::Monochrome { min, max } => {
161                        let val: f32 = channel_vals.iter().sum();
162                        pixel[0] = min[0] * (1.0 - val) + max[0] * val;
163                        pixel[1] = min[1] * (1.0 - val) + max[1] * val;
164                        pixel[2] = min[2] * (1.0 - val) + max[2] * val;
165                    }
166                }
167
168                let pixel = [
169                    (pixel[0] * 255.0) as u8,
170                    (pixel[1] * 255.0) as u8,
171                    (pixel[2] * 255.0) as u8,
172                ];
173                img.put_pixel(x, y, Rgb(pixel));
174            }
175        }
176    }
177
178    img.save(path)?;
179    Ok(())
180}
181
182/// Normalize values in 2D tensor from 0 to 1
183fn normalize<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
184    let min = tensor.clone().min().into_scalar().elem::<f32>();
185    let max = tensor.clone().max().into_scalar().elem::<f32>();
186    let range = if max - min == 0.0 { 1.0 } else { max - min };
187
188    tensor
189        .sub_scalar(min.elem::<f32>())
190        .div_scalar(range.elem::<f32>())
191}