use crate::{
Color, ColorSource, Hbb, HbbStyle, HbbStyleMode, Keypoint, KeypointStyle, KeypointStyleMode,
TextStyle, TextStyleMode,
};
/// A prompt for SAM3: a text query together with box and point hints.
///
/// Prompts built from geometry alone use [`Sam3Prompt::VISUAL`] as their text. A
/// `Default` prompt has empty text and no boxes or points.
#[derive(Clone, Debug, Default)]
pub struct Sam3Prompt {
    /// The text query, or [`Sam3Prompt::VISUAL`] for geometry-only prompts.
    pub text: String,
    /// Box hints; the builder methods name each one [`Sam3Prompt::POSITIVE`] or
    /// [`Sam3Prompt::NEGATIVE`].
    pub boxes: Vec<Hbb>,
    /// Point hints; the builder methods name each one [`Sam3Prompt::POSITIVE`] or
    /// [`Sam3Prompt::NEGATIVE`].
    pub points: Vec<Keypoint>,
}
impl Sam3Prompt {
    /// Prompt text used when the prompt carries geometry only (no text query).
    pub const VISUAL: &'static str = "visual";
    /// Name attached to positive boxes/points; reported as label `1`.
    pub const POSITIVE: &'static str = "positive";
    /// Name attached to negative boxes/points; reported as label `0`.
    pub const NEGATIVE: &'static str = "negative";

    /// Creates a prompt with the given text and no boxes or points.
    pub fn new(text: &str) -> Self {
        Self {
            text: text.to_string(),
            boxes: Vec::new(),
            points: Vec::new(),
        }
    }

    /// Creates an empty geometry-only prompt whose text is [`Self::VISUAL`].
    pub fn visual() -> Self {
        Self::new(Self::VISUAL)
    }

    /// Replaces the prompt text, keeping any boxes and points already added.
    pub fn with_text(mut self, text: &str) -> Self {
        self.text = text.to_string();
        self
    }

    /// Appends a box hint with corner `(x, y)` and size `w` x `h`.
    ///
    /// A positive box is named [`Self::POSITIVE`] and coloured cyan. A negative box is
    /// named [`Self::NEGATIVE`] and coloured red. Confidence is fixed at `1.0`.
    ///
    /// The attached style has these properties:
    /// - a dashed outline of thickness 6, with fill enabled;
    /// - the name label shown on a filled background of the same colour;
    /// - id and confidence hidden.
    pub fn with_box(mut self, x: f32, y: f32, w: f32, h: f32, positive: bool) -> Self {
        let (name, color) = if positive {
            (Self::POSITIVE, Color::cyan())
        } else {
            (Self::NEGATIVE, Color::red())
        };
        self.boxes.push(
            Hbb::from_xywh(x, y, w, h)
                .with_name(name)
                .with_confidence(1.0)
                .with_style(
                    HbbStyle::default()
                        .with_mode(HbbStyleMode::dashed())
                        .with_thickness(6)
                        .with_draw_fill(true)
                        .with_draw_outline(true)
                        .with_outline_color(ColorSource::Custom(color))
                        .with_text_visible(true)
                        .with_text_style(
                            TextStyle::default()
                                .with_mode(TextStyleMode::rect(5.))
                                .with_draw_fill(true)
                                .with_bg_fill_color(ColorSource::Custom(color)),
                        )
                        .show_id(false)
                        .show_confidence(false),
                ),
        );
        self
    }

    /// Shorthand for [`Self::with_box`] with `positive = true`.
    pub fn with_positive_box(self, x: f32, y: f32, w: f32, h: f32) -> Self {
        self.with_box(x, y, w, h, true)
    }

    /// Shorthand for [`Self::with_box`] with `positive = false`.
    pub fn with_negative_box(self, x: f32, y: f32, w: f32, h: f32) -> Self {
        self.with_box(x, y, w, h, false)
    }

