// training/util.rs

1use burn::backend::Autodiff;
2use burn::module::Module;
3use burn::nn::loss::{MseLoss, Reduction};
4use burn::optim::{AdamConfig, GradientsParams, Optimizer};
5use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
6use burn::tensor::{Tensor, TensorData};
7use burn_dataset::WarehouseLoaders;
8use std::path::Path;
9
10use crate::{BigDet, BigDetConfig, DatasetConfig, TinyDet, TinyDetConfig, TrainBackend};
11use clap::{Parser, ValueEnum};
12use std::fs;
13
14pub fn load_tinydet_from_checkpoint<P: AsRef<Path>>(
15    path: P,
16    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
17) -> Result<TinyDet<TrainBackend>, RecorderError> {
18    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
19    TinyDet::<TrainBackend>::new(TinyDetConfig::default(), device).load_file(
20        path.as_ref(),
21        &recorder,
22        device,
23    )
24}
25
26#[derive(ValueEnum, Debug, Clone, Copy)]
27pub enum ModelKind {
28    Tiny,
29    Big,
30}
31
32#[derive(ValueEnum, Debug, Clone, Copy)]
33pub enum BackendKind {
34    NdArray,
35    Wgpu,
36}
37
38#[derive(ValueEnum, Debug, Clone, Copy)]
39pub enum TrainingInputSource {
40    Warehouse,
41    CaptureLogs,
42}
43
// CLI arguments for the `train` binary. Note: the `///` comments on fields are
// consumed by clap as `--help` text, so they are left exactly as written —
// editing them changes user-visible CLI output.
#[derive(Parser, Debug)]
#[command(name = "train", about = "Train TinyDet/BigDet (warehouse-first)")]
pub struct TrainArgs {
    // -- model / backend selection --
    /// Model to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    /// Backend to use (ndarray or wgpu if enabled).
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    // -- loss shaping --
    /// Maximum boxes per image (pads/truncates to this for training).
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    /// Loss weight for box regression.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    /// Loss weight for objectness.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    // -- data source (warehouse vs. legacy capture logs) --
    /// Training input source (warehouse by default).
    #[arg(long, value_enum, default_value_t = TrainingInputSource::Warehouse)]
    pub input_source: TrainingInputSource,
    /// Warehouse manifest path (used with --input-source warehouse).
    #[arg(long, default_value = "assets/warehouse/manifest.json")]
    pub warehouse_manifest: String,
    /// Capture-log dataset root containing labels/ and images/.
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    /// Labels subdirectory relative to dataset root (capture-logs only).
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    /// Images subdirectory relative to dataset root (capture-logs only).
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    // -- optimization hyperparameters --
    /// Number of epochs.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    /// Batch size.
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    /// Learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    // -- inference thresholds (not used by the training loops in this file) --
    /// Objectness threshold (for future eval).
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    /// IoU threshold (for future eval).
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    /// Checkpoint output path (defaults by model if not provided).
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}
96
97pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
98    validate_backend_choice(args.backend)?;
99
100    let ckpt_path = args
101        .checkpoint_out
102        .clone()
103        .unwrap_or_else(|| match args.model {
104            ModelKind::Tiny => "checkpoints/tinydet.bin".to_string(),
105            ModelKind::Big => "checkpoints/bigdet.bin".to_string(),
106        });
107
108    if let Some(parent) = Path::new(&ckpt_path).parent() {
109        fs::create_dir_all(parent)?;
110    }
111
112    match args.input_source {
113        TrainingInputSource::Warehouse => {
114            let manifest_path = Path::new(&args.warehouse_manifest);
115            let loaders = WarehouseLoaders::from_manifest_path(manifest_path, 0.0, None, false)
116                .map_err(|e| {
117                    anyhow::anyhow!(
118                        "failed to load warehouse manifest at {}: {e}",
119                        manifest_path.display()
120                    )
121                })?;
122            if loaders.train_len() == 0 {
123                anyhow::bail!(
124                    "warehouse manifest {} contains no training shards",
125                    manifest_path.display()
126                );
127            }
128            match args.model {
129                ModelKind::Tiny => train_tinydet_warehouse(&args, &loaders, &ckpt_path)?,
130                ModelKind::Big => train_bigdet_warehouse(&args, &loaders, &ckpt_path)?,
131            }
132        }
133        TrainingInputSource::CaptureLogs => {
134            println!("training from capture logs (legacy path); prefer warehouse manifests");
135            let cfg = DatasetConfig {
136                root: args.dataset_root.clone().into(),
137                labels_subdir: args.labels_subdir.clone(),
138                images_subdir: args.images_subdir.clone(),
139            };
140            let samples = cfg.load()?;
141            if samples.is_empty() {
142                println!("No samples found under {}", cfg.root.display());
143                return Ok(());
144            }
145            match args.model {
146                ModelKind::Tiny => train_tinydet(&args, &samples, &ckpt_path)?,
147                ModelKind::Big => train_bigdet(&args, &samples, &ckpt_path)?,
148            }
149        }
150    }
151
152    println!("Saved checkpoint to {}", ckpt_path);
153    Ok(())
154}
155
156type ADBackend = Autodiff<TrainBackend>;
157
158fn train_tinydet(
159    args: &TrainArgs,
160    samples: &[crate::RunSample],
161    ckpt_path: &str,
162) -> anyhow::Result<()> {
163    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
164    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
165    let mut optim = AdamConfig::new().init();
166
167    let batch_size = args.batch_size.max(1);
168    let data = samples.to_vec();
169    for epoch in 0..args.epochs {
170        let mut losses = Vec::new();
171        for batch in data.chunks(batch_size) {
172            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
173            // Feature: take the first box (or zeros) as the input vector.
174            let boxes = batch.boxes.clone();
175            let first_box = boxes
176                .clone()
177                .slice([0..boxes.dims()[0], 0..1, 0..4])
178                .reshape([boxes.dims()[0], 4]);
179
180            // Target: 1.0 if any box present, else 0.0.
181            let mask = batch.box_mask.clone();
182            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
183
184            let preds = model.forward(first_box);
185            let mse = MseLoss::new();
186            let loss = mse.forward(preds, has_box, Reduction::Mean);
187            let loss_detached = loss.clone().detach();
188            let grads = GradientsParams::from_grads(loss.backward(), &model);
189            model = optim.step(args.lr as f64, model, grads);
190
191            let loss_val: f32 = loss_detached
192                .into_data()
193                .to_vec::<f32>()
194                .unwrap_or_default()
195                .into_iter()
196                .next()
197                .unwrap_or(0.0);
198            losses.push(loss_val);
199        }
200        let avg_loss: f32 = if losses.is_empty() {
201            0.0
202        } else {
203            losses.iter().sum::<f32>() / losses.len() as f32
204        };
205        println!("epoch {epoch}: avg loss {avg_loss:.4}");
206    }
207
208    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
209    model
210        .clone()
211        .save_file(Path::new(ckpt_path), &recorder)
212        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
213
214    Ok(())
215}
216
217fn train_tinydet_warehouse(
218    args: &TrainArgs,
219    loaders: &WarehouseLoaders,
220    ckpt_path: &str,
221) -> anyhow::Result<()> {
222    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
223    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
224    let mut optim = AdamConfig::new().init();
225
226    let batch_size = args.batch_size.max(1);
227    for epoch in 0..args.epochs {
228        let mut losses = Vec::new();
229        let mut iter = loaders.train_iter();
230        loop {
231            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
232                Some(batch) => batch,
233                None => break,
234            };
235            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;
236
237            // Feature: take the first box (or zeros) as the input vector.
238            let boxes = batch.boxes.clone();
239            let first_box = boxes
240                .clone()
241                .slice([0..boxes.dims()[0], 0..1, 0..4])
242                .reshape([boxes.dims()[0], 4]);
243
244            // Target: 1.0 if any box present, else 0.0.
245            let mask = batch.box_mask.clone();
246            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
247
248            let preds = model.forward(first_box);
249            let mse = MseLoss::new();
250            let loss = mse.forward(preds, has_box, Reduction::Mean);
251            let loss_detached = loss.clone().detach();
252            let grads = GradientsParams::from_grads(loss.backward(), &model);
253            model = optim.step(args.lr as f64, model, grads);
254
255            let loss_val: f32 = loss_detached
256                .into_data()
257                .to_vec::<f32>()
258                .unwrap_or_default()
259                .into_iter()
260                .next()
261                .unwrap_or(0.0);
262            losses.push(loss_val);
263        }
264        let avg_loss: f32 = if losses.is_empty() {
265            0.0
266        } else {
267            losses.iter().sum::<f32>() / losses.len() as f32
268        };
269        println!("epoch {epoch}: avg loss {avg_loss:.4}");
270    }
271
272    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
273    model
274        .clone()
275        .save_file(Path::new(ckpt_path), &recorder)
276        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
277
278    Ok(())
279}
280
281fn train_bigdet(
282    args: &TrainArgs,
283    samples: &[crate::RunSample],
284    ckpt_path: &str,
285) -> anyhow::Result<()> {
286    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
287    let mut model = BigDet::<ADBackend>::new(
288        BigDetConfig {
289            input_dim: Some(4 + 8), // first box (4) + features (8)
290            max_boxes: args.max_boxes,
291            ..Default::default()
292        },
293        &device,
294    );
295    let mut optim = AdamConfig::new().init();
296
297    let batch_size = args.batch_size.max(1);
298    let data = samples.to_vec();
299    for epoch in 0..args.epochs {
300        let mut losses = Vec::new();
301        for batch in data.chunks(batch_size) {
302            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
303            // Features: first box (or zeros) + pooled image features.
304            let boxes = batch.boxes.clone();
305            let first_box = boxes
306                .clone()
307                .slice([0..boxes.dims()[0], 0..1, 0..4])
308                .reshape([boxes.dims()[0], 4]);
309            let features = batch.features.clone();
310            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);
311
312            let (pred_boxes, pred_scores) = model.forward_multibox(input);
313
314            // Targets
315            let gt_boxes = batch.boxes.clone();
316            let gt_mask = batch.box_mask.clone();
317
318            // Greedy matching per GT: for each GT box, pick best pred by IoU.
319            let (obj_targets, box_targets, box_weights) =
320                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
321            // Greedy IoU matcher is deterministic/cheap; swap to Hungarian if finer matching is needed later.
322
323            // Objectness loss (BCE) with targets; unassigned preds stay at 0.0.
324            let eps = 1e-6;
325            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
326            let obj_targets_inv =
327                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
328                    - obj_targets.clone();
329            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
330                + (obj_targets_inv
331                    * (Tensor::<ADBackend, 2>::ones(
332                        pred_scores_clamped.dims(),
333                        &pred_scores_clamped.device(),
334                    ) - pred_scores_clamped)
335                        .log()))
336            .sum()
337            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);
338
339            // Box regression loss on matched preds only.
340            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
341            let matched = box_weights.clone().sum().div_scalar(4.0);
342            let matched_scalar = matched
343                .into_data()
344                .to_vec::<f32>()
345                .unwrap_or_default()
346                .first()
347                .copied()
348                .unwrap_or(0.0);
349            let box_loss = if matched_scalar > 0.0 {
350                box_err.sum().div_scalar(matched_scalar)
351            } else {
352                // Return a zero scalar in the same tensor rank as div output (rank 1).
353                let zeros = vec![0.0f32; 1];
354                Tensor::<ADBackend, 1>::from_data(
355                    TensorData::new(zeros, [1]),
356                    &box_weights.device(),
357                )
358            };
359
360            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
361            let loss_detached = loss.clone().detach();
362            let grads = GradientsParams::from_grads(loss.backward(), &model);
363            model = optim.step(args.lr as f64, model, grads);
364
365            let loss_val: f32 = loss_detached
366                .into_data()
367                .to_vec::<f32>()
368                .unwrap_or_default()
369                .into_iter()
370                .next()
371                .unwrap_or(0.0);
372            losses.push(loss_val);
373        }
374        let avg_loss: f32 = if losses.is_empty() {
375            0.0
376        } else {
377            losses.iter().sum::<f32>() / losses.len() as f32
378        };
379        println!("epoch {epoch}: avg loss {avg_loss:.4}");
380    }
381
382    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
383    model
384        .clone()
385        .save_file(Path::new(ckpt_path), &recorder)
386        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
387
388    Ok(())
389}
390
391fn train_bigdet_warehouse(
392    args: &TrainArgs,
393    loaders: &WarehouseLoaders,
394    ckpt_path: &str,
395) -> anyhow::Result<()> {
396    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
397    let mut model = BigDet::<ADBackend>::new(
398        BigDetConfig {
399            input_dim: Some(4 + 8), // first box (4) + features (8)
400            max_boxes: args.max_boxes,
401            ..Default::default()
402        },
403        &device,
404    );
405    let mut optim = AdamConfig::new().init();
406
407    let batch_size = args.batch_size.max(1);
408    for epoch in 0..args.epochs {
409        let mut losses = Vec::new();
410        let mut iter = loaders.train_iter();
411        loop {
412            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
413                Some(batch) => batch,
414                None => break,
415            };
416            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;
417
418            let boxes = batch.boxes.clone();
419            let first_box = boxes
420                .clone()
421                .slice([0..boxes.dims()[0], 0..1, 0..4])
422                .reshape([boxes.dims()[0], 4]);
423            let features = batch.features.clone();
424            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);
425
426            let (pred_boxes, pred_scores) = model.forward_multibox(input);
427
428            let gt_boxes = batch.boxes.clone();
429            let gt_mask = batch.box_mask.clone();
430
431            let (obj_targets, box_targets, box_weights) =
432                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
433
434            let eps = 1e-6;
435            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
436            let obj_targets_inv =
437                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
438                    - obj_targets.clone();
439            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
440                + (obj_targets_inv
441                    * (Tensor::<ADBackend, 2>::ones(
442                        pred_scores_clamped.dims(),
443                        &pred_scores_clamped.device(),
444                    ) - pred_scores_clamped)
445                        .log()))
446            .sum()
447            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);
448
449            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
450            let matched = box_weights.clone().sum().div_scalar(4.0);
451            let matched_scalar = matched
452                .into_data()
453                .to_vec::<f32>()
454                .unwrap_or_default()
455                .first()
456                .copied()
457                .unwrap_or(0.0);
458            let box_loss = if matched_scalar > 0.0 {
459                box_err.sum().div_scalar(matched_scalar)
460            } else {
461                let zeros = vec![0.0f32; 1];
462                Tensor::<ADBackend, 1>::from_data(
463                    TensorData::new(zeros, [1]),
464                    &box_weights.device(),
465                )
466            };
467
468            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
469            let loss_detached = loss.clone().detach();
470            let grads = GradientsParams::from_grads(loss.backward(), &model);
471            model = optim.step(args.lr as f64, model, grads);
472
473            let loss_val: f32 = loss_detached
474                .into_data()
475                .to_vec::<f32>()
476                .unwrap_or_default()
477                .into_iter()
478                .next()
479                .unwrap_or(0.0);
480            losses.push(loss_val);
481        }
482        let avg_loss: f32 = if losses.is_empty() {
483            0.0
484        } else {
485            losses.iter().sum::<f32>() / losses.len() as f32
486        };
487        println!("epoch {epoch}: avg loss {avg_loss:.4}");
488    }
489
490    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
491    model
492        .clone()
493        .save_file(Path::new(ckpt_path), &recorder)
494        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
495
496    Ok(())
497}
498
499pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
500    let built_wgpu = cfg!(feature = "backend-wgpu");
501    match (kind, built_wgpu) {
502        (BackendKind::Wgpu, false) => {
503            anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
504        }
505        (BackendKind::NdArray, true) => {
506            println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
507        }
508        _ => {}
509    }
510    Ok(())
511}
512
/// Intersection-over-union of two axis-aligned boxes given as `[x0, y0, x1, y1]`.
///
/// Corner order within each box does not matter (coordinates are normalized to
/// min/max first). Returns 0.0 when the union area is zero or degenerate.
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Normalize so (x0, y0) is the min corner and (x1, y1) the max corner.
    let norm = |r: [f32; 4]| (r[0].min(r[2]), r[1].min(r[3]), r[0].max(r[2]), r[1].max(r[3]));
    let (ax0, ay0, ax1, ay1) = norm(a);
    let (bx0, by0, bx1, by1) = norm(b);

    // Intersection extents clamp to zero when the boxes do not overlap.
    let inter_w = (ax1.min(bx1) - ax0.max(bx0)).max(0.0);
    let inter_h = (ay1.min(by1) - ay0.max(by0)).max(0.0);
    let inter_area = inter_w * inter_h;

    let area_a = (ax1 - ax0).max(0.0) * (ay1 - ay0).max(0.0);
    let area_b = (bx1 - bx0).max(0.0) * (by1 - by0).max(0.0);
    let union = area_a + area_b - inter_area;

    if union <= 0.0 {
        0.0
    } else {
        inter_area / union
    }
}
541pub fn load_bigdet_from_checkpoint<P: AsRef<Path>>(
542    path: P,
543    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
544    max_boxes: usize,
545) -> Result<BigDet<TrainBackend>, RecorderError> {
546    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
547    BigDet::<TrainBackend>::new(
548        BigDetConfig {
549            max_boxes,
550            input_dim: Some(4 + 8),
551            ..Default::default()
552        },
553        device,
554    )
555    .load_file(path.as_ref(), &recorder, device)
556}
557
558pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
559    pred_boxes: Tensor<B, 3>,
560    gt_boxes: Tensor<B, 3>,
561    gt_mask: Tensor<B, 2>,
562) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
563    let batch = pred_boxes.dims()[0];
564    let max_pred = pred_boxes.dims()[1];
565    let max_gt = gt_boxes.dims()[1];
566
567    let gt_mask_vec = gt_mask
568        .clone()
569        .into_data()
570        .to_vec::<f32>()
571        .unwrap_or_default();
572    let gt_boxes_vec = gt_boxes
573        .clone()
574        .into_data()
575        .to_vec::<f32>()
576        .unwrap_or_default();
577    let pred_boxes_vec = pred_boxes
578        .clone()
579        .into_data()
580        .to_vec::<f32>()
581        .unwrap_or_default();
582
583    let mut obj_targets = vec![0.0f32; batch * max_pred];
584    let mut box_targets = vec![0.0f32; batch * max_pred * 4];
585    let mut box_weights = vec![0.0f32; batch * max_pred * 4];
586
587    for b in 0..batch {
588        for g in 0..max_gt {
589            let mask_idx = b * max_gt + g;
590            if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
591                continue;
592            }
593            let gb = [
594                gt_boxes_vec[(b * max_gt + g) * 4],
595                gt_boxes_vec[(b * max_gt + g) * 4 + 1],
596                gt_boxes_vec[(b * max_gt + g) * 4 + 2],
597                gt_boxes_vec[(b * max_gt + g) * 4 + 3],
598            ];
599
600            let mut best_iou = -1.0f32;
601            let mut best_p = 0usize;
602            for p in 0..max_pred {
603                let pb = [
604                    pred_boxes_vec[(b * max_pred + p) * 4],
605                    pred_boxes_vec[(b * max_pred + p) * 4 + 1],
606                    pred_boxes_vec[(b * max_pred + p) * 4 + 2],
607                    pred_boxes_vec[(b * max_pred + p) * 4 + 3],
608                ];
609                let iou = iou_xyxy(pb, gb);
610                if iou > best_iou {
611                    best_iou = iou;
612                    best_p = p;
613                }
614            }
615
616            let obj_idx = b * max_pred + best_p;
617            obj_targets[obj_idx] = 1.0;
618            let bt_base = (b * max_pred + best_p) * 4;
619            box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
620            box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
621        }
622    }
623
624    let device = &B::Device::default();
625    let obj_targets =
626        Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
627    let box_targets =
628        Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
629    let box_weights =
630        Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);
631
632    (obj_targets, box_targets, box_weights)
633}