1use burn::backend::Autodiff;
2use burn::module::Module;
3use burn::nn::loss::{MseLoss, Reduction};
4use burn::optim::{AdamConfig, GradientsParams, Optimizer};
5use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
6use burn::tensor::{Tensor, TensorData};
7use std::path::Path;
8
9use crate::{BigDet, BigDetConfig, DatasetConfig, TinyDet, TinyDetConfig, TrainBackend};
10use clap::{Parser, ValueEnum};
11use std::fs;
12
13pub fn load_tinydet_from_checkpoint<P: AsRef<Path>>(
14 path: P,
15 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
16) -> Result<TinyDet<TrainBackend>, RecorderError> {
17 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
18 TinyDet::<TrainBackend>::new(TinyDetConfig::default(), device).load_file(
19 path.as_ref(),
20 &recorder,
21 device,
22 )
23}
24
/// Detector architecture selectable from the command line.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    /// Single-box regressor (`TinyDet`), trained by `train_tinydet`.
    Tiny,
    /// Multi-box detector with objectness scores (`BigDet`), trained by `train_bigdet`.
    Big,
}
30
/// Compute backend requested on the command line.
///
/// NOTE: the effective backend is fixed at compile time by the
/// `backend-wgpu` feature; `validate_backend_choice` reconciles the two.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    /// CPU backend (ndarray).
    NdArray,
    /// GPU backend via wgpu (requires building with `--features backend-wgpu`).
    Wgpu,
}
36
/// Command-line arguments for the `train` subcommand.
#[derive(Parser, Debug)]
#[command(name = "train", about = "Train TinyDet/BigDet on capture metadata")]
pub struct TrainArgs {
    /// Which model architecture to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    /// Requested compute backend (checked against compiled features).
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    /// Maximum number of boxes per sample after collation (padding size).
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    /// Weight of the box-regression (L1) loss term (used by BigDet training).
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    /// Weight of the objectness (BCE) loss term (used by BigDet training).
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    /// Root directory of the capture dataset.
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    /// Subdirectory under the root containing label files.
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    /// Subdirectory under the root containing images.
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    /// Number of training epochs.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    /// Samples per optimizer step (clamped to at least 1 by the loops).
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    /// Adam learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    /// Objectness threshold — presumably consumed by the inference path;
    /// not read by the training loops in this file.
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    /// IoU threshold — presumably consumed by the inference path;
    /// not read by the training loops in this file.
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    /// Checkpoint output path; defaults to `checkpoints/{tinydet,bigdet}.bin`
    /// depending on `--model`.
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}
83
84pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
85 validate_backend_choice(args.backend)?;
86
87 let cfg = DatasetConfig {
88 root: args.dataset_root.clone().into(),
89 labels_subdir: args.labels_subdir.clone(),
90 images_subdir: args.images_subdir.clone(),
91 };
92 let samples = cfg.load()?;
93 if samples.is_empty() {
94 println!("No samples found under {}", cfg.root.display());
95 return Ok(());
96 }
97
98 let ckpt_path = args
99 .checkpoint_out
100 .clone()
101 .unwrap_or_else(|| match args.model {
102 ModelKind::Tiny => "checkpoints/tinydet.bin".to_string(),
103 ModelKind::Big => "checkpoints/bigdet.bin".to_string(),
104 });
105
106 if let Some(parent) = Path::new(&ckpt_path).parent() {
107 fs::create_dir_all(parent)?;
108 }
109
110 match args.model {
111 ModelKind::Tiny => train_tinydet(&args, &samples, &ckpt_path)?,
112 ModelKind::Big => train_bigdet(&args, &samples, &ckpt_path)?,
113 }
114
115 println!("Saved checkpoint to {}", ckpt_path);
116 Ok(())
117}
118
/// Training backend: the inference backend wrapped with autodiff so
/// `loss.backward()` can produce gradients.
type ADBackend = Autodiff<TrainBackend>;
120
/// Train `TinyDet` on the collated samples and save a checkpoint to `ckpt_path`.
///
/// The objective regresses from the first ground-truth box of each sample to
/// the per-sample valid-box count (sum of the box mask) via MSE — it appears
/// to be a placeholder objective rather than real detection training.
fn train_tinydet(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
    let mut optim = AdamConfig::new().init();

    // Guard against a zero batch size from the CLI.
    let batch_size = args.batch_size.max(1);
    let data = samples.to_vec();
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in data.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            let boxes = batch.boxes.clone();
            // First ground-truth box of each sample, flattened to [batch, 4].
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);

            let mask = batch.box_mask.clone();
            // Sum of the validity mask per sample (valid-box count), [batch, 1].
            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);

            let preds = model.forward(first_box);
            let mse = MseLoss::new();
            let loss = mse.forward(preds, has_box, Reduction::Mean);
            // Detach a copy for logging before backward() consumes the graph.
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            // Pull the scalar loss to the host; default to 0.0 on transfer failure.
            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}