    /// Appends a point hint at `(x, y)`.
    ///
    /// A positive point is named [`Self::POSITIVE`] and coloured green. A negative point is
    /// named [`Self::NEGATIVE`] and coloured red. Confidence is fixed at `1.0`. The attached
    /// style is a filled, outlined star of radius 15 with no text label.
    pub fn with_point(mut self, x: f32, y: f32, positive: bool) -> Self {
        let (name, color) = if positive {
            (Self::POSITIVE, Color::green())
        } else {
            (Self::NEGATIVE, Color::red())
        };
        self.points.push(
            Keypoint::default()
                .with_xy(x, y)
                .with_name(name)
                .with_confidence(1.0)
                .with_style(
                    KeypointStyle::default()
                        .with_mode(KeypointStyleMode::star())
                        .with_radius(15)
                        .with_draw_fill(true)
                        .with_draw_outline(true)
                        .with_fill_color(ColorSource::Custom(color))
                        .with_text_visible(false),
                ),
        );
        self
    }

    /// Shorthand for [`Self::with_point`] with `positive = true`.
    pub fn with_positive_point(self, x: f32, y: f32) -> Self {
        self.with_point(x, y, true)
    }

    /// Shorthand for [`Self::with_point`] with `positive = false`.
    pub fn with_negative_point(self, x: f32, y: f32) -> Self {
        self.with_point(x, y, false)
    }

    /// Returns `true` if at least one box hint (positive or negative) is present.
    pub fn has_boxes(&self) -> bool {
        !self.boxes.is_empty()
    }

    /// Returns `true` if at least one point hint (positive or negative) is present.
    pub fn has_points(&self) -> bool {
        !self.points.is_empty()
    }

    /// Returns `true` if any box is named [`Self::POSITIVE`].
    pub fn has_positive_box(&self) -> bool {
        self.boxes.iter().any(|b| b.name() == Some(Self::POSITIVE))
    }

    /// Returns `true` if the prompt text is exactly [`Self::VISUAL`].
    pub fn is_visual(&self) -> bool {
        self.text == Self::VISUAL
    }

    /// Decides whether box geometry should be fed alongside the prompt.
    ///
    /// For a visual (geometry-only) prompt this requires at least one positive box. For a
    /// text prompt, any box qualifies. Points are not considered.
    pub fn should_use_geometry(&self) -> bool {
        if self.is_visual() {
            self.has_positive_box()
        } else {
            self.has_boxes()
        }
    }

    /// Returns one label per box, in insertion order.
    ///
    /// The label is `1` for boxes named [`Self::POSITIVE`] and `0` for all others.
    pub fn box_labels(&self) -> Vec<i64> {
        self.boxes
            .iter()
            .map(|b| i64::from(b.name() == Some(Self::POSITIVE)))
            .collect()
    }

    /// Returns one label per point, in insertion order.
    ///
    /// The label is `1` for points named [`Self::POSITIVE`] and `0` for all others.
    pub fn point_labels(&self) -> Vec<i64> {
        self.points
            .iter()
            .map(|p| i64::from(p.name() == Some(Self::POSITIVE)))
            .collect()
    }

    /// Converts each box to `[cx, cy, w, h]`, with the centre computed as `x + w / 2`,
    /// `y + h / 2`.
    ///
    /// X values are divided by `image_width` and Y values by `image_height`. No check is
    /// made on the image size: a zero dimension yields non-finite values.
    pub fn normalized_boxes(&self, image_width: f32, image_height: f32) -> Vec<[f32; 4]> {
        self.boxes
            .iter()
            .map(|hbb| {
                let (x, y, w, h) = hbb.xywh();
                let cx = (x + w / 2.0) / image_width;
                let cy = (y + h / 2.0) / image_height;
                let nw = w / image_width;
                let nh = h / image_height;
                [cx, cy, nw, nh]
            })
            .collect()
    }

    /// Returns each point as `[x * scale_x, y * scale_y]`, in insertion order.
    pub fn scaled_points(&self, scale_x: f32, scale_y: f32) -> Vec<[f32; 2]> {
        self.points
            .iter()
            .map(|kpt| [kpt.x() * scale_x, kpt.y() * scale_y])
            .collect()
    }

