1use burn::tensor::TensorData;
2use burn::tensor::{backend::Backend, Tensor};
3use burn_dataset::BurnBatch;
4use data_contracts::capture::CaptureMetadata;
5use data_contracts::preprocess::{stats_from_chw_f32, stats_from_rgb_u8};
6use serde::Deserialize;
7use std::fs;
8use std::path::PathBuf;
9
10#[derive(Debug, Clone, Deserialize)]
11pub struct DatasetPathConfig {
12 pub root: PathBuf,
13 pub labels_subdir: String,
14 pub images_subdir: String,
15}
16
17#[derive(Debug, Clone)]
18pub struct RunSample {
19 pub image: PathBuf,
20 pub metadata: CaptureMetadata,
21}
22
23#[derive(Debug, Clone)]
24pub struct CollatedBatch<B: Backend> {
25 pub images: Tensor<B, 4>,
26 pub boxes: Tensor<B, 3>,
28 pub box_mask: Tensor<B, 2>,
30 pub features: Tensor<B, 2>,
32}
33
34impl DatasetPathConfig {
35 pub fn load(&self) -> anyhow::Result<Vec<RunSample>> {
36 let mut samples = Vec::new();
37 let labels_dir = self.root.join(&self.labels_subdir);
38 for entry in fs::read_dir(&labels_dir)? {
39 let entry = entry?;
40 let path = entry.path();
41 if path.extension().and_then(|s| s.to_str()) != Some("json") {
42 continue;
43 }
44 let meta: CaptureMetadata = serde_json::from_slice(&fs::read(&path)?)?;
45 meta.validate()
46 .map_err(|e| anyhow::anyhow!("invalid metadata {:?}: {e}", path))?;
47 let img_path = self.root.join(&self.images_subdir).join(&meta.image);
48 samples.push(RunSample {
49 image: img_path,
50 metadata: meta,
51 });
52 }
53 Ok(samples)
54 }
55}
56
57pub fn collate<B: Backend>(
58 samples: &[RunSample],
59 max_boxes: usize,
60) -> anyhow::Result<CollatedBatch<B>> {
61 if samples.is_empty() {
62 anyhow::bail!("cannot collate empty batch");
63 }
64 let max_boxes = max_boxes.max(1);
65
66 let first = image::open(&samples[0].image)
68 .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", samples[0].image))?
69 .to_rgb8();
70 let (width, height) = first.dimensions();
71
72 let batch = samples.len();
73 let num_pixels = (width * height) as usize;
74 let mut image_buf: Vec<f32> = Vec::with_capacity(batch * num_pixels * 3);
75 let mut features: Vec<f32> = Vec::with_capacity(batch * 8); let mut all_boxes: Vec<Vec<[f32; 4]>> = Vec::with_capacity(batch);
79
80 for (idx, sample) in samples.iter().enumerate() {
81 let img = if idx == 0 {
82 first.clone()
83 } else {
84 let img = image::open(&sample.image)
85 .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", sample.image))?;
86 let rgb = img.to_rgb8();
87 let (w, h) = rgb.dimensions();
88 if w != width || h != height {
89 anyhow::bail!(
90 "image dimensions differ within batch: {:?} is {}x{}, expected {}x{}",
91 sample.image,
92 w,
93 h,
94 width,
95 height
96 );
97 }
98 rgb
99 };
100
101 let stats = stats_from_rgb_u8(width, height, img.as_raw())
102 .map_err(|e| anyhow::anyhow!("failed to compute image stats: {e}"))?;
103
104 for c in 0..3 {
106 for y in 0..height {
107 for x in 0..width {
108 let p = img.get_pixel(x, y);
109 let v = p[c] as f32 / 255.0;
110 image_buf.push(v);
111 }
112 }
113 }
114
115 let mut boxes = Vec::new();
116 for label in &sample.metadata.labels {
117 let bbox = if let Some(norm) = label.bbox_norm {
118 norm
119 } else if let Some(px) = label.bbox_px {
120 [
121 px[0] / width as f32,
122 px[1] / height as f32,
123 px[2] / width as f32,
124 px[3] / height as f32,
125 ]
126 } else {
127 continue;
128 };
129 boxes.push(bbox);
130 if boxes.len() >= max_boxes {
131 break;
132 }
133 }
134 let box_count = boxes.len() as f32;
135 features.extend_from_slice(&stats.feature_vector(box_count));
136 all_boxes.push(boxes);
137 }
138
139 let mut boxes_buf = vec![0.0f32; batch * max_boxes * 4];
140 let mut mask_buf = vec![0.0f32; batch * max_boxes];
141 for (b, boxes) in all_boxes.iter().enumerate() {
142 for (i, bbox) in boxes.iter().enumerate() {
143 let base = (b * max_boxes + i) * 4;
144 boxes_buf[base..base + 4].copy_from_slice(bbox);
145 mask_buf[b * max_boxes + i] = 1.0;
146 }
147 }
148
149 let device = &B::Device::default();
150 let images = Tensor::<B, 4>::from_data(
151 TensorData::new(image_buf, [batch, 3, height as usize, width as usize]),
152 device,
153 );
154 let boxes =
155 Tensor::<B, 3>::from_data(TensorData::new(boxes_buf, [batch, max_boxes, 4]), device);
156 let box_mask = Tensor::<B, 2>::from_data(TensorData::new(mask_buf, [batch, max_boxes]), device);
157
158 let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch, 8]), device);
159
160 Ok(CollatedBatch {
161 images,
162 boxes,
163 box_mask,
164 features,
165 })
166}
167
168pub fn collate_from_burn_batch<B: Backend>(
169 batch: BurnBatch<B>,
170 max_boxes: usize,
171) -> anyhow::Result<CollatedBatch<B>> {
172 let dims = batch.images.dims();
173 let batch_size = dims[0];
174 let channels = dims[1];
175 let height = dims[2];
176 let width = dims[3];
177 let max_boxes = max_boxes.max(1);
178
179 if channels != 3 {
180 anyhow::bail!("expected 3-channel images, got {channels}");
181 }
182 let box_dims = batch.boxes.dims();
183 if box_dims[1] != max_boxes {
184 anyhow::bail!(
185 "warehouse batch max_boxes {actual} does not match requested {expected}",
186 actual = box_dims[1],
187 expected = max_boxes
188 );
189 }
190
191 let image_data = batch
192 .images
193 .clone()
194 .into_data()
195 .to_vec::<f32>()
196 .unwrap_or_default();
197 let mask_data = batch
198 .box_mask
199 .clone()
200 .into_data()
201 .to_vec::<f32>()
202 .unwrap_or_default();
203 let pixels_per_channel = height * width;
204
205 if image_data.len() != batch_size * channels * pixels_per_channel {
206 anyhow::bail!("unexpected image buffer size in warehouse batch");
207 }
208 if mask_data.len() != batch_size * max_boxes {
209 anyhow::bail!("unexpected box mask size in warehouse batch");
210 }
211
212 let mut features: Vec<f32> = Vec::with_capacity(batch_size * 8);
213 for b in 0..batch_size {
214 let start = b * channels * pixels_per_channel;
215 let slice = &image_data[start..start + channels * pixels_per_channel];
216 let stats = stats_from_chw_f32(width, height, slice)
217 .map_err(|e| anyhow::anyhow!("failed to compute image stats: {e}"))?;
218 let mask_start = b * max_boxes;
219 let box_count = mask_data[mask_start..mask_start + max_boxes]
220 .iter()
221 .filter(|v| **v > 0.0)
222 .count() as f32;
223 features.extend_from_slice(&stats.feature_vector(box_count));
224 }
225
226 let device = batch.images.device();
227 let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch_size, 8]), &device);
228
229 Ok(CollatedBatch {
230 images: batch.images,
231 boxes: batch.boxes,
232 box_mask: batch.box_mask,
233 features,
234 })
235}