use scirs2_core::ndarray::Array2;
/// A prompt telling the segmentation model which region to segment.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on it
/// must include a wildcard arm, so further prompt kinds can be added later
/// without a breaking change.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum PromptType {
    /// A single point prompt at (`x`, `y`).
    ///
    /// NOTE(review): presumably pixel coordinates with `x` as the column and `y`
    /// as the row — confirm against the consumer.
    Point {
        /// `x` coordinate of the point.
        x: usize,
        /// `y` coordinate of the point.
        y: usize,
        /// `true` marks the point as foreground; `false` presumably marks it as
        /// background.
        is_foreground: bool,
    },
    /// A box described by two corners, (`x1`, `y1`) and (`x2`, `y2`).
    ///
    /// NOTE(review): presumably top-left / bottom-right corners. Nothing in this
    /// type enforces `x1 <= x2` or `y1 <= y2`, and whether the far edge is
    /// inclusive is not specified here.
    BoundingBox {
        /// `x` coordinate of the first corner.
        x1: usize,
        /// `y` coordinate of the first corner.
        y1: usize,
        /// `x` coordinate of the second corner.
        x2: usize,
        /// `y` coordinate of the second corner.
        y2: usize,
    },
    /// A dense 2-D mask used as the prompt.
    ///
    /// NOTE(review): the expected shape (full image resolution vs. a reduced
    /// grid) and value range (binary, probabilities, logits) are not fixed by
    /// this type — confirm with the consumer.
    MaskPrompt {
        /// The prompt mask.
        mask: Array2<f64>,
    },
    /// Several point prompts supplied together.
    ///
    /// Each tuple is presumably `(x, y, is_foreground)`, mirroring
    /// [`PromptType::Point`] — TODO confirm the order. An empty list is
    /// representable.
    MultiPoint {
        /// The individual points.
        points: Vec<(usize, usize, bool)>,
    },
}
/// Hyper-parameters for the segmentation model (the `SAM` prefix presumably
/// refers to a Segment-Anything-style architecture).
///
/// `SAMConfig::default()` yields `image_size = 1024`, `embed_dim = 256`,
/// `num_mask_outputs = 3`, `iou_head_hidden = 256` and `encoder_stages = 3`.
/// No validation is performed by this type; zero values are representable.
#[derive(Debug, Clone)]
pub struct SAMConfig {
    /// Input image size (presumably the side length, in pixels, of a square
    /// input — confirm).
    pub image_size: usize,
    /// Embedding dimension (presumably the channel width of the image/prompt
    /// embeddings — confirm).
    pub embed_dim: usize,
    /// Number of mask outputs (presumably candidate masks per prompt — confirm).
    pub num_mask_outputs: usize,
    /// Hidden width of the IoU-prediction head (inferred from the name — confirm).
    pub iou_head_hidden: usize,
    /// Number of encoder stages (inferred from the name — confirm).
    pub encoder_stages: usize,
}
impl Default for SAMConfig {
fn default() -> Self {
Self {
image_size: 1024,
embed_dim: 256,
num_mask_outputs: 3,
iou_head_hidden: 256,
encoder_stages: 3,
}
}
}
/// Output of a segmentation run.
///
/// NOTE(review): the three vectors are presumably parallel (entry `i` of each
/// describes the same candidate mask), but nothing in this type enforces equal
/// lengths.
#[derive(Debug, Clone)]
pub struct SegmentationResult {
    /// Predicted masks.
    pub masks: Vec<Array2<f64>>,
    /// Predicted IoU scores (presumably one per mask).
    pub iou_predictions: Vec<f64>,
    /// Stability scores (presumably one per mask).
    pub stability_scores: Vec<f64>,
}
/// A [`PromptType`] paired with an optional string label.
#[derive(Debug, Clone)]
pub struct SegmentationPrompt {
    /// The prompt payload (point, box, mask or multiple points).
    pub prompt_type: PromptType,
    /// Optional label; `None` when built via [`SegmentationPrompt::new`].
    pub label: Option<String>,
}
impl SegmentationPrompt {
pub fn new(prompt_type: PromptType) -> Self {
Self {
prompt_type,
label: None,
}
}
pub fn with_label(prompt_type: PromptType, label: impl Into<String>) -> Self {
Self {
prompt_type,
label: Some(label.into()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_sam_config_default() {
        let cfg = SAMConfig::default();
        assert_eq!(cfg.image_size, 1024);
        assert_eq!(cfg.embed_dim, 256);
        assert_eq!(cfg.num_mask_outputs, 3);
        assert_eq!(cfg.iou_head_hidden, 256);
        assert_eq!(cfg.encoder_stages, 3);
    }

    #[test]
    fn test_prompt_type_point() {
        let p = PromptType::Point {
            x: 10,
            y: 20,
            is_foreground: true,
        };
        if let PromptType::Point {
            x,
            y,
            is_foreground,
        } = &p
        {
            assert_eq!(*x, 10);
            assert_eq!(*y, 20);
            assert!(*is_foreground);
        } else {
            panic!("expected Point variant");
        }
    }

    #[test]
    fn test_prompt_type_bounding_box() {
        let p = PromptType::BoundingBox {
            x1: 5,
            y1: 10,
            x2: 50,
            y2: 60,
        };
        if let PromptType::BoundingBox { x1, y1, x2, y2 } = &p {
            assert_eq!(*x1, 5);
            assert_eq!(*y1, 10);
            assert_eq!(*x2, 50);
            assert_eq!(*y2, 60);
        } else {
            panic!("expected BoundingBox variant");
        }
    }

    #[test]
    fn test_prompt_type_mask() {
        // Non-square and non-uniform, so the check catches a transposed or
        // zeroed mask, not just a wrong shape.
        let mut mask = Array2::<f64>::zeros((64, 32));
        mask[[3, 5]] = 1.0;
        // The clone is kept so the original can be compared against afterwards.
        let p = PromptType::MaskPrompt { mask: mask.clone() };
        if let PromptType::MaskPrompt { mask: m } = &p {
            assert_eq!(m.dim(), (64, 32));
            assert_eq!(m, &mask);
        } else {
            panic!("expected MaskPrompt variant");
        }
    }

    #[test]
    fn test_prompt_type_multipoint() {
        let pts = vec![(1, 2, true), (3, 4, false)];
        // The clone is kept so the stored points can be compared element-wise.
        let p = PromptType::MultiPoint {
            points: pts.clone(),
        };
        if let PromptType::MultiPoint { points } = &p {
            assert_eq!(points.len(), 2);
            assert_eq!(points, &pts);
        } else {
            panic!("expected MultiPoint variant");
        }
    }

    #[test]
    fn test_segmentation_prompt_new() {
        let sp = SegmentationPrompt::new(PromptType::Point {
            x: 0,
            y: 0,
            is_foreground: true,
        });
        assert!(sp.label.is_none());
        assert!(matches!(
            sp.prompt_type,
            PromptType::Point {
                x: 0,
                y: 0,
                is_foreground: true
            }
        ));
    }

    #[test]
    fn test_segmentation_prompt_with_label() {
        let sp = SegmentationPrompt::with_label(
            PromptType::BoundingBox {
                x1: 0,
                y1: 0,
                x2: 10,
                y2: 10,
            },
            "cat",
        );
        assert_eq!(sp.label.as_deref(), Some("cat"));
        assert!(matches!(
            sp.prompt_type,
            PromptType::BoundingBox {
                x1: 0,
                y1: 0,
                x2: 10,
                y2: 10
            }
        ));

        // `impl Into<String>` must also accept an owned `String`.
        let owned = SegmentationPrompt::with_label(
            PromptType::MultiPoint { points: Vec::new() },
            String::from("dog"),
        );
        assert_eq!(owned.label.as_deref(), Some("dog"));
    }

    #[test]
    fn test_segmentation_result_clone() {
        let result = SegmentationResult {
            masks: vec![Array2::<f64>::zeros((2, 3))],
            iou_predictions: vec![0.9],
            stability_scores: vec![0.8],
        };
        let copy = result.clone();
        assert_eq!(copy.masks.len(), 1);
        assert_eq!(copy.masks[0].dim(), (2, 3));
        assert_eq!(copy.iou_predictions, vec![0.9]);
        assert_eq!(copy.stability_scores, vec![0.8]);
    }
}