burn_train/metric/vision/
dice.rs

1use crate::metric::MetricName;
2
3use super::super::{
4    Metric, MetricEntry, MetricMetadata, Numeric,
5    state::{FormatOptions, NumericMetricState},
6};
7use burn_core::{
8    prelude::{Backend, Tensor},
9    tensor::{ElementConversion, Int, s},
10};
11use core::marker::PhantomData;
12
13/// Input type for the [DiceMetric].
14///
15/// # Type Parameters
16/// - `B`: Backend type.
17/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
18pub struct DiceInput<B: Backend, const D: usize = 4> {
19    /// Model outputs (predictions), as a tensor.
20    outputs: Tensor<B, D, Int>,
21    /// Ground truth targets, as a tensor.
22    targets: Tensor<B, D, Int>,
23}
24
25impl<B: Backend, const D: usize> DiceInput<B, D> {
26    /// Creates a new DiceInput with the given outputs and targets.
27    ///
28    /// Inputs are expected to have the dimensions `[B, C, ...]`
29    /// where `B` is the batch size, `C` is the number of classes,
30    /// and `...` represents additional dimensions (e.g., height, width for images).
31    ///
32    /// If `C` is more than 1, the first class (index 0) is considered the background.
33    /// Additionally, one-hot encoding is the responsibility of the caller.
34    ///
35    /// # Arguments
36    /// - `outputs`: The model outputs as a tensor.
37    /// - `targets`: The ground truth targets as a tensor.
38    ///
39    /// # Returns
40    /// A new instance of `DiceInput`.
41    ///
42    ///  # Panics
43    /// - If `D` is less than 3.
44    /// - If `outputs` and `targets` do not have the same dimensions.
45    /// - If `outputs` or `targets` do not have exactly `D` dimensions.
46    /// - If `outputs` and `targets` do not have the same shape.
47    pub fn new(outputs: Tensor<B, D, Int>, targets: Tensor<B, D, Int>) -> Self {
48        assert!(D >= 3, "DiceInput requires at least 3 dimensions.");
49        assert!(
50            outputs.dims() == targets.dims(),
51            "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
52            outputs.dims(),
53            targets.dims()
54        );
55        Self { outputs, targets }
56    }
57}
58
/// Settings that control how the [DiceMetric] score is computed.
#[derive(Debug, Clone, Copy)]
pub struct DiceMetricConfig {
    /// Smoothing term added to both the numerator and the denominator of the Dice ratio,
    /// keeping the result finite when both masks are empty.
    pub epsilon: f64,
    /// Whether class 0 (assumed to be the background) takes part in the score.
    /// When `true`, inputs with fewer than 2 classes cause a panic.
    pub include_background: bool,
}

