bimm_firehose_image/
lib.rs

1use serde::{Deserialize, Serialize};
2
3pub mod augmentation;
4pub mod burn_support;
5pub mod colortype_support;
6pub mod loader;
7pub mod test_util;
8
9/// Represents the shape of an image.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ImageShape {
12    /// The width of the image in pixels.
13    pub width: u32,
14
15    /// The height of the image in pixels.
16    pub height: u32,
17}
18
19pub use image::ColorType;
20
#[cfg(test)]
mod tests {
    use crate::burn_support::ImageToTensorData;
    use crate::burn_support::image_to_f32_tensor;
    use crate::loader::{ImageLoader, ResizeSpec};
    use crate::test_util::assert_image_close;
    use crate::{ImageShape, test_util};

    use bimm_firehose::core::FirehoseValue;
    use bimm_firehose::core::operations::executor::FirehoseBatchExecutor;
    use bimm_firehose::core::operations::executor::SequentialBatchExecutor;
    use bimm_firehose::core::rows::{FirehoseRowReader, FirehoseRowWriter};
    use bimm_firehose::core::schema::ColumnSchema;
    use bimm_firehose::core::{FirehoseRowBatch, FirehoseTableSchema};
    use bimm_firehose::ops::init_default_operator_environment;
    use burn::backend::NdArray;
    use burn::prelude::TensorData;
    use image::imageops::FilterType;
    use image::{ColorType, DynamicImage};
    use indoc::indoc;
    use std::sync::Arc;

    /// End-to-end example of the image pipeline.
    ///
    /// The test:
    /// 1. builds a schema with a `path` column, then extends it with an
    ///    [`ImageLoader`] plan (`path` -> `image`) and an
    ///    [`ImageToTensorData`] plan (`image` -> `data`);
    /// 2. checks the pretty-printed JSON form of the resulting schema;
    /// 3. writes a generated gradient image to a temporary PNG, runs the
    ///    batch through a [`SequentialBatchExecutor`], and verifies both the
    ///    loaded image and the derived tensor data.
    #[test]
    fn test_example() -> anyhow::Result<()> {
        // The directory (and the PNG written into it) is removed when
        // `temp_dir` is dropped at the end of the test.
        let temp_dir = tempfile::tempdir()?;

        type B = NdArray;

        let device = Default::default();

        let env = Arc::new(init_default_operator_environment());

        // Single source of truth for the resize target: used both to
        // configure the loader and to build the reference image below.
        let target_shape = ImageShape {
            width: 16,
            height: 24,
        };

        let schema = {
            let mut schema =
                FirehoseTableSchema::from_columns(&[
                    ColumnSchema::new::<String>("path").with_description("path to the image")
                ]);

            ImageLoader::default()
                .with_resize(ResizeSpec::new(target_shape.clone()).with_filter(FilterType::Nearest))
                .with_recolor(ColorType::L16)
                .to_plan("path", "image")
                .apply_to_schema(&mut schema, env.as_ref())?;

            ImageToTensorData::default()
                .to_plan("image", "data")
                .apply_to_schema(&mut schema, env.as_ref())?;

            Arc::new(schema)
        };

        let executor = SequentialBatchExecutor::new(schema.clone(), env.clone())?;

        // Snapshot of the serialized schema: three columns and the two build
        // plans that produce `image` and `data`.
        assert_eq!(
            serde_json::to_string_pretty(schema.as_ref())?,
            indoc! {r#"
                {
                  "columns": [
                    {
                      "name": "path",
                      "description": "path to the image",
                      "data_type": {
                        "type_name": "alloc::string::String"
                      }
                    },
                    {
                      "name": "image",
                      "description": "Image loaded from disk.",
                      "data_type": {
                        "type_name": "image::images::dynimage::DynamicImage"
                      }
                    },
                    {
                      "name": "data",
                      "description": "TensorData representation of the image.",
                      "data_type": {
                        "type_name": "burn_tensor::tensor::data::TensorData"
                      }
                    }
                  ],
                  "build_plans": [
                    {
                      "operator_id": "fh:op://bimm_firehose_image::loader::LOAD_IMAGE",
                      "description": "Loads an image from disk.",
                      "config": {
                        "recolor": "L16",
                        "resize": {
                          "filter": "Nearest",
                          "shape": {
                            "height": 24,
                            "width": 16
                          }
                        }
                      },
                      "inputs": {
                        "path": "path"
                      },
                      "outputs": {
                        "image": "image"
                      }
                    },
                    {
                      "operator_id": "fh:op://bimm_firehose_image::burn_support::IMAGE_TO_TENSOR_DATA",
                      "description": "Converts an image to TensorData.",
                      "config": {},
                      "inputs": {
                        "image": "image"
                      },
                      "outputs": {
                        "data": "data"
                      }
                    }
                  ]
                }"#,
            }
        );

        let mut batch = FirehoseRowBatch::new_with_size(schema.clone(), 1);

        let source_image: DynamicImage = test_util::generate_gradient_pattern(ImageShape {
            width: 32,
            height: 32,
        })
        .into();

        {
            let image_path = temp_dir
                .path()
                .join("test.png")
                .to_string_lossy()
                .to_string();

            source_image
                .save(&image_path)
                .expect("Failed to save test image");

            batch[0].expect_set("path", FirehoseValue::serialized(image_path)?);
        }

        executor.execute_batch(&mut batch)?;

        let row = &batch[0];

        let row_image = row
            .maybe_get("image")
            .expect("`image` column should be populated after execution")
            .as_ref::<DynamicImage>()?;

        // The loader was configured with `ColorType::L16`, so the reference
        // image is converted to 16-bit luma as well (not 8-bit), after the
        // same nearest-neighbour resize to the target shape.
        let expected_image: DynamicImage = source_image
            .resize_exact(
                target_shape.width,
                target_shape.height,
                FilterType::Nearest,
            )
            .to_luma16()
            .into();
        assert_image_close(row_image, &expected_image, None);

        let row_data = row
            .maybe_get("data")
            .expect("`data` column should be populated after execution")
            .as_ref::<TensorData>()?;
        // The `data` column must match a direct f32 tensor conversion of the
        // loaded image on the NdArray backend.
        row_data.assert_eq(
            &image_to_f32_tensor::<B>(row_image, &device).to_data(),
            true,
        );

        Ok(())
    }
}