1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
//! Task definitions for YOLO models.
//!
//! This module defines the different tasks that YOLO models can perform,
//! along with their associated capabilities and string representations.
use std::fmt;
use std::str::FromStr;
/// The computer-vision problem a YOLO model is built to solve.
///
/// The selected variant determines which outputs a model is expected to
/// produce and which post-processing steps apply to them.
/// [`Task::Detect`] is the [`Default`] variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Task {
    /// Object detection: a bounding box and a class label for every object
    /// found in the image.
    #[default]
    Detect,
    /// Instance segmentation: bounding boxes and class labels, plus a
    /// pixel-level mask for every object.
    Segment,
    /// Pose estimation: bounding boxes plus skeletal keypoints for every
    /// object (e.g., humans).
    Pose,
    /// Image classification: class probabilities for the image as a whole,
    /// without any localization.
    Classify,
    /// Oriented bounding box (OBB) detection: rotated boxes for objects,
    /// useful for aerial imagery and similar data.
    Obb,
}
impl Task {
/// Get the string representation used in ONNX model metadata
/// (e.g. `"detect"`, `"segment"`).
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Detect => "detect",
Self::Segment => "segment",
Self::Pose => "pose",
Self::Classify => "classify",
Self::Obb => "obb",
}
}
/// ONNX filename suffix for this task, used to construct `yolo26n{suffix}.onnx`.
///
/// ```
/// use ultralytics_inference::Task;
/// assert_eq!(Task::Detect.model_suffix(), "");
/// assert_eq!(Task::Segment.model_suffix(), "-seg");
/// ```
#[must_use]
pub const fn model_suffix(&self) -> &'static str {
match self {
Self::Detect => "",
Self::Segment => "-seg",
Self::Pose => "-pose",
Self::Classify => "-cls",
Self::Obb => "-obb",
}
}
/// Default nano YOLO26 model filename for this task.
///
/// Used by the CLI to auto-pick a model when `--model` is omitted but `--task` is set.
///
/// ```
/// use ultralytics_inference::Task;
/// assert_eq!(Task::Detect.default_model(), "yolo26n.onnx");
/// assert_eq!(Task::Segment.default_model(), "yolo26n-seg.onnx");
/// ```
#[must_use]
pub fn default_model(&self) -> String {
format!("yolo26n{}.onnx", self.model_suffix())
}
/// Returns `true` when the task outputs bounding boxes — namely Detect, Segment, Pose, and Obb.
#[must_use]
pub const fn has_boxes(&self) -> bool {
matches!(self, Self::Detect | Self::Segment | Self::Pose | Self::Obb)
}
/// Returns `true` only for the Segment task, which outputs per-instance segmentation masks.
#[must_use]
pub const fn has_masks(&self) -> bool {
matches!(self, Self::Segment)
}
/// Returns `true` only for the Pose task, which outputs skeletal keypoints.
#[must_use]
pub const fn has_keypoints(&self) -> bool {
matches!(self, Self::Pose)
}
/// Returns `true` only for the Classify task, which outputs global class probabilities.
#[must_use]
pub const fn has_probs(&self) -> bool {
matches!(self, Self::Classify)
}
/// Returns `true` only for the Obb task, which outputs oriented (rotated) bounding boxes.
#[must_use]
pub const fn has_obb(&self) -> bool {
matches!(self, Self::Obb)
}
}
impl fmt::Display for Task {
    /// Writes the canonical task name returned by [`Task::as_str`].
    ///
    /// The name goes through [`fmt::Formatter::pad`], so format-spec options
    /// such as width, fill and alignment (e.g. `{:>10}`) are applied the same
    /// way they are for a plain `&str`. A bare `{}` still prints the name
    /// unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` would drop any width/fill/alignment the caller asked for.
        f.pad(self.as_str())
    }
}
impl FromStr for Task {
    type Err = TaskParseError;

    /// Parses a task name, ignoring letter case and surrounding whitespace.
    ///
    /// Accepted spellings:
    ///
    /// - `detect`, `detection`
    /// - `segment`, `segmentation`
    /// - `pose`, `keypoint`, `keypoints`
    /// - `classify`, `classification`, `cls`
    /// - `obb`, `oriented`
    ///
    /// # Errors
    ///
    /// Returns a [`TaskParseError`] holding the original, unmodified input
    /// when it matches none of the spellings above.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Trim first so values carrying stray whitespace (for example a
        // trailing newline) still resolve instead of being rejected.
        let normalized = s.trim().to_lowercase();
        let task = match normalized.as_str() {
            "detect" | "detection" => Self::Detect,
            "segment" | "segmentation" => Self::Segment,
            "pose" | "keypoint" | "keypoints" => Self::Pose,
            "classify" | "classification" | "cls" => Self::Classify,
            "obb" | "oriented" => Self::Obb,
            // Report exactly what the caller passed, not the normalized form.
            _ => return Err(TaskParseError(s.to_string())),
        };
        Ok(task)
    }
}
/// Error returned when a string does not name a known task.
///
/// Wraps the rejected input string so it can be echoed back in the error
/// message. Implements `PartialEq`/`Eq` so that parse results can be compared
/// directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskParseError(String);
impl fmt::Display for TaskParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"invalid task '{}', expected one of: detect, segment, pose, classify, obb",
self.0
)
}
}
// Empty impl: `Error`'s provided methods are used as-is. The `Debug` and
// `Display` supertraits it requires come from the derive on `TaskParseError`
// and its `fmt::Display` impl.
impl std::error::Error for TaskParseError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical names and alternative spellings parse to the expected variant.
    #[test]
    fn test_task_from_str() {
        let cases = [
            // Canonical names
            ("detect", Task::Detect),
            ("segment", Task::Segment),
            ("pose", Task::Pose),
            ("classify", Task::Classify),
            ("obb", Task::Obb),
            // Alternative names
            ("detection", Task::Detect),
            ("segmentation", Task::Segment),
            ("keypoints", Task::Pose),
            ("cls", Task::Classify),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<Task>().unwrap(), expected, "input: {text:?}");
        }
    }

    /// `Display` renders the canonical lowercase task name.
    #[test]
    fn test_task_display() {
        for (task, rendered) in [(Task::Detect, "detect"), (Task::Segment, "segment")] {
            assert_eq!(task.to_string(), rendered);
        }
    }

    /// Each capability flag is reported by the task that provides it.
    #[test]
    fn test_task_capabilities() {
        let detect = Task::Detect;
        assert!(detect.has_boxes());
        assert!(!detect.has_masks());

        assert!(Task::Segment.has_masks());
        assert!(Task::Pose.has_keypoints());
        assert!(Task::Classify.has_probs());
        assert!(Task::Obb.has_obb());
    }
}