use burn::backend::Autodiff;
use burn::module::Module;
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
use burn::tensor::{Tensor, TensorData};
use burn_dataset::WarehouseLoaders;
use clap::{Parser, ValueEnum};
use std::fs;
use std::path::Path;

use crate::{BigDet, BigDetConfig, DatasetConfig, TinyDet, TinyDetConfig, TrainBackend};

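/// Load a saved `TinyDet` checkpoint from `path` onto `device`, rebuilding the
/// module from the default `TinyDetConfig` before restoring its weights.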
pub fn load_tinydet_from_checkpoint<P: AsRef<Path>>(
    path: P,
    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
) -> Result<TinyDet<TrainBackend>, RecorderError> {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    TinyDet::<TrainBackend>::new(TinyDetConfig::default(), device).load_file(
        path.as_ref(),
        &recorder,
        device,
    )
}

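/// Detector architecture to train.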
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    Tiny,
    Big,
}

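/// Compute backend requested on the command line.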
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    NdArray,
    Wgpu,
}

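/// Where training samples come from: warehouse shards or raw capture logs.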
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum TrainingInputSource {
    Warehouse,
    CaptureLogs,
}

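/// Command-line arguments for the `train` subcommand, parsed with clap.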
#[derive(Parser, Debug)]
#[command(name = "train", about = "Train TinyDet/BigDet (warehouse-first)")]
pub struct TrainArgs {
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    #[arg(long, value_enum, default_value_t = TrainingInputSource::Warehouse)]
    pub input_source: TrainingInputSource,
    #[arg(long, default_value = "assets/warehouse/manifest.json")]
    pub warehouse_manifest: String,
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}

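/// Entry point for training: validates the backend choice, resolves the
/// checkpoint path, loads data from the selected source, and dispatches to the
/// model-specific training loop.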
pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
    validate_backend_choice(args.backend)?;

    let ckpt_path = args
        .checkpoint_out
        .clone()
        .unwrap_or_else(|| match args.model {
            ModelKind::Tiny => "checkpoints/tinydet.bin".to_string(),
            ModelKind::Big => "checkpoints/bigdet.bin".to_string(),
        });

    if let Some(parent) = Path::new(&ckpt_path).parent() {
        fs::create_dir_all(parent)?;
    }

    match args.input_source {
        TrainingInputSource::Warehouse => {
            let manifest_path = Path::new(&args.warehouse_manifest);
            let loaders = WarehouseLoaders::from_manifest_path(manifest_path, 0.0, None, false)
                .map_err(|e| {
                    anyhow::anyhow!(
                        "failed to load warehouse manifest at {}: {e}",
                        manifest_path.display()
                    )
                })?;
            if loaders.train_len() == 0 {
                anyhow::bail!(
                    "warehouse manifest {} contains no training shards",
                    manifest_path.display()
                );
            }
            match args.model {
                ModelKind::Tiny => train_tinydet_warehouse(&args, &loaders, &ckpt_path)?,
                ModelKind::Big => train_bigdet_warehouse(&args, &loaders, &ckpt_path)?,
            }
        }
        TrainingInputSource::CaptureLogs => {
            println!("training from capture logs (legacy path); prefer warehouse manifests");
            let cfg = DatasetConfig {
                root: args.dataset_root.clone().into(),
                labels_subdir: args.labels_subdir.clone(),
                images_subdir: args.images_subdir.clone(),
            };
            let samples = cfg.load()?;
            if samples.is_empty() {
                println!("No samples found under {}", cfg.root.display());
                return Ok(());
            }
            match args.model {
                ModelKind::Tiny => train_tinydet(&args, &samples, &ckpt_path)?,
                ModelKind::Big => train_bigdet(&args, &samples, &ckpt_path)?,
            }
        }
    }

    println!("Saved checkpoint to {}", ckpt_path);
    Ok(())
}

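/// Training backend: the configured `TrainBackend` with autodiff enabled.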
type ADBackend = Autodiff<TrainBackend>;

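/// Train `TinyDet` on capture-log samples: the model regresses the number of
/// valid boxes per sample (via MSE) from the first ground-truth box.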
fn train_tinydet(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in samples.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);

            // Per-sample count of valid boxes, used as the regression target.
            let mask = batch.box_mask.clone();
            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);

            let preds = model.forward(first_box);
            let mse = MseLoss::new();
            let loss = mse.forward(preds, has_box, Reduction::Mean);
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

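/// Warehouse-shard variant of `train_tinydet`: same loop, but batches are
/// streamed from `WarehouseLoaders` instead of collated capture-log samples.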
fn train_tinydet_warehouse(
    args: &TrainArgs,
    loaders: &WarehouseLoaders,
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        let mut iter = loaders.train_iter();
        loop {
            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
                Some(batch) => batch,
                None => break,
            };
            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;

            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);

            // Per-sample count of valid boxes, used as the regression target.
            let mask = batch.box_mask.clone();
            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);

            let preds = model.forward(first_box);
            let mse = MseLoss::new();
            let loss = mse.forward(preds, has_box, Reduction::Mean);
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

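/// Train `BigDet` on capture-log samples. Predictions are greedily matched to
/// ground-truth boxes; the loss combines binary cross-entropy on objectness
/// scores with an L1 box loss over matched predictions, weighted by
/// `lambda_obj` and `lambda_box`.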
fn train_bigdet(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = BigDet::<ADBackend>::new(
        BigDetConfig {
            input_dim: Some(4 + 8),
            max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in samples.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());

            // Binary cross-entropy on objectness scores, clamped for numerical stability.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                    - obj_targets.clone();
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                    .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // L1 box loss, averaged over matched predictions only.
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

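/// Warehouse-shard variant of `train_bigdet`: same greedy-matching loss, with
/// batches streamed from `WarehouseLoaders`.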
fn train_bigdet_warehouse(
    args: &TrainArgs,
    loaders: &WarehouseLoaders,
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = BigDet::<ADBackend>::new(
        BigDetConfig {
            input_dim: Some(4 + 8),
            max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        let mut iter = loaders.train_iter();
        loop {
            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
                Some(batch) => batch,
                None => break,
            };
            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;

            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());

            // Binary cross-entropy on objectness scores, clamped for numerical stability.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                    - obj_targets.clone();
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                    .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // L1 box loss, averaged over matched predictions only.
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

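/// Reject backend/feature combinations that cannot work: `--backend wgpu`
/// requires the `backend-wgpu` feature, and an ndarray request is ignored
/// (with a note) when the binary was built with WGPU support.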
pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
    let built_wgpu = cfg!(feature = "backend-wgpu");
    match (kind, built_wgpu) {
        (BackendKind::Wgpu, false) => {
            anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose the ndarray backend")
        }
        (BackendKind::NdArray, true) => {
            println!("note: built with backend-wgpu; training will use the WGPU backend despite --backend ndarray");
        }
        _ => {}
    }
    Ok(())
}

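/// Intersection-over-union for two axis-aligned boxes in xyxy layout; corners
/// may be given in any order. Returns 0.0 when the union is empty.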
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    let ax0 = a[0].min(a[2]);
    let ay0 = a[1].min(a[3]);
    let ax1 = a[0].max(a[2]);
    let ay1 = a[1].max(a[3]);
    let bx0 = b[0].min(b[2]);
    let by0 = b[1].min(b[3]);
    let bx1 = b[0].max(b[2]);
    let by1 = b[1].max(b[3]);

    let inter_x0 = ax0.max(bx0);
    let inter_y0 = ay0.max(by0);
    let inter_x1 = ax1.min(bx1);
    let inter_y1 = ay1.min(by1);

    let inter_w = (inter_x1 - inter_x0).max(0.0);
    let inter_h = (inter_y1 - inter_y0).max(0.0);
    let inter_area = inter_w * inter_h;

    let area_a = (ax1 - ax0).max(0.0) * (ay1 - ay0).max(0.0);
    let area_b = (bx1 - bx0).max(0.0) * (by1 - by0).max(0.0);
    let denom = area_a + area_b - inter_area;
    if denom <= 0.0 {
        0.0
    } else {
        inter_area / denom
    }
}

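/// Load a saved `BigDet` checkpoint from `path` onto `device`, rebuilding the
/// module with the same `input_dim` and `max_boxes` used at training time.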
pub fn load_bigdet_from_checkpoint<P: AsRef<Path>>(
    path: P,
    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
    max_boxes: usize,
) -> Result<BigDet<TrainBackend>, RecorderError> {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    BigDet::<TrainBackend>::new(
        BigDetConfig {
            max_boxes,
            input_dim: Some(4 + 8),
            ..Default::default()
        },
        device,
    )
    .load_file(path.as_ref(), &recorder, device)
}

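/// Greedy target assignment: for each valid ground-truth box, the predicted
/// box with the highest IoU is marked positive (objectness target 1.0) and
/// receives that ground truth as its box target with unit weights. If two
/// ground truths pick the same prediction, the later one overwrites. Returns
/// `(obj_targets [B, P], box_targets [B, P, 4], box_weights [B, P, 4])`.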
pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
    pred_boxes: Tensor<B, 3>,
    gt_boxes: Tensor<B, 3>,
    gt_mask: Tensor<B, 2>,
) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
    let batch = pred_boxes.dims()[0];
    let max_pred = pred_boxes.dims()[1];
    let max_gt = gt_boxes.dims()[1];
    // Keep the targets on the same device as the predictions rather than the
    // backend default, so training on a non-default device stays consistent.
    let device = pred_boxes.device();

    let gt_mask_vec = gt_mask
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();
    let gt_boxes_vec = gt_boxes
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();
    let pred_boxes_vec = pred_boxes
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();

    let mut obj_targets = vec![0.0f32; batch * max_pred];
    let mut box_targets = vec![0.0f32; batch * max_pred * 4];
    let mut box_weights = vec![0.0f32; batch * max_pred * 4];

    for b in 0..batch {
        for g in 0..max_gt {
            let mask_idx = b * max_gt + g;
            if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
                continue;
            }
            let gb = [
                gt_boxes_vec[(b * max_gt + g) * 4],
                gt_boxes_vec[(b * max_gt + g) * 4 + 1],
                gt_boxes_vec[(b * max_gt + g) * 4 + 2],
                gt_boxes_vec[(b * max_gt + g) * 4 + 3],
            ];

            // Find the prediction with the highest IoU against this ground truth.
            let mut best_iou = -1.0f32;
            let mut best_p = 0usize;
            for p in 0..max_pred {
                let pb = [
                    pred_boxes_vec[(b * max_pred + p) * 4],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 1],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 2],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 3],
                ];
                let iou = iou_xyxy(pb, gb);
                if iou > best_iou {
                    best_iou = iou;
                    best_p = p;
                }
            }

            let obj_idx = b * max_pred + best_p;
            obj_targets[obj_idx] = 1.0;
            let bt_base = (b * max_pred + best_p) * 4;
            box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
            box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
        }
    }

    let obj_targets =
        Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), &device);
    let box_targets =
        Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), &device);
    let box_weights =
        Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), &device);

    (obj_targets, box_targets, box_weights)
}