training/
dataset.rs

1use burn::tensor::TensorData;
2use burn::tensor::{backend::Backend, Tensor};
3use data_contracts::capture::CaptureMetadata;
4use serde::Deserialize;
5use std::fs;
6use std::path::PathBuf;
7
8#[derive(Debug, Clone, Deserialize)]
9pub struct DatasetConfig {
10    pub root: PathBuf,
11    pub labels_subdir: String,
12    pub images_subdir: String,
13}
14
15#[derive(Debug, Clone)]
16pub struct RunSample {
17    pub image: PathBuf,
18    pub metadata: CaptureMetadata,
19}
20
21#[derive(Debug, Clone)]
22pub struct CollatedBatch<B: Backend> {
23    pub images: Tensor<B, 4>,
24    /// Normalized boxes per sample (shape: [batch, max_boxes, 4]).
25    pub boxes: Tensor<B, 3>,
26    /// Mask indicating which box slots are populated (shape: [batch, max_boxes]).
27    pub box_mask: Tensor<B, 2>,
28    /// Global/image features per sample (mean/std RGB, aspect ratio, box count) shape [batch, F].
29    pub features: Tensor<B, 2>,
30}
31
32impl DatasetConfig {
33    pub fn load(&self) -> anyhow::Result<Vec<RunSample>> {
34        let mut samples = Vec::new();
35        let labels_dir = self.root.join(&self.labels_subdir);
36        for entry in fs::read_dir(&labels_dir)? {
37            let entry = entry?;
38            let path = entry.path();
39            if path.extension().and_then(|s| s.to_str()) != Some("json") {
40                continue;
41            }
42            let meta: CaptureMetadata = serde_json::from_slice(&fs::read(&path)?)?;
43            meta.validate()
44                .map_err(|e| anyhow::anyhow!("invalid metadata {:?}: {e}", path))?;
45            let img_path = self.root.join(&self.images_subdir).join(&meta.image);
46            samples.push(RunSample {
47                image: img_path,
48                metadata: meta,
49            });
50        }
51        Ok(samples)
52    }
53}
54
55pub fn collate<B: Backend>(
56    samples: &[RunSample],
57    max_boxes: usize,
58) -> anyhow::Result<CollatedBatch<B>> {
59    if samples.is_empty() {
60        anyhow::bail!("cannot collate empty batch");
61    }
62    let max_boxes = max_boxes.max(1);
63
64    // Load first image to establish dimensions.
65    let first = image::open(&samples[0].image)
66        .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", samples[0].image))?
67        .to_rgb8();
68    let (width, height) = first.dimensions();
69
70    let batch = samples.len();
71    let num_pixels = (width * height) as usize;
72    let mut image_buf: Vec<f32> = Vec::with_capacity(batch * num_pixels * 3);
73    let mut features: Vec<f32> = Vec::with_capacity(batch * 6); // mean/std RGB, aspect, box_count
74
75    // Gather normalized boxes, truncated to max_boxes.
76    let mut all_boxes: Vec<Vec<[f32; 4]>> = Vec::with_capacity(batch);
77
78    for (idx, sample) in samples.iter().enumerate() {
79        let img = if idx == 0 {
80            first.clone()
81        } else {
82            let img = image::open(&sample.image)
83                .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", sample.image))?;
84            let rgb = img.to_rgb8();
85            let (w, h) = rgb.dimensions();
86            if w != width || h != height {
87                anyhow::bail!(
88                    "image dimensions differ within batch: {:?} is {}x{}, expected {}x{}",
89                    sample.image,
90                    w,
91                    h,
92                    width,
93                    height
94                );
95            }
96            rgb
97        };
98
99        // Push normalized pixel data in CHW order.
100        let mut sum = [0f32; 3];
101        let mut sumsq = [0f32; 3];
102        for c in 0..3 {
103            for y in 0..height {
104                for x in 0..width {
105                    let p = img.get_pixel(x, y);
106                    let v = p[c] as f32 / 255.0;
107                    image_buf.push(v);
108                    sum[c] += v;
109                    sumsq[c] += v * v;
110                }
111            }
112        }
113        let pix_count = (width * height) as f32;
114        let mean = [sum[0] / pix_count, sum[1] / pix_count, sum[2] / pix_count];
115        let std = [
116            (sumsq[0] / pix_count - mean[0] * mean[0]).max(0.0).sqrt(),
117            (sumsq[1] / pix_count - mean[1] * mean[1]).max(0.0).sqrt(),
118            (sumsq[2] / pix_count - mean[2] * mean[2]).max(0.0).sqrt(),
119        ];
120
121        let mut boxes = Vec::new();
122        for label in &sample.metadata.polyp_labels {
123            let bbox = if let Some(norm) = label.bbox_norm {
124                norm
125            } else if let Some(px) = label.bbox_px {
126                [
127                    px[0] / width as f32,
128                    px[1] / height as f32,
129                    px[2] / width as f32,
130                    px[3] / height as f32,
131                ]
132            } else {
133                continue;
134            };
135            boxes.push(bbox);
136            if boxes.len() >= max_boxes {
137                break;
138            }
139        }
140        let box_count = boxes.len() as f32;
141        features.extend_from_slice(&[
142            mean[0],
143            mean[1],
144            mean[2],
145            std[0],
146            std[1],
147            std[2],
148            width as f32 / height as f32,
149            box_count,
150        ]);
151        all_boxes.push(boxes);
152    }
153
154    let mut boxes_buf = vec![0.0f32; batch * max_boxes * 4];
155    let mut mask_buf = vec![0.0f32; batch * max_boxes];
156    for (b, boxes) in all_boxes.iter().enumerate() {
157        for (i, bbox) in boxes.iter().enumerate() {
158            let base = (b * max_boxes + i) * 4;
159            boxes_buf[base..base + 4].copy_from_slice(bbox);
160            mask_buf[b * max_boxes + i] = 1.0;
161        }
162    }
163
164    let device = &B::Device::default();
165    let images = Tensor::<B, 4>::from_data(
166        TensorData::new(image_buf, [batch, 3, height as usize, width as usize]),
167        device,
168    );
169    let boxes =
170        Tensor::<B, 3>::from_data(TensorData::new(boxes_buf, [batch, max_boxes, 4]), device);
171    let box_mask = Tensor::<B, 2>::from_data(TensorData::new(mask_buf, [batch, max_boxes]), device);
172
173    let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch, 8]), device);
174
175    Ok(CollatedBatch {
176        images,
177        boxes,
178        box_mask,
179        features,
180    })
181}