ultralytics-inference 0.0.13

Ultralytics YOLO inference library and CLI for Rust
Documentation
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

use crate::InferenceConfig;
use crate::task::Task;
use clap::{Args, Parser, Subcommand};

/// Extended help printed after the generated `--help` output: a summary of the `predict`
/// options followed by typical invocations.
const PREDICT_AFTER_HELP: &str = r#"Predict Options:
    --model, -m <MODEL>    Path to ONNX model file [default: yolo26n.onnx]
    --task <TASK>          Task type: detect, segment, pose, obb, classify [default: detect]
                           Selects the matching nano model when --model is omitted
    --source, -s <SOURCE>  Input source (image, directory, glob, video, webcam, or URL)
    --conf <CONF>          Confidence threshold [default: 0.25]
    --iou <IOU>            IoU threshold for NMS [default: 0.7]
    --max-det <MAX_DET>    Maximum number of detections [default: 300]
    --imgsz <IMGSZ>        Inference image size [default: model metadata]
    --rect                 Enable rectangular inference (minimal padding) [default: true]
    --batch <BATCH>        Batch size for inference [default: 1]
    --half                 Use FP16 half-precision inference [default: false]
    --save                 Save annotated images to runs/<task>/predict [default: true]
    --save-frames          Save individual frames for video input (instead of video file)
    --show                 Display results in a window [default: false]
    --device <DEVICE>      Device (cpu, cuda:0, mps, coreml, directml:0, openvino, tensorrt:0, xnnpack)
    --verbose              Show verbose output [default: true]
    --classes <CLASSES>    Filter by class IDs (e.g., "0", "0,1,2", "[0, 1]")

Examples:
    ultralytics-inference predict
    ultralytics-inference predict --task segment
    ultralytics-inference predict --task pose
    ultralytics-inference predict --task obb --source aerial.jpg
    ultralytics-inference predict --task classify --source image.jpg
    ultralytics-inference predict --model yolo26n.onnx --source image.jpg
    ultralytics-inference predict --source video.mp4 --rect
    ultralytics-inference predict --source video.mp4 --save-frames
    ultralytics-inference predict --source 0 --conf 0.5 --show
    ultralytics-inference predict --source assets/ --save --half
    ultralytics-inference predict --source image.jpg --device cuda:0
    ultralytics-inference predict --source image.jpg --classes 0"#;

/// CLI arguments parser.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None, propagate_version = true)]
#[command(after_help = PREDICT_AFTER_HELP)]
pub struct Cli {
    #[command(subcommand)]
    /// Subcommand to execute.
    pub command: Commands,
}

/// Commands for the CLI.
// The `///` line on each variant is rendered by clap as that subcommand's help text,
// so treat those lines as user-facing strings.
#[derive(Subcommand, Debug)]
pub enum Commands {
    /// Run inference on an image, video, or stream
    // The options accepted by `predict` are declared on `PredictArgs`.
    Predict(PredictArgs),
    /// Print version information
    // Unit variant: takes no arguments of its own.
    Version,
}

/// Arguments for the predict command.
// The `///` doc comments on the fields below are rendered by clap as each option's help text,
// so treat them as user-facing strings. Bool options declared with `num_args = 0..=1` and
// `default_missing_value = "true"` accept both a bare flag (`--rect`) and an explicit value
// (`--rect false`).
#[derive(Args, Debug)]
// Each bool below is an independent on/off CLI switch, so having several is intentional.
#[allow(clippy::struct_excessive_bools)]
pub struct PredictArgs {
    /// Path to ONNX model file
    #[arg(short, long)]
    pub model: Option<String>,

    /// Task type; selects nano model for auto-download when --model is omitted
    /// (detect, segment, pose, obb, classify)
    #[arg(long)]
    pub task: Option<Task>,

    /// Input source (image, directory, glob, video, webcam, or URL)
    #[arg(short, long)]
    pub source: Option<String>,

