bunsen_firehose_image/
lib.rs1use serde::{
2 Deserialize,
3 Serialize,
4};
5
6pub mod augmentation;
7pub mod burn_support;
8pub mod colortype_support;
9pub mod loader;
10pub mod test_util;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ImageShape {
15 pub width: u32,
17
18 pub height: u32,
20}
21
22pub use image::ColorType;
23
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Context as _;
    use bunsen::support::testing::PerfTestBackend;
    use bunsen_firehose::{
        core::{
            FirehoseRowBatch,
            FirehoseTableSchema,
            FirehoseValue,
            operations::executor::{
                FirehoseBatchExecutor,
                SequentialBatchExecutor,
            },
            rows::{
                FirehoseRowReader,
                FirehoseRowWriter,
            },
            schema::ColumnSchema,
        },
        ops::init_default_operator_environment,
    };
    use burn::prelude::TensorData;
    use image::{
        ColorType,
        DynamicImage,
        imageops::FilterType,
    };
    use indoc::indoc;

    use crate::{
        ImageShape,
        burn_support::{
            ImageToTensorData,
            image_to_f32_tensor,
        },
        loader::{
            ImageLoader,
            ResizeSpec,
        },
        test_util::{
            self,
            assert_image_close,
        },
    };

    /// End-to-end check of the image pipeline.
    ///
    /// Builds a schema that loads an image from a `path` column (resized to
    /// 16x24 with nearest-neighbour filtering and recolored to `L16`) and then
    /// converts it to `TensorData`. One row is pushed through a
    /// [`SequentialBatchExecutor`], and the outputs are compared against the
    /// same transformations applied directly to the source image.
    #[test]
    fn test_example() -> anyhow::Result<()> {
        type B = PerfTestBackend;

        // Bound for the whole test: `tempfile::TempDir` removes the directory
        // (and the image the executor reads from it) when dropped.
        let temp_dir = tempfile::tempdir()?;

        let device = Default::default();

        let env = Arc::new(init_default_operator_environment());

        let schema = {
            let mut schema =
                FirehoseTableSchema::from_columns(&[
                    ColumnSchema::new::<String>("path").with_description("path to the image")
                ]);

            ImageLoader::default()
                .with_resize(
                    ResizeSpec::new(ImageShape {
                        width: 16,
                        height: 24,
                    })
                    .with_filter(FilterType::Nearest),
                )
                .with_recolor(ColorType::L16)
                .to_plan("path", "image")
                .apply_to_schema(&mut schema, env.as_ref())?;

            ImageToTensorData::default()
                .to_plan("image", "data")
                .apply_to_schema(&mut schema, env.as_ref())?;

            Arc::new(schema)
        };

        let executor = SequentialBatchExecutor::new(schema.clone(), env.clone())?;

        // Pin the serialized schema: columns, operator ids and operator configs.
        // NOTE(review): the `type_name` values look like fully-qualified Rust
        // type paths, which can change when dependencies are upgraded (e.g. a
        // type moving between crates) — confirm they are meant to be pinned.
        assert_eq!(
            serde_json::to_string_pretty(schema.as_ref())?,
            indoc! {r#"
                {
                  "columns": [
                    {
                      "name": "path",
                      "description": "path to the image",
                      "data_type": {
                        "type_name": "alloc::string::String"
                      }
                    },
                    {
                      "name": "image",
                      "description": "Image loaded from disk.",
                      "data_type": {
                        "type_name": "image::images::dynimage::DynamicImage"
                      }
                    },
                    {
                      "name": "data",
                      "description": "TensorData representation of the image.",
                      "data_type": {
                        "type_name": "burn_backend::data::tensor::TensorData"
                      }
                    }
                  ],
                  "build_plans": [
                    {
                      "operator_id": "fh:op://bunsen_firehose_image::loader::LOAD_IMAGE",
                      "description": "Loads an image from disk.",
                      "config": {
                        "recolor": "L16",
                        "resize": {
                          "filter": "Nearest",
                          "shape": {
                            "height": 24,
                            "width": 16
                          }
                        }
                      },
                      "inputs": {
                        "path": "path"
                      },
                      "outputs": {
                        "image": "image"
                      }
                    },
                    {
                      "operator_id": "fh:op://bunsen_firehose_image::burn_support::IMAGE_TO_TENSOR_DATA",
                      "description": "Converts an image to TensorData.",
                      "config": {},
                      "inputs": {
                        "image": "image"
                      },
                      "outputs": {
                        "data": "data"
                      }
                    }
                  ]
                }"#,
            }
        );

        let mut batch = FirehoseRowBatch::new_with_size(schema.clone(), 1);

        let source_image: DynamicImage = test_util::generate_gradient_pattern(ImageShape {
            width: 32,
            height: 32,
        })
        .into();

        // Write the source image to disk and point the single row at it.
        {
            let image_path = temp_dir
                .path()
                .join("test.png")
                .to_string_lossy()
                .to_string();

            source_image
                .save(&image_path)
                .context("Failed to save test image")?;

            batch[0].expect_set("path", FirehoseValue::serialized(image_path)?);
        }

        executor.execute_batch(&mut batch)?;

        let row = &batch[0];

        // NOTE(review): the loader is configured with `ColorType::L16`, but the
        // reference is built with `to_luma8()`; this relies on
        // `assert_image_close` tolerating the bit-depth difference — confirm.
        let row_image = row
            .maybe_get("image")
            .expect("executor should populate the `image` column")
            .as_ref::<DynamicImage>()?;
        assert_image_close(
            row_image,
            &source_image
                .resize_exact(16, 24, FilterType::Nearest)
                .to_luma8()
                .into(),
            None,
        );

        // The `data` column must equal converting the loaded image directly.
        // `strict = true`; presumably this also requires matching dtypes —
        // confirm against burn's `TensorData::assert_eq`.
        let row_data = row
            .maybe_get("data")
            .expect("executor should populate the `data` column")
            .as_ref::<TensorData>()?;
        row_data.assert_eq(
            &image_to_f32_tensor::<B>(row_image, &device).to_data(),
            true,
        );

        Ok(())
    }
}