use crate::{
ColorJitter, Compose, RandomCrop, RandomHorizontalFlip, RandomResizedCrop, Result, Transform,
};
use scirs2_core::random::{thread_rng, Random}; use torsh_tensor::Tensor;
/// Crate-internal shorthand for a single-precision (`f32`) [`Tensor`].
pub(crate) type TensorF32 = Tensor<f32>;
/// SimCLR-style augmentation: a single stochastic pipeline that is applied
/// independently to the same image to obtain each view.
pub struct SimCLRAugmentation {
    /// Side length of the square crop produced for every view.
    crop_size: usize,
    /// Multiplier applied to the colour-jitter magnitudes.
    color_strength: f32,
    /// Probability that a view is blurred.
    blur_probability: f32,
    /// Pre-built pipeline shared by all views.
    transform: Box<dyn Transform>,
}

impl SimCLRAugmentation {
    /// Builds the SimCLR pipeline, applied in this order:
    ///
    /// 1. random resized crop to `crop_size x crop_size`;
    /// 2. horizontal flip with probability 0.5;
    /// 3. colour jitter: brightness, contrast and saturation `0.8 * color_strength`,
    ///    hue `0.2 * color_strength`;
    /// 4. random grayscale with probability 0.2;
    /// 5. Gaussian blur with probability `blur_probability`. The kernel is odd, about
    ///    one tenth of `crop_size`, and at least 3.
    pub fn new(crop_size: usize, color_strength: f32, blur_probability: f32) -> Self {
        let jitter = 0.8 * color_strength;
        let hue_jitter = 0.2 * color_strength;
        // `| 1` forces the kernel size to be odd.
        let blur_kernel = (crop_size / 10).max(3) | 1;

        let pipeline = Compose::new(vec![
            Box::new(RandomResizedCrop::new((crop_size, crop_size))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(jitter)
                    .contrast(jitter)
                    .saturation(jitter)
                    .hue(hue_jitter),
            ),
            Box::new(RandomGrayscale::new(0.2)),
            Box::new(GaussianBlur::new(blur_kernel, blur_probability)),
        ]);

        Self {
            crop_size,
            color_strength,
            blur_probability,
            transform: Box::new(pipeline),
        }
    }

    /// Returns two independently augmented views of `image`.
    ///
    /// # Errors
    /// Propagates the first error raised by the pipeline.
    pub fn generate_pair(&self, image: &Tensor<f32>) -> Result<(Tensor<f32>, Tensor<f32>)> {
        // Tuple fields are evaluated left to right, so the first view is drawn first.
        Ok((self.transform.forward(image)?, self.transform.forward(image)?))
    }

    /// Returns `num_views` independently augmented views of `image`.
    ///
    /// # Errors
    /// Stops at, and returns, the first error raised by the pipeline.
    pub fn generate_views(
        &self,
        image: &Tensor<f32>,
        num_views: usize,
    ) -> Result<Vec<Tensor<f32>>> {
        (0..num_views)
            .map(|_| self.transform.forward(image))
            .collect()
    }
}
/// MoCo-style augmentation producing a *query* view and a *key* view of one image.
///
/// MoCo applies the same augmentation recipe to query and key. The two views differ
/// only through independent random draws, so both are built by the same pipeline builder.
pub struct MoCoAugmentation {
    /// Pipeline producing the query view.
    query_transform: Box<dyn Transform>,
    /// Pipeline producing the key view.
    key_transform: Box<dyn Transform>,
}

impl MoCoAugmentation {
    /// Builds identical query and key pipelines for `crop_size x crop_size` outputs.
    pub fn new(crop_size: usize) -> Self {
        Self {
            query_transform: Self::view_transform(crop_size),
            key_transform: Self::view_transform(crop_size),
        }
    }