    /// Confidence threshold
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_CONF)]
    pub conf: f32,

    /// `IoU` threshold for NMS
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_IOU)]
    pub iou: f32,

    /// Maximum number of detections
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_MAX_DET)]
    pub max_det: usize,

    /// Inference image size
    // `None` means no size was given on the command line.
    #[arg(long)]
    pub imgsz: Option<usize>,

    /// Enable minimal padding (rectangular inference)
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_RECT, num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set)]
    pub rect: bool,

    /// Batch size for inference
    // The ranged value parser rejects `--batch 0` at parse time.
    #[arg(long, default_value_t = 1, value_parser = clap::value_parser!(u32).range(1..))]
    pub batch: u32,

    /// Use FP16 half-precision inference
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_HALF)]
    pub half: bool,

    /// Save annotated images to runs/\<task\>/predict
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_SAVE, num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set)]
    pub save: bool,

    /// Save individual frames for video input (instead of video file)
    #[arg(long, default_value_t = InferenceConfig::DEFAULT_SAVE_FRAMES)]
    pub save_frames: bool,

    /// Display results in a window
    #[arg(long, default_value_t = false)]
    pub show: bool,

    /// Device to use (cpu, cuda:0, mps, coreml, directml:0, openvino, tensorrt:0, etc.)
    #[arg(long)]
    pub device: Option<String>,

    /// Show verbose output
    // `num_args = 0..=1` + `default_missing_value` let a bare `--verbose` mean `true` (as the
    // CLI after-help advertises), while `--verbose false` still disables it. This matches
    // `--rect` and `--save`. With `action = Set` alone, clap would require a value.
    #[arg(long, default_value_t = true, num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set)]
    pub verbose: bool,

    /// Filter by class IDs (e.g. 0 or "0,1,2" or "[0, 1, 2]")
    ///
    /// Supported formats:
    /// - Single integer: --classes 0
    /// - Comma-separated list: --classes "0,1,2"
    /// - List syntax: --classes "[0, 1, 2]"
    ///
    /// Note: When passing a list directly without quotes, avoid spaces to prevent
    /// shell argument parsing issues (e.g. use --classes 0,1 not --classes 0, 1).
    // `allow_hyphen_values` keeps a value beginning with `-` (e.g. `-1`) bound to `--classes`
    // instead of having clap treat it as another flag. The raw string is presumably validated
    // later (e.g. by `parse_classes`); confirm at the call site.
    #[arg(long, allow_hyphen_values = true)]
    pub classes: Option<String>,
}

