use crate::{Result, VisionError};
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use torsh_core::dtype::DType;
use torsh_core::{Device, DeviceType};
use torsh_tensor::Tensor;
/// Interactive image viewer: holds the image being displayed, its overlay
/// annotations, and user-registered event callbacks.
///
/// `Clone` is derived: clones share the same handler set through the `Arc`,
/// but get independent copies of the image, annotations, and config.
#[derive(Clone)]
pub struct InteractiveViewer {
    // Image currently shown, if any (validated to be 2-D or 3-D on load).
    current_image: Option<Tensor>,
    // Overlay annotations in insertion order; indices serve as handles.
    annotations: Vec<Annotation>,
    // Display settings (canvas size, colors, tool visibility).
    config: ViewerConfig,
    // Callbacks keyed by a caller-chosen name; behind Arc<Mutex> so clones
    // of the viewer observe and fire the same handlers.
    event_handlers: Arc<Mutex<HashMap<String, Box<dyn Fn(&ViewerEvent) + Send + Sync>>>>,
}
// Manual `Debug` impl: the boxed handler closures are not `Debug`, so the
// derive cannot be used; handlers are rendered as a fixed placeholder.
impl std::fmt::Debug for InteractiveViewer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InteractiveViewer")
            .field("current_image", &self.current_image)
            .field("annotations", &self.annotations)
            .field("config", &self.config)
            // Closures cannot be printed; emit a marker string instead.
            .field("event_handlers", &"<event_handlers>")
            .finish()
    }
}
/// Display configuration for [`InteractiveViewer`].
#[derive(Debug, Clone)]
pub struct ViewerConfig {
    /// Canvas width in pixels.
    pub width: u32,
    /// Canvas height in pixels.
    pub height: u32,
    /// Whether zoom controls are rendered.
    pub show_zoom_controls: bool,
    /// Whether the annotation toolbar is rendered.
    pub show_annotation_tools: bool,
    /// RGB color used for new annotations when none is given.
    pub default_annotation_color: [u8; 3],
    /// RGB background (canvas) color.
    pub background_color: [u8; 3],
}
impl Default for ViewerConfig {
fn default() -> Self {
Self {
width: 800,
height: 600,
show_zoom_controls: true,
show_annotation_tools: true,
default_annotation_color: [255, 0, 0], background_color: [240, 240, 240], }
}
}
/// An overlay drawn on top of the displayed image.
///
/// All coordinates are `f32` in image space; colors are `[r, g, b]`.
#[derive(Debug, Clone)]
pub enum Annotation {
    /// Axis-aligned rectangle, e.g. a detection box.
    BoundingBox {
        /// Left edge.
        x: f32,
        /// Top edge.
        y: f32,
        width: f32,
        height: f32,
        label: String,
        color: [u8; 3],
        /// Optional detection score.
        confidence: Option<f32>,
    },
    /// A labelled point marker (e.g. a landmark).
    Point {
        x: f32,
        y: f32,
        label: String,
        color: [u8; 3],
        /// Marker radius in pixels.
        radius: f32,
    },
    /// A closed polygon, optionally filled.
    Polygon {
        /// Vertices as (x, y) pairs, in drawing order.
        points: Vec<(f32, f32)>,
        label: String,
        color: [u8; 3],
        filled: bool,
    },
    /// Free-standing text placed at (x, y).
    Text {
        x: f32,
        y: f32,
        text: String,
        color: [u8; 3],
        font_size: f32,
    },
    /// A tensor mask blended over the image.
    /// NOTE: the tensor payload is NOT serialized by `export_annotations`.
    Mask {
        mask: Tensor,
        color: [u8; 3],
        /// Blend factor; presumably in [0, 1] — not validated here.
        alpha: f32,
        label: String,
    },
}
/// Events delivered to handlers registered via `InteractiveViewer::on_event`.
///
/// Every registered handler receives every event; the registration name is
/// not used as a filter.
#[derive(Debug, Clone)]
pub enum ViewerEvent {
    /// Mouse button pressed at viewer coordinates (x, y).
    MouseClick { x: f32, y: f32, button: MouseButton },
    /// Cursor moved to viewer coordinates (x, y).
    MouseMove { x: f32, y: f32 },
    /// Key pressed, identified by its string name.
    KeyPress { key: String },
    /// Fired by `add_annotation` with a copy of the new annotation.
    AnnotationCreated { annotation: Annotation },
    /// An existing annotation was selected (by index).
    AnnotationSelected { index: usize },
    /// Fired by `update_annotation` with the replacement annotation.
    AnnotationModified {
        index: usize,
        annotation: Annotation,
    },
    /// Fired by `remove_annotation`; `index` is the pre-removal position.
    AnnotationDeleted { index: usize },
    /// Zoom level changed (1.0 presumably means unzoomed — not set here).
    ZoomChanged { zoom_level: f32 },
    /// Fired by `load_image` with a copy of the new image tensor.
    ImageChanged { image: Tensor },
}
/// Mouse button identifier for [`ViewerEvent::MouseClick`].
///
/// `Copy`, `Eq`, and `Hash` are derived in addition to the original traits:
/// the enum is fieldless, so all are free, and `Eq`/`Hash` let it be used as
/// a map key while `Copy` avoids needless clones. Purely additive, so
/// existing callers are unaffected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MouseButton {
    Left,
    Right,
    Middle,
}
impl InteractiveViewer {
    /// Creates an empty viewer with the default [`ViewerConfig`].
    pub fn new() -> Self {
        Self {
            current_image: None,
            annotations: Vec::new(),
            config: ViewerConfig::default(),
            event_handlers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Creates an empty viewer with a caller-supplied configuration.
    pub fn with_config(config: ViewerConfig) -> Self {
        Self {
            current_image: None,
            annotations: Vec::new(),
            config,
            event_handlers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Sets the displayed image and emits [`ViewerEvent::ImageChanged`].
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if the tensor is not 2-D (grayscale)
    /// or 3-D (color).
    pub fn load_image(&mut self, image: Tensor) -> Result<()> {
        if image.ndim() < 2 || image.ndim() > 3 {
            return Err(VisionError::InvalidInput(
                "Image must be 2D (grayscale) or 3D (color)".to_string(),
            ));
        }
        self.current_image = Some(image.clone());
        self.emit_event(ViewerEvent::ImageChanged { image });
        Ok(())
    }

    /// Loads an image from disk and displays it via [`Self::load_image`].
    pub fn load_image_from_path<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        use crate::io::{global_io, VisionIO};
        use crate::utils::image_to_tensor;
        let dynamic_image = global_io().load_image(path)?;
        let tensor = image_to_tensor(&dynamic_image)?;
        self.load_image(tensor)
    }

    /// Appends an annotation and emits [`ViewerEvent::AnnotationCreated`].
    pub fn add_annotation(&mut self, annotation: Annotation) {
        self.annotations.push(annotation.clone());
        self.emit_event(ViewerEvent::AnnotationCreated { annotation });
    }

    /// Removes the annotation at `index` and emits
    /// [`ViewerEvent::AnnotationDeleted`].
    ///
    /// # Errors
    /// [`VisionError::InvalidArgument`] if `index` is out of bounds.
    pub fn remove_annotation(&mut self, index: usize) -> Result<()> {
        if index >= self.annotations.len() {
            return Err(VisionError::InvalidArgument(format!(
                "Annotation index {} out of bounds",
                index
            )));
        }
        self.annotations.remove(index);
        self.emit_event(ViewerEvent::AnnotationDeleted { index });
        Ok(())
    }

    /// Replaces the annotation at `index` and emits
    /// [`ViewerEvent::AnnotationModified`].
    ///
    /// # Errors
    /// [`VisionError::InvalidArgument`] if `index` is out of bounds.
    pub fn update_annotation(&mut self, index: usize, annotation: Annotation) -> Result<()> {
        if index >= self.annotations.len() {
            return Err(VisionError::InvalidArgument(format!(
                "Annotation index {} out of bounds",
                index
            )));
        }
        self.annotations[index] = annotation.clone();
        self.emit_event(ViewerEvent::AnnotationModified { index, annotation });
        Ok(())
    }

    /// Current annotations, in insertion order.
    pub fn annotations(&self) -> &[Annotation] {
        &self.annotations
    }

    /// Removes all annotations. Note: no `AnnotationDeleted` events fire.
    pub fn clear_annotations(&mut self) {
        self.annotations.clear();
    }

    /// The image currently displayed, if any.
    pub fn current_image(&self) -> Option<&Tensor> {
        self.current_image.as_ref()
    }

    /// Registers (or replaces) an event handler under `event_name`.
    ///
    /// The name is only a registration key for later replacement: every
    /// handler receives every event regardless of its name.
    ///
    /// # Panics
    /// Panics if the handler mutex is poisoned.
    pub fn on_event<F>(&mut self, event_name: String, handler: F)
    where
        F: Fn(&ViewerEvent) + Send + Sync + 'static,
    {
        let mut handlers = self
            .event_handlers
            .lock()
            .expect("lock should not be poisoned");
        handlers.insert(event_name, Box::new(handler));
    }

    /// Invokes every registered handler with `event`.
    ///
    /// NOTE(review): the mutex is held while handlers run, so a handler that
    /// re-enters the handler map (e.g. through a cloned viewer) will
    /// deadlock — verify no handler does this.
    fn emit_event(&self, event: ViewerEvent) {
        let handlers = self
            .event_handlers
            .lock()
            .expect("lock should not be poisoned");
        for handler in handlers.values() {
            handler(&event);
        }
    }

    /// Emits a [`ViewerEvent::MouseClick`] at viewer coordinates.
    pub fn handle_mouse_click(&mut self, x: f32, y: f32, button: MouseButton) {
        self.emit_event(ViewerEvent::MouseClick { x, y, button });
    }

    /// Emits a [`ViewerEvent::MouseMove`] at viewer coordinates.
    pub fn handle_mouse_move(&mut self, x: f32, y: f32) {
        self.emit_event(ViewerEvent::MouseMove { x, y });
    }

    /// Emits a [`ViewerEvent::KeyPress`] for `key`.
    pub fn handle_key_press(&mut self, key: String) {
        self.emit_event(ViewerEvent::KeyPress { key });
    }

    /// Serializes the annotations plus the canvas size to pretty-printed
    /// JSON.
    ///
    /// Mask annotations are exported WITHOUT their tensor payload (only
    /// label/color/alpha), so they cannot be reconstructed by
    /// [`Self::import_annotations`].
    ///
    /// # Errors
    /// [`VisionError::InvalidArgument`] if serialization fails.
    pub fn export_annotations(&self) -> Result<String> {
        use serde_json::json;
        let annotations_json: Vec<serde_json::Value> = self
            .annotations
            .iter()
            .map(|ann| match ann {
                Annotation::BoundingBox {
                    x,
                    y,
                    width,
                    height,
                    label,
                    color,
                    confidence,
                } => json!({
                    "type": "bounding_box",
                    "x": x,
                    "y": y,
                    "width": width,
                    "height": height,
                    "label": label,
                    "color": color,
                    "confidence": confidence
                }),
                Annotation::Point {
                    x,
                    y,
                    label,
                    color,
                    radius,
                } => json!({
                    "type": "point",
                    "x": x,
                    "y": y,
                    "label": label,
                    "color": color,
                    "radius": radius
                }),
                Annotation::Polygon {
                    points,
                    label,
                    color,
                    filled,
                } => json!({
                    "type": "polygon",
                    "points": points,
                    "label": label,
                    "color": color,
                    "filled": filled
                }),
                Annotation::Text {
                    x,
                    y,
                    text,
                    color,
                    font_size,
                } => json!({
                    "type": "text",
                    "x": x,
                    "y": y,
                    "text": text,
                    "color": color,
                    "font_size": font_size
                }),
                // The mask tensor itself is intentionally dropped here.
                Annotation::Mask {
                    label, color, alpha, ..
                } => json!({
                    "type": "mask",
                    "label": label,
                    "color": color,
                    "alpha": alpha
                }),
            })
            .collect();
        let export = json!({
            "annotations": annotations_json,
            "config": {
                "width": self.config.width,
                "height": self.config.height
            }
        });
        // The map_err result is already a `Result`; previously this was
        // redundantly wrapped as `Ok(...?)`.
        serde_json::to_string_pretty(&export)
            .map_err(|e| VisionError::InvalidArgument(format!("JSON serialization error: {}", e)))
    }

    /// Replaces the current annotations with those parsed from `json_str`
    /// (the format produced by [`Self::export_annotations`]).
    ///
    /// The import is atomic: if any entry fails to parse, the existing
    /// annotations are left untouched. (Previously the list was cleared
    /// before parsing, so a mid-stream error lost all annotations.)
    ///
    /// # Errors
    /// [`VisionError::InvalidArgument`] on malformed JSON;
    /// [`VisionError::InvalidInput`] on an unknown or unsupported entry.
    pub fn import_annotations(&mut self, json_str: &str) -> Result<()> {
        let data: serde_json::Value = serde_json::from_str(json_str)
            .map_err(|e| VisionError::InvalidArgument(format!("JSON parsing error: {}", e)))?;
        if let Some(annotations) = data["annotations"].as_array() {
            // Parse into a temporary buffer first so failure cannot corrupt
            // the viewer's state.
            let parsed = annotations
                .iter()
                .map(|ann_json| self.parse_annotation_from_json(ann_json))
                .collect::<Result<Vec<_>>>()?;
            self.annotations = parsed;
        }
        Ok(())
    }

    /// Reads the `"color"` field of an annotation object as `[r, g, b]`,
    /// defaulting missing components to red ([255, 0, 0]). Values above 255
    /// are truncated by the `as u8` cast, matching the previous inline
    /// parsers this helper replaces.
    fn parse_color(json: &serde_json::Value) -> [u8; 3] {
        [
            json["color"][0].as_u64().unwrap_or(255) as u8,
            json["color"][1].as_u64().unwrap_or(0) as u8,
            json["color"][2].as_u64().unwrap_or(0) as u8,
        ]
    }

    /// Parses one exported annotation object back into an [`Annotation`].
    fn parse_annotation_from_json(&self, json: &serde_json::Value) -> Result<Annotation> {
        let ann_type = json["type"]
            .as_str()
            .ok_or_else(|| VisionError::InvalidInput("Missing annotation type".to_string()))?;
        match ann_type {
            "bounding_box" => Ok(Annotation::BoundingBox {
                x: json["x"].as_f64().unwrap_or(0.0) as f32,
                y: json["y"].as_f64().unwrap_or(0.0) as f32,
                width: json["width"].as_f64().unwrap_or(0.0) as f32,
                height: json["height"].as_f64().unwrap_or(0.0) as f32,
                label: json["label"].as_str().unwrap_or("").to_string(),
                color: Self::parse_color(json),
                confidence: json["confidence"].as_f64().map(|v| v as f32),
            }),
            "point" => Ok(Annotation::Point {
                x: json["x"].as_f64().unwrap_or(0.0) as f32,
                y: json["y"].as_f64().unwrap_or(0.0) as f32,
                label: json["label"].as_str().unwrap_or("").to_string(),
                color: Self::parse_color(json),
                radius: json["radius"].as_f64().unwrap_or(3.0) as f32,
            }),
            "polygon" => {
                let points = json["points"]
                    .as_array()
                    .ok_or_else(|| {
                        VisionError::InvalidInput("Missing polygon points".to_string())
                    })?
                    .iter()
                    .map(|p| {
                        (
                            p[0].as_f64().unwrap_or(0.0) as f32,
                            p[1].as_f64().unwrap_or(0.0) as f32,
                        )
                    })
                    .collect();
                Ok(Annotation::Polygon {
                    points,
                    label: json["label"].as_str().unwrap_or("").to_string(),
                    color: Self::parse_color(json),
                    filled: json["filled"].as_bool().unwrap_or(false),
                })
            }
            "text" => Ok(Annotation::Text {
                x: json["x"].as_f64().unwrap_or(0.0) as f32,
                y: json["y"].as_f64().unwrap_or(0.0) as f32,
                text: json["text"].as_str().unwrap_or("").to_string(),
                color: Self::parse_color(json),
                font_size: json["font_size"].as_f64().unwrap_or(12.0) as f32,
            }),
            // Exported masks carry no tensor data, so they cannot be
            // reconstructed. Report that precisely instead of falling through
            // to the generic "unknown type" error (export_annotations emits
            // "mask" entries, so round-trips previously failed opaquely).
            "mask" => Err(VisionError::InvalidInput(
                "Mask annotations cannot be imported: tensor data is not included in exports"
                    .to_string(),
            )),
            _ => Err(VisionError::InvalidInput(format!(
                "Unknown annotation type: {}",
                ann_type
            ))),
        }
    }
}
/// A browsable collection of named images with per-image annotations.
#[derive(Debug)]
pub struct InteractiveGallery {
    /// Images as (name, tensor) pairs, in insertion order.
    images: Vec<(String, Tensor)>,
    /// Index into `images` of the currently selected image.
    current_index: usize,
    /// Layout and navigation options.
    config: GalleryConfig,
    /// Annotations keyed by image name.
    annotations: HashMap<String, Vec<Annotation>>,
}
/// Layout options for [`InteractiveGallery`].
#[derive(Debug, Clone)]
pub struct GalleryConfig {
    /// Thumbnail (width, height) in pixels.
    pub thumbnail_size: (u32, u32),
    /// Number of thumbnails per gallery row.
    pub images_per_row: usize,
    /// Whether image names are rendered with their thumbnails.
    pub show_names: bool,
    /// Whether previous/next navigation controls are rendered.
    pub show_navigation: bool,
}
impl Default for GalleryConfig {
fn default() -> Self {
Self {
thumbnail_size: (150, 150),
images_per_row: 4,
show_names: true,
show_navigation: true,
}
}
}
impl InteractiveGallery {
    /// Creates an empty gallery with the default [`GalleryConfig`].
    pub fn new() -> Self {
        Self {
            images: Vec::new(),
            current_index: 0,
            config: GalleryConfig::default(),
            annotations: HashMap::new(),
        }
    }

    /// Creates an empty gallery with a caller-supplied configuration.
    pub fn with_config(config: GalleryConfig) -> Self {
        Self {
            images: Vec::new(),
            current_index: 0,
            config,
            annotations: HashMap::new(),
        }
    }

    /// Appends `image` under `name`.
    ///
    /// Uses the entry API instead of `insert`: previously, re-adding an image
    /// under an existing name silently wiped that name's annotations while
    /// still keeping both images in the list. Existing annotations for the
    /// name are now preserved.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if the tensor is not 2-D or 3-D.
    pub fn add_image(&mut self, name: String, image: Tensor) -> Result<()> {
        if image.ndim() < 2 || image.ndim() > 3 {
            return Err(VisionError::InvalidInput(
                "Image must be 2D (grayscale) or 3D (color)".to_string(),
            ));
        }
        self.images.push((name.clone(), image));
        self.annotations.entry(name).or_default();
        Ok(())
    }

    /// Loads every recognized image file (by extension) from `dir_path`.
    ///
    /// Image names are taken from the file stem, so `foo.jpg` and `foo.png`
    /// will both be added under the name "foo" and share annotations.
    ///
    /// # Errors
    /// Propagates directory-read and image-decode failures.
    pub fn load_from_directory<P: AsRef<Path>>(&mut self, dir_path: P) -> Result<()> {
        use crate::io::{global_io, VisionIO};
        use crate::utils::image_to_tensor;
        use std::fs;
        let dir = fs::read_dir(dir_path)?;
        for entry in dir {
            let entry = entry?;
            let path = entry.path();
            if let Some(extension) = path.extension() {
                let ext_str = extension.to_str().unwrap_or("").to_lowercase();
                if matches!(
                    ext_str.as_str(),
                    "jpg" | "jpeg" | "png" | "bmp" | "tiff" | "webp"
                ) {
                    let name = path
                        .file_stem()
                        .and_then(|s| s.to_str())
                        .unwrap_or("unknown")
                        .to_string();
                    let dynamic_image = global_io().load_image(&path)?;
                    let tensor = image_to_tensor(&dynamic_image)?;
                    self.add_image(name, tensor)?;
                }
            }
        }
        Ok(())
    }

    /// The currently selected (name, image) pair, or `None` when empty.
    pub fn current_image(&self) -> Option<&(String, Tensor)> {
        self.images.get(self.current_index)
    }

    /// Advances the selection, wrapping from the last image to the first.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if the gallery is empty.
    pub fn next_image(&mut self) -> Result<()> {
        if self.images.is_empty() {
            return Err(VisionError::InvalidInput(
                "No images in gallery".to_string(),
            ));
        }
        self.current_index = (self.current_index + 1) % self.images.len();
        Ok(())
    }

    /// Moves the selection back, wrapping from the first image to the last.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if the gallery is empty.
    pub fn previous_image(&mut self) -> Result<()> {
        if self.images.is_empty() {
            return Err(VisionError::InvalidInput(
                "No images in gallery".to_string(),
            ));
        }
        if self.current_index == 0 {
            self.current_index = self.images.len() - 1;
        } else {
            self.current_index -= 1;
        }
        Ok(())
    }

    /// Jumps to the image at `index`.
    ///
    /// # Errors
    /// [`VisionError::InvalidArgument`] if `index` is out of bounds.
    pub fn goto_image(&mut self, index: usize) -> Result<()> {
        if index >= self.images.len() {
            return Err(VisionError::InvalidArgument(format!(
                "Image index {} out of bounds",
                index
            )));
        }
        self.current_index = index;
        Ok(())
    }

    /// Names of all images, in insertion order.
    pub fn image_names(&self) -> Vec<&String> {
        self.images.iter().map(|(name, _)| name).collect()
    }

    /// Number of images in the gallery.
    pub fn len(&self) -> usize {
        self.images.len()
    }

    /// `true` when the gallery holds no images.
    pub fn is_empty(&self) -> bool {
        self.images.is_empty()
    }

    /// Appends `annotation` to the currently selected image.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if no image is selected.
    pub fn add_annotation_to_current(&mut self, annotation: Annotation) -> Result<()> {
        let current_name = self
            .current_image()
            .ok_or_else(|| VisionError::InvalidInput("No current image".to_string()))?
            .0
            .clone();
        self.annotations
            .entry(current_name)
            .or_default()
            .push(annotation);
        Ok(())
    }

    /// Annotations of the currently selected image, if any image is selected.
    pub fn current_annotations(&self) -> Option<&Vec<Annotation>> {
        let current_name = &self.current_image()?.0;
        self.annotations.get(current_name)
    }

    /// Resets the current image's annotations to an empty list.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if no image is selected.
    pub fn clear_current_annotations(&mut self) -> Result<()> {
        let current_name = self
            .current_image()
            .ok_or_else(|| VisionError::InvalidInput("No current image".to_string()))?
            .0
            .clone();
        self.annotations.insert(current_name, Vec::new());
        Ok(())
    }
}
/// Rolling-buffer visualization for streamed frames (e.g. video), with an
/// FPS estimate over recent frames.
#[derive(Debug)]
pub struct LiveVisualization {
    // Most recently accepted frame, if any.
    current_frame: Option<Tensor>,
    // Recent frames, oldest at the front; capped at `buffer_size`.
    frame_buffer: std::collections::VecDeque<Tensor>,
    // Cap for `frame_buffer`; mirrors `config.buffer_size`.
    buffer_size: usize,
    // Updated on every accepted frame.
    fps_counter: FpsCounter,
    // Rendering options; show_fps/show_metrics are not read in this file —
    // presumably consumed by a renderer elsewhere (TODO confirm).
    config: LiveConfig,
}
/// Options for [`LiveVisualization`].
#[derive(Debug, Clone)]
pub struct LiveConfig {
    /// Desired playback rate in frames per second (not enforced here).
    pub target_fps: f32,
    /// Maximum number of frames retained in the rolling buffer.
    pub buffer_size: usize,
    /// Whether an FPS readout should be rendered.
    pub show_fps: bool,
    /// Whether additional metrics should be rendered.
    pub show_metrics: bool,
}
impl Default for LiveConfig {
fn default() -> Self {
Self {
target_fps: 30.0,
buffer_size: 10,
show_fps: true,
show_metrics: false,
}
}
}
/// Sliding-window FPS estimator over recent frame timestamps.
#[derive(Debug)]
pub struct FpsCounter {
    // Timestamps of the most recent frames, oldest at the front.
    frame_times: std::collections::VecDeque<std::time::Instant>,
    // Maximum number of timestamps retained.
    window_size: usize,
}
impl FpsCounter {
    /// Creates a counter that averages over the last `window_size` frames.
    fn new(window_size: usize) -> Self {
        Self {
            frame_times: std::collections::VecDeque::new(),
            window_size,
        }
    }

    /// Records a frame at the current instant, evicting the oldest
    /// timestamp once the window is full.
    fn update(&mut self) {
        self.frame_times.push_back(std::time::Instant::now());
        if self.frame_times.len() > self.window_size {
            self.frame_times.pop_front();
        }
    }

    /// Average FPS over the recorded window.
    ///
    /// Returns 0.0 when fewer than two frames are recorded, or when no
    /// measurable time elapsed between the first and last timestamp —
    /// the previous version divided by zero in that case, yielding `+inf`.
    fn current_fps(&self) -> f32 {
        if self.frame_times.len() < 2 {
            return 0.0;
        }
        let elapsed = self
            .frame_times
            .back()
            .expect("frame_times should have back element")
            .duration_since(
                *self
                    .frame_times
                    .front()
                    .expect("frame_times should have front element"),
            );
        let secs = elapsed.as_secs_f32();
        if secs <= 0.0 {
            return 0.0;
        }
        let num_frames = self.frame_times.len() - 1;
        num_frames as f32 / secs
    }
}
impl LiveVisualization {
    /// Creates a visualization with the default [`LiveConfig`].
    pub fn new() -> Self {
        Self::with_config(LiveConfig::default())
    }

    /// Creates a visualization with a caller-supplied configuration.
    pub fn with_config(config: LiveConfig) -> Self {
        Self {
            current_frame: None,
            frame_buffer: std::collections::VecDeque::with_capacity(config.buffer_size),
            buffer_size: config.buffer_size,
            // FPS is averaged over a fixed 30-frame window, independent of
            // `config.target_fps`.
            fps_counter: FpsCounter::new(30),
            config,
        }
    }

    /// Accepts a new frame: buffers it (evicting the oldest when full),
    /// makes it current, and updates the FPS counter.
    ///
    /// A `buffer_size` of 0 disables buffering entirely; previously the
    /// frame was pushed unconditionally after a single pop, so a
    /// zero-capacity buffer still ended up holding one frame.
    ///
    /// # Errors
    /// [`VisionError::InvalidInput`] if the tensor is not 2-D or 3-D.
    pub fn add_frame(&mut self, frame: Tensor) -> Result<()> {
        if frame.ndim() < 2 || frame.ndim() > 3 {
            return Err(VisionError::InvalidInput(
                "Frame must be 2D (grayscale) or 3D (color)".to_string(),
            ));
        }
        if self.buffer_size > 0 {
            // `while` (not `if`) restores the invariant even if the buffer
            // was somehow overfilled.
            while self.frame_buffer.len() >= self.buffer_size {
                self.frame_buffer.pop_front();
            }
            self.frame_buffer.push_back(frame.clone());
        }
        self.current_frame = Some(frame);
        self.fps_counter.update();
        Ok(())
    }

    /// The most recently accepted frame, if any.
    pub fn current_frame(&self) -> Option<&Tensor> {
        self.current_frame.as_ref()
    }

    /// Current FPS estimate (0.0 until at least two frames arrive).
    pub fn current_fps(&self) -> f32 {
        self.fps_counter.current_fps()
    }

    /// Number of frames currently buffered.
    pub fn buffer_len(&self) -> usize {
        self.frame_buffer.len()
    }

    /// Drops all buffered frames and the current frame.
    pub fn clear_buffer(&mut self) {
        self.frame_buffer.clear();
        self.current_frame = None;
    }
}
impl Default for InteractiveViewer {
    /// Equivalent to [`InteractiveViewer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Default for InteractiveGallery {
    /// Equivalent to [`InteractiveGallery::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Default for LiveVisualization {
    /// Equivalent to [`LiveVisualization::new`] (uses the default config).
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    // `use super::*` already brings `Tensor`, `DeviceType`, and the module's
    // own types into scope; the previous explicit
    // `use torsh_core::{DType, Device};` was unused (a warning under
    // `-D warnings`) and has been removed.
    use super::*;

    #[test]
    fn test_interactive_viewer_creation() {
        let viewer = InteractiveViewer::new();
        assert!(viewer.current_image().is_none());
        assert_eq!(viewer.annotations().len(), 0);
    }

    #[test]
    fn test_interactive_viewer_load_image() {
        let mut viewer = InteractiveViewer::new();
        let image = Tensor::zeros(&[3, 224, 224], DeviceType::Cpu).unwrap();
        viewer.load_image(image).unwrap();
        assert!(viewer.current_image().is_some());
    }

    #[test]
    fn test_annotation_management() {
        let mut viewer = InteractiveViewer::new();
        let annotation = Annotation::BoundingBox {
            x: 10.0,
            y: 20.0,
            width: 50.0,
            height: 30.0,
            label: "test".to_string(),
            color: [255, 0, 0],
            confidence: Some(0.95),
        };
        viewer.add_annotation(annotation);
        assert_eq!(viewer.annotations().len(), 1);
        viewer.remove_annotation(0).unwrap();
        assert_eq!(viewer.annotations().len(), 0);
    }

    #[test]
    fn test_interactive_gallery() {
        let mut gallery = InteractiveGallery::new();
        let image1 = Tensor::zeros(&[3, 224, 224], DeviceType::Cpu).unwrap();
        let image2 = Tensor::ones(&[3, 224, 224], DeviceType::Cpu).unwrap();
        gallery.add_image("image1".to_string(), image1).unwrap();
        gallery.add_image("image2".to_string(), image2).unwrap();
        assert_eq!(gallery.len(), 2);
        assert!(!gallery.is_empty());
        gallery.next_image().unwrap();
        let (current_name, _) = gallery.current_image().unwrap();
        assert_eq!(current_name, "image2");
    }

    #[test]
    fn test_live_visualization() {
        let mut live_viz = LiveVisualization::new();
        let frame = Tensor::zeros(&[3, 480, 640], DeviceType::Cpu).unwrap();
        live_viz.add_frame(frame).unwrap();
        assert!(live_viz.current_frame().is_some());
        assert_eq!(live_viz.buffer_len(), 1);
    }

    #[test]
    fn test_annotation_export_import() {
        let mut viewer = InteractiveViewer::new();
        let annotation = Annotation::Point {
            x: 100.0,
            y: 200.0,
            label: "landmark".to_string(),
            color: [0, 255, 0],
            radius: 5.0,
        };
        viewer.add_annotation(annotation);
        let exported = viewer.export_annotations().unwrap();
        assert!(exported.contains("landmark"));
        let mut new_viewer = InteractiveViewer::new();
        new_viewer.import_annotations(&exported).unwrap();
        assert_eq!(new_viewer.annotations().len(), 1);
    }
}