use crate::{Result, VisionError};
use image::{DynamicImage, GenericImageView};
use torsh_tensor::{creation, creation::zeros_mut, Tensor};
use super::text_rendering::draw_simple_text;
/// Arrange a batch of equally-shaped `(C, H, W)` tensors into a single grid tensor.
///
/// Images are placed row-major: they fill the first grid row left to right, then
/// the next row, and so on. `nrow` is the requested number of grid rows. The number
/// of columns is derived as `ceil(len / nrow)`, and the actual number of rows is
/// then recomputed as `ceil(len / ncol)` so that no trailing row is entirely empty.
///
/// Adjacent images are separated by `padding` pixels (no padding is added around
/// the outer border). Padding and any unused trailing cells keep the initial value
/// of the buffer created by `zeros_mut` (presumably zero).
///
/// The returned tensor has shape `(C, rows * H + (rows - 1) * padding,
/// ncol * W + (ncol - 1) * padding)`.
///
/// # Errors
///
/// - [`VisionError::TransformError`] if `tensors` is empty.
/// - [`VisionError::InvalidArgument`] if `nrow` is zero or the tensors do not all
///   share the same shape.
/// - [`VisionError::InvalidShape`] if the tensors are not 3-D.
/// - Any error returned while reading from or writing to the tensors.
pub fn make_grid(tensors: &[Tensor<f32>], nrow: usize, padding: usize) -> Result<Tensor<f32>> {
    if tensors.is_empty() {
        return Err(VisionError::TransformError(
            "No tensors provided".to_string(),
        ));
    }
    // The layout arithmetic below divides by `nrow`; reject zero instead of panicking.
    if nrow == 0 {
        return Err(VisionError::InvalidArgument(
            "nrow must be greater than zero".to_string(),
        ));
    }

    let first_shape = tensors[0].shape();
    for tensor in tensors.iter().skip(1) {
        if tensor.shape() != first_shape {
            return Err(VisionError::InvalidArgument(
                "All tensors must have the same shape".to_string(),
            ));
        }
    }
    if first_shape.dims().len() != 3 {
        return Err(VisionError::InvalidShape(format!(
            "Expected 3D tensors (C, H, W), got {}D",
            first_shape.dims().len()
        )));
    }
    let (channels, height, width) = (
        first_shape.dims()[0],
        first_shape.dims()[1],
        first_shape.dims()[2],
    );

    let num_images = tensors.len();
    // `(n - 1) / d + 1` equals ceil(n / d) for n >= 1 (guaranteed above) and, unlike
    // `(n + d - 1) / d`, cannot overflow for very large `nrow`.
    let ncol = (num_images - 1) / nrow + 1;
    // Recompute the row count so no trailing grid row is completely empty.
    let actual_nrow = (num_images - 1) / ncol + 1;

    let grid_height = actual_nrow * height + (actual_nrow - 1) * padding;
    let grid_width = ncol * width + (ncol - 1) * padding;
    let grid = zeros_mut(&[channels, grid_height, grid_width]);

    for (idx, tensor) in tensors.iter().enumerate() {
        // Row-major placement: fill each grid row before moving to the next.
        let start_y = (idx / ncol) * (height + padding);
        let start_x = (idx % ncol) * (width + padding);
        for c in 0..channels {
            for y in 0..height {
                for x in 0..width {
                    let pixel_val = tensor.get(&[c, y, x])?;
                    grid.set(&[c, start_y + y, start_x + x], pixel_val)?;
                }
            }
        }
    }
    Ok(grid)
}
/// Draw axis-aligned bounding boxes, with optional labels and scores, onto an image in place.
///
/// `boxes` must be an `[N, 4]` tensor whose rows are `[x1, y1, x2, y2]` pixel
/// coordinates. Coordinates are converted with `as u32`: fractional parts are
/// truncated, negative or NaN values become `0`, and values beyond `u32::MAX`
/// saturate. The two corners may be given in either order; each box spans from
/// the smaller to the larger coordinate on each axis. Any part of an outline
/// outside the image is clipped.
///
/// The image is converted to 8-bit RGB before drawing and written back as
/// [`DynamicImage::ImageRgb8`], so an alpha channel or higher bit depth is dropped.
///
/// # Arguments
///
/// * `labels` - optional per-box label text. Box `i` gets a label only if
///   `labels` has an entry at index `i`. The label is drawn just above the box's
///   top-left corner in the box colour.
/// * `scores` - optional per-box scores, read as `scores[i]`. A score is used
///   only when a label is drawn for that box, producing `"{label}: {score:.2}"`.
/// * `colors` - optional palette cycled over the boxes. When `None` or empty, a
///   built-in six-colour palette (red, green, blue, yellow, magenta, cyan) is used.
///
/// # Errors
///
/// Returns [`VisionError::InvalidShape`] if `boxes` is not of shape `[N, 4]`.
/// Propagates any error from reading tensor elements.
pub fn draw_bounding_boxes(
    image: &mut DynamicImage,
    boxes: &Tensor<f32>,
    labels: Option<&[String]>,
    scores: Option<&Tensor<f32>>,
    colors: Option<&[(u8, u8, u8)]>,
) -> Result<()> {
    const DEFAULT_COLORS: [(u8, u8, u8); 6] = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
    ];

    let box_shape = boxes.shape();
    let dims = box_shape.dims();
    if dims.len() != 2 || dims[1] != 4 {
        // Format the whole dims list: indexing `dims[1]` here would panic for
        // 0-D or 1-D input instead of reporting the shape error.
        return Err(VisionError::InvalidShape(format!(
            "Expected boxes tensor of shape [N, 4], got {:?}",
            dims
        )));
    }
    let num_boxes = dims[0];

    // An empty caller palette would make `i % len` panic, so fall back to the defaults.
    let palette: &[(u8, u8, u8)] = match colors {
        Some(custom) if !custom.is_empty() => custom,
        _ => &DEFAULT_COLORS,
    };

    let mut rgb_image = image.to_rgb8();
    for i in 0..num_boxes {
        let (ax, ay) = (boxes.get(&[i, 0])? as u32, boxes.get(&[i, 1])? as u32);
        let (bx, by) = (boxes.get(&[i, 2])? as u32, boxes.get(&[i, 3])? as u32);
        // Normalise so swapped corners still yield a complete rectangle.
        let (x1, x2) = (ax.min(bx), ax.max(bx));
        let (y1, y2) = (ay.min(by), ay.max(by));

        let (r, g, b) = palette[i % palette.len()];
        let rgb_color = image::Rgb([r, g, b]);
        draw_rect_outline(&mut rgb_image, (x1, y1), (x2, y2), rgb_color);

        if let Some(label) = labels.and_then(|texts| texts.get(i)) {
            let text = format_box_label(label, scores, i)?;
            draw_simple_text(
                &mut rgb_image,
                &text,
                // Saturate: coordinates can be as large as `u32::MAX` after the cast.
                x1.saturating_add(2),
                y1.saturating_sub(15),
                rgb_color,
            );
        }
    }
    *image = DynamicImage::ImageRgb8(rgb_image);
    Ok(())
}

