// training/dataset.rs

1use burn::tensor::TensorData;
2use burn::tensor::{backend::Backend, Tensor};
3use burn_dataset::BurnBatch;
4use data_contracts::capture::CaptureMetadata;
5use data_contracts::preprocess::{stats_from_chw_f32, stats_from_rgb_u8};
6use serde::Deserialize;
7use std::fs;
8use std::path::PathBuf;
9
10#[derive(Debug, Clone, Deserialize)]
11pub struct DatasetPathConfig {
12    pub root: PathBuf,
13    pub labels_subdir: String,
14    pub images_subdir: String,
15}
16
17#[derive(Debug, Clone)]
18pub struct RunSample {
19    pub image: PathBuf,
20    pub metadata: CaptureMetadata,
21}
22
23#[derive(Debug, Clone)]
24pub struct CollatedBatch<B: Backend> {
25    pub images: Tensor<B, 4>,
26    /// Normalized boxes per sample (shape: [batch, max_boxes, 4]).
27    pub boxes: Tensor<B, 3>,
28    /// Mask indicating which box slots are populated (shape: [batch, max_boxes]).
29    pub box_mask: Tensor<B, 2>,
30    /// Global/image features per sample (mean/std RGB, aspect ratio, box count) shape [batch, F].
31    pub features: Tensor<B, 2>,
32}
33
34impl DatasetPathConfig {
35    pub fn load(&self) -> anyhow::Result<Vec<RunSample>> {
36        let mut samples = Vec::new();
37        let labels_dir = self.root.join(&self.labels_subdir);
38        for entry in fs::read_dir(&labels_dir)? {
39            let entry = entry?;
40            let path = entry.path();
41            if path.extension().and_then(|s| s.to_str()) != Some("json") {
42                continue;
43            }
44            let meta: CaptureMetadata = serde_json::from_slice(&fs::read(&path)?)?;
45            meta.validate()
46                .map_err(|e| anyhow::anyhow!("invalid metadata {:?}: {e}", path))?;
47            let img_path = self.root.join(&self.images_subdir).join(&meta.image);
48            samples.push(RunSample {
49                image: img_path,
50                metadata: meta,
51            });
52        }
53        Ok(samples)
54    }
55}
56
57pub fn collate<B: Backend>(
58    samples: &[RunSample],
59    max_boxes: usize,
60) -> anyhow::Result<CollatedBatch<B>> {
61    if samples.is_empty() {
62        anyhow::bail!("cannot collate empty batch");
63    }
64    let max_boxes = max_boxes.max(1);
65
66    // Load first image to establish dimensions.
67    let first = image::open(&samples[0].image)
68        .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", samples[0].image))?
69        .to_rgb8();
70    let (width, height) = first.dimensions();
71
72    let batch = samples.len();
73    let num_pixels = (width * height) as usize;
74    let mut image_buf: Vec<f32> = Vec::with_capacity(batch * num_pixels * 3);
75    let mut features: Vec<f32> = Vec::with_capacity(batch * 8); // mean/std RGB, aspect, box_count
76
77    // Gather normalized boxes, truncated to max_boxes.
78    let mut all_boxes: Vec<Vec<[f32; 4]>> = Vec::with_capacity(batch);
79
80    for (idx, sample) in samples.iter().enumerate() {
81        let img = if idx == 0 {
82            first.clone()
83        } else {
84            let img = image::open(&sample.image)
85                .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", sample.image))?;
86            let rgb = img.to_rgb8();
87            let (w, h) = rgb.dimensions();
88            if w != width || h != height {
89                anyhow::bail!(
90                    "image dimensions differ within batch: {:?} is {}x{}, expected {}x{}",
91                    sample.image,
92                    w,
93                    h,
94                    width,
95                    height
96                );
97            }
98            rgb
99        };
100
101        let stats = stats_from_rgb_u8(width, height, img.as_raw())
102            .map_err(|e| anyhow::anyhow!("failed to compute image stats: {e}"))?;
103
104        // Push normalized pixel data in CHW order.
105        for c in 0..3 {
106            for y in 0..height {
107                for x in 0..width {
108                    let p = img.get_pixel(x, y);
109                    let v = p[c] as f32 / 255.0;
110                    image_buf.push(v);
111                }
112            }
113        }
114
115        let mut boxes = Vec::new();
116        for label in &sample.metadata.labels {
117            let bbox = if let Some(norm) = label.bbox_norm {
118                norm
119            } else if let Some(px) = label.bbox_px {
120                [
121                    px[0] / width as f32,
122                    px[1] / height as f32,
123                    px[2] / width as f32,
124                    px[3] / height as f32,
125                ]
126            } else {
127                continue;
128            };
129            boxes.push(bbox);
130            if boxes.len() >= max_boxes {
131                break;
132            }
133        }
134        let box_count = boxes.len() as f32;
135        features.extend_from_slice(&stats.feature_vector(box_count));
136        all_boxes.push(boxes);
137    }
138
139    let mut boxes_buf = vec![0.0f32; batch * max_boxes * 4];
140    let mut mask_buf = vec![0.0f32; batch * max_boxes];
141    for (b, boxes) in all_boxes.iter().enumerate() {
142        for (i, bbox) in boxes.iter().enumerate() {
143            let base = (b * max_boxes + i) * 4;
144            boxes_buf[base..base + 4].copy_from_slice(bbox);
145            mask_buf[b * max_boxes + i] = 1.0;
146        }
147    }
148
149    let device = &B::Device::default();
150    let images = Tensor::<B, 4>::from_data(
151        TensorData::new(image_buf, [batch, 3, height as usize, width as usize]),
152        device,
153    );
154    let boxes =
155        Tensor::<B, 3>::from_data(TensorData::new(boxes_buf, [batch, max_boxes, 4]), device);
156    let box_mask = Tensor::<B, 2>::from_data(TensorData::new(mask_buf, [batch, max_boxes]), device);
157
158    let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch, 8]), device);
159
160    Ok(CollatedBatch {
161        images,
162        boxes,
163        box_mask,
164        features,
165    })
166}
167
168pub fn collate_from_burn_batch<B: Backend>(
169    batch: BurnBatch<B>,
170    max_boxes: usize,
171) -> anyhow::Result<CollatedBatch<B>> {
172    let dims = batch.images.dims();
173    let batch_size = dims[0];
174    let channels = dims[1];
175    let height = dims[2];
176    let width = dims[3];
177    let max_boxes = max_boxes.max(1);
178
179    if channels != 3 {
180        anyhow::bail!("expected 3-channel images, got {channels}");
181    }
182    let box_dims = batch.boxes.dims();
183    if box_dims[1] != max_boxes {
184        anyhow::bail!(
185            "warehouse batch max_boxes {actual} does not match requested {expected}",
186            actual = box_dims[1],
187            expected = max_boxes
188        );
189    }
190
191    let image_data = batch
192        .images
193        .clone()
194        .into_data()
195        .to_vec::<f32>()
196        .unwrap_or_default();
197    let mask_data = batch
198        .box_mask
199        .clone()
200        .into_data()
201        .to_vec::<f32>()
202        .unwrap_or_default();
203    let pixels_per_channel = height * width;
204
205    if image_data.len() != batch_size * channels * pixels_per_channel {
206        anyhow::bail!("unexpected image buffer size in warehouse batch");
207    }
208    if mask_data.len() != batch_size * max_boxes {
209        anyhow::bail!("unexpected box mask size in warehouse batch");
210    }
211
212    let mut features: Vec<f32> = Vec::with_capacity(batch_size * 8);
213    for b in 0..batch_size {
214        let start = b * channels * pixels_per_channel;
215        let slice = &image_data[start..start + channels * pixels_per_channel];
216        let stats = stats_from_chw_f32(width, height, slice)
217            .map_err(|e| anyhow::anyhow!("failed to compute image stats: {e}"))?;
218        let mask_start = b * max_boxes;
219        let box_count = mask_data[mask_start..mask_start + max_boxes]
220            .iter()
221            .filter(|v| **v > 0.0)
222            .count() as f32;
223        features.extend_from_slice(&stats.feature_vector(box_count));
224    }
225
226    let device = batch.images.device();
227    let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch_size, 8]), &device);
228
229    Ok(CollatedBatch {
230        images: batch.images,
231        boxes: batch.boxes,
232        box_mask: batch.box_mask,
233        features,
234    })
235}