pub mod analysis;
pub mod color;
pub mod common;
pub mod detection;
pub mod filtering;
pub mod geometric;
pub use common::{
constants, to_tensor, utils, EdgeDetectionAlgorithm, InterpolationMode, MorphologicalOperation,
PaddingMode, VisionOpConfig,
};
pub use geometric::{
center_crop,
horizontal_flip,
pad,
random_crop,
resize,
resize_with_mode,
rotate,
vertical_flip,
};
pub use filtering::{
edge_detection,
gaussian_blur,
median_filter,
morphological_operation,
sobel_edge_detection,
};
pub use color::{
adjust_brightness,
adjust_contrast,
adjust_hue,
adjust_saturation,
histogram_equalization,
hsv_to_rgb,
normalize,
rgb_to_grayscale,
rgb_to_hsv,
};
pub use detection::{
calculate_iou,
generate_anchors,
nms,
BoundingBox,
Detection,
};
pub use analysis::{
compute_classification_metrics,
cross_entropy_loss,
ClassificationMetrics,
DetectionMetrics,
SegmentationMetrics,
};
use crate::Result;
use torsh_tensor::Tensor;
/// Normalizes `image` using caller-supplied `mean` and `std` values.
///
/// Both slices are copied into a [`color::NormalizationConfig::custom`] configuration,
/// which is then applied with [`color::normalize`].
///
/// # Errors
///
/// Propagates any error returned by [`color::normalize`].
// NOTE(review): this wrapper does not check that `mean`, `std` and the image's channel
// count agree — presumably validated inside `color`; confirm before relying on it.
pub fn normalize_with_mean_std(
    image: &Tensor<f32>,
    mean: &[f32],
    std: &[f32],
) -> Result<Tensor<f32>> {
    let mean_values = Vec::from(mean);
    let std_values = Vec::from(std);
    let custom_config = color::NormalizationConfig::custom(mean_values, std_values);
    color::normalize(image, custom_config)
}
/// Resizes `image` to `size`; a plainly named wrapper around [`geometric::resize`].
///
/// # Errors
///
/// Propagates any error returned by [`geometric::resize`].
pub fn resize_simple(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
    let target_size: (usize, usize) = size;
    geometric::resize(image, target_size)
}
/// Normalizes `image` with explicit `mean`/`std` values.
///
/// Alias kept for API convenience; it forwards unchanged to [`normalize_with_mean_std`].
///
/// # Errors
///
/// Propagates any error returned by [`normalize_with_mean_std`].
pub fn normalize_simple(image: &Tensor<f32>, mean: &[f32], std: &[f32]) -> Result<Tensor<f32>> {
    let (mean_values, std_values) = (mean, std);
    normalize_with_mean_std(image, mean_values, std_values)
}
pub use geometric::{center_crop as image_center_crop, resize as image_resize};
pub use filtering::{gaussian_blur as image_blur, sobel_edge_detection as edge_detection_sobel};
pub use color::{
adjust_brightness as brightness, adjust_contrast as contrast,
normalize_imagenet as imagenet_normalize, rgb_to_grayscale as to_grayscale,
};
pub use detection::{calculate_iou as box_iou, nms as non_max_suppression};
pub use analysis::{
compute_classification_metrics as classification_eval, cross_entropy_loss as ce_loss,
};
/// Returns the default preprocessing normalization configuration.
///
/// This is exactly [`color::NormalizationConfig::imagenet`].
#[must_use]
pub fn standard_preprocessing_config() -> color::NormalizationConfig {
    color::NormalizationConfig::imagenet()
}
/// Builds the default non-maximum-suppression configuration.
///
/// Passes `0.5` for both positional thresholds of [`detection::NMSConfig::new`] and
/// enables per-class suppression via `with_per_class(true)`.
#[must_use]
pub fn standard_nms_config() -> detection::NMSConfig {
    // NOTE(review): the meaning/order of `NMSConfig::new`'s two arguments (presumably an
    // IoU threshold and a score threshold) is defined in `detection`; confirm there.
    detection::NMSConfig::new(0.5, 0.5).with_per_class(true)
}
/// Builds an alternative non-maximum-suppression configuration.
///
/// Passes `0.3` and `0.7` as the positional thresholds of [`detection::NMSConfig::new`]
/// and enables per-class suppression via `with_per_class(true)`.
#[must_use]
pub fn strict_nms_config() -> detection::NMSConfig {
    // NOTE(review): "strict" presumably means a lower IoU cut-off and a higher score
    // cut-off than `standard_nms_config`; verify against `NMSConfig::new`'s parameter order.
    detection::NMSConfig::new(0.3, 0.7).with_per_class(true)
}
/// Returns the default edge-detection configuration.
///
/// This is exactly [`filtering::EdgeDetectionConfig::sobel`].
#[must_use]
pub fn standard_edge_detection_config() -> filtering::EdgeDetectionConfig {
    filtering::EdgeDetectionConfig::sobel()
}
/// Builds a Canny edge-detection configuration with thresholds `50.0` and `150.0`.
///
/// Forwards to [`filtering::EdgeDetectionConfig::canny`].
#[must_use]
pub fn canny_edge_detection_config() -> filtering::EdgeDetectionConfig {
    // NOTE(review): presumably (low, high) hysteresis thresholds — confirm the argument
    // order in `filtering::EdgeDetectionConfig::canny`.
    filtering::EdgeDetectionConfig::canny(50.0, 150.0)
}
/// Resizes `image` to `size` using bicubic interpolation.
///
/// See [`fast_resize`] for the nearest-neighbour variant.
///
/// # Errors
///
/// Propagates any error returned by [`geometric::resize_with_mode`].
pub fn high_quality_resize(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
    let interpolation = common::InterpolationMode::Bicubic;
    geometric::resize_with_mode(image, size, interpolation)
}
/// Resizes `image` to `size` using nearest-neighbour interpolation.
///
/// See [`high_quality_resize`] for the bicubic variant.
///
/// # Errors
///
/// Propagates any error returned by [`geometric::resize_with_mode`].
pub fn fast_resize(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
    let interpolation = common::InterpolationMode::Nearest;
    geometric::resize_with_mode(image, size, interpolation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::zeros;

    /// Mean values fed to the custom-normalization wrappers.
    const TEST_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    /// Std values fed to the custom-normalization wrappers (all non-zero).
    const TEST_STD: [f32; 3] = [0.229, 0.224, 0.225];

    /// Two partially overlapping, same-class detections shared by the NMS tests.
    fn sample_detections() -> Vec<Detection> {
        vec![
            detection::Detection::new([0.0, 0.0, 10.0, 10.0], 0.9, 0),
            detection::Detection::new([5.0, 5.0, 15.0, 15.0], 0.8, 0),
        ]
    }

    #[test]
    fn test_backward_compatibility() -> Result<()> {
        let image = zeros(&[3, 32, 32])?;
        let resized = resize(&image, (16, 16))?;
        assert_eq!(resized.shape().dims(), &[3, 16, 16]);
        let cropped = center_crop(&image, (16, 16))?;
        assert_eq!(cropped.shape().dims(), &[3, 16, 16]);
        let gray = rgb_to_grayscale(&image)?;
        assert_eq!(gray.shape().dims(), &[1, 32, 32]);
        let blurred = gaussian_blur(&image, 1.0)?;
        assert_eq!(blurred.shape().dims(), &[3, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_enhanced_functionality() -> Result<()> {
        let image = zeros(&[3, 32, 32])?;
        let resized = high_quality_resize(&image, (16, 16))?;
        assert_eq!(resized.shape().dims(), &[3, 16, 16]);
        let norm_config = standard_preprocessing_config();
        let normalized = color::normalize(&image, norm_config)?;
        assert_eq!(normalized.shape().dims(), &[3, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_simple_wrappers() -> Result<()> {
        let image = zeros(&[3, 32, 32])?;
        let resized = resize_simple(&image, (16, 16))?;
        assert_eq!(resized.shape().dims(), &[3, 16, 16]);
        let normalized = normalize_with_mean_std(&image, &TEST_MEAN, &TEST_STD)?;
        assert_eq!(normalized.shape().dims(), &[3, 32, 32]);
        let normalized_simple = normalize_simple(&image, &TEST_MEAN, &TEST_STD)?;
        assert_eq!(normalized_simple.shape().dims(), &[3, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_convenience_aliases() -> Result<()> {
        let image = zeros(&[3, 32, 32])?;
        let resized = image_resize(&image, (16, 16))?;
        assert_eq!(resized.shape().dims(), &[3, 16, 16]);
        let cropped = image_center_crop(&image, (16, 16))?;
        assert_eq!(cropped.shape().dims(), &[3, 16, 16]);
        let gray = to_grayscale(&image)?;
        assert_eq!(gray.shape().dims(), &[1, 32, 32]);
        let bright = brightness(&image, 0.1)?;
        assert_eq!(bright.shape().dims(), &[3, 32, 32]);
        let contrasted = contrast(&image, 1.5)?;
        assert_eq!(contrasted.shape().dims(), &[3, 32, 32]);
        let blurred = image_blur(&image, 1.0)?;
        assert_eq!(blurred.shape().dims(), &[3, 32, 32]);
        // Sobel is exercised on a single-channel image, matching test_filtering_operations.
        let single_channel = zeros(&[1, 32, 32])?;
        let edges = edge_detection_sobel(&single_channel)?;
        assert_eq!(edges.shape().dims(), &[1, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_factory_functions() -> Result<()> {
        let image = zeros(&[3, 32, 32])?;
        let resized_high = high_quality_resize(&image, (224, 224))?;
        assert_eq!(resized_high.shape().dims(), &[3, 224, 224]);
        let resized_fast = fast_resize(&image, (224, 224))?;
        assert_eq!(resized_fast.shape().dims(), &[3, 224, 224]);
        let _standard_nms = standard_nms_config();
        let _strict_nms = strict_nms_config();
        let _sobel_config = standard_edge_detection_config();
        let _canny_config = canny_edge_detection_config();
        let _preprocessing_config = standard_preprocessing_config();
        Ok(())
    }

    #[test]
    fn test_detection_operations() -> Result<()> {
        // The highest-scoring detection must always survive suppression, so the
        // result can never be empty; at most both inputs are kept.
        let filtered = nms(sample_detections(), standard_nms_config())?;
        assert!((1..=2).contains(&filtered.len()));
        let strict_filtered = nms(sample_detections(), strict_nms_config())?;
        assert!((1..=2).contains(&strict_filtered.len()));
        let aliased = non_max_suppression(sample_detections(), standard_nms_config())?;
        assert_eq!(aliased.len(), filtered.len());
        Ok(())
    }

    #[test]
    fn test_filtering_operations() -> Result<()> {
        let image = zeros(&[1, 32, 32])?;
        let edges = sobel_edge_detection(&image)?;
        assert_eq!(edges.shape().dims(), &[1, 32, 32]);
        let blurred = gaussian_blur(&image, 2.0)?;
        assert_eq!(blurred.shape().dims(), &[1, 32, 32]);
        let median_filtered = median_filter(&image, 3)?;
        assert_eq!(median_filtered.shape().dims(), &[1, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_color_operations() -> Result<()> {
        let rgb_image = zeros(&[3, 16, 16])?;
        let hsv = rgb_to_hsv(&rgb_image)?;
        assert_eq!(hsv.shape().dims(), &[3, 16, 16]);
        let back_to_rgb = hsv_to_rgb(&hsv)?;
        assert_eq!(back_to_rgb.shape().dims(), &[3, 16, 16]);
        let bright = adjust_brightness(&rgb_image, 0.2)?;
        assert_eq!(bright.shape().dims(), &[3, 16, 16]);
        let contrasted = adjust_contrast(&rgb_image, 1.5)?;
        assert_eq!(contrasted.shape().dims(), &[3, 16, 16]);
        Ok(())
    }

    #[test]
    fn test_analysis_operations() -> Result<()> {
        let predictions = zeros(&[5, 3])?;
        let targets = zeros(&[5])?;
        let metrics = compute_classification_metrics(&predictions, &targets, 3)?;
        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        let aliased_metrics = classification_eval(&predictions, &targets, 3)?;
        assert_eq!(aliased_metrics.precision.len(), 3);
        let config = analysis::LossConfig::default();
        let loss = cross_entropy_loss(&predictions, &targets, config)?;
        assert_eq!(loss.shape().dims(), &[1]);
        let aliased_config = analysis::LossConfig::default();
        let aliased_loss = ce_loss(&predictions, &targets, aliased_config)?;
        assert_eq!(aliased_loss.shape().dims(), &[1]);
        Ok(())
    }
}