1use burn::tensor::TensorData;
2use burn::tensor::{backend::Backend, Tensor};
3use data_contracts::capture::CaptureMetadata;
4use serde::Deserialize;
5use std::fs;
6use std::path::PathBuf;
7
8#[derive(Debug, Clone, Deserialize)]
9pub struct DatasetConfig {
10 pub root: PathBuf,
11 pub labels_subdir: String,
12 pub images_subdir: String,
13}
14
15#[derive(Debug, Clone)]
16pub struct RunSample {
17 pub image: PathBuf,
18 pub metadata: CaptureMetadata,
19}
20
21#[derive(Debug, Clone)]
22pub struct CollatedBatch<B: Backend> {
23 pub images: Tensor<B, 4>,
24 pub boxes: Tensor<B, 3>,
26 pub box_mask: Tensor<B, 2>,
28 pub features: Tensor<B, 2>,
30}
31
32impl DatasetConfig {
33 pub fn load(&self) -> anyhow::Result<Vec<RunSample>> {
34 let mut samples = Vec::new();
35 let labels_dir = self.root.join(&self.labels_subdir);
36 for entry in fs::read_dir(&labels_dir)? {
37 let entry = entry?;
38 let path = entry.path();
39 if path.extension().and_then(|s| s.to_str()) != Some("json") {
40 continue;
41 }
42 let meta: CaptureMetadata = serde_json::from_slice(&fs::read(&path)?)?;
43 meta.validate()
44 .map_err(|e| anyhow::anyhow!("invalid metadata {:?}: {e}", path))?;
45 let img_path = self.root.join(&self.images_subdir).join(&meta.image);
46 samples.push(RunSample {
47 image: img_path,
48 metadata: meta,
49 });
50 }
51 Ok(samples)
52 }
53}
54
55pub fn collate<B: Backend>(
56 samples: &[RunSample],
57 max_boxes: usize,
58) -> anyhow::Result<CollatedBatch<B>> {
59 if samples.is_empty() {
60 anyhow::bail!("cannot collate empty batch");
61 }
62 let max_boxes = max_boxes.max(1);
63
64 let first = image::open(&samples[0].image)
66 .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", samples[0].image))?
67 .to_rgb8();
68 let (width, height) = first.dimensions();
69
70 let batch = samples.len();
71 let num_pixels = (width * height) as usize;
72 let mut image_buf: Vec<f32> = Vec::with_capacity(batch * num_pixels * 3);
73 let mut features: Vec<f32> = Vec::with_capacity(batch * 6); let mut all_boxes: Vec<Vec<[f32; 4]>> = Vec::with_capacity(batch);
77
78 for (idx, sample) in samples.iter().enumerate() {
79 let img = if idx == 0 {
80 first.clone()
81 } else {
82 let img = image::open(&sample.image)
83 .map_err(|e| anyhow::anyhow!("failed to open image {:?}: {e}", sample.image))?;
84 let rgb = img.to_rgb8();
85 let (w, h) = rgb.dimensions();
86 if w != width || h != height {
87 anyhow::bail!(
88 "image dimensions differ within batch: {:?} is {}x{}, expected {}x{}",
89 sample.image,
90 w,
91 h,
92 width,
93 height
94 );
95 }
96 rgb
97 };
98
99 let mut sum = [0f32; 3];
101 let mut sumsq = [0f32; 3];
102 for c in 0..3 {
103 for y in 0..height {
104 for x in 0..width {
105 let p = img.get_pixel(x, y);
106 let v = p[c] as f32 / 255.0;
107 image_buf.push(v);
108 sum[c] += v;
109 sumsq[c] += v * v;
110 }
111 }
112 }
113 let pix_count = (width * height) as f32;
114 let mean = [sum[0] / pix_count, sum[1] / pix_count, sum[2] / pix_count];
115 let std = [
116 (sumsq[0] / pix_count - mean[0] * mean[0]).max(0.0).sqrt(),
117 (sumsq[1] / pix_count - mean[1] * mean[1]).max(0.0).sqrt(),
118 (sumsq[2] / pix_count - mean[2] * mean[2]).max(0.0).sqrt(),
119 ];
120
121 let mut boxes = Vec::new();
122 for label in &sample.metadata.polyp_labels {
123 let bbox = if let Some(norm) = label.bbox_norm {
124 norm
125 } else if let Some(px) = label.bbox_px {
126 [
127 px[0] / width as f32,
128 px[1] / height as f32,
129 px[2] / width as f32,
130 px[3] / height as f32,
131 ]
132 } else {
133 continue;
134 };
135 boxes.push(bbox);
136 if boxes.len() >= max_boxes {
137 break;
138 }
139 }
140 let box_count = boxes.len() as f32;
141 features.extend_from_slice(&[
142 mean[0],
143 mean[1],
144 mean[2],
145 std[0],
146 std[1],
147 std[2],
148 width as f32 / height as f32,
149 box_count,
150 ]);
151 all_boxes.push(boxes);
152 }
153
154 let mut boxes_buf = vec![0.0f32; batch * max_boxes * 4];
155 let mut mask_buf = vec![0.0f32; batch * max_boxes];
156 for (b, boxes) in all_boxes.iter().enumerate() {
157 for (i, bbox) in boxes.iter().enumerate() {
158 let base = (b * max_boxes + i) * 4;
159 boxes_buf[base..base + 4].copy_from_slice(bbox);
160 mask_buf[b * max_boxes + i] = 1.0;
161 }
162 }
163
164 let device = &B::Device::default();
165 let images = Tensor::<B, 4>::from_data(
166 TensorData::new(image_buf, [batch, 3, height as usize, width as usize]),
167 device,
168 );
169 let boxes =
170 Tensor::<B, 3>::from_data(TensorData::new(boxes_buf, [batch, max_boxes, 4]), device);
171 let box_mask = Tensor::<B, 2>::from_data(TensorData::new(mask_buf, [batch, max_boxes]), device);
172
173 let features = Tensor::<B, 2>::from_data(TensorData::new(features, [batch, 8]), device);
174
175 Ok(CollatedBatch {
176 images,
177 boxes,
178 box_mask,
179 features,
180 })
181}