Skip to main content

bunsen_firehose_image/
lib.rs

1use serde::{
2    Deserialize,
3    Serialize,
4};
5
6pub mod augmentation;
7pub mod burn_support;
8pub mod colortype_support;
9pub mod loader;
10pub mod test_util;
11
12/// Represents the shape of an image.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ImageShape {
15    /// The width of the image in pixels.
16    pub width: u32,
17
18    /// The height of the image in pixels.
19    pub height: u32,
20}
21
22pub use image::ColorType;
23
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Context;
    use bunsen::support::testing::PerfTestBackend;
    use bunsen_firehose::{
        core::{
            FirehoseRowBatch,
            FirehoseTableSchema,
            FirehoseValue,
            operations::executor::{
                FirehoseBatchExecutor,
                SequentialBatchExecutor,
            },
            rows::{
                FirehoseRowReader,
                FirehoseRowWriter,
            },
            schema::ColumnSchema,
        },
        ops::init_default_operator_environment,
    };
    use burn::prelude::TensorData;
    use image::{
        ColorType,
        DynamicImage,
        imageops::FilterType,
    };
    use indoc::indoc;

    use crate::{
        ImageShape,
        burn_support::{
            ImageToTensorData,
            image_to_f32_tensor,
        },
        loader::{
            ImageLoader,
            ResizeSpec,
        },
        test_util,
        test_util::assert_image_close,
    };

    /// End-to-end check of the image pipeline.
    ///
    /// Builds a schema with two plans: `path -> image` (load, resize to 16x24 with nearest
    /// filtering, recolor to `L16`) and `image -> data` (convert to `TensorData`). It checks
    /// the serialized schema, writes a 32x32 gradient PNG to a temp dir, and runs a one-row
    /// batch through a `SequentialBatchExecutor`. It then verifies both the loaded image
    /// and the tensor data derived from it.
    #[test]
    fn test_example() -> anyhow::Result<()> {
        type B = PerfTestBackend;

        // Target size the loader resizes to; shared by the plan and the expected image so
        // the two cannot drift apart.
        const TARGET_WIDTH: u32 = 16;
        const TARGET_HEIGHT: u32 = 24;

        // Must stay alive until the batch has executed: dropping it deletes the saved image.
        let temp_dir = tempfile::tempdir()?;

        let device = Default::default();

        let env = Arc::new(init_default_operator_environment());

        let schema = {
            let mut schema = FirehoseTableSchema::from_columns(&[
                ColumnSchema::new::<String>("path").with_description("path to the image"),
            ]);

            ImageLoader::default()
                .with_resize(
                    ResizeSpec::new(ImageShape {
                        width: TARGET_WIDTH,
                        height: TARGET_HEIGHT,
                    })
                    .with_filter(FilterType::Nearest),
                )
                .with_recolor(ColorType::L16)
                .to_plan("path", "image")
                .apply_to_schema(&mut schema, env.as_ref())?;

            ImageToTensorData::default()
                .to_plan("image", "data")
                .apply_to_schema(&mut schema, env.as_ref())?;

            Arc::new(schema)
        };

        let executor = SequentialBatchExecutor::new(schema.clone(), env.clone())?;

        assert_eq!(
            serde_json::to_string_pretty(schema.as_ref())?,
            indoc! {r#"
                {
                  "columns": [
                    {
                      "name": "path",
                      "description": "path to the image",
                      "data_type": {
                        "type_name": "alloc::string::String"
                      }
                    },
                    {
                      "name": "image",
                      "description": "Image loaded from disk.",
                      "data_type": {
                        "type_name": "image::images::dynimage::DynamicImage"
                      }
                    },
                    {
                      "name": "data",
                      "description": "TensorData representation of the image.",
                      "data_type": {
                        "type_name": "burn_backend::data::tensor::TensorData"
                      }
                    }
                  ],
                  "build_plans": [
                    {
                      "operator_id": "fh:op://bunsen_firehose_image::loader::LOAD_IMAGE",
                      "description": "Loads an image from disk.",
                      "config": {
                        "recolor": "L16",
                        "resize": {
                          "filter": "Nearest",
                          "shape": {
                            "height": 24,
                            "width": 16
                          }
                        }
                      },
                      "inputs": {
                        "path": "path"
                      },
                      "outputs": {
                        "image": "image"
                      }
                    },
                    {
                      "operator_id": "fh:op://bunsen_firehose_image::burn_support::IMAGE_TO_TENSOR_DATA",
                      "description": "Converts an image to TensorData.",
                      "config": {},
                      "inputs": {
                        "image": "image"
                      },
                      "outputs": {
                        "data": "data"
                      }
                    }
                  ]
                }"#,
            }
        );

        let mut batch = FirehoseRowBatch::new_with_size(schema.clone(), 1);

        let source_image: DynamicImage = test_util::generate_gradient_pattern(ImageShape {
            width: 32,
            height: 32,
        })
        .into();

        {
            let image_path = temp_dir.path().join("test.png");
            // The `path` column stores a `String`. Fail loudly on a non-UTF-8 temp path
            // instead of storing a lossily-converted path that would not name the saved file.
            let image_path = image_path
                .to_str()
                .context("temporary image path is not valid UTF-8")?
                .to_owned();

            source_image
                .save(&image_path)
                .context("Failed to save test image")?;

            batch[0].expect_set("path", FirehoseValue::serialized(image_path)?);
        }

        executor.execute_batch(&mut batch)?;

        let row = &batch[0];

        let row_image = row
            .maybe_get("image")
            .expect("pipeline should populate the `image` column")
            .as_ref::<DynamicImage>()?;

        // Mirror the loader configuration exactly: same resize, same 16-bit luma recolor.
        // The comparison then does not rely on the checker bridging an 8-bit/16-bit mismatch.
        let expected_image: DynamicImage = source_image
            .resize_exact(TARGET_WIDTH, TARGET_HEIGHT, FilterType::Nearest)
            .to_luma16()
            .into();
        assert_image_close(row_image, &expected_image, None);

        let row_data = row
            .maybe_get("data")
            .expect("pipeline should populate the `data` column")
            .as_ref::<TensorData>()?;
        row_data.assert_eq(
            &image_to_f32_tensor::<B>(row_image, &device).to_data(),
            true,
        );

        Ok(())
    }
}