1use burn::backend::Autodiff;
2use burn::module::Module;
3use burn::nn::loss::{MseLoss, Reduction};
4use burn::optim::{AdamConfig, GradientsParams, Optimizer};
5use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
6use burn::tensor::{Tensor, TensorData};
7use std::path::Path;
8
9use crate::{BigDet, BigDetConfig, DatasetConfig, TinyDet, TinyDetConfig, TrainBackend};
10use clap::{Parser, ValueEnum};
11use std::fs;
12
13pub fn load_tinydet_from_checkpoint<P: AsRef<Path>>(
14 path: P,
15 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
16) -> Result<TinyDet<TrainBackend>, RecorderError> {
17 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
18 TinyDet::<TrainBackend>::new(TinyDetConfig::default(), device).load_file(
19 path.as_ref(),
20 &recorder,
21 device,
22 )
23}
24
// Detector architecture selected on the command line.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// `ValueEnum` variants become clap help text and would change --help output.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    // Small single-box model (TinyDet).
    Tiny,
    // Larger multi-box model (BigDet).
    Big,
}
30
// Compute backend selected on the command line.
// NOTE: `//` comments on purpose — `///` on `ValueEnum` variants would feed
// clap's help output. See `validate_backend_choice` for how the choice is
// reconciled with the compiled-in features.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    // CPU backend (always available).
    NdArray,
    // GPU backend; requires building with the `backend-wgpu` feature.
    Wgpu,
}
36
// Command-line arguments for the `train` command.
// NOTE: plain `//` comments on purpose — `///` doc comments on clap fields
// become per-flag help text and would change the program's --help output.
#[derive(Parser, Debug)]
#[command(name = "train", about = "Train TinyDet/BigDet on capture metadata")]
pub struct TrainArgs {
    // Which model architecture to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    // Compute backend; wgpu requires the `backend-wgpu` build feature.
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    // Maximum ground-truth boxes kept per sample during collation.
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    // Weight of the box-regression term in the combined loss (BigDet).
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    // Weight of the objectness term in the combined loss (BigDet).
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    // Dataset root directory.
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    // Subdirectory (under the root) containing label files.
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    // Subdirectory (under the root) containing images; "." means the root itself.
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    // Number of training epochs.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    // Samples per optimizer step (clamped to at least 1 at train time).
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    // Adam learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    // Objectness threshold for inference; not read by the training loops in
    // this file — presumably consumed by an inference path elsewhere (confirm).
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    // IoU threshold for inference-time suppression; likewise not read here.
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    // Output checkpoint path; defaults to checkpoints/{tinydet,bigdet}.bin.
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}
83
84pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
85 validate_backend_choice(args.backend)?;
86
87 let cfg = DatasetConfig {
88 root: args.dataset_root.clone().into(),
89 labels_subdir: args.labels_subdir.clone(),
90 images_subdir: args.images_subdir.clone(),
91 };
92 let samples = cfg.load()?;
93 if samples.is_empty() {
94 println!("No samples found under {}", cfg.root.display());
95 return Ok(());
96 }
97
98 let ckpt_path = args
99 .checkpoint_out
100 .clone()
101 .unwrap_or_else(|| match args.model {
102 ModelKind::Tiny => "checkpoints/tinydet.bin".to_string(),
103 ModelKind::Big => "checkpoints/bigdet.bin".to_string(),
104 });
105
106 if let Some(parent) = Path::new(&ckpt_path).parent() {
107 fs::create_dir_all(parent)?;
108 }
109
110 match args.model {
111 ModelKind::Tiny => train_tinydet(&args, &samples, &ckpt_path)?,
112 ModelKind::Big => train_bigdet(&args, &samples, &ckpt_path)?,
113 }
114
115 println!("Saved checkpoint to {}", ckpt_path);
116 Ok(())
117}
118
/// Autodiff-wrapped training backend: gradient tracking layered on top of the
/// inference backend `TrainBackend`.
type ADBackend = Autodiff<TrainBackend>;
120
121fn train_tinydet(
122 args: &TrainArgs,
123 samples: &[crate::RunSample],
124 ckpt_path: &str,
125) -> anyhow::Result<()> {
126 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
127 let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
128 let mut optim = AdamConfig::new().init();
129
130 let batch_size = args.batch_size.max(1);
131 let data = samples.to_vec();
132 for epoch in 0..args.epochs {
133 let mut losses = Vec::new();
134 for batch in data.chunks(batch_size) {
135 let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
136 let boxes = batch.boxes.clone();
138 let first_box = boxes
139 .clone()
140 .slice([0..boxes.dims()[0], 0..1, 0..4])
141 .reshape([boxes.dims()[0], 4]);
142
143 let mask = batch.box_mask.clone();
145 let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
146
147 let preds = model.forward(first_box);
148 let mse = MseLoss::new();
149 let loss = mse.forward(preds, has_box, Reduction::Mean);
150 let loss_detached = loss.clone().detach();
151 let grads = GradientsParams::from_grads(loss.backward(), &model);
152 model = optim.step(args.lr as f64, model, grads);
153
154 let loss_val: f32 = loss_detached
155 .into_data()
156 .to_vec::<f32>()
157 .unwrap_or_default()
158 .into_iter()
159 .next()
160 .unwrap_or(0.0);
161 losses.push(loss_val);
162 }
163 let avg_loss: f32 = if losses.is_empty() {
164 0.0
165 } else {
166 losses.iter().sum::<f32>() / losses.len() as f32
167 };
168 println!("epoch {epoch}: avg loss {avg_loss:.4}");
169 }
170
171 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
172 model
173 .clone()
174 .save_file(Path::new(ckpt_path), &recorder)
175 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
176
177 Ok(())
178}
179
180fn train_bigdet(
181 args: &TrainArgs,
182 samples: &[crate::RunSample],
183 ckpt_path: &str,
184) -> anyhow::Result<()> {
185 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
186 let mut model = BigDet::<ADBackend>::new(
187 BigDetConfig {
188 input_dim: Some(4 + 8), ..Default::default()
190 },
191 &device,
192 );
193 let mut optim = AdamConfig::new().init();
194
195 let batch_size = args.batch_size.max(1);
196 let data = samples.to_vec();
197 for epoch in 0..args.epochs {
198 let mut losses = Vec::new();
199 for batch in data.chunks(batch_size) {
200 let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
201 let boxes = batch.boxes.clone();
203 let first_box = boxes
204 .clone()
205 .slice([0..boxes.dims()[0], 0..1, 0..4])
206 .reshape([boxes.dims()[0], 4]);
207 let features = batch.features.clone();
208 let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);
209
210 let (pred_boxes, pred_scores) = model.forward_multibox(input);
211
212 let gt_boxes = batch.boxes.clone();
214 let gt_mask = batch.box_mask.clone();
215
216 let (obj_targets, box_targets, box_weights) =
218 build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
219 let eps = 1e-6;
223 let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
224 let obj_targets_inv =
225 Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
226 - obj_targets.clone();
227 let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
228 + (obj_targets_inv
229 * (Tensor::<ADBackend, 2>::ones(
230 pred_scores_clamped.dims(),
231 &pred_scores_clamped.device(),
232 ) - pred_scores_clamped)
233 .log()))
234 .sum()
235 .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);
236
237 let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
239 let matched = box_weights.clone().sum().div_scalar(4.0);
240 let matched_scalar = matched
241 .into_data()
242 .to_vec::<f32>()
243 .unwrap_or_default()
244 .first()
245 .copied()
246 .unwrap_or(0.0);
247 let box_loss = if matched_scalar > 0.0 {
248 box_err.sum().div_scalar(matched_scalar)
249 } else {
250 let zeros = vec![0.0f32; 1];
252 Tensor::<ADBackend, 1>::from_data(
253 TensorData::new(zeros, [1]),
254 &box_weights.device(),
255 )
256 };
257
258 let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
259 let loss_detached = loss.clone().detach();
260 let grads = GradientsParams::from_grads(loss.backward(), &model);
261 model = optim.step(args.lr as f64, model, grads);
262
263 let loss_val: f32 = loss_detached
264 .into_data()
265 .to_vec::<f32>()
266 .unwrap_or_default()
267 .into_iter()
268 .next()
269 .unwrap_or(0.0);
270 losses.push(loss_val);
271 }
272 let avg_loss: f32 = if losses.is_empty() {
273 0.0
274 } else {
275 losses.iter().sum::<f32>() / losses.len() as f32
276 };
277 println!("epoch {epoch}: avg loss {avg_loss:.4}");
278 }
279
280 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
281 model
282 .clone()
283 .save_file(Path::new(ckpt_path), &recorder)
284 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
285
286 Ok(())
287}
288
289pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
290 let built_wgpu = cfg!(feature = "backend-wgpu");
291 match (kind, built_wgpu) {
292 (BackendKind::Wgpu, false) => {
293 anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
294 }
295 (BackendKind::NdArray, true) => {
296 println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
297 }
298 _ => {}
299 }
300 Ok(())
301}
302
/// Intersection-over-union of two axis-aligned boxes given as [x0, y0, x1, y1].
///
/// Corner order is irrelevant: each box is first normalized so that the first
/// pair is the min corner and the second the max. Returns 0.0 when the union
/// area is zero or negative (degenerate boxes).
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Normalize a box to (min_x, min_y, max_x, max_y).
    let norm =
        |r: [f32; 4]| (r[0].min(r[2]), r[1].min(r[3]), r[0].max(r[2]), r[1].max(r[3]));
    let (ax0, ay0, ax1, ay1) = norm(a);
    let (bx0, by0, bx1, by1) = norm(b);

    // Overlap extents, clamped at zero for disjoint boxes.
    let iw = (ax1.min(bx1) - ax0.max(bx0)).max(0.0);
    let ih = (ay1.min(by1) - ay0.max(by0)).max(0.0);
    let inter = iw * ih;

    let union = (ax1 - ax0).max(0.0) * (ay1 - ay0).max(0.0)
        + (bx1 - bx0).max(0.0) * (by1 - by0).max(0.0)
        - inter;

    if union > 0.0 {
        inter / union
    } else {
        0.0
    }
}
331pub fn load_bigdet_from_checkpoint<P: AsRef<Path>>(
332 path: P,
333 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
334 max_boxes: usize,
335) -> Result<BigDet<TrainBackend>, RecorderError> {
336 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
337 BigDet::<TrainBackend>::new(
338 BigDetConfig {
339 max_boxes,
340 input_dim: Some(4 + 8),
341 ..Default::default()
342 },
343 device,
344 )
345 .load_file(path.as_ref(), &recorder, device)
346}
347
348pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
349 pred_boxes: Tensor<B, 3>,
350 gt_boxes: Tensor<B, 3>,
351 gt_mask: Tensor<B, 2>,
352) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
353 let batch = pred_boxes.dims()[0];
354 let max_pred = pred_boxes.dims()[1];
355 let max_gt = gt_boxes.dims()[1];
356
357 let gt_mask_vec = gt_mask
358 .clone()
359 .into_data()
360 .to_vec::<f32>()
361 .unwrap_or_default();
362 let gt_boxes_vec = gt_boxes
363 .clone()
364 .into_data()
365 .to_vec::<f32>()
366 .unwrap_or_default();
367 let pred_boxes_vec = pred_boxes
368 .clone()
369 .into_data()
370 .to_vec::<f32>()
371 .unwrap_or_default();
372
373 let mut obj_targets = vec![0.0f32; batch * max_pred];
374 let mut box_targets = vec![0.0f32; batch * max_pred * 4];
375 let mut box_weights = vec![0.0f32; batch * max_pred * 4];
376
377 for b in 0..batch {
378 for g in 0..max_gt {
379 let mask_idx = b * max_gt + g;
380 if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
381 continue;
382 }
383 let gb = [
384 gt_boxes_vec[(b * max_gt + g) * 4],
385 gt_boxes_vec[(b * max_gt + g) * 4 + 1],
386 gt_boxes_vec[(b * max_gt + g) * 4 + 2],
387 gt_boxes_vec[(b * max_gt + g) * 4 + 3],
388 ];
389
390 let mut best_iou = -1.0f32;
391 let mut best_p = 0usize;
392 for p in 0..max_pred {
393 let pb = [
394 pred_boxes_vec[(b * max_pred + p) * 4],
395 pred_boxes_vec[(b * max_pred + p) * 4 + 1],
396 pred_boxes_vec[(b * max_pred + p) * 4 + 2],
397 pred_boxes_vec[(b * max_pred + p) * 4 + 3],
398 ];
399 let iou = iou_xyxy(pb, gb);
400 if iou > best_iou {
401 best_iou = iou;
402 best_p = p;
403 }
404 }
405
406 let obj_idx = b * max_pred + best_p;
407 obj_targets[obj_idx] = 1.0;
408 let bt_base = (b * max_pred + best_p) * 4;
409 box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
410 box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
411 }
412 }
413
414 let device = &B::Device::default();
415 let obj_targets =
416 Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
417 let box_targets =
418 Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
419 let box_weights =
420 Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);
421
422 (obj_targets, box_targets, box_weights)
423}