impl Default for DiceMetricConfig {
    /// Uses `epsilon = 1e-7` and excludes the background class.
    fn default() -> Self {
        let include_background = false;
        let epsilon = 1e-7;
        Self {
            epsilon,
            include_background,
        }
    }
}
78
79/// The Dice-Sorenson coefficient (DSC) for evaluating overlap between two binary masks.
80/// The DSC is defined as:
81/// `DSC = 2 * (|X ∩ Y|) / (|X| + |Y|)`
82/// where `X` is the model output and `Y` is the ground truth target.
83///
84///  # Type Parameters
85/// - `B`: Backend type.
86/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).
87#[derive(Default, Clone)]
88pub struct DiceMetric<B: Backend, const D: usize = 4> {
89    name: MetricName,
90    /// Internal state for numeric metric aggregation.
91    state: NumericMetricState,
92    /// Marker for backend type.
93    _b: PhantomData<B>,
94    /// Configuration for the metric.
95    config: DiceMetricConfig,
96}
97
98impl<B: Backend, const D: usize> DiceMetric<B, D> {
99    /// Creates a new Dice metric instance with default config.
100    pub fn new() -> Self {
101        Self::with_config(DiceMetricConfig::default())
102    }
103
104    /// Creates a new Dice metric with a custom config.
105    pub fn with_config(config: DiceMetricConfig) -> Self {
106        let name = MetricName::new(format!("{D}D Dice Metric"));
107        assert!(D >= 3, "DiceMetric requires at least 3 dimensions.");
108        Self {
109            name,
110            config,
111            ..Default::default()
112        }
113    }
114}
115
116impl<B: Backend, const D: usize> Metric for DiceMetric<B, D> {
117    type Input = DiceInput<B, D>;
118
119    fn name(&self) -> MetricName {
120        self.name.clone()
121    }
122
123    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
124        // Dice coefficient: 2 * (|X ∩ Y|) / (|X| + |Y|)
125        if item.outputs.dims() != item.targets.dims() {
126            panic!(
127                "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
128                item.outputs.dims(),
129                item.targets.dims()
130            );
131        }
132
133        let dims = item.outputs.dims();
134        let batch_size = dims[0];
135        let n_classes = dims[1];
136
137        let mut outputs = item.outputs.clone();
138        let mut targets = item.targets.clone();
139
140        if !self.config.include_background && n_classes > 1 {
141            // If not including background, we can ignore the first class
142            outputs = outputs.slice(s![.., 1..]);
143            targets = targets.slice(s![.., 1..]);
144        } else if self.config.include_background && n_classes < 2 {
145            // If including background, we need at least 2 classes
146            panic!("Dice metric requires at least 2 classes when including background.");
147        }
148
149        let intersection = (outputs.clone() * targets.clone()).sum();
150        let outputs_sum = outputs.sum();
151        let targets_sum = targets.sum();
152
153        // Convert to f64
154        let intersection_val = intersection.into_scalar().elem::<f64>();
155        let outputs_sum_val = outputs_sum.into_scalar().elem::<f64>();
156        let targets_sum_val = targets_sum.into_scalar().elem::<f64>();
157
158        // Use epsilon from config
159        let epsilon = self.config.epsilon;
160        let dice =
161            (2.0 * intersection_val + epsilon) / (outputs_sum_val + targets_sum_val + epsilon);
162
163        self.state.update(
164            dice,
165            batch_size,
166            FormatOptions::new(self.name()).precision(4),
167        )
168    }
169
170    /// Clears the metric state.
171    fn clear(&mut self) {
172        self.state.reset();
173    }
174}
175
176impl<B: Backend, const D: usize> Numeric for DiceMetric<B, D> {
177    /// Returns the current value of the metric.
178    fn value(&self) -> crate::metric::NumericEntry {
179        self.state.value()
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_core::tensor::{Shape, Tensor};

    #[test]
    fn test_dice_perfect_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_no_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[0, 1], [0, 1]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!(metric.value().current() < 1e-6);
    }

    #[test]
    fn test_dice_partial_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 1], [0, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        // intersection = 1, sum = 2+2=4, dice = 2*1/4 = 0.5
        assert!((metric.value().current() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_dice_empty_masks() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_no_background() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_with_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_ignored_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: false,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_background_inclusion_changes_score() {
        let device = Default::default();
        // Shape [1, 2, 1, 2]: the background channel (class 0) disagrees,
        // while the foreground channel (class 1) matches exactly.
        let input = DiceInput::<TestBackend, 4>::new(
            Tensor::from_data([[[[1, 0]], [[0, 1]]]], &device),
            Tensor::from_data([[[[0, 1]], [[0, 1]]]], &device),
        );

        // Background ignored: only class 1 counts -> 2*1 / (1+1) = 1.
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);

        // Background included: intersection = 1, |X| + |Y| = 2+2 -> 0.5.
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 0.5).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "DiceInput requires at least 3 dimensions.")]
    fn test_invalid_input_dimensions() {
        let device = Default::default();
        // D = 2, should panic
        let _ = DiceInput::<TestBackend, 2>::new(
            Tensor::from_data([[0, 0]], &device),
            Tensor::from_data([[0, 0]], &device),
        );
    }

    #[test]
    #[should_panic(
        expected = "Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]"
    )]
    fn test_mismatched_shape() {
        let device = Default::default();
        // shapes differ
        let _ = DiceInput::<TestBackend, 4>::new(
            Tensor::from_data([[[[0; 2]; 2]; 1]; 1], &device),
            Tensor::from_data([[[[0; 3]; 2]; 1]; 1], &device),
        );
    }

    #[test]
    #[should_panic(expected = "Dice metric requires at least 2 classes when including background.")]
    fn test_include_background_panic() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };

        // Shape [1, 2, 1, 2]: n_classes = 2, should not panic.
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::from_data([[[[1; 2]; 1]; 2]; 1], &device),
            Tensor::from_data([[[[1; 2]; 1]; 2]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);

        // Shape [1, 1, 1, 1]: n_classes = 1, should panic.
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::from_data([[[[1; 1]; 1]; 1]; 1], &device),
            Tensor::from_data([[[[1; 1]; 1]; 1]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }
}