// training/util.rs — training entrypoints and checkpoint helpers.

1use burn::backend::Autodiff;
2use burn::module::Module;
3use burn::nn::loss::{MseLoss, Reduction};
4use burn::optim::{AdamConfig, GradientsParams, Optimizer};
5use burn::record::{BinFileRecorder, FullPrecisionSettings, RecorderError};
6use burn::tensor::{Tensor, TensorData};
7use std::path::Path;
8
9use crate::{BigDet, BigDetConfig, DatasetConfig, TinyDet, TinyDetConfig, TrainBackend};
10use clap::{Parser, ValueEnum};
11use std::fs;
12
13pub fn load_tinydet_from_checkpoint<P: AsRef<Path>>(
14    path: P,
15    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
16) -> Result<TinyDet<TrainBackend>, RecorderError> {
17    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
18    TinyDet::<TrainBackend>::new(TinyDetConfig::default(), device).load_file(
19        path.as_ref(),
20        &recorder,
21        device,
22    )
23}
24
/// Detector architecture selectable from the CLI.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum ModelKind {
    /// Train `TinyDet`.
    Tiny,
    /// Train `BigDet`.
    Big,
}
30
/// Compute backend selectable from the CLI.
/// NOTE: the effective backend is fixed at build time by the `backend-wgpu`
/// feature; see `validate_backend_choice`.
#[derive(ValueEnum, Debug, Clone, Copy)]
pub enum BackendKind {
    /// CPU ndarray backend.
    NdArray,
    /// WGPU backend (requires the `backend-wgpu` feature).
    Wgpu,
}
36
/// Command-line arguments for the `train` binary.
///
/// Parsed with clap's derive API; every option has a default so the binary
/// can run with no arguments on the bundled dataset layout.
#[derive(Parser, Debug)]
#[command(name = "train", about = "Train TinyDet/BigDet on capture metadata")]
pub struct TrainArgs {
    /// Model to train.
    #[arg(long, value_enum, default_value_t = ModelKind::Tiny)]
    pub model: ModelKind,
    /// Backend to use (ndarray or wgpu if enabled).
    #[arg(long, value_enum, default_value_t = BackendKind::NdArray)]
    pub backend: BackendKind,
    /// Maximum boxes per image (pads/truncates to this for training).
    #[arg(long, default_value_t = 64)]
    pub max_boxes: usize,
    /// Loss weight for box regression.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_box: f32,
    /// Loss weight for objectness.
    #[arg(long, default_value_t = 1.0)]
    pub lambda_obj: f32,
    /// Dataset root containing labels/ and images/ (uses data_contracts schemas).
    #[arg(long, default_value = "assets/datasets/captures_filtered")]
    pub dataset_root: String,
    /// Labels subdirectory relative to dataset root.
    #[arg(long, default_value = "labels")]
    pub labels_subdir: String,
    /// Images subdirectory relative to dataset root.
    #[arg(long, default_value = ".")]
    pub images_subdir: String,
    /// Number of epochs.
    #[arg(long, default_value_t = 1)]
    pub epochs: usize,
    /// Batch size.
    #[arg(long, default_value_t = 1)]
    pub batch_size: usize,
    /// Learning rate.
    #[arg(long, default_value_t = 1e-3)]
    pub lr: f32,
    /// Objectness threshold (for future eval).
    #[arg(long, default_value_t = 0.3)]
    pub infer_obj_thresh: f32,
    /// IoU threshold (for future eval).
    #[arg(long, default_value_t = 0.5)]
    pub infer_iou_thresh: f32,
    /// Checkpoint output path (defaults by model if not provided).
    #[arg(long)]
    pub checkpoint_out: Option<String>,
}
83
84pub fn run_train(args: TrainArgs) -> anyhow::Result<()> {
85    validate_backend_choice(args.backend)?;
86
87    let cfg = DatasetConfig {
88        root: args.dataset_root.clone().into(),
89        labels_subdir: args.labels_subdir.clone(),
90        images_subdir: args.images_subdir.clone(),
91    };
92    let samples = cfg.load()?;
93    if samples.is_empty() {
94        println!("No samples found under {}", cfg.root.display());
95        return Ok(());
96    }
97
98    let ckpt_path = args
99        .checkpoint_out
100        .clone()
101        .unwrap_or_else(|| match args.model {
102            ModelKind::Tiny => "checkpoints/tinydet.bin".to_string(),
103            ModelKind::Big => "checkpoints/bigdet.bin".to_string(),
104        });
105
106    if let Some(parent) = Path::new(&ckpt_path).parent() {
107        fs::create_dir_all(parent)?;
108    }
109
110    match args.model {
111        ModelKind::Tiny => train_tinydet(&args, &samples, &ckpt_path)?,
112        ModelKind::Big => train_bigdet(&args, &samples, &ckpt_path)?,
113    }
114
115    println!("Saved checkpoint to {}", ckpt_path);
116    Ok(())
117}
118
119type ADBackend = Autodiff<TrainBackend>;
120
121fn train_tinydet(
122    args: &TrainArgs,
123    samples: &[crate::RunSample],
124    ckpt_path: &str,
125) -> anyhow::Result<()> {
126    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
127    let mut model = TinyDet::<ADBackend>::new(TinyDetConfig::default(), &device);
128    let mut optim = AdamConfig::new().init();
129
130    let batch_size = args.batch_size.max(1);
131    let data = samples.to_vec();
132    for epoch in 0..args.epochs {
133        let mut losses = Vec::new();
134        for batch in data.chunks(batch_size) {
135            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
136            // Feature: take the first box (or zeros) as the input vector.
137            let boxes = batch.boxes.clone();
138            let first_box = boxes
139                .clone()
140                .slice([0..boxes.dims()[0], 0..1, 0..4])
141                .reshape([boxes.dims()[0], 4]);
142
143            // Target: 1.0 if any box present, else 0.0.
144            let mask = batch.box_mask.clone();
145            let has_box = mask.clone().sum_dim(1).reshape([mask.dims()[0], 1]);
146
147            let preds = model.forward(first_box);
148            let mse = MseLoss::new();
149            let loss = mse.forward(preds, has_box, Reduction::Mean);
150            let loss_detached = loss.clone().detach();
151            let grads = GradientsParams::from_grads(loss.backward(), &model);
152            model = optim.step(args.lr as f64, model, grads);
153
154            let loss_val: f32 = loss_detached
155                .into_data()
156                .to_vec::<f32>()
157                .unwrap_or_default()
158                .into_iter()
159                .next()
160                .unwrap_or(0.0);
161            losses.push(loss_val);
162        }
163        let avg_loss: f32 = if losses.is_empty() {
164            0.0
165        } else {
166            losses.iter().sum::<f32>() / losses.len() as f32
167        };
168        println!("epoch {epoch}: avg loss {avg_loss:.4}");
169    }
170
171    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
172    model
173        .clone()
174        .save_file(Path::new(ckpt_path), &recorder)
175        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;
176
177    Ok(())
178}
179
/// Train `BigDet` on the collated samples and write a checkpoint to `ckpt_path`.
///
/// Per batch: the input is the first GT box (or zeros) concatenated with the
/// pooled image features; the model predicts boxes plus objectness scores,
/// supervised by greedy-IoU-matched targets (see `build_greedy_targets`).
fn train_bigdet(
    args: &TrainArgs,
    samples: &[crate::RunSample],
    ckpt_path: &str,
) -> anyhow::Result<()> {
    let device = <ADBackend as burn::tensor::backend::Backend>::Device::default();
    let mut model = BigDet::<ADBackend>::new(
        BigDetConfig {
            input_dim: Some(4 + 8), // first box (4) + features (8)
            max_boxes: args.max_boxes,
            ..Default::default()
        },
        &device,
    );
    let mut optim = AdamConfig::new().init();

    let batch_size = args.batch_size.max(1); // guard against batch_size == 0
    let data = samples.to_vec();
    for epoch in 0..args.epochs {
        let mut losses = Vec::new();
        for batch in data.chunks(batch_size) {
            let batch = crate::collate::<ADBackend>(batch, args.max_boxes)?;
            // Features: first box (or zeros) + pooled image features.
            let boxes = batch.boxes.clone();
            let first_box = boxes
                .clone()
                .slice([0..boxes.dims()[0], 0..1, 0..4])
                .reshape([boxes.dims()[0], 4]);
            let features = batch.features.clone();
            let input = burn::tensor::Tensor::cat(vec![first_box, features], 1);

            let (pred_boxes, pred_scores) = model.forward_multibox(input);

            // Targets
            let gt_boxes = batch.boxes.clone();
            let gt_mask = batch.box_mask.clone();

            // Greedy matching per GT: for each GT box, pick best pred by IoU.
            let (obj_targets, box_targets, box_weights) =
                build_greedy_targets(pred_boxes.clone(), gt_boxes.clone(), gt_mask.clone());
            // Greedy IoU matcher is deterministic/cheap; swap to Hungarian if finer matching is needed later.

            // Objectness loss: hand-rolled binary cross-entropy
            //   -[t*log(p) + (1-t)*log(1-p)]
            // averaged over every prediction slot; unassigned preds keep target 0.0.
            // Scores are clamped away from 0 and 1 so log() stays finite.
            let eps = 1e-6;
            let pred_scores_clamped = pred_scores.clamp(eps, 1.0 - eps);
            let obj_targets_inv =
                Tensor::<ADBackend, 2>::ones(obj_targets.dims(), &obj_targets.device())
                    - obj_targets.clone();
            let obj_loss = -((obj_targets.clone() * pred_scores_clamped.clone().log())
                + (obj_targets_inv
                    * (Tensor::<ADBackend, 2>::ones(
                        pred_scores_clamped.dims(),
                        &pred_scores_clamped.device(),
                    ) - pred_scores_clamped)
                        .log()))
            .sum()
            .div_scalar((obj_targets.dims()[0] * obj_targets.dims()[1]) as f32);

            // Box regression loss: L1 on matched preds only (box_weights zeroes
            // the unmatched slots), normalized by the number of matches.
            let box_err = (pred_boxes - box_targets.clone()).abs() * box_weights.clone();
            // Each match contributes 4 coordinate weights, so dividing the
            // weight sum by 4 recovers the match count.
            let matched = box_weights.clone().sum().div_scalar(4.0);
            let matched_scalar = matched
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .first()
                .copied()
                .unwrap_or(0.0);
            let box_loss = if matched_scalar > 0.0 {
                box_err.sum().div_scalar(matched_scalar)
            } else {
                // No matches this batch: contribute a zero box loss.
                // Return a zero scalar in the same tensor rank as div output (rank 1).
                let zeros = vec![0.0f32; 1];
                Tensor::<ADBackend, 1>::from_data(
                    TensorData::new(zeros, [1]),
                    &box_weights.device(),
                )
            };

            let loss = box_loss * args.lambda_box + obj_loss * args.lambda_obj;
            // Detach a copy for logging before backward consumes the graph.
            let loss_detached = loss.clone().detach();
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(args.lr as f64, model, grads);

            let loss_val: f32 = loss_detached
                .into_data()
                .to_vec::<f32>()
                .unwrap_or_default()
                .into_iter()
                .next()
                .unwrap_or(0.0);
            losses.push(loss_val);
        }
        let avg_loss: f32 = if losses.is_empty() {
            0.0
        } else {
            losses.iter().sum::<f32>() / losses.len() as f32
        };
        println!("epoch {epoch}: avg loss {avg_loss:.4}");
    }

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    model
        .clone()
        .save_file(Path::new(ckpt_path), &recorder)
        .map_err(|e| anyhow::anyhow!("failed to save checkpoint: {e}"))?;

    Ok(())
}
289
290pub fn validate_backend_choice(kind: BackendKind) -> anyhow::Result<()> {
291    let built_wgpu = cfg!(feature = "backend-wgpu");
292    match (kind, built_wgpu) {
293        (BackendKind::Wgpu, false) => {
294            anyhow::bail!("backend-wgpu feature not enabled; rebuild with --features backend-wgpu or choose ndarray backend")
295        }
296        (BackendKind::NdArray, true) => {
297            println!("note: built with backend-wgpu; training will still use the WGPU backend despite --backend ndarray");
298        }
299        _ => {}
300    }
301    Ok(())
302}
303
/// Intersection-over-union of two axis-aligned boxes in `[x0, y0, x1, y1]`
/// form. Corner order is normalized first, so swapped corners are accepted.
/// Returns 0.0 when the union area is zero or negative (degenerate boxes).
fn iou_xyxy(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Normalize so (x0, y0) is the min corner and (x1, y1) the max corner.
    let norm = |r: [f32; 4]| -> [f32; 4] {
        [r[0].min(r[2]), r[1].min(r[3]), r[0].max(r[2]), r[1].max(r[3])]
    };
    let a = norm(a);
    let b = norm(b);

    // Intersection extents clamp to zero when the boxes do not overlap.
    let iw = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let ih = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter = iw * ih;

    let area = |r: [f32; 4]| (r[2] - r[0]).max(0.0) * (r[3] - r[1]).max(0.0);
    let union = area(a) + area(b) - inter;

    if union <= 0.0 {
        0.0
    } else {
        inter / union
    }
}
332pub fn load_bigdet_from_checkpoint<P: AsRef<Path>>(
333    path: P,
334    device: &<TrainBackend as burn::tensor::backend::Backend>::Device,
335    max_boxes: usize,
336) -> Result<BigDet<TrainBackend>, RecorderError> {
337    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
338    BigDet::<TrainBackend>::new(
339        BigDetConfig {
340            max_boxes,
341            input_dim: Some(4 + 8),
342            ..Default::default()
343        },
344        device,
345    )
346    .load_file(path.as_ref(), &recorder, device)
347}
348
349pub fn build_greedy_targets<B: burn::tensor::backend::Backend>(
350    pred_boxes: Tensor<B, 3>,
351    gt_boxes: Tensor<B, 3>,
352    gt_mask: Tensor<B, 2>,
353) -> (Tensor<B, 2>, Tensor<B, 3>, Tensor<B, 3>) {
354    let batch = pred_boxes.dims()[0];
355    let max_pred = pred_boxes.dims()[1];
356    let max_gt = gt_boxes.dims()[1];
357
358    let gt_mask_vec = gt_mask
359        .clone()
360        .into_data()
361        .to_vec::<f32>()
362        .unwrap_or_default();
363    let gt_boxes_vec = gt_boxes
364        .clone()
365        .into_data()
366        .to_vec::<f32>()
367        .unwrap_or_default();
368    let pred_boxes_vec = pred_boxes
369        .clone()
370        .into_data()
371        .to_vec::<f32>()
372        .unwrap_or_default();
373
374    let mut obj_targets = vec![0.0f32; batch * max_pred];
375    let mut box_targets = vec![0.0f32; batch * max_pred * 4];
376    let mut box_weights = vec![0.0f32; batch * max_pred * 4];
377
378    for b in 0..batch {
379        for g in 0..max_gt {
380            let mask_idx = b * max_gt + g;
381            if gt_mask_vec.get(mask_idx).copied().unwrap_or(0.0) < 0.5 {
382                continue;
383            }
384            let gb = [
385                gt_boxes_vec[(b * max_gt + g) * 4],
386                gt_boxes_vec[(b * max_gt + g) * 4 + 1],
387                gt_boxes_vec[(b * max_gt + g) * 4 + 2],
388                gt_boxes_vec[(b * max_gt + g) * 4 + 3],
389            ];
390
391            let mut best_iou = -1.0f32;
392            let mut best_p = 0usize;
393            for p in 0..max_pred {
394                let pb = [
395                    pred_boxes_vec[(b * max_pred + p) * 4],
396                    pred_boxes_vec[(b * max_pred + p) * 4 + 1],
397                    pred_boxes_vec[(b * max_pred + p) * 4 + 2],
398                    pred_boxes_vec[(b * max_pred + p) * 4 + 3],
399                ];
400                let iou = iou_xyxy(pb, gb);
401                if iou > best_iou {
402                    best_iou = iou;
403                    best_p = p;
404                }
405            }
406
407            let obj_idx = b * max_pred + best_p;
408            obj_targets[obj_idx] = 1.0;
409            let bt_base = (b * max_pred + best_p) * 4;
410            box_targets[bt_base..bt_base + 4].copy_from_slice(&gb);
411            box_weights[bt_base..bt_base + 4].copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
412        }
413    }
414
415    let device = &B::Device::default();
416    let obj_targets =
417        Tensor::<B, 2>::from_data(TensorData::new(obj_targets, [batch, max_pred]), device);
418    let box_targets =
419        Tensor::<B, 3>::from_data(TensorData::new(box_targets, [batch, max_pred, 4]), device);
420    let box_weights =
421        Tensor::<B, 3>::from_data(TensorData::new(box_weights, [batch, max_pred, 4]), device);
422
423    (obj_targets, box_targets, box_weights)
424}