Skip to main content

scirs2_ndimage/visualization/
statistical.rs

1//! Statistical Visualization Functions
2//!
3//! This module provides specialized visualization functions for statistical analysis,
4//! comparative studies, and multi-dataset visualization. These functions are designed
5//! to support statistical research and data analysis workflows.
6
7use scirs2_core::ndarray::ArrayStatCompat;
8use scirs2_core::ndarray::{ArrayView1, ArrayView2};
9use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive, Zero};
10use std::fmt::{Debug, Write};
11
12use crate::error::{NdimageError, NdimageResult};
13use crate::visualization::types::{PlotConfig, ReportFormat};
14use statrs::statistics::Statistics;
15
16/// Create an image montage/grid from multiple 2D arrays
17///
18/// This function arranges multiple images in a grid layout for comparison and analysis.
19/// It automatically scales all images using global min/max values for consistent visualization.
20///
21/// # Arguments
22///
23/// * `images` - Slice of 2D array views representing the images to arrange
24/// * `grid_cols` - Number of columns in the grid layout
25/// * `config` - Plot configuration specifying format and styling
26///
27/// # Returns
28///
29/// A formatted string representation of the image grid in the specified format
30///
31/// # Examples
32///
33/// ```rust,ignore
34/// use scirs2_core::ndarray::Array2;
35/// use scirs2_ndimage::visualization::{PlotConfig, ReportFormat, create_image_montage};
36///
37/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
38/// let img1 = Array2::zeros((10, 10));
39/// let img2 = Array2::ones((10, 10));
40/// let images = vec![img1.view(), img2.view()];
41///
42/// let config = PlotConfig::new()
43///     .with_format(ReportFormat::Text)
44///     .with_title("Image Comparison");
45///
46/// let montage = create_image_montage(&images, 2, &config)?;
47/// # Ok(())
48/// # }
49/// ```
50#[allow(dead_code)]
51pub fn create_image_montage<T>(
52    images: &[ArrayView2<T>],
53    grid_cols: usize,
54    config: &PlotConfig,
55) -> NdimageResult<String>
56where
57    T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
58{
59    if images.is_empty() {
60        return Err(NdimageError::InvalidInput("No images provided".into()));
61    }
62
63    if grid_cols == 0 {
64        return Err(NdimageError::InvalidInput(
65            "Grid columns must be positive".into(),
66        ));
67    }
68
69    let mut plot = String::new();
70    let grid_rows = (images.len() + grid_cols - 1) / grid_cols;
71
72    // Find global min/max for consistent scaling
73    let mut global_min = T::infinity();
74    let mut global_max = T::neg_infinity();
75
76    for image in images {
77        let min_val = image.iter().cloned().fold(T::infinity(), T::min);
78        let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
79        global_min = global_min.min(min_val);
80        global_max = global_max.max(max_val);
81    }
82
83    if global_max <= global_min {
84        return Err(NdimageError::InvalidInput(
85            "All image values are the same".into(),
86        ));
87    }
88
89    match config.format {
90        ReportFormat::Html => {
91            writeln!(&mut plot, "<div class='image-montage'>")?;
92            writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
93            writeln!(&mut plot, "<div class='montage-grid' style='display: grid; grid-template-columns: repeat({}, 1fr); gap: 10px;'>", grid_cols)?;
94
95            for (idx, image) in images.iter().enumerate() {
96                let (height, width) = image.dim();
97                writeln!(&mut plot, "<div class='montage-cell'>")?;
98                writeln!(&mut plot, "<h4>Image {}</h4>", idx + 1)?;
99                writeln!(
100                    &mut plot,
101                    "<div class='image-data' data-width='{}' data-height='{}'>",
102                    width, height
103                )?;
104
105                // Simple representation - would need actual image rendering in practice
106                writeln!(&mut plot, "<p>{}×{} array</p>", height, width)?;
107                writeln!(
108                    &mut plot,
109                    "<p>Range: [{:.3}, {:.3}]</p>",
110                    image
111                        .iter()
112                        .cloned()
113                        .fold(T::infinity(), T::min)
114                        .to_f64()
115                        .unwrap_or(0.0),
116                    image
117                        .iter()
118                        .cloned()
119                        .fold(T::neg_infinity(), T::max)
120                        .to_f64()
121                        .unwrap_or(0.0)
122                )?;
123
124                writeln!(&mut plot, "</div>")?;
125                writeln!(&mut plot, "</div>")?;
126            }
127
128            writeln!(&mut plot, "</div>")?;
129            writeln!(&mut plot, "<div class='montage-info'>")?;
130            writeln!(
131                &mut plot,
132                "<p>Global range: [{:.3}, {:.3}]</p>",
133                global_min.to_f64().unwrap_or(0.0),
134                global_max.to_f64().unwrap_or(0.0)
135            )?;
136            writeln!(
137                &mut plot,
138                "<p>Grid: {} rows × {} columns</p>",
139                grid_rows, grid_cols
140            )?;
141            writeln!(&mut plot, "</div>")?;
142            writeln!(&mut plot, "</div>")?;
143        }
144        ReportFormat::Markdown => {
145            writeln!(&mut plot, "## {} (Image Montage)", config.title)?;
146            writeln!(&mut plot)?;
147            writeln!(
148                &mut plot,
149                "Grid layout: {} rows × {} columns",
150                grid_rows, grid_cols
151            )?;
152            writeln!(
153                &mut plot,
154                "Global value range: [{:.3}, {:.3}]",
155                global_min.to_f64().unwrap_or(0.0),
156                global_max.to_f64().unwrap_or(0.0)
157            )?;
158            writeln!(&mut plot)?;
159
160            for (idx, image) in images.iter().enumerate() {
161                let (height, width) = image.dim();
162                let min_val = image.iter().cloned().fold(T::infinity(), T::min);
163                let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
164
165                writeln!(&mut plot, "### Image {}", idx + 1)?;
166                writeln!(&mut plot, "- Dimensions: {}×{}", height, width)?;
167                writeln!(
168                    &mut plot,
169                    "- Value range: [{:.3}, {:.3}]",
170                    min_val.to_f64().unwrap_or(0.0),
171                    max_val.to_f64().unwrap_or(0.0)
172                )?;
173                writeln!(&mut plot)?;
174            }
175        }
176        ReportFormat::Text => {
177            writeln!(&mut plot, "{} (Image Montage)", config.title)?;
178            writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 16))?;
179            writeln!(&mut plot)?;
180            writeln!(
181                &mut plot,
182                "Grid layout: {} rows × {} columns",
183                grid_rows, grid_cols
184            )?;
185            writeln!(
186                &mut plot,
187                "Global value range: [{:.3}, {:.3}]",
188                global_min.to_f64().unwrap_or(0.0),
189                global_max.to_f64().unwrap_or(0.0)
190            )?;
191            writeln!(&mut plot)?;
192
193            for (idx, image) in images.iter().enumerate() {
194                let (height, width) = image.dim();
195                let min_val = image.iter().cloned().fold(T::infinity(), T::min);
196                let max_val = image.iter().cloned().fold(T::neg_infinity(), T::max);
197
198                writeln!(
199                    &mut plot,
200                    "Image {}: {}×{}, range [{:.3}, {:.3}]",
201                    idx + 1,
202                    height,
203                    width,
204                    min_val.to_f64().unwrap_or(0.0),
205                    max_val.to_f64().unwrap_or(0.0)
206                )?;
207            }
208        }
209    }
210
211    Ok(plot)
212}
213
214/// Generate a comparative statistical plot for multiple datasets
215///
216/// This function creates a comprehensive statistical comparison table showing
217/// key statistics (count, mean, standard deviation, min, max) for multiple datasets.
218/// Useful for comparing different experimental conditions or processing results.
219///
220/// # Arguments
221///
222/// * `datasets` - Slice of tuples containing dataset names and their 1D data arrays
223/// * `config` - Plot configuration specifying format and styling
224///
225/// # Returns
226///
227/// A formatted statistical comparison table in the specified format
228///
229/// # Examples
230///
231/// ```rust
232/// use scirs2_core::ndarray::Array1;
233/// use scirs2_ndimage::visualization::{PlotConfig, ReportFormat, plot_statistical_comparison};
234///
235/// let control = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
236/// let treatment = Array1::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0]);
237///
238/// let datasets = vec![
239///     ("Control", control.view()),
240///     ("Treatment", treatment.view()),
241/// ];
242///
243/// let config = PlotConfig::new()
244///     .with_format(ReportFormat::Markdown)
245///     .with_title("Statistical Comparison");
246///
247/// let comparison = plot_statistical_comparison(&datasets, &config)?;
248/// # Ok::<(), Box<dyn std::error::Error>>(())
249/// ```
250#[allow(dead_code)]
251pub fn plot_statistical_comparison<T>(
252    datasets: &[(&str, ArrayView1<T>)],
253    config: &PlotConfig,
254) -> NdimageResult<String>
255where
256    T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
257{
258    if datasets.is_empty() {
259        return Err(NdimageError::InvalidInput("No datasets provided".into()));
260    }
261
262    let mut plot = String::new();
263
264    // Compute statistics for each dataset
265    let mut stats = Vec::new();
266    for (name, data) in datasets {
267        if data.is_empty() {
268            continue;
269        }
270
271        let mean = data.mean_or(T::zero());
272        let min_val = data.iter().cloned().fold(T::infinity(), T::min);
273        let max_val = data.iter().cloned().fold(T::neg_infinity(), T::max);
274        let variance = data
275            .mapv(|x| (x - mean) * (x - mean))
276            .mean()
277            .unwrap_or(T::zero());
278        let std_dev = variance.sqrt();
279
280        stats.push((name, mean, std_dev, min_val, max_val, data.len()));
281    }
282
283    match config.format {
284        ReportFormat::Html => {
285            writeln!(&mut plot, "<div class='statistical-comparison'>")?;
286            writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
287            writeln!(&mut plot, "<table class='stats-table'>")?;
288            writeln!(&mut plot, "<tr><th>Dataset</th><th>Count</th><th>Mean</th><th>Std Dev</th><th>Min</th><th>Max</th></tr>")?;
289
290            for (name, mean, std_dev, min_val, max_val, count) in &stats {
291                writeln!(
292                    &mut plot,
293                    "<tr><td>{}</td><td>{}</td><td>{:.4}</td><td>{:.4}</td><td>{:.4}</td><td>{:.4}</td></tr>",
294                    name, count,
295                    mean.to_f64().unwrap_or(0.0),
296                    std_dev.to_f64().unwrap_or(0.0),
297                    min_val.to_f64().unwrap_or(0.0),
298                    max_val.to_f64().unwrap_or(0.0)
299                )?;
300            }
301
302            writeln!(&mut plot, "</table>")?;
303            writeln!(&mut plot, "</div>")?;
304        }
305        ReportFormat::Markdown => {
306            writeln!(&mut plot, "## {} (Statistical Comparison)", config.title)?;
307            writeln!(&mut plot)?;
308            writeln!(
309                &mut plot,
310                "| Dataset | Count | Mean | Std Dev | Min | Max |"
311            )?;
312            writeln!(
313                &mut plot,
314                "|---------|-------|------|---------|-----|-----|"
315            )?;
316
317            for (name, mean, std_dev, min_val, max_val, count) in &stats {
318                writeln!(
319                    &mut plot,
320                    "| {} | {} | {:.4} | {:.4} | {:.4} | {:.4} |",
321                    name,
322                    count,
323                    mean.to_f64().unwrap_or(0.0),
324                    std_dev.to_f64().unwrap_or(0.0),
325                    min_val.to_f64().unwrap_or(0.0),
326                    max_val.to_f64().unwrap_or(0.0)
327                )?;
328            }
329            writeln!(&mut plot)?;
330        }
331        ReportFormat::Text => {
332            writeln!(&mut plot, "{} (Statistical Comparison)", config.title)?;
333            writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 25))?;
334            writeln!(&mut plot)?;
335            writeln!(
336                &mut plot,
337                "{:<15} {:>8} {:>10} {:>10} {:>10} {:>10}",
338                "Dataset", "Count", "Mean", "Std Dev", "Min", "Max"
339            )?;
340            writeln!(&mut plot, "{}", "-".repeat(75))?;
341
342            for (name, mean, std_dev, min_val, max_val, count) in &stats {
343                writeln!(
344                    &mut plot,
345                    "{:<15} {:>8} {:>10.4} {:>10.4} {:>10.4} {:>10.4}",
346                    name,
347                    count,
348                    mean.to_f64().unwrap_or(0.0),
349                    std_dev.to_f64().unwrap_or(0.0),
350                    min_val.to_f64().unwrap_or(0.0),
351                    max_val.to_f64().unwrap_or(0.0)
352                )?;
353            }
354        }
355    }
356
357    Ok(plot)
358}
359
360/// Calculate statistical summary for a dataset
361///
362/// Helper function that computes comprehensive statistics for a single dataset.
363/// Used internally by other statistical visualization functions.
364///
365/// # Arguments
366///
367/// * `data` - 1D array view of the dataset
368///
369/// # Returns
370///
371/// Tuple containing (mean, std_dev, min, max, count)
372pub fn calculate_dataset_statistics<T>(data: &ArrayView1<T>) -> (T, T, T, T, usize)
373where
374    T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
375{
376    if data.is_empty() {
377        return (T::zero(), T::zero(), T::zero(), T::zero(), 0);
378    }
379
380    let mean = data.mean_or(T::zero());
381    let min_val = data.iter().cloned().fold(T::infinity(), T::min);
382    let max_val = data.iter().cloned().fold(T::neg_infinity(), T::max);
383    let variance = data
384        .mapv(|x| (x - mean) * (x - mean))
385        .mean()
386        .unwrap_or(T::zero());
387    let std_dev = variance.sqrt();
388
389    (mean, std_dev, min_val, max_val, data.len())
390}
391
392/// Generate correlation matrix visualization for multiple datasets
393///
394/// Creates a text-based correlation matrix showing relationships between datasets.
395/// Useful for understanding data dependencies and relationships.
396///
397/// # Arguments
398///
399/// * `datasets` - Slice of tuples containing dataset names and their 1D data arrays
400/// * `config` - Plot configuration specifying format and styling
401///
402/// # Returns
403///
404/// A formatted correlation matrix in the specified format
405#[allow(dead_code)]
406pub fn plot_correlation_matrix<T>(
407    datasets: &[(&str, ArrayView1<T>)],
408    config: &PlotConfig,
409) -> NdimageResult<String>
410where
411    T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
412{
413    if datasets.len() < 2 {
414        return Err(NdimageError::InvalidInput(
415            "Need at least 2 datasets for correlation".into(),
416        ));
417    }
418
419    let mut plot = String::new();
420    let n = datasets.len();
421
422    // Calculate correlation matrix
423    let mut correlations = vec![vec![0.0; n]; n];
424
425    for i in 0..n {
426        for j in 0..n {
427            if i == j {
428                correlations[i][j] = 1.0;
429            } else {
430                let corr = calculate_correlation(&datasets[i].1, &datasets[j].1);
431                correlations[i][j] = corr;
432            }
433        }
434    }
435
436    match config.format {
437        ReportFormat::Html => {
438            writeln!(&mut plot, "<div class='correlation-matrix'>")?;
439            writeln!(&mut plot, "<h3>{}</h3>", config.title)?;
440            writeln!(&mut plot, "<table class='correlation-table'>")?;
441
442            // Header row
443            write!(&mut plot, "<tr><th></th>")?;
444            for (name, _) in datasets {
445                write!(&mut plot, "<th>{}</th>", name)?;
446            }
447            writeln!(&mut plot, "</tr>")?;
448
449            // Data rows
450            for i in 0..n {
451                write!(&mut plot, "<tr><th>{}</th>", datasets[i].0)?;
452                for j in 0..n {
453                    let corr = correlations[i][j];
454                    let color_class = if corr.abs() > 0.7 {
455                        "strong-corr"
456                    } else {
457                        "weak-corr"
458                    };
459                    write!(&mut plot, "<td class='{}'>{:.3}</td>", color_class, corr)?;
460                }
461                writeln!(&mut plot, "</tr>")?;
462            }
463
464            writeln!(&mut plot, "</table>")?;
465            writeln!(&mut plot, "</div>")?;
466        }
467        ReportFormat::Markdown => {
468            writeln!(&mut plot, "## {} (Correlation Matrix)", config.title)?;
469            writeln!(&mut plot)?;
470
471            // Header row
472            write!(&mut plot, "|")?;
473            for (name, _) in datasets {
474                write!(&mut plot, " {} |", name)?;
475            }
476            writeln!(&mut plot)?;
477
478            // Separator row
479            write!(&mut plot, "|")?;
480            for _ in 0..n {
481                write!(&mut plot, "------|")?;
482            }
483            writeln!(&mut plot)?;
484
485            // Data rows
486            for i in 0..n {
487                write!(&mut plot, "| **{}** |", datasets[i].0)?;
488                for j in 0..n {
489                    write!(&mut plot, " {:.3} |", correlations[i][j])?;
490                }
491                writeln!(&mut plot)?;
492            }
493            writeln!(&mut plot)?;
494        }
495        ReportFormat::Text => {
496            writeln!(&mut plot, "{} (Correlation Matrix)", config.title)?;
497            writeln!(&mut plot, "{}", "=".repeat(config.title.len() + 20))?;
498            writeln!(&mut plot)?;
499
500            // Header row
501            write!(&mut plot, "{:>12}", "")?;
502            for (name, _) in datasets {
503                write!(&mut plot, " {:>8}", &name[..name.len().min(8)])?;
504            }
505            writeln!(&mut plot)?;
506
507            // Data rows
508            for i in 0..n {
509                write!(
510                    &mut plot,
511                    "{:>12}",
512                    &datasets[i].0[..datasets[i].0.len().min(12)]
513                )?;
514                for j in 0..n {
515                    write!(&mut plot, " {:>8.3}", correlations[i][j])?;
516                }
517                writeln!(&mut plot)?;
518            }
519        }
520    }
521
522    Ok(plot)
523}
524
525/// Calculate Pearson correlation coefficient between two datasets
526fn calculate_correlation<T>(data1: &ArrayView1<T>, data2: &ArrayView1<T>) -> f64
527where
528    T: Float + FromPrimitive + ToPrimitive + Debug + Clone,
529{
530    if data1.len() != data2.len() || data1.len() < 2 {
531        return 0.0;
532    }
533
534    let mean1 = data1.mean_or(T::zero()).to_f64().unwrap_or(0.0);
535    let mean2 = data2.mean_or(T::zero()).to_f64().unwrap_or(0.0);
536
537    let mut sum_xy = 0.0;
538    let mut sum_x2 = 0.0;
539    let mut sum_y2 = 0.0;
540
541    for i in 0..data1.len() {
542        let x = data1[i].to_f64().unwrap_or(0.0) - mean1;
543        let y = data2[i].to_f64().unwrap_or(0.0) - mean2;
544
545        sum_xy += x * y;
546        sum_x2 += x * x;
547        sum_y2 += y * y;
548    }
549
550    let denominator = (sum_x2 * sum_y2).sqrt();
551    if denominator == 0.0 {
552        0.0
553    } else {
554        sum_xy / denominator
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// A text montage of three 5×5 images lists every image and a 2×2 grid.
    #[test]
    fn test_create_image_montage() {
        let zeros = Array2::<f64>::zeros((5, 5));
        let ones = Array2::<f64>::ones((5, 5));
        let twos = Array2::from_elem((5, 5), 2.0);
        let images = [zeros.view(), ones.view(), twos.view()];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Text)
            .with_title("Test Montage");

        let montage = create_image_montage(&images, 2, &config)
            .expect("montage of three distinct images should succeed");

        assert!(montage.contains("Test Montage"));
        assert!(montage.contains("Grid layout: 2 rows × 2 columns"));
        for idx in 1..=images.len() {
            let expected = format!("Image {}: 5×5", idx);
            assert!(montage.contains(&expected), "missing entry: {}", expected);
        }
    }

    /// The Markdown comparison table carries the title, both names, and the header.
    #[test]
    fn test_plot_statistical_comparison() {
        let first = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let second = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);
        let datasets = [("Dataset A", first.view()), ("Dataset B", second.view())];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Markdown)
            .with_title("Statistical Test");

        let comparison = plot_statistical_comparison(&datasets, &config)
            .expect("comparison of two non-empty datasets should succeed");

        for needle in [
            "Statistical Test",
            "Dataset A",
            "Dataset B",
            "| Dataset | Count | Mean",
        ] {
            assert!(comparison.contains(needle), "missing: {}", needle);
        }
    }

    /// Summary statistics of 1..=5.
    #[test]
    fn test_calculate_dataset_statistics() {
        let values = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let (mean, std_dev, lowest, highest, count) = calculate_dataset_statistics(&values.view());

        assert!((mean - 3.0).abs() < 1e-6);
        assert_eq!(lowest, 1.0);
        assert_eq!(highest, 5.0);
        assert_eq!(count, 5);
        assert!(std_dev > 0.0);
    }

    /// A dataset and its positive multiple correlate perfectly.
    #[test]
    fn test_calculate_correlation() {
        let base = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let doubled = base.mapv(|v| v * 2.0);

        let corr = calculate_correlation(&base.view(), &doubled.view());
        assert!((corr - 1.0).abs() < 1e-10);
    }

    /// The text correlation matrix names every dataset.
    #[test]
    fn test_plot_correlation_matrix() {
        let rising = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let doubled = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);
        let falling = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);

        let datasets = [
            ("Data A", rising.view()),
            ("Data B", doubled.view()),
            ("Data C", falling.view()),
        ];

        let config = PlotConfig::new()
            .with_format(ReportFormat::Text)
            .with_title("Correlation Test");

        let matrix = plot_correlation_matrix(&datasets, &config)
            .expect("three equal-length datasets should correlate");

        assert!(matrix.contains("Correlation Test"));
        for (name, _) in &datasets {
            assert!(matrix.contains(name), "missing dataset: {}", name);
        }
    }

    /// No images is an input error.
    #[test]
    fn test_empty_image_montage() {
        let images = Vec::<scirs2_core::ndarray::ArrayView2<f64>>::new();
        let config = PlotConfig::new();

        let err = create_image_montage(&images, 2, &config)
            .expect_err("an empty image list must be rejected");
        assert!(err.to_string().contains("No images provided"));
    }

    /// Zero grid columns is an input error.
    #[test]
    fn test_zero_grid_cols() {
        let image = Array2::<f64>::zeros((5, 5));
        let images = [image.view()];
        let config = PlotConfig::new();

        let err = create_image_montage(&images, 0, &config)
            .expect_err("zero grid columns must be rejected");
        assert!(err.to_string().contains("Grid columns must be positive"));
    }
}