burn-train 0.21.0-pre.3

Training crate for the Burn framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
//! DISTS (Deep Image Structure and Texture Similarity) metric.
//!
//! DISTS is a full-reference image quality assessment metric that combines
//! structure and texture similarity using deep features from VGG16.
//!
//! Reference: "Image Quality Assessment: Unifying Structure and Texture Similarity"
//! https://arxiv.org/abs/2004.07728

use burn_core as burn;

use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_nn::loss::Reduction;

use super::vgg16_l2pool::Vgg16L2PoolExtractor;

/// Channel counts for each stage: [input, stage1, stage2, stage3, stage4, stage5]
/// (index 0 is the raw RGB input; the rest are VGG16 stage outputs).
const CHANNELS: [usize; 6] = [3, 64, 128, 256, 512, 512];

/// Small constant for numerical stability in structure similarity.
/// Keeps the (2*mx*my + C1)/(mx^2 + my^2 + C1) ratio finite when both means are ~0.
const C1: f32 = 1e-6;

/// Small constant for numerical stability in texture similarity.
/// Keeps the (2*cov + C2)/(var_x + var_y + C2) ratio finite when both variances are ~0.
const C2: f32 = 1e-6;

/// ImageNet normalization constants (per-channel RGB mean/std) used to match
/// the preprocessing the VGG16 backbone was trained with.
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Image normalizer with pre-initialized mean and std tensors.
///
/// This struct holds the mean and std tensors for normalization,
/// avoiding the need to create them on each forward pass.
/// Image normalizer with pre-initialized mean and std tensors.
///
/// This struct holds the mean and std tensors for normalization,
/// avoiding the need to create them on each forward pass. Both tensors
/// are stored with shape `[1, 3, 1, 1]` so they broadcast over
/// `[batch, channels, height, width]` image batches.
#[derive(Module, Debug)]
pub struct Normalizer<B: Backend> {
    /// Mean tensor of shape [1, 3, 1, 1] for broadcasting.
    pub mean: Tensor<B, 4>,
    /// Std tensor of shape [1, 3, 1, 1] for broadcasting.
    pub std: Tensor<B, 4>,
}

impl<B: Backend> Normalizer<B> {
    /// Create a new ImageNet normalizer.
    ///
    /// The per-channel constants are loaded as 1-D tensors and reshaped to
    /// `[1, 3, 1, 1]` so they broadcast over `[batch, channels, height, width]`.
    pub fn imagenet(device: &B::Device) -> Self {
        let mean = Tensor::<B, 1>::from_floats(IMAGENET_MEAN, device).reshape([1, 3, 1, 1]);
        let std = Tensor::<B, 1>::from_floats(IMAGENET_STD, device).reshape([1, 3, 1, 1]);
        Self { mean, std }
    }

    /// Normalize a tensor: (x - mean) / std
    ///
    /// The stored `[1, 3, 1, 1]` statistics broadcast over the batch and
    /// spatial dimensions of `x`.
    pub fn normalize(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        (x - self.mean.clone()) / self.std.clone()
    }
}

/// Configuration for DISTS metric.
/// Configuration for DISTS metric.
///
/// Use [`DistsConfig::init`] for constant-initialized weights or
/// [`DistsConfig::init_pretrained`] to load the published weights.
#[derive(Config, Debug)]
pub struct DistsConfig {
    /// Whether to apply ImageNet normalization to input images
    /// before feature extraction (matches VGG16 preprocessing).
    #[config(default = true)]
    pub normalize: bool,
}

impl DistsConfig {
    /// Initialize a DISTS module with default weights.
    ///
    /// `alpha` and `beta` are each initialized to a constant `0.1` for every
    /// feature channel (one entry per channel across all stages, including
    /// the raw input stage). Use [`Self::init_pretrained`] to load the
    /// learned weights instead.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Dists<B> {
        let total_channels: usize = CHANNELS.iter().sum();

        // Constant 0.1 initialization, shared by alpha and beta.
        let init_weights = vec![0.1_f32; total_channels];

        // Only build the (mean, std) tensors when normalization is requested.
        let normalizer = self.normalize.then(|| Normalizer::imagenet(device));