    /// The MoCo view pipeline, applied in this order: random resized crop, horizontal
    /// flip (p = 0.5), colour jitter (brightness/contrast/saturation 0.4, hue 0.1),
    /// random grayscale (p = 0.2).
    ///
    /// Previously the query pipeline omitted the grayscale step, which made query and
    /// key come from different distributions.
    fn view_transform(crop_size: usize) -> Box<dyn Transform> {
        Box::new(Compose::new(vec![
            Box::new(RandomResizedCrop::new((crop_size, crop_size))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(0.4)
                    .contrast(0.4)
                    .saturation(0.4)
                    .hue(0.1),
            ),
            Box::new(RandomGrayscale::new(0.2)),
        ]))
    }

    /// Returns `(query, key)`, two independently augmented views of `image`.
    ///
    /// # Errors
    /// Propagates the first error raised by either pipeline.
    pub fn generate_pair(&self, image: &Tensor<f32>) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let query = self.query_transform.forward(image)?;
        let key = self.key_transform.forward(image)?;
        Ok((query, key))
    }
}
/// BYOL-style augmentation producing an *online* view and a *target* view.
///
/// BYOL uses two asymmetric pipelines:
/// * online view (τ): Gaussian blur always applied, never solarized;
/// * target view (τ'): Gaussian blur with probability 0.1, solarization with probability 0.2.
pub struct BYOLAugmentation {
    /// Pipeline producing the online-network view.
    online_transform: Box<dyn Transform>,
    /// Pipeline producing the target-network view.
    target_transform: Box<dyn Transform>,
}

impl BYOLAugmentation {
    /// Builds the online and target pipelines for `crop_size x crop_size` outputs.
    pub fn new(crop_size: usize) -> Self {
        Self {
            online_transform: Self::view_transform(crop_size, 1.0, 0.0),
            // Solarization belongs to the target view. It was previously attached, as a
            // no-op, to the online view only.
            target_transform: Self::view_transform(crop_size, 0.1, 0.2),
        }
    }

    /// Shared BYOL pipeline, applied in this order: random resized crop, horizontal flip
    /// (p = 0.5), colour jitter (brightness/contrast 0.4, saturation 0.2, hue 0.1), random
    /// grayscale (p = 0.2), Gaussian blur (p = `blur_probability`), solarization
    /// (p = `solarize_probability`).
    fn view_transform(
        crop_size: usize,
        blur_probability: f32,
        solarize_probability: f32,
    ) -> Box<dyn Transform> {
        // NOTE(review): a threshold of 0.5 assumes pixel values in [0, 1]. Confirm that
        // inputs are not mean/std-normalized at this point.
        let solarize_threshold = 0.5;
        Box::new(Compose::new(vec![
            Box::new(RandomResizedCrop::new((crop_size, crop_size))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(0.4)
                    .contrast(0.4)
                    .saturation(0.2)
                    .hue(0.1),
            ),
            Box::new(RandomGrayscale::new(0.2)),
            Box::new(GaussianBlur::new(
                (crop_size / 10).max(3) | 1,
                blur_probability,
            )),
            Box::new(Solarize::new(solarize_threshold, solarize_probability)),
        ]))
    }

    /// Returns `(online_view, target_view)` for `image`.
    ///
    /// # Errors
    /// Propagates the first error raised by either pipeline.
    pub fn generate_pair(&self, image: &Tensor<f32>) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let online_view = self.online_transform.forward(image)?;
        let target_view = self.target_transform.forward(image)?;
        Ok((online_view, target_view))
    }
}
/// SwAV multi-crop augmentation: `num_global_crops` large views followed by
/// `num_local_crops` small views, all drawn from the same image.
#[derive(Debug)]
pub struct SwAVAugmentation {
    /// Side length of each global (large) crop.
    global_crop_size: usize,
    /// Side length of each local (small) crop.
    local_crop_size: usize,
    /// Number of global views produced per image.
    num_global_crops: usize,
    /// Number of local views produced per image.
    num_local_crops: usize,
}

impl SwAVAugmentation {
    /// Creates a multi-crop configuration.
    pub fn new(
        global_crop_size: usize,
        local_crop_size: usize,
        num_global_crops: usize,
        num_local_crops: usize,
    ) -> Self {
        Self {
            global_crop_size,
            local_crop_size,
            num_global_crops,
            num_local_crops,
        }
    }

