#![allow(clippy::float_cmp)]
use ndarray::Array3;
use std::collections::HashMap;
use tempfile::tempdir;
use ultralytics_inference::cli::args::PredictArgs;
use ultralytics_inference::cli::predict::run_prediction;
use ultralytics_inference::{Boxes, InferenceConfig, Results, Speed};
/// End-to-end smoke test of the CLI prediction entry point.
///
/// Marked `#[ignore]` because, per its ignore reason, it downloads a YOLO model and a
/// sample image, so it needs network access. The model path points into a fresh
/// temporary directory; presumably `run_prediction` fetches the model to that path
/// when it is missing — TODO confirm.
///
/// The test asserts nothing explicitly; it only checks that the call completes
/// without panicking.
#[test]
#[ignore = "downloads a YOLO model and sample image"]
fn test_run_prediction_e2e() {
    // Keep the guard bound for the whole test so the directory outlives the call.
    let workspace = tempdir().expect("temp dir should be created");
    let onnx_file = workspace.path().join("yolo26n.onnx");
    let model = onnx_file.to_string_lossy().into_owned();
    let image_url = String::from("https://ultralytics.com/images/bus.jpg");

    let predict_args = PredictArgs {
        // Inputs.
        model: Some(model),
        source: Some(image_url),
        task: None,
        device: None,
        classes: None,
        // Inference parameters.
        imgsz: Some(640),
        conf: 0.25,
        iou: 0.45,
        max_det: 300,
        batch: 1,
        rect: false,
        half: false,
        // Output / reporting switches, all disabled for a headless run.
        verbose: false,
        save: false,
        save_frames: false,
        show: false,
    };

    run_prediction(&predict_args);
}
/// `InferenceConfig::default()` yields confidence 0.25, IoU 0.7 and max_det 300.
#[test]
fn test_inference_config_creation() {
    let defaults = InferenceConfig::default();
    assert_eq!(
        (defaults.confidence_threshold, defaults.iou_threshold, defaults.max_det),
        (0.25, 0.7, 300)
    );
}
/// Builder methods overwrite the corresponding `InferenceConfig` fields.
///
/// Every value passed to the builder deliberately differs from the baseline produced
/// by `InferenceConfig::new()`. Otherwise a builder method that silently ignored its
/// argument would still satisfy the assertions.
#[test]
fn test_inference_config_builder() {
    let baseline = InferenceConfig::new();
    // Guard: if these ever coincide with the baseline, the assertions below stop
    // proving anything about the builder.
    assert_ne!(baseline.confidence_threshold, 0.5);
    assert_ne!(baseline.iou_threshold, 0.45);
    assert_ne!(baseline.max_det, 100);

    let config = InferenceConfig::new()
        .with_confidence(0.5)
        .with_iou(0.45)
        .with_max_det(100);
    assert_eq!(config.confidence_threshold, 0.5);
    assert_eq!(config.iou_threshold, 0.45);
    assert_eq!(config.max_det, 100);
}
/// `with_batch` records the requested batch size as `Some(n)`.
#[test]
fn test_inference_config_batch() {
    let batched = InferenceConfig::new().with_batch(32);
    let expected = Some(32);
    assert_eq!(batched.batch, expected);
}
/// `Boxes::new` keeps one entry per input row.
///
/// Each row is `[x1, y1, x2, y2, conf, cls]`, matching how the sibling tests index
/// the data. The `(480, 640)` tuple is presumably the original image shape
/// `(height, width)` — TODO confirm against `Boxes::new`.
#[test]
fn test_boxes_creation() {
    let rows = ndarray::array![
        [10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0],
        [50.0, 60.0, 70.0, 80.0, 0.85, 1.0],
    ];
    let detections = Boxes::new(rows, (480, 640));

    assert_eq!(detections.len(), 2);
    assert!(!detections.is_empty());
}
/// `Boxes::xyxy` exposes the first four columns of each row unchanged.
#[test]
fn test_boxes_xyxy() {
    let single_row = ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0],];
    let detections = Boxes::new(single_row, (480, 640));
    let corners = detections.xyxy();

    // Columns 0..4 of row 0 must equal x1, y1, x2, y2 as supplied.
    for (col, &expected) in [10.0f32, 20.0, 30.0, 40.0].iter().enumerate() {
        assert_eq!(corners[[0, col]], expected);
    }
}
/// `conf()` and `cls()` read columns 4 and 5 of each row, respectively.
#[test]
fn test_boxes_conf_and_cls() {
    let single_row = ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 2.0],];
    let detections = Boxes::new(single_row, (480, 640));

    let confidences = detections.conf();
    let classes = detections.cls();
    assert_eq!(confidences[[0]], 0.95);
    assert_eq!(classes[[0]], 2.0);
}
/// A freshly constructed `Results` carries no task-specific outputs.
#[test]
fn test_results_creation() {
    let fresh = Results::new(
        Array3::zeros((480, 640, 3)),
        String::from("test.jpg"),
        HashMap::new(),
        Speed::default(),
        (640, 640),
    );

    // None of the optional prediction payloads are populated by the constructor.
    assert!(fresh.boxes.is_none());
    assert!(fresh.masks.is_none());
    assert!(fresh.keypoints.is_none());
    assert!(fresh.probs.is_none());
    assert!(fresh.obb.is_none());
}
/// Attaching a one-row `Boxes` to `Results` makes the result non-empty.
#[test]
fn test_results_with_boxes() {
    let orig_img = Array3::zeros((480, 640, 3));
    let names = HashMap::new();
    let speed = Speed::default();
    let mut results = Results::new(orig_img, "test.jpg".to_string(), names, speed, (640, 640));

    let boxes_data = ndarray::array![[10.0f32, 20.0, 30.0, 40.0, 0.95, 0.0],];
    results.boxes = Some(Boxes::new(boxes_data, (480, 640)));

    // Bind the attached boxes once instead of `is_some()` followed by `unwrap()`.
    let boxes = results
        .boxes
        .as_ref()
        .expect("boxes were assigned just above");
    assert_eq!(boxes.len(), 1);
    assert!(!boxes.is_empty());

    assert!(!results.is_empty());
}
/// `Results::is_empty` is true when no predictions have been attached.
#[test]
fn test_results_is_empty() {
    let blank = Results::new(
        Array3::zeros((480, 640, 3)),
        String::from("test.jpg"),
        HashMap::new(),
        Speed::default(),
        (640, 640),
    );
    assert!(blank.is_empty());
}
/// `Speed::new(pre, inf, post)` stores each timing wrapped in `Some`, in argument order.
#[test]
fn test_speed_timing() {
    let timings = Speed::new(10.0, 20.0, 5.0);
    assert_eq!(
        (timings.preprocess, timings.inference, timings.postprocess),
        (Some(10.0), Some(20.0), Some(5.0))
    );
}