// burn-vision 0.20.1
//
// Vision processing operations for burn tensors.
use burn_vision::{Nms, NmsOptions};

mod common;
use common::*;

#[test]
fn should_suppress_non_maximum() {
    let boxes = TestTensor::<2>::from([
        [0, 0, 100, 100],
        [0, 1, 100, 100],
        [0, 101, 200, 200],
        [0, 100, 200, 200],
        [0, 170, 300, 300],
    ]);
    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
    let options = NmsOptions {
        iou_threshold: 0.5,
        score_threshold: 0.0,
        max_output_boxes: 0,
    };

    let output = boxes.nms(scores, options);

    let expected = TestTensorInt::<1>::from([4, 2, 1]);
    output.into_data().assert_eq(&expected.into_data(), true);
}

#[test]
fn should_apply_score_threshold() {
    let boxes = TestTensor::<2>::from([
        [0, 0, 100, 100],
        [0, 1, 100, 100],
        [0, 101, 200, 200],
        [0, 100, 200, 200],
        [0, 170, 300, 300],
    ]);
    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
    let options = NmsOptions {
        iou_threshold: 0.5,
        score_threshold: 0.3,
        max_output_boxes: 0,
    };

    let output = boxes.nms(scores, options);

    let expected = TestTensorInt::<1>::from([4, 2]);
    output.into_data().assert_eq(&expected.into_data(), true);
}

#[test]
fn should_apply_iou_threshold() {
    let boxes = TestTensor::<2>::from([
        [0, 0, 100, 100],
        [0, 1, 100, 100],
        [0, 101, 200, 200],
        [0, 100, 200, 200],
        [0, 170, 300, 300],
    ]);
    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
    let options = NmsOptions {
        iou_threshold: 0.1,
        score_threshold: 0.0,
        max_output_boxes: 0,
    };

    let output = boxes.nms(scores, options);

    let expected = TestTensorInt::<1>::from([4, 1]);
    output.into_data().assert_eq(&expected.into_data(), true);
}

#[test]
fn should_apply_max_output_boxes() {
    let boxes = TestTensor::<2>::from([
        [0, 0, 100, 100],
        [0, 1, 100, 100],
        [0, 101, 200, 200],
        [0, 100, 200, 200],
        [0, 170, 300, 300],
    ]);
    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
    let options = NmsOptions {
        iou_threshold: 0.5,
        score_threshold: 0.0,
        max_output_boxes: 1,
    };

    let output = boxes.nms(scores, options);

    let expected = TestTensorInt::<1>::from([4]);
    output.into_data().assert_eq(&expected.into_data(), true);
}