/// Parse class IDs from various formats: `"1,2,3"`, `"[1,2,3]"`, `"(1,2,3)"`
///
/// Any mix of whitespace, brackets, parentheses, and quotes is stripped from both ends before
/// the comma-separated IDs are parsed, and whitespace around each ID is ignored.
///
/// - An empty list (`""`, `"[]"`, `"[ ]"`) yields an empty `Vec`.
/// - A single trailing comma after at least one ID (`"[0, 1,]"`) is accepted, as in Python
///   list syntax.
///
/// # Errors
///
/// Returns a `String` error message if any segment of the string cannot be parsed as a `usize`.
pub fn parse_classes(s: &str) -> Result<Vec<usize>, String> {
    // Strip whitespace and delimiters together, so inputs like "[ ]" or "' [0, 1] '" don't leave
    // stray whitespace or an inner bracket behind after the outer layer is removed.
    let cleaned = s.trim_matches(|c: char| {
        c.is_whitespace() || matches!(c, '[' | ']' | '(' | ')' | '"' | '\'')
    });

    // Tolerate one trailing comma ("[0, 1,]"), but only when something precedes it. A lone ","
    // still reaches the parser below and is reported as an invalid (empty) ID.
    let cleaned = match cleaned.strip_suffix(',') {
        Some(rest) if !rest.trim().is_empty() => rest,
        _ => cleaned,
    };

    if cleaned.is_empty() {
        return Ok(Vec::new());
    }

    cleaned
        .split(',')
        .map(|part| {
            let part = part.trim();
            part.parse::<usize>()
                .map_err(|e| format!("Invalid class ID '{part}': {e}"))
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `argv` with [`Cli`] and returns the `predict` subcommand's arguments,
    /// panicking if a different subcommand was selected.
    fn parse_predict(argv: &[&str]) -> PredictArgs {
        let Commands::Predict(predict_args) = Cli::parse_from(argv.iter().copied()).command else {
            panic!("expected predict command");
        };
        predict_args
    }

    #[test]
    fn verify_cli() {
        use clap::CommandFactory;
        Cli::command().debug_assert();
    }

    #[test]
    fn test_predict_args_defaults() {
        let predict_args = parse_predict(&["app", "predict", "--model", "yolo26n.onnx"]);
        assert_eq!(predict_args.model, Some("yolo26n.onnx".to_string()));
        assert!((predict_args.conf - InferenceConfig::DEFAULT_CONF).abs() < f32::EPSILON);
        assert!((predict_args.iou - InferenceConfig::DEFAULT_IOU).abs() < f32::EPSILON);
        // The literal values below pin the defaults advertised in the CLI after-help text.
        assert!(predict_args.rect);
        assert_eq!(predict_args.max_det, 300);
        assert!(!predict_args.half);
        assert!(predict_args.verbose);
        assert!(predict_args.source.is_none());
    }

    #[test]
    fn test_predict_args_custom() {
        let predict_args = parse_predict(&[
            "app",
            "predict",
            "--model",
            "custom.onnx",
            "--source",
            "test.jpg",
            "--conf",
            "0.8",
            "--verbose",
            "false",
        ]);
        assert_eq!(predict_args.model, Some("custom.onnx".to_string()));
        assert_eq!(predict_args.source, Some("test.jpg".to_string()));
        assert!((predict_args.conf - 0.8).abs() < f32::EPSILON);
        assert!(!predict_args.verbose);
    }

    #[test]
    fn test_optional_value_bool_flags() {
        // `--rect` and `--save` accept a bare flag as well as an explicit boolean value.
        assert!(parse_predict(&["app", "predict", "--rect"]).rect);
        assert!(!parse_predict(&["app", "predict", "--rect", "false"]).rect);
        assert!(parse_predict(&["app", "predict", "--save"]).save);
        assert!(!parse_predict(&["app", "predict", "--save", "false"]).save);
    }

    #[test]
    fn test_batch_range() {
        assert!(Cli::try_parse_from(["app", "predict", "--batch", "0"]).is_err());
        assert_eq!(parse_predict(&["app", "predict", "--batch", "4"]).batch, 4);
    }

    #[test]
    fn test_classes_accepts_hyphen_value() {
        let predict_args = parse_predict(&["app", "predict", "--classes", "-1"]);
        assert_eq!(predict_args.classes.as_deref(), Some("-1"));
    }

    #[test]
    fn test_parse_classes_formats() {
        assert_eq!(parse_classes("0"), Ok(vec![0]));
        assert_eq!(parse_classes("0,1,2"), Ok(vec![0, 1, 2]));
        assert_eq!(parse_classes("[0, 1, 2]"), Ok(vec![0, 1, 2]));
        assert_eq!(parse_classes("(1,2,3)"), Ok(vec![1, 2, 3]));
        assert_eq!(parse_classes("'[3]'"), Ok(vec![3]));
        assert_eq!(parse_classes(""), Ok(vec![]));
        assert_eq!(parse_classes("[]"), Ok(vec![]));
    }

    #[test]
    fn test_parse_classes_rejects_invalid() {
        assert!(parse_classes("1,,2").is_err());
        assert!(parse_classes("-1").is_err());
        let err = parse_classes("abc").unwrap_err();
        assert!(err.starts_with("Invalid class ID 'abc': "), "unexpected error: {err}");
    }

    #[test]
    fn test_version_command() {
        let args = Cli::parse_from(["app", "version"]);
        assert!(matches!(args.command, Commands::Version));
    }
}