ultralytics-inference 0.0.13

Ultralytics YOLO inference library and CLI for Rust
Documentation
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

#![allow(clippy::float_cmp)]

//! Integration tests for the inference library

use ndarray::Array3;
use std::collections::HashMap;
use tempfile::tempdir;
use ultralytics_inference::cli::args::PredictArgs;
use ultralytics_inference::cli::predict::run_prediction;
use ultralytics_inference::{Boxes, InferenceConfig, Results, Speed};

/// End-to-end smoke test of the CLI prediction path.
///
/// Calls `run_prediction` with a `yolo26n.onnx` model path inside a fresh
/// temporary directory and a remote sample image URL. The test is ignored by
/// default because, per its `#[ignore]` reason, it downloads a YOLO model and
/// the sample image.
#[test]
#[ignore = "downloads a YOLO model and sample image"]
fn test_run_prediction_e2e() {
    // Keep the guard alive until after `run_prediction`; the directory is
    // removed when it is dropped at the end of the test.
    let workdir = tempdir().expect("temp dir should be created");
    let weights = workdir.path().join("yolo26n.onnx");

    let args = PredictArgs {
        // Inputs.
        model: Some(weights.to_string_lossy().into_owned()),
        source: Some("https://ultralytics.com/images/bus.jpg".to_string()),
        task: None,
        device: None,
        classes: None,
        // Thresholds and detection limit.
        conf: 0.25,
        iou: 0.45,
        max_det: 300,
        // Input sizing, batching and precision.
        imgsz: Some(640),
        rect: false,
        batch: 1,
        half: false,
        // Output / display toggles, all disabled.
        save: false,
        save_frames: false,
        show: false,
        verbose: false,
    };

    run_prediction(&args);
}

/// `InferenceConfig::default()` yields confidence 0.25, IoU 0.7 and `max_det` 300.
#[test]
fn test_inference_config_creation() {
    let defaults = InferenceConfig::default();

    assert_eq!(defaults.max_det, 300);
    assert_eq!(defaults.iou_threshold, 0.7);
    assert_eq!(defaults.confidence_threshold, 0.25);
}

/// Verifies that each builder setter overrides its corresponding field.
///
/// The values are deliberately different from the defaults checked in
/// `test_inference_config_creation` (confidence 0.25, IoU 0.7, `max_det` 300).
/// Reusing the defaults would let this test pass even if a setter were a no-op.
/// The confidence and IoU values also differ from each other, so swapped
/// setters would be caught too.
#[test]
fn test_inference_config_builder() {
    let config = InferenceConfig::new()
        .with_confidence(0.5)
        .with_iou(0.45)
        .with_max_det(100);

    assert_eq!(config.confidence_threshold, 0.5);
    assert_eq!(config.iou_threshold, 0.45);
    assert_eq!(config.max_det, 100);
}

/// `with_batch(n)` stores the requested batch size as `Some(n)`.
#[test]
fn test_inference_config_batch() {
    let requested = 32;
    assert_eq!(
        InferenceConfig::new().with_batch(requested).batch,
        Some(requested)
    );
}

/// Two detection rows produce a two-element, non-empty `Boxes`.
#[test]
fn test_boxes_creation() {
    // Each row is one detection laid out as [x1, y1, x2, y2, conf, cls].
    let detections = ndarray::array![
        [10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0],
        [50.0, 60.0, 70.0, 80.0, 0.85, 1.0],
    ];
    let shape = (480, 640);
    let boxes = Boxes::new(detections, shape);

    assert!(!boxes.is_empty());
    assert_eq!(boxes.len(), 2);
}

/// For a single detection, `xyxy()` reports the corner coordinates it was given.
#[test]
fn test_boxes_xyxy() {
    let boxes = Boxes::new(
        ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0]],
        (480, 640),
    );
    let xyxy = boxes.xyxy();

    // Expected [x1, y1, x2, y2] for row 0, checked column by column.
    let expected = [10.0, 20.0, 30.0, 40.0];
    for (col, want) in expected.iter().copied().enumerate() {
        assert_eq!(xyxy[[0, col]], want);
    }
}

/// `conf()` and `cls()` report the confidence (0.95) and class (2.0) stored in
/// the last two columns of the detection row.
#[test]
fn test_boxes_conf_and_cls() {
    let row = ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 2.0]];
    let boxes = Boxes::new(row, (480, 640));

    let (conf, cls) = (boxes.conf(), boxes.cls());
    assert_eq!(conf[[0]], 0.95);
    assert_eq!(cls[[0]], 2.0);
}

/// A freshly constructed `Results` has no boxes, masks, keypoints, probs or OBB.
#[test]
fn test_results_creation() {
    let results = Results::new(
        Array3::zeros((480, 640, 3)),
        "test.jpg".to_string(),
        HashMap::new(),
        Speed::default(),
        (640, 640),
    );

    // Every optional task output starts unset.
    assert!(results.boxes.is_none());
    assert!(results.masks.is_none());
    assert!(results.keypoints.is_none());
    assert!(results.probs.is_none());
    assert!(results.obb.is_none());
}

/// Attaching a one-detection `Boxes` to a fresh `Results` makes it non-empty.
#[test]
fn test_results_with_boxes() {
    let orig_img = Array3::zeros((480, 640, 3));
    let names = HashMap::new();
    let speed = Speed::default();

    let mut results = Results::new(orig_img, "test.jpg".to_string(), names, speed, (640, 640));

    // One detection laid out as [x1, y1, x2, y2, conf, cls].
    let boxes_data = ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0]];
    results.boxes = Some(Boxes::new(boxes_data, (480, 640)));

    // `expect` asserts presence and binds the value in one step, replacing a
    // separate `is_some()` assertion followed by `unwrap()`.
    let boxes = results
        .boxes
        .as_ref()
        .expect("boxes were just assigned");
    assert_eq!(boxes.len(), 1);
    assert!(!results.is_empty());
}

/// `Results::is_empty` is true for a freshly constructed `Results`.
#[test]
fn test_results_is_empty() {
    let blank = Results::new(
        Array3::zeros((480, 640, 3)),
        String::from("test.jpg"),
        HashMap::new(),
        Speed::default(),
        (640, 640),
    );

    assert!(blank.is_empty());
}

/// `Speed::new(pre, inf, post)` stores each timing as `Some(..)` in the
/// `preprocess`, `inference` and `postprocess` fields respectively.
#[test]
fn test_speed_timing() {
    let speed = Speed::new(10.0, 20.0, 5.0);

    assert_eq!(
        (speed.preprocess, speed.inference, speed.postprocess),
        (Some(10.0), Some(20.0), Some(5.0)),
    );
}