/// Draw the one-pixel outline of the rectangle spanning `top_left..=bottom_right`,
/// skipping any pixel outside the image. Expects `x1 <= x2` and `y1 <= y2`.
fn draw_rect_outline(
    img: &mut image::RgbImage,
    (x1, y1): (u32, u32),
    (x2, y2): (u32, u32),
    color: image::Rgb<u8>,
) {
    let (width, height) = (img.width(), img.height());
    // Walk only the visible part of each edge. Coordinates from a malformed box can
    // be as large as `u32::MAX`, and iterating the full range would waste billions
    // of iterations that all fail the bounds check. `x1 < width` / `y1 < height`
    // also guarantee `width - 1` / `height - 1` cannot underflow.
    if x1 < width {
        for x in x1..=x2.min(width - 1) {
            if y1 < height {
                img.put_pixel(x, y1, color);
            }
            if y2 < height {
                img.put_pixel(x, y2, color);
            }
        }
    }
    if y1 < height {
        for y in y1..=y2.min(height - 1) {
            if x1 < width {
                img.put_pixel(x1, y, color);
            }
            if x2 < width {
                img.put_pixel(x2, y, color);
            }
        }
    }
}

/// Build the caption for box `index`: `"{label}: {score:.2}"` when `scores` has an
/// entry at `index`, otherwise just `label`.
fn format_box_label(label: &str, scores: Option<&Tensor<f32>>, index: usize) -> Result<String> {
    if let Some(score_tensor) = scores {
        // `first()` rather than `[0]` so a 0-D scores tensor is skipped, not a panic.
        let available = score_tensor.shape().dims().first().copied().unwrap_or(0);
        if index < available {
            let score = score_tensor.get(&[index])?;
            return Ok(format!("{}: {:.2}", label, score));
        }
    }
    Ok(label.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{DynamicImage, RgbImage};
    use torsh_tensor::creation;

    /// First entry of the default bounding-box palette.
    const RED: image::Rgb<u8> = image::Rgb([255, 0, 0]);
    /// Second entry of the default bounding-box palette.
    const GREEN: image::Rgb<u8> = image::Rgb([0, 255, 0]);
    /// Colour of an untouched pixel in a freshly created `RgbImage`.
    const BLACK: image::Rgb<u8> = image::Rgb([0, 0, 0]);

    #[test]
    fn test_make_grid_single_tensor() {
        let tensor = creation::ones(&[3, 32, 32]).expect("creation should succeed");
        let tensors = vec![tensor];
        let grid = make_grid(&tensors, 1, 0).expect("make grid should succeed");
        assert_eq!(grid.shape().dims(), &[3, 32, 32]);
    }

    #[test]
    fn test_make_grid_multiple_tensors() {
        let tensors: Vec<_> = (0..4)
            .map(|_| creation::ones(&[3, 16, 16]).expect("map operation should succeed"))
            .collect();
        // 2x2 grid of 16x16 images with 2 pixels of padding: 2 * 16 + 2 = 34.
        let grid = make_grid(&tensors, 2, 2).expect("make grid should succeed");
        assert_eq!(grid.shape().dims(), &[3, 34, 34]);
    }

    #[test]
    fn test_make_grid_places_tensors_row_major_with_padding() -> Result<()> {
        let tensors = vec![
            creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0])?.reshape(&[1, 2, 2])?,
            creation::tensor_1d(&[5.0f32, 6.0, 7.0, 8.0])?.reshape(&[1, 2, 2])?,
            creation::tensor_1d(&[9.0f32, 10.0, 11.0, 12.0])?.reshape(&[1, 2, 2])?,
        ];
        // Three images with nrow = 2 -> 2 columns x 2 rows, 1 pixel of padding.
        let grid = make_grid(&tensors, 2, 1)?;
        assert_eq!(grid.shape().dims(), &[1, 5, 5]);
        // First image at the top-left.
        assert_eq!(grid.get(&[0, 0, 0])?, 1.0);
        assert_eq!(grid.get(&[0, 1, 1])?, 4.0);
        // Second image to its right, after one padding column.
        assert_eq!(grid.get(&[0, 0, 3])?, 5.0);
        assert_eq!(grid.get(&[0, 1, 4])?, 8.0);
        // Third image starts the second row, after one padding row.
        assert_eq!(grid.get(&[0, 3, 0])?, 9.0);
        assert_eq!(grid.get(&[0, 4, 1])?, 12.0);
        // Padding and the unused fourth cell stay zero.
        assert_eq!(grid.get(&[0, 2, 2])?, 0.0);
        assert_eq!(grid.get(&[0, 4, 4])?, 0.0);
        Ok(())
    }

    #[test]
    fn test_make_grid_nrow_exceeding_batch() {
        let tensors: Vec<_> = (0..2)
            .map(|_| creation::ones(&[1, 4, 4]).expect("creation should succeed"))
            .collect();
        // More requested rows than images collapses to a single column of two rows.
        let grid = make_grid(&tensors, 5, 1).expect("make grid should succeed");
        assert_eq!(grid.shape().dims(), &[1, 9, 4]);
    }

    #[test]
    fn test_make_grid_empty_tensors() {
        let tensors: Vec<Tensor<f32>> = vec![];
        let result = make_grid(&tensors, 1, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_make_grid_mismatched_shapes() {
        let tensor1 = creation::ones(&[3, 32, 32]).expect("creation should succeed");
        let tensor2 = creation::ones(&[3, 16, 16]).expect("creation should succeed");
        let tensors = vec![tensor1, tensor2];
        let result = make_grid(&tensors, 1, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_make_grid_invalid_dimensions() {
        let tensor = creation::ones(&[32, 32]).expect("creation should succeed");
        let tensors = vec![tensor];
        let result = make_grid(&tensors, 1, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_draw_bounding_boxes_valid_input() -> Result<()> {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
        let boxes_data = vec![10.0, 10.0, 50.0, 50.0, 60.0, 60.0, 90.0, 90.0];
        let boxes = creation::tensor_1d(&boxes_data)?.reshape(&[2, 4])?;
        let labels = vec!["box1".to_string(), "box2".to_string()];
        let result = draw_bounding_boxes(&mut image, &boxes, Some(&labels), None, None);
        assert!(result.is_ok());
        // Boxes cycle through the default palette: red, then green.
        let rgb = image.to_rgb8();
        assert_eq!(*rgb.get_pixel(10, 30), RED);
        assert_eq!(*rgb.get_pixel(90, 75), GREEN);
        Ok(())
    }

    #[test]
    fn test_draw_bounding_boxes_draws_outline_only() -> Result<()> {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
        let boxes = creation::tensor_1d(&[10.0f32, 10.0, 50.0, 50.0])?.reshape(&[1, 4])?;
        draw_bounding_boxes(&mut image, &boxes, None, None, None)?;
        let rgb = image.to_rgb8();
        let edge_pixels = [
            (10, 10),
            (50, 10),
            (10, 50),
            (50, 50),
            (30, 10),
            (30, 50),
            (10, 30),
            (50, 30),
        ];
        for &(x, y) in edge_pixels.iter() {
            assert_eq!(*rgb.get_pixel(x, y), RED, "edge pixel ({}, {})", x, y);
        }
        // Interior and immediately outside the outline are untouched.
        assert_eq!(*rgb.get_pixel(30, 30), BLACK);
        assert_eq!(*rgb.get_pixel(9, 30), BLACK);
        assert_eq!(*rgb.get_pixel(51, 30), BLACK);
        Ok(())
    }

    #[test]
    fn test_draw_bounding_boxes_clips_to_image() -> Result<()> {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(20, 20));
        let boxes = creation::tensor_1d(&[5.0f32, 5.0, 500.0, 500.0])?.reshape(&[1, 4])?;
        draw_bounding_boxes(&mut image, &boxes, None, None, None)?;
        let rgb = image.to_rgb8();
        // Visible top and left edges run to the image border; the off-image
        // right and bottom edges are not drawn.
        assert_eq!(*rgb.get_pixel(19, 5), RED);
        assert_eq!(*rgb.get_pixel(5, 19), RED);
        assert_eq!(*rgb.get_pixel(19, 19), BLACK);
        Ok(())
    }

    #[test]
    fn test_draw_bounding_boxes_invalid_shape() {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
        let boxes = creation::ones(&[2, 3]).expect("creation should succeed");
        let result = draw_bounding_boxes(&mut image, &boxes, None, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_draw_bounding_boxes_with_scores() -> Result<()> {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
        let boxes = creation::tensor_1d(&[10.0, 10.0, 50.0, 50.0])?.reshape(&[1, 4])?;
        let scores = creation::tensor_1d(&[0.95])?;
        let labels = vec!["object".to_string()];
        let result = draw_bounding_boxes(&mut image, &boxes, Some(&labels), Some(&scores), None);
        assert!(result.is_ok());
        assert_eq!(*image.to_rgb8().get_pixel(10, 30), RED);
        Ok(())
    }

    #[test]
    fn test_draw_bounding_boxes_custom_colors() -> Result<()> {
        let mut image = DynamicImage::ImageRgb8(RgbImage::new(100, 100));
        let boxes = creation::tensor_1d(&[10.0, 10.0, 50.0, 50.0])?.reshape(&[1, 4])?;
        let colors = vec![(128, 64, 192)];
        let result = draw_bounding_boxes(&mut image, &boxes, None, None, Some(&colors));
        assert!(result.is_ok());
        assert_eq!(
            *image.to_rgb8().get_pixel(10, 30),
            image::Rgb([128, 64, 192])
        );
        Ok(())
    }
}