training/util.rs

use burn::backend::Autodiff;
use burn::module::Module;
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
use burn::tensor::{Tensor, TensorData};
use burn_dataset::WarehouseLoaders;
use std::path::Path;

use crate::{
    DatasetPathConfig, LinearClassifier, LinearClassifierConfig, MultiboxModel,
    MultiboxModelConfig, TrainBackend,
};
use clap::{Parser, ValueEnum};
use std::fs;
pub fn load_linear_classifier_from_checkpoint<P: AsRef<Path>>(
    path: P,
    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
) -> Result<LinearClassifier<TrainBackend>, RecorderError> {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    LinearClassifier::<TrainBackend>::new(LinearClassifierConfig::default(), device).load_file(
        path.as_ref(),
        &recorder,
        device,
    )
}

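/// Model architecture to train: `Tiny` is the `LinearClassifier`, `Big` is the `MultiboxModel`.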
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    Tiny,
    Big,
}

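/// Compute backend to run training on; `Wgpu` requires building with the `backend-wgpu` feature.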
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    NdArray,
    Wgpu,
}

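/// Source of training samples: warehouse shards (preferred) or legacy capture-log datasets.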
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum TrainingInputSource {
    Warehouse,
    CaptureLogs,
}

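/// Command-line arguments for the `train` command.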
#[derive(Parser, Debug)]
#[command(
    name = "train",
    about = "Train LinearClassifier/MultiboxModel (warehouse-first)"
)]
pub struct TrainArgs {
    /// Model to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    /// Backend to use (ndarray or wgpu if enabled).
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    /// Maximum boxes per image (pads/truncates to this for training).
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    /// Loss weight for box regression.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    /// Loss weight for objectness.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    /// Training input source (warehouse by default).
    #[arg(long, value_enum, default_value_t = TrainingInputSource::Warehouse)]
    pub input_source: TrainingInputSource,
    /// Warehouse manifest path (used with --input-source warehouse).
    #[arg(long, default_value = "assets/warehouse/manifest.json")]
    pub warehouse_manifest: String,
    /// Capture-log dataset root containing labels/ and images/.
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    /// Labels subdirectory relative to dataset root (capture-logs only).
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    /// Images subdirectory relative to dataset root (capture-logs only).
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    /// Number of epochs.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    /// Batch size.
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    /// Learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    /// Objectness threshold (for future eval).
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    /// IoU threshold (for future eval).
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    /// Checkpoint output path (defaults by model if not provided).
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}

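/// Validates the backend choice, resolves the checkpoint path, loads the selected input source,
/// and dispatches to the model-specific training loop.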
pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
    validate_backend_choice(args.backend)?;

    let ckpt_path = args
        .checkpoint_out
        .clone()
        .unwrap_or_else(|| match args.model {
            ModelKind::Tiny => "checkpoints/linear_detector.bin".to_string(),
            ModelKind::Big => "checkpoints/convolutional_detector.bin".to_string(),
        });

    if let Some(parent) = Path::new(&ckpt_path).parent() {
        fs::create_dir_all(parent)?;
    }

    match args.input_source {
        TrainingInputSource::Warehouse => {
            let manifest_path = Path::new(&args.warehouse_manifest);
            let loaders = WarehouseLoaders::from_manifest_path(manifest_path, 0.0, None, false)
                .map_err(|e| {
                    anyhow::anyhow!(
                        "failed to load warehouse manifest at {}: {e}",
                        manifest_path.display()
                    )
                })?;
            if loaders.train_len() == 0 {
                anyhow::bail!(
                    "warehouse manifest {} contains no training shards",
                    manifest_path.display()
                );
            }
            match args.model {
                ModelKind::Tiny => train_linear_detector_warehouse(&args, &loaders, &ckpt_path)?,
                ModelKind::Big => {
                    train_convolutional_detector_warehouse(&args, &loaders, &ckpt_path)?
                }
            }
        }
        TrainingInputSource::CaptureLogs => {
            println!("training from capture logs (legacy path); prefer warehouse manifests");
            let cfg = DatasetPathConfig {
                root: args.dataset_root.clone().into(),
                labels_subdir: args.labels_subdir.clone(),
                images_subdir: args.images_subdir.clone(),
            };
            let samples = cfg.load()?;
            if samples.is_empty() {
                println!("No samples found under {}", cfg.root.display());
                return Ok(());
            }
            match args.model {
                ModelKind::Tiny => train_linear_detector(&args, &samples, &ckpt_path)?,
                ModelKind::Big => train_convolutional_detector(&args, &samples, &ckpt_path)?,
            }
        }
    }

    println!("Saved checkpoint to {}", ckpt_path);
    Ok(())
}

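/// Autodiff wrapper around the configured training backend.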
type ADBackend = Autodiff<TrainBackend>;

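/// Trains the linear classifier on capture-log samples: the first box of each sample is the
/// input feature and the target is whether the sample contains any box at all.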
fn train_linear_detector(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = LinearClassifier::<ADBackend>::new(LinearClassifierConfig::default(), &device);
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    let data = samples.to_vec();
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in data.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            // Feature: take the first box (or zeros) as the input vector.
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);

            // Target: 1.0 if any box present, else 0.0 (clamp the box count to a 0/1 indicator).
            let mask = batch.box_mask.clone();
            let has_box = mask
                .clone()
                .sum_dim(1)
                .clamp(0.0, 1.0)
                .reshape([mask.dims()[0], 1]);

            let preds = model.forward(first_box);
            let mse = MseLoss::new();
            let loss = mse.forward(preds, has_box, Reduction::Mean);
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

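/// Warehouse-backed variant of `train_linear_detector`: streams batches from warehouse shards
/// instead of pre-loaded capture-log samples.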
fn train_linear_detector_warehouse(
    args: &TrainArgs,
    loaders: &WarehouseLoaders,
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = LinearClassifier::<ADBackend>::new(LinearClassifierConfig::default(), &device);
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        let mut iter = loaders.train_iter();
        loop {
            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
                Some(batch) => batch,
                None => break,
            };
            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;

            // Feature: take the first box (or zeros) as the input vector.
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);

            // Target: 1.0 if any box present, else 0.0 (clamp the box count to a 0/1 indicator).
            let mask = batch.box_mask.clone();
            let has_box = mask
                .clone()
                .sum_dim(1)
                .clamp(0.0, 1.0)
                .reshape([mask.dims()[0], 1]);

            let preds = model.forward(first_box);
            let mse = MseLoss::new();
            let loss = mse.forward(preds, has_box, Reduction::Mean);
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

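/// Trains the multibox model on capture-log samples: predicts boxes plus objectness scores,
/// builds targets with greedy IoU matching, and optimizes an L1 box loss plus a BCE objectness loss.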
fn train_convolutional_detector(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = MultiboxModel::<ADBackend>::new(
        MultiboxModelConfig {
            input_dim: Some(4 + 8), // first box (4) + features (8)
            max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    let data = samples.to_vec();
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in data.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            // Features: first box (or zeros) + pooled image features.
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            // Targets
            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            // Greedy matching per GT: for each GT box, pick the best prediction by IoU.
            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
            // The greedy IoU matcher is deterministic and cheap; swap in Hungarian matching if finer assignment is needed later.

            // Objectness loss (BCE) with targets; unassigned preds stay at 0.0.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                    - obj_targets.clone();
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                        .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // Box regression loss (L1) on matched preds only.
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                // No matches: return a rank-1 zero tensor so both branches have the same type.
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

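/// Warehouse-backed variant of `train_convolutional_detector`.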
fn train_convolutional_detector_warehouse(
    args: &TrainArgs,
    loaders: &WarehouseLoaders,
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = MultiboxModel::<ADBackend>::new(
        MultiboxModelConfig {
            input_dim: Some(4 + 8), // first box (4) + features (8)
            max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1);
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        let mut iter = loaders.train_iter();
        loop {
            let batch = match iter.next_batch::<ADBackend>(batch_size, &device)? {
                Some(batch) => batch,
                None => break,
            };
            let batch = crate::collate_from_burn_batch::<ADBackend>(batch, args.max_boxes)?;

            // Features: first box (or zeros) + pooled image features.
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            // Greedy matching per GT: for each GT box, pick the best prediction by IoU.
            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());

            // Objectness loss (BCE) with targets; unassigned preds stay at 0.0.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                    - obj_targets.clone();
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                        .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // Box regression loss (L1) on matched preds only.
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                // No matches: return a rank-1 zero tensor so both branches have the same type.
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}

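/// Checks that the requested backend is usable with the features this binary was compiled with.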
pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
    let built_wgpu = cfg!(feature = "backend-wgpu");
    match (kind, built_wgpu) {
        (BackendKind::Wgpu, false) => {
            anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
        }
        (BackendKind::NdArray, true) => {
            println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
        }
        _ => {}
    }
    Ok(())
}

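/// Intersection-over-union of two axis-aligned boxes in `[x0, y0, x1, y1]` form; corner order is
/// normalized, so boxes may be given with either corner first.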
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    let ax0 = a[0].min(a[2]);
    let ay0 = a[1].min(a[3]);
    let ax1 = a[0].max(a[2]);
    let ay1 = a[1].max(a[3]);
    let bx0 = b[0].min(b[2]);
    let by0 = b[1].min(b[3]);
    let bx1 = b[0].max(b[2]);
    let by1 = b[1].max(b[3]);

    let inter_x0 = ax0.max(bx0);
    let inter_y0 = ay0.max(by0);
    let inter_x1 = ax1.min(bx1);
    let inter_y1 = ay1.min(by1);

    let inter_w = (inter_x1 - inter_x0).max(0.0);
    let inter_h = (inter_y1 - inter_y0).max(0.0);
    let inter_area = inter_w * inter_h;

    let area_a = (ax1 - ax0).max(0.0) * (ay1 - ay0).max(0.0);
    let area_b = (bx1 - bx0).max(0.0) * (by1 - by0).max(0.0);
    let denom = area_a + area_b - inter_area;
    if denom <= 0.0 {
        0.0
    } else {
        inter_area / denom
    }
}
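
/// Load a `MultiboxModel` from a binary checkpoint written by `run_train`.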
pub fn load_multibox_model_from_checkpoint<P: AsRef<Path>>(
    path: P,
    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
    max_boxes: usize,
) -> Result<MultiboxModel<TrainBackend>, RecorderError> {
    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    MultiboxModel::<TrainBackend>::new(
        MultiboxModelConfig {
            max_boxes,
            input_dim: Some(4 + 8),
            ..Default::default()
        },
        device,
    )
    .load_file(path.as_ref(), &recorder, device)
}

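/// For each ground-truth box (where `gt_mask` is set), greedily assigns the prediction with the
/// highest IoU and returns objectness targets `[batch, max_pred]`, box regression targets
/// `[batch, max_pred, 4]`, and per-coordinate weights (1.0 on matched predictions, 0.0 elsewhere).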
pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
    pred_boxes: Tensor<B, 3>,
    gt_boxes: Tensor<B, 3>,
    gt_mask: Tensor<B, 2>,
) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
    let batch = pred_boxes.dims()[0];
    let max_pred = pred_boxes.dims()[1];
    let max_gt = gt_boxes.dims()[1];

    let gt_mask_vec = gt_mask
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();
    let gt_boxes_vec = gt_boxes
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();
    let pred_boxes_vec = pred_boxes
        .clone()
        .into_data()
        .to_vec::<f32>()
        .unwrap_or_default();

    let mut obj_targets = vec![0.0f32; batch * max_pred];
    let mut box_targets = vec![0.0f32; batch * max_pred * 4];
    let mut box_weights = vec![0.0f32; batch * max_pred * 4];

    for b in 0..batch {
        for g in 0..max_gt {
            let mask_idx = b * max_gt + g;
            if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
                continue;
            }
            let gb = [
                gt_boxes_vec[(b * max_gt + g) * 4],
                gt_boxes_vec[(b * max_gt + g) * 4 + 1],
                gt_boxes_vec[(b * max_gt + g) * 4 + 2],
                gt_boxes_vec[(b * max_gt + g) * 4 + 3],
            ];

            let mut best_iou = -1.0f32;
            let mut best_p = 0usize;
            for p in 0..max_pred {
                let pb = [
                    pred_boxes_vec[(b * max_pred + p) * 4],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 1],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 2],
                    pred_boxes_vec[(b * max_pred + p) * 4 + 3],
                ];
                let iou = iou_xyxy(pb, gb);
                if iou > best_iou {
                    best_iou = iou;
                    best_p = p;
                }
            }

            let obj_idx = b * max_pred + best_p;
            obj_targets[obj_idx] = 1.0;
            let bt_base = (b * max_pred + best_p) * 4;
            box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
            box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
        }
    }

    // Materialize the targets on the same device as the predictions.
    let device = &pred_boxes.device();
    let obj_targets =
        Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
    let box_targets =
        Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
    let box_weights =
        Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);

    (obj_targets, box_targets, box_weights)
}
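
// A minimal sanity check for the IoU helper, added as an illustrative sketch rather than part of
// the training pipeline; it only exercises `iou_xyxy`, which is pure and backend-free.
#[cfg(test)]
mod tests {
    use super::iou_xyxy;

    #[test]
    fn iou_of_identical_boxes_is_one() {
        let b = [0.0, 0.0, 1.0, 1.0];
        assert!((iou_xyxy(b, b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn iou_of_disjoint_boxes_is_zero() {
        let a = [0.0, 0.0, 1.0, 1.0];
        let b = [2.0, 2.0, 3.0, 3.0];
        assert!(iou_xyxy(a, b).abs() < 1e-6);
    }

    #[test]
    fn iou_normalizes_swapped_corners() {
        // Corners given in reverse order describe the same unit box.
        let a = [1.0, 1.0, 0.0, 0.0];
        let b = [0.0, 0.0, 1.0, 1.0];
        assert!((iou_xyxy(a, b) - 1.0).abs() < 1e-6);
    }
}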