filtration_domination/datasets/
mod.rs

1//! Dataset reading and sampling.
2//!
3//! The main entry point is [get_dataset_density_edge_list], which returns a bifiltered edge list.
4use ordered_float::OrderedFloat;
5use std::cmp::max;
6use std::fmt::Formatter;
7use std::io;
8use thiserror::Error;
9
10use crate::datasets::distance_matrices::get_dataset_distance_matrix;
11use crate::distance_matrix::density_estimation::DensityEstimator;
12use crate::distance_matrix::DistanceMatrix;
13use crate::edges::{BareEdge, EdgeList, FilteredEdge};
14use crate::{OneCriticalGrade, Value};
15
16mod distance_matrices;
17mod sampling;
18
/// Directory name used for the datasets.
// NOTE(review): not referenced in the visible part of this module; presumably the
// submodules resolve dataset files under this directory — confirm against them.
const DATASET_DIRECTORY: &str = "datasets";
20
/// All datasets that we support.
///
/// The first five variants name fixed datasets from the PH-roadmap repository.
/// The remaining variants describe synthetic point clouds and carry the number of
/// points (`n_points`) of the cloud.
///
/// Values are cheap to copy and can be compared and hashed, so a `Dataset` can be
/// used directly as a key (e.g. to index cached results).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Dataset {
    /// The senate dataset from <https://github.com/n-otter/PH-roadmap>.
    Senate,
    /// The eleg dataset from <https://github.com/n-otter/PH-roadmap>.
    Eleg,
    /// The netwsc dataset from <https://github.com/n-otter/PH-roadmap>.
    Netwsc,
    /// The hiv dataset from <https://github.com/n-otter/PH-roadmap>.
    Hiv,
    /// The dragon dataset from <https://github.com/n-otter/PH-roadmap>.
    Dragon,
    /// A circle in R^2.
    Circle { n_points: usize },
    /// A noisy sphere in R^3.
    Sphere { n_points: usize },
    /// A torus in R^3.
    Torus { n_points: usize },
    /// A swiss roll, that is, a plane rolled up in a spiral in R^3.
    SwissRoll { n_points: usize },
    /// Points sampled uniformly from a square in the plane.
    Uniform { n_points: usize },
}
45
46impl std::fmt::Display for Dataset {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Dataset::Senate => {
50                write!(f, "senate")
51            }
52            Dataset::Eleg => {
53                write!(f, "eleg")
54            }
55            Dataset::Netwsc => {
56                write!(f, "netwsc")
57            }
58            Dataset::Hiv => {
59                write!(f, "hiv")
60            }
61            Dataset::Dragon => {
62                write!(f, "dragon")
63            }
64            Dataset::Circle { n_points } => {
65                write!(f, "circle({n_points})")
66            }
67            Dataset::Sphere { n_points } => {
68                write!(f, "sphere({n_points})")
69            }
70            Dataset::Torus { n_points } => {
71                write!(f, "torus({n_points})")
72            }
73            Dataset::SwissRoll { n_points } => {
74                write!(f, "swiss-roll({n_points})")
75            }
76            Dataset::Uniform { n_points } => {
77                write!(f, "uniform({n_points})")
78            }
79        }
80    }
81}
82
/// Possible thresholding settings.
///
/// A threshold decides which edges are kept when an edge list is built from a
/// distance matrix. Settings can be compared with `==` (note that, because the
/// payloads are `f64`, a `NaN` payload never compares equal to itself).
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Threshold {
    /// Keep all edges.
    KeepAll,
    /// Restrict to the edges of length less than the given percentile of all distances.
    Percentile(f64),
    /// Restrict to the edges of length less than the given value.
    Fixed(f64),
}
93
/// Error when reading or creating a dataset.
#[derive(Error, Debug)]
pub enum DatasetError {
    /// A file needed for the dataset could not be found. The payload is the text
    /// inserted into the message as the file name.
    #[error("Couldn't find file \"{0}\". Did you download the datasets?")]
    FileNotFound(String),

    /// Any other I/O failure. Built automatically from an [io::Error] (`#[from]`), and
    /// displayed exactly as the wrapped error (`transparent`).
    #[error(transparent)]
    Io(#[from] io::Error),
}
103
104/// Return the edge list of the associated dataset. Each edge is bifiltered by codensity and length.
105/// Codensity means that we order the density parameter from densest to least dense.
106///
107/// Possibly removes some edges according to `threshold`. See [Threshold].
108/// If a `estimator` is not provided, the function uses the Gaussian kernel estimator with
109/// bandwidth parameter set to the 20th percentile of the distances.
110/// If `use_cache` is set, the function caches the distance matrices of the sampled datasets.
111pub fn get_dataset_density_edge_list(
112    dataset: Dataset,
113    threshold: Threshold,
114    estimator: Option<DensityEstimator<OrderedFloat<f64>>>,
115    use_cache: bool,
116) -> Result<EdgeList<FilteredEdge<OneCriticalGrade<OrderedFloat<f64>, 2>>>, DatasetError> {
117    let distance_matrix = get_dataset_distance_matrix(dataset, use_cache)?;
118
119    let estimator = estimator.unwrap_or_else(|| default_estimator(&distance_matrix));
120    let mut estimations = estimator.estimate(&distance_matrix);
121    // Instead of working with densities, we work with codensities. That is, smaller values correspond
122    // to higher density estimations.
123    for e in estimations.iter_mut() {
124        *e = OrderedFloat::from(1.0) - *e;
125    }
126
127    let edges = distance_matrices::get_distance_matrix_edge_list(&distance_matrix, threshold);
128
129    let density_edges_it = edges.edges().iter().map(|edge| {
130        let FilteredEdge {
131            grade: OneCriticalGrade([dist]),
132            edge: BareEdge(u, v),
133        } = edge;
134
135        // The edge density is the max of the codensity of its vertices.
136        let edge_density = max(estimations[*u], estimations[*v]);
137
138        FilteredEdge {
139            grade: OneCriticalGrade([edge_density, *dist]),
140            edge: BareEdge(*u, *v),
141        }
142    });
143
144    Ok(EdgeList::from_iterator(density_edges_it))
145}
146
147fn default_estimator<F: Value + std::fmt::Display>(
148    matrix: &DistanceMatrix<F>,
149) -> DensityEstimator<F> {
150    let bandwidth = matrix.percentile(0.2);
151    DensityEstimator::Gaussian(*bandwidth)
152}