    /// Returns all global views first, then all local views.
    ///
    /// # Errors
    /// Stops at, and returns, the first error raised by a pipeline.
    pub fn generate_views(&self, image: &Tensor<f32>) -> Result<Vec<Tensor<f32>>> {
        let mut views = Vec::new();

        let global = self.create_global_transform();
        Self::append_views(&mut views, &*global, image, self.num_global_crops)?;

        let local = self.create_local_transform();
        Self::append_views(&mut views, &*local, image, self.num_local_crops)?;

        Ok(views)
    }

    /// Pushes `count` independently augmented views of `image` onto `views`.
    fn append_views(
        views: &mut Vec<Tensor<f32>>,
        transform: &dyn Transform,
        image: &Tensor<f32>,
        count: usize,
    ) -> Result<()> {
        for _ in 0..count {
            views.push(transform.forward(image)?);
        }
        Ok(())
    }

    /// Pipeline for the global views.
    fn create_global_transform(&self) -> Box<dyn Transform> {
        Self::crop_pipeline(self.global_crop_size)
    }

    /// Pipeline for the local views.
    fn create_local_transform(&self) -> Box<dyn Transform> {
        Self::crop_pipeline(self.local_crop_size)
    }

    /// Shared pipeline, applied in this order: random resized crop to `size x size`,
    /// horizontal flip (p = 0.5), colour jitter (0.4, 0.4, 0.2, 0.1), random grayscale
    /// (p = 0.2), Gaussian blur (p = 0.5). The blur kernel is odd, at least 3, and about
    /// `size / 10`.
    fn crop_pipeline(size: usize) -> Box<dyn Transform> {
        let blur_kernel = (size / 10).max(3) | 1;
        Box::new(Compose::new(vec![
            Box::new(RandomResizedCrop::new((size, size))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(0.4)
                    .contrast(0.4)
                    .saturation(0.2)
                    .hue(0.1),
            ),
            Box::new(RandomGrayscale::new(0.2)),
            Box::new(GaussianBlur::new(blur_kernel, 0.5)),
        ]))
    }
}
/// DINO multi-crop augmentation: two global views followed by `num_local_crops` local views.
///
/// The two global views use *different* pipelines, as in the DINO recipe:
/// * global view 1: Gaussian blur always applied, no solarization;
/// * global view 2: Gaussian blur with probability 0.1, solarization with probability 0.2.
///
/// Local views use Gaussian blur with probability 0.5 and no solarization.
#[derive(Debug)]
pub struct DINOAugmentation {
    /// Side length of each global (large) crop.
    global_crop_size: usize,
    /// Side length of each local (small) crop.
    local_crop_size: usize,
}

impl DINOAugmentation {
    /// Creates a DINO multi-crop configuration.
    pub fn new(global_crop_size: usize, local_crop_size: usize) -> Self {
        Self {
            global_crop_size,
            local_crop_size,
        }
    }

    /// Returns `2 + num_local_crops` views: the two global views first, then the local views.
    ///
    /// # Errors
    /// Stops at, and returns, the first error raised by a pipeline.
    pub fn generate_views(
        &self,
        image: &Tensor<f32>,
        num_local_crops: usize,
    ) -> Result<Vec<Tensor<f32>>> {
        let mut views = Vec::new();

        // Previously both global views shared one pipeline. DINO uses asymmetric views.
        let first_global = self.create_global_transform(1.0, 0.0);
        views.push(first_global.forward(image)?);
        let second_global = self.create_global_transform(0.1, 0.2);
        views.push(second_global.forward(image)?);

        let local_transform = self.create_local_transform();
        for _ in 0..num_local_crops {
            views.push(local_transform.forward(image)?);
        }
        Ok(views)
    }

