1use crate::metric::MetricName;
2
3use super::super::{
4 Metric, MetricEntry, MetricMetadata, Numeric,
5 state::{FormatOptions, NumericMetricState},
6};
7use burn_core::{
8 prelude::{Backend, Tensor},
9 tensor::{ElementConversion, Int, s},
10};
11use core::marker::PhantomData;
12
13pub struct DiceInput<B: Backend, const D: usize = 4> {
19 outputs: Tensor<B, D, Int>,
21 targets: Tensor<B, D, Int>,
23}
24
25impl<B: Backend, const D: usize> DiceInput<B, D> {
26 pub fn new(outputs: Tensor<B, D, Int>, targets: Tensor<B, D, Int>) -> Self {
48 assert!(D >= 3, "DiceInput requires at least 3 dimensions.");
49 assert!(
50 outputs.dims() == targets.dims(),
51 "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
52 outputs.dims(),
53 targets.dims()
54 );
55 Self { outputs, targets }
56 }
57}
58
/// Options controlling how the Dice metric is computed.
#[derive(Debug, Clone, Copy)]
pub struct DiceMetricConfig {
    /// Smoothing term added to both the numerator and the denominator of the Dice
    /// ratio, so the score stays defined (and equals 1) when both masks are empty.
    pub epsilon: f64,
    /// Whether channel 0 (the background) takes part in the score. When `false`
    /// and the input has more than one channel, channel 0 is dropped first.
    pub include_background: bool,
}

impl Default for DiceMetricConfig {
    /// `epsilon = 1e-7`, background excluded.
    fn default() -> Self {
        let epsilon = 1e-7;
        let include_background = false;
        Self {
            epsilon,
            include_background,
        }
    }
}
78
79#[derive(Default, Clone)]
88pub struct DiceMetric<B: Backend, const D: usize = 4> {
89 name: MetricName,
90 state: NumericMetricState,
92 _b: PhantomData<B>,
94 config: DiceMetricConfig,
96}
97
98impl<B: Backend, const D: usize> DiceMetric<B, D> {
99 pub fn new() -> Self {
101 Self::with_config(DiceMetricConfig::default())
102 }
103
104 pub fn with_config(config: DiceMetricConfig) -> Self {
106 let name = MetricName::new(format!("{D}D Dice Metric"));
107 assert!(D >= 3, "DiceMetric requires at least 3 dimensions.");
108 Self {
109 name,
110 config,
111 ..Default::default()
112 }
113 }
114}
115
116impl<B: Backend, const D: usize> Metric for DiceMetric<B, D> {
117 type Input = DiceInput<B, D>;
118
119 fn name(&self) -> MetricName {
120 self.name.clone()
121 }
122
123 fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
124 if item.outputs.dims() != item.targets.dims() {
126 panic!(
127 "Outputs and targets must have the same dimensions. Got {:?} and {:?}",
128 item.outputs.dims(),
129 item.targets.dims()
130 );
131 }
132
133 let dims = item.outputs.dims();
134 let batch_size = dims[0];
135 let n_classes = dims[1];
136
137 let mut outputs = item.outputs.clone();
138 let mut targets = item.targets.clone();
139
140 if !self.config.include_background && n_classes > 1 {
141 outputs = outputs.slice(s![.., 1..]);
143 targets = targets.slice(s![.., 1..]);
144 } else if self.config.include_background && n_classes < 2 {
145 panic!("Dice metric requires at least 2 classes when including background.");
147 }
148
149 let intersection = (outputs.clone() * targets.clone()).sum();
150 let outputs_sum = outputs.sum();
151 let targets_sum = targets.sum();
152
153 let intersection_val = intersection.into_scalar().elem::<f64>();
155 let outputs_sum_val = outputs_sum.into_scalar().elem::<f64>();
156 let targets_sum_val = targets_sum.into_scalar().elem::<f64>();
157
158 let epsilon = self.config.epsilon;
160 let dice =
161 (2.0 * intersection_val + epsilon) / (outputs_sum_val + targets_sum_val + epsilon);
162
163 self.state.update(
164 dice,
165 batch_size,
166 FormatOptions::new(self.name()).precision(4),
167 )
168 }
169
170 fn clear(&mut self) {
172 self.state.reset();
173 }
174}
175
176impl<B: Backend, const D: usize> Numeric for DiceMetric<B, D> {
177 fn value(&self) -> crate::metric::NumericEntry {
179 self.state.value()
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_core::tensor::{Shape, Tensor};

    #[test]
    fn test_dice_perfect_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_no_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
            Tensor::from_data([[[[0, 1], [0, 1]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!(metric.value().current() < 1e-6);
    }

    #[test]
    fn test_dice_partial_overlap() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 1], [0, 0]]]], &device),
            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_dice_empty_masks() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_no_background() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_with_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dice_ignored_background() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: false,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// Channel 0 disagrees completely while channel 1 matches exactly, so dropping
    /// the background must yield a perfect score.
    #[test]
    fn test_dice_background_excluded_from_score() {
        let device = Default::default();
        let mut metric = DiceMetric::<TestBackend, 4>::new();
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0]], [[1, 1]]]], &device),
            Tensor::from_data([[[[0, 1]], [[1, 1]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 1.0).abs() < 1e-6);
    }

    /// Same data as above, but with the background channel scored: the overlap is 2,
    /// and each mask sums to 3, so the score is 4 / 6.
    #[test]
    fn test_dice_background_included_in_score() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::from_data([[[[1, 0]], [[1, 1]]]], &device),
            Tensor::from_data([[[[0, 1]], [[1, 1]]]], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert!((metric.value().current() - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "DiceInput requires at least 3 dimensions.")]
    fn test_invalid_input_dimensions() {
        let device = Default::default();
        let _ = DiceInput::<TestBackend, 2>::new(
            Tensor::from_data([[0.0, 0.0]], &device),
            Tensor::from_data([[0.0, 0.0]], &device),
        );
    }

    #[test]
    #[should_panic(
        expected = "Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]"
    )]
    fn test_mismatched_shape() {
        let device = Default::default();
        let _ = DiceInput::<TestBackend, 4>::new(
            Tensor::from_data([[[[0.0; 2]; 2]; 1]; 1], &device),
            Tensor::from_data([[[[0.0; 3]; 2]; 1]; 1], &device),
        );
    }

    #[test]
    #[should_panic(expected = "Dice metric requires at least 2 classes when including background.")]
    fn test_include_background_panic() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),
            Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }

    /// Formerly unreachable code after the first `update` in
    /// `test_include_background_panic`; now exercised on its own.
    #[test]
    #[should_panic(expected = "Dice metric requires at least 2 classes when including background.")]
    fn test_include_background_panic_single_element() {
        let device = Default::default();
        let config = DiceMetricConfig {
            epsilon: 1e-7,
            include_background: true,
        };
        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);
        let input = DiceInput::new(
            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
    }
}