use crate::InferenceConfig;
use crate::task::Task;
use clap::{Args, Parser, Subcommand};
// Top-level CLI parser. `author`, `version`, and `about` are filled in from
// Cargo.toml package metadata; `propagate_version` copies the version string
// into every subcommand's `--version`/help output as well.
//
// All predict options are documented in a single hand-written `after_help`
// block (shown below the auto-generated help) instead of per-field `///` doc
// comments — the doc-comment route would make clap emit per-flag help text,
// and this file deliberately keeps the generated help minimal.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[command(propagate_version = true)]
// NOTE: this raw string is user-visible output; its wording must stay in
// sync with the actual defaults declared on `PredictArgs` below.
#[command(after_help = r#"Predict Options:
--model, -m <MODEL> Path to ONNX model file [default: yolo26n.onnx]
--task <TASK> Task type: detect, segment, pose, obb, classify [default: detect]
Selects the matching nano model when --model is omitted
--source, -s <SOURCE> Input source (image, directory, glob, video, webcam, or URL)
--conf <CONF> Confidence threshold [default: 0.25]
--iou <IOU> IoU threshold for NMS [default: 0.7]
--max-det <MAX_DET> Maximum number of detections [default: 300]
--imgsz <IMGSZ> Inference image size [default: model metadata]
--rect Enable rectangular inference (minimal padding) [default: true]
--batch <BATCH> Batch size for inference [default: 1]
--half Use FP16 half-precision inference [default: false]
--save Save annotated images to runs/<task>/predict [default: true]
--save-frames Save individual frames for video input (instead of video file)
--show Display results in a window [default: false]
--device <DEVICE> Device (cpu, cuda:0, mps, coreml, directml:0, openvino, tensorrt:0, xnnpack)
--verbose Show verbose output [default: true]
--classes <CLASSES> Filter by class IDs (e.g., "0", "0,1,2", "[0, 1]")
Examples:
ultralytics-inference predict
ultralytics-inference predict --task segment
ultralytics-inference predict --task pose
ultralytics-inference predict --task obb --source aerial.jpg
ultralytics-inference predict --task classify --source image.jpg
ultralytics-inference predict --model yolo26n.onnx --source image.jpg
ultralytics-inference predict --source video.mp4 --rect
ultralytics-inference predict --source video.mp4 --save-frames
ultralytics-inference predict --source 0 --conf 0.5 --show
ultralytics-inference predict --source assets/ --save --half
ultralytics-inference predict --source image.jpg --device cuda:0
ultralytics-inference predict --source image.jpg --classes 0"#)]
pub struct Cli {
// The selected subcommand; clap dispatches into this enum.
#[command(subcommand)]
pub command: Commands,
}
// Available subcommands. `//` comments are used instead of `///` doc
// comments here on purpose: clap turns doc comments into generated help
// text, and this CLI documents everything in `after_help` instead.
#[derive(Subcommand, Debug)]
pub enum Commands {
// Run inference on an image/video/stream source with `PredictArgs`.
Predict(PredictArgs),
}
// Arguments for the `predict` subcommand. Defaults are pulled from the
// `InferenceConfig` constants so the CLI and library stay in sync.
//
// Fields intentionally carry `//` comments rather than `///` doc comments:
// clap would render doc comments as per-flag help text, while this CLI
// documents its flags in the top-level `after_help` block instead.
#[derive(Args, Debug)]
// Many independent on/off CLI switches; bools are the natural model here.
#[allow(clippy::struct_excessive_bools)]
pub struct PredictArgs {
// Path to the ONNX model file. Optional: per the after_help text, a
// task-matched nano model is presumably selected downstream when this
// is omitted — TODO confirm against the caller.
#[arg(short, long)]
pub model: Option<String>,
// Task type (detect/segment/pose/obb/classify); parsed via `Task`'s
// clap value integration.
#[arg(long)]
pub task: Option<Task>,
// Input source: image, directory, glob, video, webcam index, or URL.
#[arg(short, long)]
pub source: Option<String>,
// Confidence threshold for keeping predictions.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_CONF)]
pub conf: f32,
// IoU threshold used by NMS.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_IOU)]
pub iou: f32,
// Upper bound on the number of detections returned.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_MAX_DET)]
pub max_det: usize,
// Inference image size; falls back to model metadata when omitted.
#[arg(long)]
pub imgsz: Option<usize>,
// Tri-state flag pattern: `num_args = 0..=1` + `default_missing_value`
// lets users write either bare `--rect` (=> true) or an explicit
// `--rect false`, while `ArgAction::Set` makes the value authoritative.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_RECT, num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set)]
pub rect: bool,
// Batch size; `range(1..)` rejects 0 at parse time.
#[arg(long, default_value_t = 1, value_parser = clap::value_parser!(u32).range(1..))]
pub batch: u32,
// Use FP16 half-precision inference.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_HALF)]
pub half: bool,
// Save annotated output; same bare-flag-or-explicit-value pattern as
// `--rect` so a true default can still be turned off via `--save false`.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_SAVE, num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set)]
pub save: bool,
// For video input: save individual frames instead of a video file.
#[arg(long, default_value_t = InferenceConfig::DEFAULT_SAVE_FRAMES)]
pub save_frames: bool,
// Display results in a window. NOTE(review): literal `false` here while
// sibling flags use InferenceConfig constants — confirm no DEFAULT_SHOW
// constant exists to reference.
#[arg(long, default_value_t = false)]
pub show: bool,
// Execution provider/device string (cpu, cuda:0, mps, coreml, ...).
#[arg(long)]
pub device: Option<String>,
// Verbose output; `ArgAction::Set` without `num_args` means disabling
// requires an explicit value, i.e. `--verbose false`.
#[arg(long, default_value_t = true, action = clap::ArgAction::Set)]
pub verbose: bool,
// Raw class-filter string (e.g. "0", "0,1,2", "[0, 1]"); decoded later
// by `parse_classes`. `allow_hyphen_values` keeps values starting with
// '-' from being misread as flags.
#[arg(long, allow_hyphen_values = true)]
pub classes: Option<String>,
}
/// Parse a user-supplied class filter string into a list of class IDs.
///
/// Accepts bare comma-separated IDs ("0,1,2") as well as bracketed,
/// parenthesized, or quoted forms ("[0, 1]", "(2)", "\"0,1\"", "'3'").
/// An empty or wrapper-only input yields an empty list.
///
/// # Errors
///
/// Returns a human-readable message naming the first element that does
/// not parse as an unsigned integer.
pub fn parse_classes(s: &str) -> Result<Vec<usize>, String> {
    // Peel surrounding whitespace, then any bracket/paren/quote wrappers.
    let is_wrapper = |c: char| matches!(c, '[' | ']' | '(' | ')' | '"' | '\'');
    let inner = s.trim().trim_matches(is_wrapper);

    if inner.is_empty() {
        return Ok(Vec::new());
    }

    let mut ids = Vec::with_capacity(inner.split(',').count());
    for raw in inner.split(',') {
        let token = raw.trim();
        match token.parse::<usize>() {
            Ok(id) => ids.push(id),
            Err(e) => return Err(format!("Invalid class ID '{token}': {e}")),
        }
    }
    Ok(ids)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// clap's built-in self-check: catches conflicting flags and invalid
    /// attribute combinations at test time instead of at first runtime parse.
    #[test]
    fn verify_cli() {
        use clap::CommandFactory;
        Cli::command().debug_assert();
    }

    /// Parsing with only `--model` supplied leaves every other option at
    /// its declared default.
    #[test]
    fn test_predict_args_defaults() {
        let cli = Cli::parse_from(["app", "predict", "--model", "yolo26n.onnx"]);
        // `Commands` has a single variant, so the pattern is irrefutable
        // and a plain `let` destructure replaces the one-arm `match`.
        let Commands::Predict(p) = cli.command;
        assert_eq!(p.model.as_deref(), Some("yolo26n.onnx"));
        assert!((p.conf - InferenceConfig::DEFAULT_CONF).abs() < f32::EPSILON);
        assert!((p.iou - InferenceConfig::DEFAULT_IOU).abs() < f32::EPSILON);
        assert!(p.rect);
        assert_eq!(p.max_det, 300);
        assert!(!p.half);
        assert!(p.verbose);
        assert!(p.source.is_none());
    }

    /// Explicitly supplied values override the defaults, including the
    /// `--verbose false` form required by `ArgAction::Set`.
    #[test]
    fn test_predict_args_custom() {
        let cli = Cli::parse_from([
            "app", "predict",
            "--model", "custom.onnx",
            "--source", "test.jpg",
            "--conf", "0.8",
            "--verbose", "false",
        ]);
        let Commands::Predict(p) = cli.command;
        assert_eq!(p.model.as_deref(), Some("custom.onnx"));
        assert_eq!(p.source.as_deref(), Some("test.jpg"));
        assert!((p.conf - 0.8).abs() < f32::EPSILON);
        assert!(!p.verbose);
    }
}