    /// Returns each box in corner form `[x1, y1, x2, y2]`, scaled.
    ///
    /// X coordinates are multiplied by `scale_x` and Y coordinates by `scale_y`.
    pub fn scaled_boxes_xyxy(&self, scale_x: f32, scale_y: f32) -> Vec<[f32; 4]> {
        self.boxes
            .iter()
            .map(|hbb| {
                let (x1, y1, x2, y2) = hbb.xyxy();
                [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]
            })
            .collect()
    }

    /// Returns the prompt text, which doubles as the class name of the prompt.
    pub fn class_name(&self) -> &str {
        &self.text
    }

    /// Parses a comma-separated list of coordinates such as `"10, 20.5"`.
    ///
    /// Surrounding whitespace on each item is ignored.
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending item when any of these hold:
    /// - an item is not a number (this includes the empty item produced by a trailing comma);
    /// - an item parses to NaN or infinity.
    fn parse_coords(s: &str) -> std::result::Result<Vec<f32>, String> {
        s.split(',')
            .map(|raw| {
                let token = raw.trim();
                let value = token
                    .parse::<f32>()
                    .map_err(|e| format!("Invalid coordinate '{}': {}", token, e))?;
                // `f32::from_str` happily accepts "nan"/"inf"/"infinity"; such values would
                // silently poison every downstream geometry computation, so reject them here.
                if value.is_finite() {
                    Ok(value)
                } else {
                    Err(format!(
                        "Invalid coordinate '{}': must be a finite number",
                        token
                    ))
                }
            })
            .collect()
    }
}
impl std::str::FromStr for Sam3Prompt {
    type Err = String;

    /// Parses a prompt from a `;`-separated string.
    ///
    /// Accepted forms:
    /// - `"text"` — a text-only prompt.
    /// - `"text;pos:x,y;neg:x,y,w,h"` — text plus geometry.
    /// - `"pos:x,y,w,h;neg:x,y"` — geometry only. The text becomes [`Sam3Prompt::VISUAL`].
    ///
    /// Each geometry segment starts with `pos:` or `neg:`. It is followed by either 2
    /// comma-separated numbers (a point) or 4 numbers (a box given as `x,y,w,h`). Empty
    /// segments are skipped.
    ///
    /// # Errors
    ///
    /// Returns a message if any of these hold:
    /// - the input carries neither text nor geometry (e.g. `""`, `"  "`, `";"`);
    /// - a segment has an unknown prefix;
    /// - a coordinate is invalid;
    /// - a segment holds a number of coordinates other than 2 or 4.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // `split` always yields at least one item (even for ""), so `parts[0]` exists.
        let parts: Vec<&str> = s.split(';').collect();
        let first = parts[0].trim();
        let (text, geo_parts) = if first.starts_with("pos:") || first.starts_with("neg:") {
            // No leading text: every segment, including the first, is geometry.
            (Self::VISUAL, &parts[..])
        } else {
            (first, &parts[1..])
        };

        let mut prompt = Self::new(text);
        for part in geo_parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }
            let (tag, positive, coords_str) = if let Some(rest) = part.strip_prefix("pos:") {
                ("pos", true, rest)
            } else if let Some(rest) = part.strip_prefix("neg:") {
                ("neg", false, rest)
            } else {
                return Err(format!(
                    "Invalid format: '{}'. Use 'pos:x,y' (point) or 'pos:x,y,w,h' (box)",
                    part
                ));
            };
            let coords = Self::parse_coords(coords_str)?;
            prompt = match coords[..] {
                [x, y] => prompt.with_point(x, y, positive),
                [x, y, w, h] => prompt.with_box(x, y, w, h, positive),
                _ => {
                    return Err(format!(
                        "{}: expects 2 (point) or 4 (box) coords, got {}",
                        tag,
                        coords.len()
                    ))
                }
            };
        }

        // Because `split` never returns an empty list, emptiness must be judged on the result.
        // Reject input that produced neither text nor any geometry.
        if prompt.text.is_empty() && !prompt.has_boxes() && !prompt.has_points() {
            return Err("Empty prompt string".to_string());
        }
        Ok(prompt)
    }
}