filtration_domination/datasets/
mod.rs1use ordered_float::OrderedFloat;
5use std::cmp::max;
6use std::fmt::Formatter;
7use std::io;
8use thiserror::Error;
9
10use crate::datasets::distance_matrices::get_dataset_distance_matrix;
11use crate::distance_matrix::density_estimation::DensityEstimator;
12use crate::distance_matrix::DistanceMatrix;
13use crate::edges::{BareEdge, EdgeList, FilteredEdge};
14use crate::{OneCriticalGrade, Value};
15
16mod distance_matrices;
17mod sampling;
18
/// Name of the directory where dataset files are presumably looked up.
// NOTE(review): not referenced in the visible part of this module; presumably read by
// the child modules via `super::DATASET_DIRECTORY` and resolved relative to the current
// working directory — confirm.
const DATASET_DIRECTORY: &str = "datasets";
20
/// The datasets this module can turn into bifiltered edge lists.
///
/// The distance matrix of every variant is obtained through `get_dataset_distance_matrix`.
/// NOTE(review): the unit variants are presumably loaded from downloaded files (see the
/// "Did you download the datasets?" error), and the variants carrying `n_points` are
/// presumably sampled synthetically (see the `sampling` module) — confirm.
///
/// `PartialEq`, `Eq` and `Hash` are derived so datasets can be compared, asserted on,
/// and used as keys (e.g. for caching or result tables).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Dataset {
    /// The "senate" dataset.
    Senate,
    /// The "eleg" dataset.
    Eleg,
    /// The "netwsc" dataset.
    Netwsc,
    /// The "hiv" dataset.
    Hiv,
    /// The "dragon" dataset.
    Dragon,
    /// A "circle" point cloud with `n_points` points.
    Circle { n_points: usize },
    /// A "sphere" point cloud with `n_points` points.
    Sphere { n_points: usize },
    /// A "torus" point cloud with `n_points` points.
    Torus { n_points: usize },
    /// A "swiss roll" point cloud with `n_points` points.
    SwissRoll { n_points: usize },
    /// A "uniform" point cloud with `n_points` points.
    Uniform { n_points: usize },
}
45
46impl std::fmt::Display for Dataset {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Dataset::Senate => {
50 write!(f, "senate")
51 }
52 Dataset::Eleg => {
53 write!(f, "eleg")
54 }
55 Dataset::Netwsc => {
56 write!(f, "netwsc")
57 }
58 Dataset::Hiv => {
59 write!(f, "hiv")
60 }
61 Dataset::Dragon => {
62 write!(f, "dragon")
63 }
64 Dataset::Circle { n_points } => {
65 write!(f, "circle({n_points})")
66 }
67 Dataset::Sphere { n_points } => {
68 write!(f, "sphere({n_points})")
69 }
70 Dataset::Torus { n_points } => {
71 write!(f, "torus({n_points})")
72 }
73 Dataset::SwissRoll { n_points } => {
74 write!(f, "swiss-roll({n_points})")
75 }
76 Dataset::Uniform { n_points } => {
77 write!(f, "uniform({n_points})")
78 }
79 }
80 }
81}
82
/// Selects which edges of a distance matrix are kept. It is forwarded to
/// `distance_matrices::get_distance_matrix_edge_list`.
///
/// `PartialEq` is derived so thresholds can be compared and asserted on. `Eq` is
/// impossible because of the `f64` payloads.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Threshold {
    /// Presumably keeps every edge (no thresholding).
    KeepAll,
    /// Presumably keeps edges up to the given percentile of the distances.
    /// NOTE(review): confirm whether the value is a fraction in `[0, 1]` or a
    /// percentage in `[0, 100]`.
    Percentile(f64),
    /// Presumably keeps edges whose length does not exceed the given distance.
    /// NOTE(review): confirm whether the bound is inclusive.
    Fixed(f64),
}
93
94#[derive(Error, Debug)]
96pub enum DatasetError {
97 #[error("Couldn't find file \"{0}\". Did you download the datasets?")]
98 FileNotFound(String),
99
100 #[error(transparent)]
101 Io(#[from] io::Error),
102}
103
104pub fn get_dataset_density_edge_list(
112 dataset: Dataset,
113 threshold: Threshold,
114 estimator: Option<DensityEstimator<OrderedFloat<f64>>>,
115 use_cache: bool,
116) -> Result<EdgeList<FilteredEdge<OneCriticalGrade<OrderedFloat<f64>, 2>>>, DatasetError> {
117 let distance_matrix = get_dataset_distance_matrix(dataset, use_cache)?;
118
119 let estimator = estimator.unwrap_or_else(|| default_estimator(&distance_matrix));
120 let mut estimations = estimator.estimate(&distance_matrix);
121 for e in estimations.iter_mut() {
124 *e = OrderedFloat::from(1.0) - *e;
125 }
126
127 let edges = distance_matrices::get_distance_matrix_edge_list(&distance_matrix, threshold);
128
129 let density_edges_it = edges.edges().iter().map(|edge| {
130 let FilteredEdge {
131 grade: OneCriticalGrade([dist]),
132 edge: BareEdge(u, v),
133 } = edge;
134
135 let edge_density = max(estimations[*u], estimations[*v]);
137
138 FilteredEdge {
139 grade: OneCriticalGrade([edge_density, *dist]),
140 edge: BareEdge(*u, *v),
141 }
142 });
143
144 Ok(EdgeList::from_iterator(density_edges_it))
145}
146
147fn default_estimator<F: Value + std::fmt::Display>(
148 matrix: &DistanceMatrix<F>,
149) -> DensityEstimator<F> {
150 let bandwidth = matrix.percentile(0.2);
151 DensityEstimator::Gaussian(*bandwidth)
152}