bimm_firehose_image/
lib.rs1use serde::{Deserialize, Serialize};
2
3pub mod augmentation;
4pub mod burn_support;
5pub mod colortype_support;
6pub mod loader;
7pub mod test_util;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ImageShape {
12 pub width: u32,
14
15 pub height: u32,
17}
18
19pub use image::ColorType;
20
#[cfg(test)]
mod tests {
    use crate::burn_support::ImageToTensorData;
    use crate::burn_support::image_to_f32_tensor;
    use crate::loader::{ImageLoader, ResizeSpec};
    use crate::test_util::assert_image_close;
    use crate::{ImageShape, test_util};

    use anyhow::Context;
    use bimm_firehose::core::FirehoseValue;
    use bimm_firehose::core::operations::executor::FirehoseBatchExecutor;
    use bimm_firehose::core::operations::executor::SequentialBatchExecutor;
    use bimm_firehose::core::rows::{FirehoseRowReader, FirehoseRowWriter};
    use bimm_firehose::core::schema::ColumnSchema;
    use bimm_firehose::core::{FirehoseRowBatch, FirehoseTableSchema};
    use bimm_firehose::ops::init_default_operator_environment;
    use burn::backend::NdArray;
    use burn::prelude::TensorData;
    use image::imageops::FilterType;
    use image::{ColorType, DynamicImage};
    use indoc::indoc;
    use std::sync::Arc;

    /// Width (in pixels) the loader plan resizes to.
    ///
    /// Shared by the plan config and the expected reference image so they cannot drift apart.
    const TARGET_WIDTH: u32 = 16;

    /// Height (in pixels) the loader plan resizes to.
    ///
    /// Shared by the plan config and the expected reference image so they cannot drift apart.
    const TARGET_HEIGHT: u32 = 24;

    /// End-to-end check of the image pipeline.
    ///
    /// 1. Builds a `path -> image -> data` schema from the image-loader plan and the
    ///    image-to-tensor-data plan, and pins its JSON serialization.
    /// 2. Writes a generated 32x32 gradient PNG to a temp dir and runs one row through a
    ///    `SequentialBatchExecutor`.
    /// 3. Checks the loaded image against a reference resize + grayscale conversion, and the
    ///    tensor data against `image_to_f32_tensor` applied to the loaded image.
    #[test]
    fn test_example() -> anyhow::Result<()> {
        type B = NdArray;

        // Scratch directory holding the generated source image.
        let temp_dir = tempfile::tempdir()?;
        let device = Default::default();
        let env = Arc::new(init_default_operator_environment());

        // Schema: `path` (String) -> `image` (DynamicImage) -> `data` (TensorData).
        let schema = {
            let mut schema = FirehoseTableSchema::from_columns(&[
                ColumnSchema::new::<String>("path").with_description("path to the image"),
            ]);

            ImageLoader::default()
                .with_resize(
                    ResizeSpec::new(ImageShape {
                        width: TARGET_WIDTH,
                        height: TARGET_HEIGHT,
                    })
                    .with_filter(FilterType::Nearest),
                )
                .with_recolor(ColorType::L16)
                .to_plan("path", "image")
                .apply_to_schema(&mut schema, env.as_ref())?;

            ImageToTensorData::default()
                .to_plan("image", "data")
                .apply_to_schema(&mut schema, env.as_ref())?;

            Arc::new(schema)
        };

        let executor = SequentialBatchExecutor::new(Arc::clone(&schema), Arc::clone(&env))?;

        // Pin the serialized schema, including the resize/recolor config built above.
        assert_eq!(
            serde_json::to_string_pretty(schema.as_ref())?,
            indoc! {r#"
                {
                  "columns": [
                    {
                      "name": "path",
                      "description": "path to the image",
                      "data_type": {
                        "type_name": "alloc::string::String"
                      }
                    },
                    {
                      "name": "image",
                      "description": "Image loaded from disk.",
                      "data_type": {
                        "type_name": "image::images::dynimage::DynamicImage"
                      }
                    },
                    {
                      "name": "data",
                      "description": "TensorData representation of the image.",
                      "data_type": {
                        "type_name": "burn_tensor::tensor::data::TensorData"
                      }
                    }
                  ],
                  "build_plans": [
                    {
                      "operator_id": "fh:op://bimm_firehose_image::loader::LOAD_IMAGE",
                      "description": "Loads an image from disk.",
                      "config": {
                        "recolor": "L16",
                        "resize": {
                          "filter": "Nearest",
                          "shape": {
                            "height": 24,
                            "width": 16
                          }
                        }
                      },
                      "inputs": {
                        "path": "path"
                      },
                      "outputs": {
                        "image": "image"
                      }
                    },
                    {
                      "operator_id": "fh:op://bimm_firehose_image::burn_support::IMAGE_TO_TENSOR_DATA",
                      "description": "Converts an image to TensorData.",
                      "config": {},
                      "inputs": {
                        "image": "image"
                      },
                      "outputs": {
                        "data": "data"
                      }
                    }
                  ]
                }"#,
            }
        );

        let mut batch = FirehoseRowBatch::new_with_size(Arc::clone(&schema), 1);

        let source_image: DynamicImage = test_util::generate_gradient_pattern(ImageShape {
            width: 32,
            height: 32,
        })
        .into();

        // Write the source image to disk and point row 0's `path` column at it.
        {
            let image_path = temp_dir
                .path()
                .join("test.png")
                .to_string_lossy()
                .into_owned();

            source_image
                .save(&image_path)
                .with_context(|| format!("Failed to save test image to {image_path}"))?;

            batch[0].expect_set("path", FirehoseValue::serialized(image_path)?);
        }

        executor.execute_batch(&mut batch)?;

        let row = &batch[0];

        let row_image = row
            .maybe_get("image")
            .expect("executor should populate the `image` column")
            .as_ref::<DynamicImage>()?;

        // NOTE(review): the plan recolors to `ColorType::L16`, but the reference below is built
        // with `to_luma8()`; this relies on `assert_image_close` tolerating the bit-depth
        // difference — confirm.
        assert_image_close(
            row_image,
            &source_image
                .resize_exact(TARGET_WIDTH, TARGET_HEIGHT, FilterType::Nearest)
                .to_luma8()
                .into(),
            None,
        );

        let row_data = row
            .maybe_get("data")
            .expect("executor should populate the `data` column")
            .as_ref::<TensorData>()?;
        row_data.assert_eq(
            &image_to_f32_tensor::<B>(row_image, &device).to_data(),
            true,
        );

        Ok(())
    }
}