179
/// Train `BigDet` on the collated samples and save a checkpoint to `ckpt_path`.
///
/// Input per sample is the first ground-truth box concatenated with the
/// sample features; the model predicts `max_boxes` boxes and objectness
/// scores. Total loss = lambda_box * masked-L1(boxes) + lambda_obj * BCE(scores),
/// with targets built by `build_greedy_targets`.
fn train_bigdet(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = BigDet::<ADBackend>::new(
        BigDetConfig {
            // Input width: 4 box coordinates + 8 feature dims.
            input_dim: Some(4 + 8), max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    // Guard against a zero batch size from the CLI.
    let batch_size = args.batch_size.max(1);
    let data = samples.to_vec();
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in data.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            let boxes = batch.boxes.clone();
            // First ground-truth box of each sample, flattened to [batch, 4].
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            // Model input: first box concatenated with the sample features.
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            // Greedy matching: each valid GT selects its highest-IoU prediction.
            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
            // Clamp scores away from {0, 1} so log() stays finite.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                - obj_targets.clone();
            // Manual binary cross-entropy: -(t*log(p) + (1-t)*log(1-p)),
            // averaged over every prediction slot in the batch.
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                    .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // L1 error counted only at matched slots (weights are 0 elsewhere).
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            // Each match sets 4 weight entries, so sum/4 = number of matches.
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                // No matched boxes in this batch: contribute a zero box loss.
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            // Detach a copy for logging before backward() consumes the graph.
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            // Pull the scalar loss to the host; default to 0.0 on transfer failure.
            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}
289
290pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
291 let built_wgpu = cfg!(feature = "backend-wgpu");
292 match (kind, built_wgpu) {
293 (BackendKind::Wgpu, false) => {
294 anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
295 }
296 (BackendKind::NdArray, true) => {
297 println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
298 }
299 _ => {}
300 }
301 Ok(())
302}
303
/// Intersection-over-union of two axis-aligned boxes given as
/// `[x0, y0, x1, y1]`.
///
/// Corner order is normalized first, so boxes with swapped coordinates are
/// handled. Returns 0.0 when the union area is not positive (degenerate or
/// empty boxes).
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Normalize so index 0/1 hold the min corner and 2/3 the max corner.
    let norm = |r: [f32; 4]| -> [f32; 4] {
        [
            r[0].min(r[2]),
            r[1].min(r[3]),
            r[0].max(r[2]),
            r[1].max(r[3]),
        ]
    };
    let a = norm(a);
    let b = norm(b);

    // Overlap rectangle; width/height clamp to zero when the boxes are disjoint.
    let iw = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let ih = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter = iw * ih;

    let area_a = (a[2] - a[0]).max(0.0) * (a[3] - a[1]).max(0.0);
    let area_b = (b[2] - b[0]).max(0.0) * (b[3] - b[1]).max(0.0);
    let union = area_a + area_b - inter;

    if union <= 0.0 {
        0.0
    } else {
        inter / union
    }
}
332pub fn load_bigdet_from_checkpoint<P: AsRef<Path>>(
333 path: P,
334 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
335 max_boxes: usize,
336) -> Result<BigDet<TrainBackend>, RecorderError> {
337 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
338 BigDet::<TrainBackend>::new(
339 BigDetConfig {
340 max_boxes,
341 input_dim: Some(4 + 8),
342 ..Default::default()
343 },
344 device,
345 )
346 .load_file(path.as_ref(), &recorder, device)
347}
348
349pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
350 pred_boxes: Tensor<B, 3>,
351 gt_boxes: Tensor<B, 3>,
352 gt_mask: Tensor<B, 2>,
353) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
354 let batch = pred_boxes.dims()[0];
355 let max_pred = pred_boxes.dims()[1];
356 let max_gt = gt_boxes.dims()[1];
357
358 let gt_mask_vec = gt_mask
359 .clone()
360 .into_data()
361 .to_vec::<f32>()
362 .unwrap_or_default();
363 let gt_boxes_vec = gt_boxes
364 .clone()
365 .into_data()
366 .to_vec::<f32>()
367 .unwrap_or_default();
368 let pred_boxes_vec = pred_boxes
369 .clone()
370 .into_data()
371 .to_vec::<f32>()
372 .unwrap_or_default();
373
374 let mut obj_targets = vec![0.0f32; batch * max_pred];
375 let mut box_targets = vec![0.0f32; batch * max_pred * 4];
376 let mut box_weights = vec![0.0f32; batch * max_pred * 4];
377
378 for b in 0..batch {
379 for g in 0..max_gt {
380 let mask_idx = b * max_gt + g;
381 if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
382 continue;
383 }
384 let gb = [
385 gt_boxes_vec[(b * max_gt + g) * 4],
386 gt_boxes_vec[(b * max_gt + g) * 4 + 1],
387 gt_boxes_vec[(b * max_gt + g) * 4 + 2],
388 gt_boxes_vec[(b * max_gt + g) * 4 + 3],
389 ];
390
391 let mut best_iou = -1.0f32;
392 let mut best_p = 0usize;
393 for p in 0..max_pred {
394 let pb = [
395 pred_boxes_vec[(b * max_pred + p) * 4],
396 pred_boxes_vec[(b * max_pred + p) * 4 + 1],
397 pred_boxes_vec[(b * max_pred + p) * 4 + 2],
398 pred_boxes_vec[(b * max_pred + p) * 4 + 3],
399 ];
400 let iou = iou_xyxy(pb, gb);
401 if iou > best_iou {
402 best_iou = iou;
403 best_p = p;
404 }
405 }
406
407 let obj_idx = b * max_pred + best_p;
408 obj_targets[obj_idx] = 1.0;
409 let bt_base = (b * max_pred + best_p) * 4;
410 box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
411 box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
412 }
413 }
414
415 let device = &B::Device::default();
416 let obj_targets =
417 Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
418 let box_targets =
419 Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
420 let box_weights =
421 Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);
422
423 (obj_targets, box_targets, box_weights)
424}