    /// Global-view pipeline, applied in this order: random resized crop, horizontal flip
    /// (p = 0.5), colour jitter (0.4, 0.4, 0.2, 0.1), random grayscale (p = 0.2), Gaussian
    /// blur (p = `blur_probability`), solarization (p = `solarize_probability`).
    fn create_global_transform(
        &self,
        blur_probability: f32,
        solarize_probability: f32,
    ) -> Box<dyn Transform> {
        // NOTE(review): a threshold of 0.5 assumes pixel values in [0, 1]. Confirm that
        // inputs are not mean/std-normalized at this point.
        let solarize_threshold = 0.5;
        Box::new(Compose::new(vec![
            Box::new(RandomResizedCrop::new((
                self.global_crop_size,
                self.global_crop_size,
            ))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(0.4)
                    .contrast(0.4)
                    .saturation(0.2)
                    .hue(0.1),
            ),
            Box::new(RandomGrayscale::new(0.2)),
            Box::new(GaussianBlur::new(
                (self.global_crop_size / 10).max(3) | 1,
                blur_probability,
            )),
            Box::new(Solarize::new(solarize_threshold, solarize_probability)),
        ]))
    }

    /// Local-view pipeline: same as the global one but sized to `local_crop_size`,
    /// with blur probability 0.5 and no solarization.
    fn create_local_transform(&self) -> Box<dyn Transform> {
        Box::new(Compose::new(vec![
            Box::new(RandomResizedCrop::new((
                self.local_crop_size,
                self.local_crop_size,
            ))),
            Box::new(RandomHorizontalFlip::new(0.5)),
            Box::new(
                ColorJitter::new()
                    .brightness(0.4)
                    .contrast(0.4)
                    .saturation(0.2)
                    .hue(0.1),
            ),
            Box::new(RandomGrayscale::new(0.2)),
            Box::new(GaussianBlur::new(
                (self.local_crop_size / 10).max(3) | 1,
                0.5,
            )),
        ]))
    }
}
/// Randomly converts an image to grayscale.
///
/// With probability `probability`, the output of `crate::rgb_to_grayscale` is repeated
/// three times along dimension 0. Otherwise the input is returned unchanged.
#[derive(Debug)]
pub struct RandomGrayscale {
    /// Chance that the conversion is applied.
    probability: f32,
}

impl RandomGrayscale {
    /// Creates a grayscale transform applied with the given probability.
    pub fn new(probability: f32) -> Self {
        Self { probability }
    }
}

impl Transform for RandomGrayscale {
    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let mut rng = thread_rng();
        let apply = rng.random::<f32>() < self.probability;
        if !apply {
            return Ok(input.clone());
        }
        let gray = crate::rgb_to_grayscale(input)?;
        // Repeat the single grayscale plane so the channel count stays at three.
        let replicated = Tensor::cat(&[&gray, &gray, &gray], 0)?;
        Ok(replicated)
    }

    fn clone_transform(&self) -> Box<dyn Transform> {
        Box::new(Self::new(self.probability))
    }
}
/// Randomly applies a Gaussian blur.
///
/// With probability `probability`, the input is passed to `crate::gaussian_blur` with a
/// standard deviation of `0.3 * kernel_size`. Otherwise it is returned unchanged.
#[derive(Debug)]
pub struct GaussianBlur {
    /// Nominal kernel size. In this file it is only used to derive sigma.
    kernel_size: usize,
    /// Chance that the blur is applied.
    probability: f32,
}

impl GaussianBlur {
    /// Creates a blur transform with the given nominal kernel size and application probability.
    pub fn new(kernel_size: usize, probability: f32) -> Self {
        Self {
            kernel_size,
            probability,
        }
    }

    /// Standard deviation passed to `crate::gaussian_blur`.
    ///
    /// NOTE(review): `0.3 * kernel_size` is noticeably larger than common choices (for
    /// example OpenCV's `0.3 * ((k - 1) * 0.5 - 1) + 0.8`). Confirm this is intended.
    fn sigma(&self) -> f32 {
        (self.kernel_size as f32) * 0.3
    }
}

