use std::fs;
use std::path::PathBuf;
use data_contracts::capture::{CaptureMetadata, DetectionLabel};
use image::{Rgb, RgbImage};
use training::{collate, DatasetPathConfig};
#[test]
fn load_and_collate_synthetic() {
let temp = tempfile::tempdir().unwrap();
let labels_dir = temp.path().join("labels");
fs::create_dir_all(&labels_dir).unwrap();
let meta = CaptureMetadata {
frame_id: 1,
sim_time: 0.0,
unix_time: 0.0,
image: "frame_00001.png".into(),
image_present: true,
camera_active: true,
label_seed: 42,
labels: vec![DetectionLabel {
center_world: [0.0, 0.0, 0.0],
bbox_px: Some([0.0, 0.0, 10.0, 10.0]),
bbox_norm: Some([0.1, 0.1, 0.2, 0.2]),
source: None,
source_confidence: None,
}],
};
let json = serde_json::to_vec(&meta).unwrap();
fs::write(labels_dir.join("frame_00001.json"), json).unwrap();
let mut img = RgbImage::new(2, 2);
for pixel in img.pixels_mut() {
*pixel = Rgb([255, 0, 0]);
}
let img_path = temp.path().join("frame_00001.png");
img.save(&img_path).unwrap();
let cfg = DatasetPathConfig {
root: PathBuf::from(temp.path()),
labels_subdir: "labels".into(),
images_subdir: ".".into(),
};
let samples = cfg.load().unwrap();
assert_eq!(samples.len(), 1);
let batch = collate::<burn_ndarray::NdArray<f32>>(&samples, 4).unwrap();
assert_eq!(batch.images.dims(), [1, 3, 2, 2]);
assert_eq!(batch.boxes.dims(), [1, 4, 4]);
assert_eq!(batch.box_mask.dims(), [1, 4]);
assert_eq!(batch.features.dims(), [1, 8]);
let mask: Vec<f32> = batch
.box_mask
.into_data()
.to_vec::<f32>()
.unwrap_or_default();
assert_eq!(mask, vec![1.0, 0.0, 0.0, 0.0]);
let feats: Vec<f32> = batch
.features
.into_data()
.to_vec::<f32>()
.unwrap_or_default();
assert_eq!(feats.len(), 8);
assert!((feats[0] - 1.0).abs() < 1e-5);
assert!((feats[1] - 0.0).abs() < 1e-5);
assert!((feats[2] - 0.0).abs() < 1e-5);
assert!((feats[3]).abs() < 1e-5);
assert!((feats[4]).abs() < 1e-5);
assert!((feats[5]).abs() < 1e-5);
assert!((feats[6] - 1.0).abs() < 1e-5); assert!((feats[7] - 1.0).abs() < 1e-5); }