        Dists {
            extractor: Vgg16L2PoolExtractor::new(device),
            alpha: Param::from_tensor(Tensor::from_floats(init_weights.as_slice(), device)),
            beta: Param::from_tensor(Tensor::from_floats(init_weights.as_slice(), device)),
            normalizer,
        }
    }

    /// Initialize a DISTS module with pretrained weights.
    ///
    /// Builds a default module via [`Self::init`] and then overwrites its
    /// parameters with the pretrained weights loader.
    pub fn init_pretrained<B: Backend>(&self, device: &B::Device) -> Dists<B> {
        let dists = self.init(device);
        super::weights::load_pretrained_weights(dists)
    }
}

/// DISTS (Deep Image Structure and Texture Similarity) metric.
///
/// Computes perceptual similarity between two images by combining
/// structure similarity (based on spatial means) and texture similarity
/// (based on variances and covariances) across VGG16 feature maps.
///
/// # Example
///
/// ```ignore
/// use burn_train::metric::vision::{DistsConfig, Reduction};
///
/// let device = Default::default();
/// let dists = DistsConfig::new().init_pretrained(&device);
///
/// let img1: Tensor<B, 4> = /* [batch, 3, H, W] */;
/// let img2: Tensor<B, 4> = /* [batch, 3, H, W] */;
///
/// let distance = dists.forward(img1, img2, Reduction::Mean);
/// ```
/// DISTS (Deep Image Structure and Texture Similarity) metric.
///
/// Computes perceptual similarity between two images by combining
/// structure similarity (based on spatial means) and texture similarity
/// (based on variances and covariances) across VGG16 feature maps.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Dists<B: Backend> {
    /// VGG16 feature extractor with L2 pooling
    pub(crate) extractor: Vgg16L2PoolExtractor<B>,
    /// Learned weights for structure similarity (per channel).
    /// Length equals the sum of `CHANNELS`.
    pub(crate) alpha: Param<Tensor<B, 1>>,
    /// Learned weights for texture similarity (per channel).
    /// Length equals the sum of `CHANNELS`.
    pub(crate) beta: Param<Tensor<B, 1>>,
    /// Optional normalizer for input preprocessing (ImageNet stats when
    /// `DistsConfig::normalize` is true, `None` otherwise).
    pub(crate) normalizer: Option<Normalizer<B>>,
}

impl<B: Backend> ModuleDisplay for Dists<B> {
    /// Keep attributes on a single line when displaying the module.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Report the backbone name and whether input normalization is enabled.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let backbone = "VGG16-L2Pool".to_string();
        let normalize = self.normalizer.is_some().to_string();
        content
            .add("backbone", &backbone)
            .add("normalize", &normalize)
            .optional()
    }
}

impl<B: Backend> Dists<B> {
    /// Compute DISTS distance with reduction.
    ///
    /// # Arguments
    ///
    /// * `input` - First image tensor of shape `[batch, 3, H, W]`
    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`
    /// * `reduction` - How to reduce the output (Mean, Sum, or Auto)
    ///
    /// # Returns
    ///
    /// Scalar tensor of shape `[1]`.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
        target: Tensor<B, 4>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let distance = self.forward_no_reduction(input, target);

        // `distance` is one value per sample, so BatchMean (sum / batch size)
        // coincides with Mean here and shares the same arm.
        match reduction {
            Reduction::Mean | Reduction::Auto | Reduction::BatchMean => distance.mean(),
            Reduction::Sum => distance.sum(),
        }
    }

    /// Compute DISTS distance without reduction.
    ///
    /// # Arguments
    ///
    /// * `input` - First image tensor of shape `[batch, 3, H, W]`
    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`
    ///
    /// # Returns
    ///
    /// Per-sample distance tensor of shape `[batch]`.
    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {
        let [batch, _, _, _] = input.dims();

        // Preprocess inputs (ImageNet normalization when configured)
        let (input, target) = self.preprocess(input, target);

        // Extract features from both images. The extractor is assumed to
        // return one feature map per entry of `CHANNELS` (the channel-offset
        // bookkeeping below relies on that) — TODO confirm against
        // `Vgg16L2PoolExtractor::forward`.
        let feats_x = self.extractor.forward(input);
        let feats_y = self.extractor.forward(target);

        // Get alpha and beta weights
        let alpha = self.alpha.val();
        let beta = self.beta.val();

        // Compute weighted sum of alpha and beta for normalization
        let alpha_sum = alpha.clone().sum();
        let beta_sum = beta.clone().sum();

        let device = feats_x[0].device();

        // Initialize accumulators (one distance per sample)
        let mut structure_dist = Tensor::<B, 1>::zeros([batch], &device);
        let mut texture_dist = Tensor::<B, 1>::zeros([batch], &device);

        // Running offset into the flat alpha/beta weight vectors.
        let mut channel_offset = 0;

        // Compute similarity for each stage
        for (feat_x, feat_y) in feats_x.iter().zip(feats_y.iter()) {
            let [_b, c, _h, _w] = feat_x.dims();

            // Get alpha and beta for this stage (slice of length `c` at the
            // current offset)
            let alpha_stage = alpha.clone().narrow(0, channel_offset, c);
            let beta_stage = beta.clone().narrow(0, channel_offset, c);

            // Compute structure and texture similarity for this stage
            let (s_dist, t_dist) = self.compute_stage_similarity(
                feat_x.clone(),
                feat_y.clone(),
                alpha_stage,
                beta_stage,
            );

            structure_dist = structure_dist.add(s_dist);
            texture_dist = texture_dist.add(t_dist);

            channel_offset += c;
        }

        // Normalize by sum of weights.
        // NOTE(review): the reference DISTS implementation normalizes both
        // terms by the joint sum `alpha.sum() + beta.sum()`, whereas here each
        // term is divided by its own sum — confirm this matches the scaling
        // the pretrained weights were trained with.
        structure_dist = structure_dist.div(alpha_sum);
        texture_dist = texture_dist.div(beta_sum);

        // DISTS = 1 - (structure_similarity + texture_similarity)
        // Since we computed distances (1 - similarity), we return the sum
        structure_dist.add(texture_dist)
    }