impl Transform for GaussianBlur {
    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let mut rng = thread_rng();
        let should_blur = rng.random::<f32>() < self.probability;
        if !should_blur {
            return Ok(input.clone());
        }
        crate::gaussian_blur(input, self.sigma())
    }

    fn clone_transform(&self) -> Box<dyn Transform> {
        Box::new(Self::new(self.kernel_size, self.probability))
    }
}
#[derive(Debug)]
pub struct Solarize {
threshold: f32,
probability: f32,
}
impl Solarize {
pub fn new(threshold: f32, probability: f32) -> Self {
Self {
threshold,
probability,
}
}
}
impl Transform for Solarize {
fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let mut rng = thread_rng();
if rng.random::<f32>() < self.probability {
let inverted = input.mul_scalar(-1.0)?.add_scalar(1.0)?;
Ok(inverted)
} else {
Ok(input.clone())
}
}
fn clone_transform(&self) -> Box<dyn Transform> {
Box::new(Self {
threshold: self.threshold,
probability: self.probability,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    /// Creates a random 3-channel `size x size` test image.
    fn random_image(size: usize) -> Tensor<f32> {
        randn::<f32>(&[3, size, size]).expect("failed to create random test image")
    }

    #[test]
    fn test_simclr_augmentation() {
        let aug = SimCLRAugmentation::new(224, 1.0, 0.5);
        assert_eq!(aug.crop_size, 224);
        assert_eq!(aug.color_strength, 1.0);
        assert_eq!(aug.blur_probability, 0.5);
    }

    #[test]
    fn test_simclr_generate_pair() {
        let aug = SimCLRAugmentation::new(224, 1.0, 0.5);
        let image = random_image(256);
        let (view1, view2) = aug
            .generate_pair(&image)
            .expect("SimCLR pair generation failed");
        assert!(view1.numel() > 0);
        assert!(view2.numel() > 0);
    }

    #[test]
    fn test_simclr_generate_views_count() {
        let aug = SimCLRAugmentation::new(224, 1.0, 0.5);
        let image = random_image(256);
        let views = aug
            .generate_views(&image, 3)
            .expect("SimCLR view generation failed");
        assert_eq!(views.len(), 3);
        let none = aug
            .generate_views(&image, 0)
            .expect("SimCLR view generation with zero views failed");
        assert!(none.is_empty());
    }

    #[test]
    fn test_moco_augmentation() {
        let aug = MoCoAugmentation::new(224);
        let image = random_image(256);
        let (query, key) = aug
            .generate_pair(&image)
            .expect("MoCo pair generation failed");
        assert!(query.numel() > 0);
        assert!(key.numel() > 0);
    }

    #[test]
    fn test_byol_augmentation() {
        let aug = BYOLAugmentation::new(224);
        let image = random_image(256);
        let (online, target) = aug
            .generate_pair(&image)
            .expect("BYOL pair generation failed");
        assert!(online.numel() > 0);
        assert!(target.numel() > 0);
    }

    #[test]
    fn test_swav_augmentation() {
        let aug = SwAVAugmentation::new(224, 96, 2, 6);
        let image = random_image(256);
        let views = aug
            .generate_views(&image)
            .expect("SwAV view generation failed");
        // 2 global + 6 local crops.
        assert_eq!(views.len(), 8);
    }

    #[test]
    fn test_dino_augmentation() {
        let aug = DINOAugmentation::new(224, 96);
        let image = random_image(256);
        let views = aug
            .generate_views(&image, 4)
            .expect("DINO view generation failed");
        // 2 global + 4 local crops.
        assert_eq!(views.len(), 6);
    }

    #[test]
    fn test_random_grayscale() {
        // Probability 1.0 forces the grayscale branch.
        let transform = RandomGrayscale::new(1.0);
        let image = random_image(64);
        transform
            .forward(&image)
            .expect("RandomGrayscale failed");
    }

    #[test]
    fn test_gaussian_blur() {
        let transform = GaussianBlur::new(5, 1.0);
        let image = random_image(64);
        transform.forward(&image).expect("GaussianBlur failed");
    }

    #[test]
    fn test_solarize() {
        // Probability 1.0 forces the solarize branch; element count must be preserved.
        let transform = Solarize::new(0.5, 1.0);
        let image = random_image(64);
        let out = transform.forward(&image).expect("Solarize failed");
        assert_eq!(out.numel(), image.numel());
    }
}