1use burn::backend::Autodiff;
2use burn::module::Module;
3use burn::nn::loss::{MseLoss, Reduction};
4use burn::optim::{AdamConfig, GradientsParams, Optimizer};
5use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
6use burn::tensor::{Tensor, TensorData};
7use burn_dataset::WarehouseLoaders;
8use std::path::Path;
9
10use crate::{
11 DatasetPathConfig, LinearClassifier, LinearClassifierConfig, MultiboxModel,
12 MultiboxModelConfig, TrainBackend,
13};
14use clap::{Parser, ValueEnum};
15use std::fs;
16
17pub fn load_linear_classifier_from_checkpoint<P: AsRef<Path>>(
18 path: P,
19 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
20) -> Result<LinearClassifier<TrainBackend>, RecorderError> {
21 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
22 LinearClassifier::<TrainBackend>::new(LinearClassifierConfig::default(), device).load_file(
23 path.as_ref(),
24 &recorder,
25 device,
26 )
27}
28
/// Model architecture selectable on the command line.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    // Small linear classifier (`LinearClassifier`); default checkpoint
    // path is checkpoints/linear_detector.bin.
    Tiny,
    // Multibox detector (`MultiboxModel`); default checkpoint path is
    // checkpoints/convolutional_detector.bin.
    Big,
}
34
/// Compute backend requested on the command line; see
/// `validate_backend_choice` for how this interacts with build features.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    // CPU ndarray backend (always available).
    NdArray,
    // WGPU backend; requires the binary to be built with the
    // `backend-wgpu` feature.
    Wgpu,
}
40
/// Where training samples are loaded from.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum TrainingInputSource {
    // Sharded warehouse dataset described by a manifest file (preferred path).
    Warehouse,
    // Legacy capture-log directories on disk.
    CaptureLogs,
}
46
/// Command-line arguments for the `train` subcommand.
#[derive(Parser, Debug)]
#[command(
    name = "train",
    about = "Train LinearClassifier/MultiboxModel (warehouse-first)"
)]
pub struct TrainArgs {
    /// Which model architecture to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    /// Requested compute backend (validated against build features).
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    /// Maximum number of boxes per sample after collation.
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    /// Weight of the box-regression term in the multibox loss.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    /// Weight of the objectness term in the multibox loss.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    /// Data source: warehouse manifest (preferred) or legacy capture logs.
    #[arg(long, value_enum, default_value_t = TrainingInputSource::Warehouse)]
    pub input_source: TrainingInputSource,
    /// Path to the warehouse manifest JSON (warehouse mode only).
    #[arg(long, default_value = "assets/warehouse/manifest.json")]
    pub warehouse_manifest: String,
    /// Root directory of capture-log samples (capture-logs mode only).
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    /// Subdirectory of `dataset_root` holding label files.
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    /// Subdirectory of `dataset_root` holding images ("." = root itself).
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    /// Number of passes over the training data.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    /// Samples per optimizer step (clamped to at least 1 at runtime).
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    /// Adam learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    /// Objectness threshold used at inference time.
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    /// IoU threshold used at inference time.
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    /// Output checkpoint path; defaults to a per-model path when omitted.
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}
102
103pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
104 validate_backend_choice(args.backend)?;
105
106 let ckpt_path = args
107 .checkpoint_out
108 .clone()
109 .unwrap_or_else(|| match args.model {
110 ModelKind::Tiny => "checkpoints/linear_detector.bin".to_string(),
111 ModelKind::Big => "checkpoints/convolutional_detector.bin".to_string(),
112 });
113
114 if let Some(parent) = Path::new(&ckpt_path).parent() {
115 fs::create_dir_all(parent)?;
116 }
117
118 match args.input_source {
119 TrainingInputSource::Warehouse => {
120 let manifest_path = Path::new(&args.warehouse_manifest);
121 let loaders = WarehouseLoaders::from_manifest_path(manifest_path, 0.0, None, false)
122 .map_err(|e| {
123 anyhow::anyhow!(
124 "failed to load warehouse manifest at {}: {e}",
125 manifest_path.display()
126 )
127 })?;
128 if loaders.train_len() == 0 {
129 anyhow::bail!(
130 "warehouse manifest {} contains no training shards",
131 manifest_path.display()
132 );
133 }
134 match args.model {
135 ModelKind::Tiny => train_linear_detector_warehouse(&args, &loaders, &ckpt_path)?,
136 ModelKind::Big => {
137 train_convolutional_detector_warehouse(&args, &loaders, &ckpt_path)?
138 }
139 }
140 }
141 TrainingInputSource::CaptureLogs => {
142 println!("training from capture logs (legacy path); prefer warehouse manifests");
143 let cfg = DatasetPathConfig {
144 root: args.dataset_root.clone().into(),
145 labels_subdir: args.labels_subdir.clone(),
146 images_subdir: args.images_subdir.clone(),
147 };
148 let samples = cfg.load()?;
149 if samples.is_empty() {
150 println!("No samples found under {}", cfg.root.display());
151 return Ok(());
152 }
153 match args.model {
154 ModelKind::Tiny => train_linear_detector(&args, &samples, &ckpt_path)?,
155 ModelKind::Big => train_convolutional_detector(&args, &samples, &ckpt_path)?,
156 }
157 }
158 }
159
160 println!("Saved checkpoint to {}", ckpt_path);
161 Ok(())
162}
163
// Autodiff-wrapped training backend used by all training loops below.
type ADBackend = Autodiff<TrainBackend>;
165
166fn train_linear_detector(
167 args: &TrainArgs,
168 samples: &[crate::RunSample],
169 ckpt_path: &str,
170) -> anyhow::Result<()> {
171 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
172 let mut model = LinearClassifier::<ADBackend>::new(LinearClassifierConfig::default(), &device);
173 let mut optim = AdamConfig::new().init();
174
175 let batch_size = args.batch_size.max(1);
176 let data = samples.to_vec();
177 for epoch in 0..args.epochs {
178 let mut losses = Vec::new();
179 for batch in data.chunks(batch_size) {
180 let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
181 let boxes = batch.boxes.clone();
183 let first_box = boxes
184 .clone()
185 .slice([0..boxes.dims()[0], 0..1, 0..4])
186 .reshape([boxes.dims()[0], 4]);
187
188 let mask = batch.box_mask.clone();
190 let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
191
192 let preds = model.forward(first_box);
193 let mse = MseLoss::new();
194 let loss = mse.forward(preds, has_box, Reduction::Mean);
195 let loss_detached = loss.clone().detach();
196 let grads = GradientsParams::from_grads(loss.backward(), &model);
197 model = optim.step(args.lr as f64, model, grads);
198
199 let loss_val: f32 = loss_detached
200 .into_data()
201 .to_vec::<f32>()
202 .unwrap_or_default()
203 .into_iter()
204 .next()
205 .unwrap_or(0.0);
206 losses.push(loss_val);
207 }
208 let avg_loss: f32 = if losses.is_empty() {
209 0.0
210 } else {
211 losses.iter().sum::<f32>() / losses.len() as f32
212 };
213 println!("epoch {epoch}: avg loss {avg_loss:.4}");
214 }
215
216 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
217 model
218 .clone()
219 .save_file(Path::new(ckpt_path), &recorder)
220 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
221
222 Ok(())
223}
224
225fn train_linear_detector_warehouse(
226 args: &TrainArgs,
227 loaders: &WarehouseLoaders,
228 ckpt_path: &str,
229) -> anyhow::Result<()> {
230 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
231 let mut model = LinearClassifier::<ADBackend>::new(LinearClassifierConfig::default(), &device);
232 let mut optim = AdamConfig::new().init();
233
234 let batch_size = args.batch_size.max(1);
235 for epoch in 0..args.epochs {
236 let mut losses = Vec::new();
237 let mut iter = loaders.train_iter();
238 loop {
239 let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
240 Some(batch) => batch,
241 None => break,
242 };
243 let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;
244
245 let boxes = batch.boxes.clone();
247 let first_box = boxes
248 .clone()
249 .slice([0..boxes.dims()[0], 0..1, 0..4])
250 .reshape([boxes.dims()[0], 4]);
251
252 let mask = batch.box_mask.clone();
254 let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
255
256 let preds = model.forward(first_box);
257 let mse = MseLoss::new();
258 let loss = mse.forward(preds, has_box, Reduction::Mean);
259 let loss_detached = loss.clone().detach();
260 let grads = GradientsParams::from_grads(loss.backward(), &model);
261 model = optim.step(args.lr as f64, model, grads);
262
263 let loss_val: f32 = loss_detached
264 .into_data()
265 .to_vec::<f32>()
266 .unwrap_or_default()
267 .into_iter()
268 .next()
269 .unwrap_or(0.0);
270 losses.push(loss_val);
271 }
272 let avg_loss: f32 = if losses.is_empty() {
273 0.0
274 } else {
275 losses.iter().sum::<f32>() / losses.len() as f32
276 };
277 println!("epoch {epoch}: avg loss {avg_loss:.4}");
278 }
279
280 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
281 model
282 .clone()
283 .save_file(Path::new(ckpt_path), &recorder)
284 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
285
286 Ok(())
287}
288
289fn train_convolutional_detector(
290 args: &TrainArgs,
291 samples: &[crate::RunSample],
292 ckpt_path: &str,
293) -> anyhow::Result<()> {
294 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
295 let mut model = MultiboxModel::<ADBackend>::new(
296 MultiboxModelConfig {
297 input_dim: Some(4 + 8), max_boxes: args.max_boxes,
299 ..Default::default()
300 },
301 &device,
302 );
303 let mut optim = AdamConfig::new().init();
304
305 let batch_size = args.batch_size.max(1);
306 let data = samples.to_vec();
307 for epoch in 0..args.epochs {
308 let mut losses = Vec::new();
309 for batch in data.chunks(batch_size) {
310 let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
311 let boxes = batch.boxes.clone();
313 let first_box = boxes
314 .clone()
315 .slice([0..boxes.dims()[0], 0..1, 0..4])
316 .reshape([boxes.dims()[0], 4]);
317 let features = batch.features.clone();
318 let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);
319
320 let (pred_boxes, pred_scores) = model.forward_multibox(input);
321
322 let gt_boxes = batch.boxes.clone();
324 let gt_mask = batch.box_mask.clone();
325
326 let (obj_targets, box_targets, box_weights) =
328 build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
329 let eps = 1e-6;
333 let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
334 let obj_targets_inv =
335 Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
336 - obj_targets.clone();
337 let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
338 + (obj_targets_inv
339 * (Tensor::<ADBackend, 2>::ones(
340 pred_scores_clamped.dims(),
341 &pred_scores_clamped.device(),
342 ) - pred_scores_clamped)
343 .log()))
344 .sum()
345 .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);
346
347 let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
349 let matched = box_weights.clone().sum().div_scalar(4.0);
350 let matched_scalar = matched
351 .into_data()
352 .to_vec::<f32>()
353 .unwrap_or_default()
354 .first()
355 .copied()
356 .unwrap_or(0.0);
357 let box_loss = if matched_scalar > 0.0 {
358 box_err.sum().div_scalar(matched_scalar)
359 } else {
360 let zeros = vec![0.0f32; 1];
362 Tensor::<ADBackend, 1>::from_data(
363 TensorData::new(zeros, [1]),
364 &box_weights.device(),
365 )
366 };
367
368 let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
369 let loss_detached = loss.clone().detach();
370 let grads = GradientsParams::from_grads(loss.backward(), &model);
371 model = optim.step(args.lr as f64, model, grads);
372
373 let loss_val: f32 = loss_detached
374 .into_data()
375 .to_vec::<f32>()
376 .unwrap_or_default()
377 .into_iter()
378 .next()
379 .unwrap_or(0.0);
380 losses.push(loss_val);
381 }
382 let avg_loss: f32 = if losses.is_empty() {
383 0.0
384 } else {
385 losses.iter().sum::<f32>() / losses.len() as f32
386 };
387 println!("epoch {epoch}: avg loss {avg_loss:.4}");
388 }
389
390 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
391 model
392 .clone()
393 .save_file(Path::new(ckpt_path), &recorder)
394 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
395
396 Ok(())
397}
398
399fn train_convolutional_detector_warehouse(
400 args: &TrainArgs,
401 loaders: &WarehouseLoaders,
402 ckpt_path: &str,
403) -> anyhow::Result<()> {
404 let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
405 let mut model = MultiboxModel::<ADBackend>::new(
406 MultiboxModelConfig {
407 input_dim: Some(4 + 8), max_boxes: args.max_boxes,
409 ..Default::default()
410 },
411 &device,
412 );
413 let mut optim = AdamConfig::new().init();
414
415 let batch_size = args.batch_size.max(1);
416 for epoch in 0..args.epochs {
417 let mut losses = Vec::new();
418 let mut iter = loaders.train_iter();
419 loop {
420 let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
421 Some(batch) => batch,
422 None => break,
423 };
424 let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;
425
426 let boxes = batch.boxes.clone();
427 let first_box = boxes
428 .clone()
429 .slice([0..boxes.dims()[0], 0..1, 0..4])
430 .reshape([boxes.dims()[0], 4]);
431 let features = batch.features.clone();
432 let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);
433
434 let (pred_boxes, pred_scores) = model.forward_multibox(input);
435
436 let gt_boxes = batch.boxes.clone();
437 let gt_mask = batch.box_mask.clone();
438
439 let (obj_targets, box_targets, box_weights) =
440 build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
441
442 let eps = 1e-6;
443 let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
444 let obj_targets_inv =
445 Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
446 - obj_targets.clone();
447 let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
448 + (obj_targets_inv
449 * (Tensor::<ADBackend, 2>::ones(
450 pred_scores_clamped.dims(),
451 &pred_scores_clamped.device(),
452 ) - pred_scores_clamped)
453 .log()))
454 .sum()
455 .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);
456
457 let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
458 let matched = box_weights.clone().sum().div_scalar(4.0);
459 let matched_scalar = matched
460 .into_data()
461 .to_vec::<f32>()
462 .unwrap_or_default()
463 .first()
464 .copied()
465 .unwrap_or(0.0);
466 let box_loss = if matched_scalar > 0.0 {
467 box_err.sum().div_scalar(matched_scalar)
468 } else {
469 let zeros = vec![0.0f32; 1];
470 Tensor::<ADBackend, 1>::from_data(
471 TensorData::new(zeros, [1]),
472 &box_weights.device(),
473 )
474 };
475
476 let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
477 let loss_detached = loss.clone().detach();
478 let grads = GradientsParams::from_grads(loss.backward(), &model);
479 model = optim.step(args.lr as f64, model, grads);
480
481 let loss_val: f32 = loss_detached
482 .into_data()
483 .to_vec::<f32>()
484 .unwrap_or_default()
485 .into_iter()
486 .next()
487 .unwrap_or(0.0);
488 losses.push(loss_val);
489 }
490 let avg_loss: f32 = if losses.is_empty() {
491 0.0
492 } else {
493 losses.iter().sum::<f32>() / losses.len() as f32
494 };
495 println!("epoch {epoch}: avg loss {avg_loss:.4}");
496 }
497
498 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
499 model
500 .clone()
501 .save_file(Path::new(ckpt_path), &recorder)
502 .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
503
504 Ok(())
505}
506
507pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
508 let built_wgpu = cfg!(feature = "backend-wgpu");
509 match (kind, built_wgpu) {
510 (BackendKind::Wgpu, false) => {
511 anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
512 }
513 (BackendKind::NdArray, true) => {
514 println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
515 }
516 _ => {}
517 }
518 Ok(())
519}
520
/// Intersection-over-union of two axis-aligned boxes given as [x0, y0, x1, y1].
///
/// Coordinates may arrive in either order per axis; each box is normalized to
/// (min corner, max corner) first. Returns 0.0 when the union area is empty
/// (both boxes degenerate).
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Normalize a box to (x_min, y_min, x_max, y_max).
    let norm = |r: [f32; 4]| {
        (
            r[0].min(r[2]),
            r[1].min(r[3]),
            r[0].max(r[2]),
            r[1].max(r[3]),
        )
    };
    let (ax0, ay0, ax1, ay1) = norm(a);
    let (bx0, by0, bx1, by1) = norm(b);

    // Overlap extents clamp to zero when the boxes don't intersect.
    let overlap_w = (ax1.min(bx1) - ax0.max(bx0)).max(0.0);
    let overlap_h = (ay1.min(by1) - ay0.max(by0)).max(0.0);
    let intersection = overlap_w * overlap_h;

    let union = (ax1 - ax0).max(0.0) * (ay1 - ay0).max(0.0)
        + (bx1 - bx0).max(0.0) * (by1 - by0).max(0.0)
        - intersection;

    if union <= 0.0 {
        0.0
    } else {
        intersection / union
    }
}
549pub fn load_multibox_model_from_checkpoint<P: AsRef<Path>>(
550 path: P,
551 device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
552 max_boxes: usize,
553) -> Result<MultiboxModel<TrainBackend>, RecorderError> {
554 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
555 MultiboxModel::<TrainBackend>::new(
556 MultiboxModelConfig {
557 max_boxes,
558 input_dim: Some(4 + 8),
559 ..Default::default()
560 },
561 device,
562 )
563 .load_file(path.as_ref(), &recorder, device)
564}
565
566pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
567 pred_boxes: Tensor<B, 3>,
568 gt_boxes: Tensor<B, 3>,
569 gt_mask: Tensor<B, 2>,
570) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
571 let batch = pred_boxes.dims()[0];
572 let max_pred = pred_boxes.dims()[1];
573 let max_gt = gt_boxes.dims()[1];
574
575 let gt_mask_vec = gt_mask
576 .clone()
577 .into_data()
578 .to_vec::<f32>()
579 .unwrap_or_default();
580 let gt_boxes_vec = gt_boxes
581 .clone()
582 .into_data()
583 .to_vec::<f32>()
584 .unwrap_or_default();
585 let pred_boxes_vec = pred_boxes
586 .clone()
587 .into_data()
588 .to_vec::<f32>()
589 .unwrap_or_default();
590
591 let mut obj_targets = vec![0.0f32; batch * max_pred];
592 let mut box_targets = vec![0.0f32; batch * max_pred * 4];
593 let mut box_weights = vec![0.0f32; batch * max_pred * 4];
594
595 for b in 0..batch {
596 for g in 0..max_gt {
597 let mask_idx = b * max_gt + g;
598 if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
599 continue;
600 }
601 let gb = [
602 gt_boxes_vec[(b * max_gt + g) * 4],
603 gt_boxes_vec[(b * max_gt + g) * 4 + 1],
604 gt_boxes_vec[(b * max_gt + g) * 4 + 2],
605 gt_boxes_vec[(b * max_gt + g) * 4 + 3],
606 ];
607
608 let mut best_iou = -1.0f32;
609 let mut best_p = 0usize;
610 for p in 0..max_pred {
611 let pb = [
612 pred_boxes_vec[(b * max_pred + p) * 4],
613 pred_boxes_vec[(b * max_pred + p) * 4 + 1],
614 pred_boxes_vec[(b * max_pred + p) * 4 + 2],
615 pred_boxes_vec[(b * max_pred + p) * 4 + 3],
616 ];
617 let iou = iou_xyxy(pb, gb);
618 if iou > best_iou {
619 best_iou = iou;
620 best_p = p;
621 }
622 }
623
624 let obj_idx = b * max_pred + best_p;
625 obj_targets[obj_idx] = 1.0;
626 let bt_base = (b * max_pred + best_p) * 4;
627 box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
628 box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
629 }
630 }
631
632 let device = &B::Device::default();
633 let obj_targets =
634 Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
635 let box_targets =
636 Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
637 let box_weights =
638 Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);
639
640 (obj_targets, box_targets, box_weights)
641}