    /// Compute structure and texture similarity for a single stage.
    ///
    /// Returns the per-sample, weight-summed `(structure, texture)` distance
    /// contributions of this stage, each of shape `[batch]`.
    fn compute_stage_similarity(
        &self,
        feat_x: Tensor<B, 4>,
        feat_y: Tensor<B, 4>,
        alpha: Tensor<B, 1>,
        beta: Tensor<B, 1>,
    ) -> (Tensor<B, 1>, Tensor<B, 1>) {
        let [batch, channels, height, width] = feat_x.dims();
        let device = feat_x.device();

        // Reshape to [batch, channels, H*W] for easier computation
        let x = feat_x.reshape([batch, channels, height * width]);
        let y = feat_y.reshape([batch, channels, height * width]);

        // Compute means: [batch, channels] (squeeze after mean_dim to remove the reduced dimension)
        let mean_x = x.clone().mean_dim(2).squeeze_dim::<2>(2);
        let mean_y = y.clone().mean_dim(2).squeeze_dim::<2>(2);

        // Compute structure similarity: (2*mean_x*mean_y + c1) / (mean_x^2 + mean_y^2 + c1)
        let c1 = Tensor::<B, 2>::full([batch, channels], C1, &device);
        let structure_sim = mean_x
            .clone()
            .mul(mean_y.clone())
            .mul_scalar(2.0)
            .add(c1.clone())
            .div(
                mean_x
                    .clone()
                    .mul(mean_x.clone())
                    .add(mean_y.clone().mul(mean_y.clone()))
                    .add(c1),
            );

        // Compute variances and covariance
        // var_x = E[x^2] - E[x]^2, clamped at 0 for numerical stability
        // (floating-point roundoff can make the difference slightly negative)
        let var_x = x
            .clone()
            .mul(x.clone())
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_x.clone().mul(mean_x.clone()))
            .clamp_min(0.0);
        let var_y = y
            .clone()
            .mul(y.clone())
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_y.clone().mul(mean_y.clone()))
            .clamp_min(0.0);

        // cov_xy = E[xy] - E[x]E[y]
        let cov_xy = x
            .mul(y)
            .mean_dim(2)
            .squeeze_dim::<2>(2)
            .sub(mean_x.clone().mul(mean_y.clone()));

        // Compute texture similarity: (2*cov_xy + c2) / (var_x + var_y + c2)
        let c2 = Tensor::<B, 2>::full([batch, channels], C2, &device);
        let texture_sim = cov_xy
            .mul_scalar(2.0)
            .add(c2.clone())
            .div(var_x.add(var_y).add(c2));

        // Convert similarity to distance: 1 - similarity
        let structure_dist = Tensor::<B, 2>::ones([batch, channels], &device).sub(structure_sim);
        let texture_dist = Tensor::<B, 2>::ones([batch, channels], &device).sub(texture_sim);

        // Apply weights: [batch, channels] * [channels] -> [batch, channels]
        // (unsqueeze makes the weights [1, channels] so they broadcast over batch)
        // Then sum over channels -> [batch]
        let weighted_structure = structure_dist
            .mul(alpha.unsqueeze_dim::<2>(0))
            .sum_dim(1)
            .squeeze_dim::<1>(1);
        let weighted_texture = texture_dist
            .mul(beta.unsqueeze_dim::<2>(0))
            .sum_dim(1)
            .squeeze_dim::<1>(1);

        (weighted_structure, weighted_texture)
    }

    /// Preprocess input images using the configured normalizer.
    ///
    /// Returns the inputs unchanged when no normalizer is configured.
    fn preprocess(
        &self,
        input: Tensor<B, 4>,
        target: Tensor<B, 4>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        match &self.normalizer {
            Some(normalizer) => {
                let input = normalizer.normalize(input);
                let target = normalizer.normalize(target);
                (input, target)
            }
            None => (input, target),
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem};
    use burn_ndarray::NdArray;

    type TestBackend = NdArray<f32>;
    type FT = FloatElem<TestBackend>;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    // DISTS(x, x) must be exactly 0: both similarity ratios evaluate to 1.
    #[test]
    fn test_dists_identical_images_zero_distance() {
        let device = Default::default();
        // Use random image instead of constant to avoid numerical edge cases
        let image = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 1.0),
            &device,
        );

        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let distance = dists.forward(image.clone(), image, Reduction::Mean);

        let expected = TensorData::from([0.0]);
        distance
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Distinct images (all-zeros vs all-ones) must yield a non-zero distance.
    #[test]
    fn test_dists_different_images_nonzero_distance() {
        let device = Default::default();

        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);

        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let distance = dists.forward(image1, image2, Reduction::Mean);

        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() > 1e-6,
            "DISTS should be != 0 for different images"
        );
    }

    // DISTS is symmetric in its arguments: d(a, b) == d(b, a).
    #[test]
    fn test_dists_symmetry() {
        let device = Default::default();

        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);

        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let distance_forward = dists.forward(image1.clone(), image2.clone(), Reduction::Mean);
        let distance_reverse = dists.forward(image2, image1, Reduction::Mean);

        distance_forward
            .into_data()
            .assert_approx_eq::<FT>(&distance_reverse.into_data(), Tolerance::default());
    }

    // With Reduction::Mean a multi-sample batch still reduces to a single scalar.
    #[test]
    fn test_dists_batch_processing() {
        let device = Default::default();

        let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device);

        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let distance = dists.forward(image1, image2, Reduction::Mean);

        assert_eq!(distance.dims(), [1]);
    }

    // forward_no_reduction must return one distance per batch sample.
    #[test]
    fn test_dists_no_reduction() {
        let device = Default::default();

        let batch_size = 4;
        let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device);
        let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device);

        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);
        let distance = dists.forward_no_reduction(image1, image2);

        assert_eq!(distance.dims(), [batch_size]);
    }

    // The custom ModuleDisplay impl must surface the type and backbone names.
    #[test]
    fn display_dists() {
        let device = Default::default();
        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);

        let display_str = format!("{dists}");
        assert!(display_str.contains("Dists"));
        assert!(display_str.contains("VGG16-L2Pool"));
    }

    // =========================================================================
    // Pretrained Weights Tests (requires network)
    // =========================================================================

    /// Test DISTS pretrained weights download and loading.
    #[test]
    #[ignore = "downloads pre-trained weights"]
    fn test_dists_pretrained() {
        let device = Default::default();

        let dists: Dists<TestBackend> = DistsConfig::new().init_pretrained(&device);

        // Test with identical images - should be ~0
        // Use random image to avoid numerical edge cases with constant images
        let image = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 1.0),
            &device,
        );
        let distance = dists.forward(image.clone(), image, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value.abs() < 1e-5,
            "Pretrained DISTS should be ~0 for identical images, got {}",
            distance_value
        );

        // Test with different images - should be positive
        // (disjoint value ranges guarantee the images differ everywhere)
        let image1 = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.0, 0.3),
            &device,
        );
        let image2 = TestTensor::<4>::random(
            [1, 3, 64, 64],
            burn_core::tensor::Distribution::Uniform(0.7, 1.0),
            &device,
        );
        let distance = dists.forward(image1, image2, Reduction::Mean);
        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
        assert!(
            distance_value > 0.0,
            "Pretrained DISTS should be > 0 for different images"
        );
    }
}