fbleau/fbleau_estimation.rs

//! F-BLEAU estimation routines.

use ndarray::*;
use std::fs::File;
use std::io::Write;

use crate::Label;
use crate::estimates::*;
use crate::utils::{prepare_data, estimate_random_guessing, has_integer_support};


/// Log data either to a .csv file or into a Vec.
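///
/// # Example
///
/// A minimal sketch of the two modes (the file name below is just a
/// placeholder):
///
/// ```ignore
/// use std::fs::File;
///
/// // Write estimates to a .csv file on disk...
/// let file_logger: Logger<f64> = Logger::LogFile(File::create("errors.csv").unwrap());
/// // ...or collect them in memory.
/// let vec_logger: Logger<f64> = Logger::LogVec(Vec::new());
/// ```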
pub enum Logger<T> {
    LogFile(File),
    LogVec(Vec<T>),
}

/// Main estimation routine: prepares everything for running F-BLEAU,
/// and runs a forward estimation strategy.
///
/// Given training and evaluation (test) data, this function runs
/// the desired estimator and returns a tuple containing:
///     - the smallest estimate,
///     - the final estimate (i.e., the estimate when all the training data
///       was available),
///     - the random guessing error.
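///
/// # Example
///
/// A minimal, illustrative sketch (not run as a doc-test): it assumes
/// `Label = usize` and a tiny, perfectly separable dataset.
///
/// ```ignore
/// use ndarray::{arr1, arr2};
///
/// let train_x = arr2(&[[0.0], [1.0], [0.0], [1.0]]);
/// let train_y = arr1(&[0, 1, 0, 1]);
/// let test_x = arr2(&[[0.0], [1.0]]);
/// let test_y = arr1(&[0, 1]);
///
/// // Collect the estimates in memory rather than in a .csv file.
/// let mut error_log = Some(Logger::LogVec(Vec::new()));
///
/// let (min_error, last_error, random_guessing) = run_fbleau(
///     train_x, train_y, test_x, test_y,
///     Estimate::NN, None, None,
///     &mut error_log, &mut None,
///     None, None, false, false);
///
/// assert!(min_error <= random_guessing);
/// ```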
pub fn run_fbleau(train_x: Array2<f64>, train_y: Array1<Label>,
                  test_x: Array2<f64>, test_y: Array1<Label>,
                  estimate: Estimate, knn_strategy: Option<KNNStrategy>,
                  distance: Option<String>,
                  error_logger: &mut Option<Logger<f64>>,
                  individual_error_logger: &mut Option<Logger<bool>>,
                  delta: Option<f64>, qstop: Option<usize>, absolute: bool,
                  scale: bool)
        -> (f64, f64, f64) {

    // Check the labels' indexes, and scale data if required.
    let (train_x, train_y, test_x, test_y, nlabels) =
        prepare_data(train_x, train_y, test_x, test_y, scale);

    // Convergence with a (delta, q)-convergence checker.
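    // (Informally: estimation stops once the estimate has changed by at most
    // `delta` -- a relative change unless `absolute` is set -- over the last
    // `q` processed training examples; e.g. (delta=0.02, q=500) stops once
    // the last 500 estimates stayed within a 2% band.)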
    let convergence_checker = if let Some(delta) = delta {
        let q = match qstop {
            Some(q) => q,
            // Default: q = 10% of the training examples.
            None => (train_x.nrows() as f64 * 0.1) as usize,
        };
        println!("will stop when (delta={}, q={})-converged", delta, q);
        Some(ForwardChecker::new(&[delta], q, !absolute))
    } else if qstop.is_some() {
        panic!("--qstop should only be specified with --delta");
    } else {
        // No convergence checker (i.e., run for all training data).
        None
    };

    // Random guessing error.
    let random_guessing = estimate_random_guessing(&test_y.view());
    println!("Random guessing error: {}", random_guessing);
    println!("Estimating leakage measures...");

    // Distance for k-NN (defaults to Euclidean; unrecognized names also
    // fall back to Euclidean).
    let distance = match distance.as_ref().map(String::as_ref) {
        Some("euclidean") => euclidean_distance,
        Some("levenshtein") => levenshtein_distance,
        _ => euclidean_distance,
    };

    // Init estimator and run.
    let (min_error, last_error) = match estimate {
        Estimate::Frequentist => {
            if !has_integer_support(&train_x) || !has_integer_support(&test_x) {
                println!("Warning: frequentist discouraged for continuous observations!");
            }
            let estimator = FrequentistEstimator::new(nlabels,
                                                      &test_x.view(),
                                                      &test_y.view());
            run_forward_strategy(estimator, convergence_checker, error_logger,
                                 individual_error_logger, train_x, train_y)
        },
        Estimate::NN => {
            if !has_integer_support(&train_x) || !has_integer_support(&test_x) {
                println!("Warning: NN discouraged for continuous observations!");
            }
            let estimator = KNNEstimator::new(&test_x.view(), &test_y.view(),
                                              train_x.nrows(), distance,
                                              KNNStrategy::NN);
            run_forward_strategy(estimator, convergence_checker, error_logger,
                                 individual_error_logger, train_x, train_y)
        },
        Estimate::KNN => {
            let estimator = KNNEstimator::new(&test_x.view(), &test_y.view(),
                                              train_x.nrows(), distance,
                                              knn_strategy.expect(
                                                  "Specify a k-NN strategy."));
            run_forward_strategy(estimator, convergence_checker, error_logger,
                                 individual_error_logger, train_x, train_y)
        },
        Estimate::NNBound => {
            let estimator = NNBoundEstimator::new(&test_x.view(), &test_y.view(),
                                                  distance, nlabels);
            run_forward_strategy(estimator, convergence_checker, error_logger,
                                 individual_error_logger, train_x, train_y)
        },
    };
    (min_error, last_error, random_guessing)
}

/// Forward strategy for estimation.
///
/// Estimates security measures with a forward strategy: the estimator
/// is trained with an increasing number of examples, and its estimate
/// is progressively logged.
///
/// This function returns:
///     - the smallest estimate,
///     - the final estimate (i.e., the estimate when all the training data
///       was available).
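///
/// A minimal, illustrative sketch of a direct call (normally this is done
/// via `run_fbleau`):
///
/// ```ignore
/// let estimator = FrequentistEstimator::new(nlabels, &test_x.view(), &test_y.view());
/// // Stop at (delta=0.05, q=100)-convergence, measured on relative changes.
/// let checker = Some(ForwardChecker::new(&[0.05], 100, true));
/// let (min_error, last_error) =
///     run_forward_strategy(estimator, checker, &mut None, &mut None,
///                          train_x, train_y);
/// ```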
fn run_forward_strategy<E>(mut estimator: E,
                           mut convergence_checker: Option<ForwardChecker>,
                           error_logger: &mut Option<Logger<f64>>,
                           individual_error_logger: &mut Option<Logger<bool>>,
                           train_x: Array2<f64>, train_y: Array1<Label>)
        -> (f64, f64)
where E: BayesEstimator {
    // Init the log file, if specified.
    if let Some(ref mut logger) = error_logger {
        if let Logger::LogFile(file) = logger {
            writeln!(file, "n, error-count, estimate")
                    .expect("Could not write to log file");
        }
    }

    // We keep track of both the minimum and the last estimate.
    let mut min_error = 1.0;
    let mut last_error = 1.0;
    let mut min_individual_errors = vec![];

    for (n, (x, y)) in train_x.outer_iter().zip(train_y.iter()).enumerate() {
        // Add the next training example and compute the current error.
        estimator.add_example(&x, *y)
                 .expect("Could not add more examples.");
        last_error = estimator.get_error();

        if min_error > last_error {
            min_error = last_error;
            min_individual_errors = estimator.get_individual_errors();
        }

        // Log current error.
        if let Some(ref mut logger) = error_logger {
            match logger {
                Logger::LogFile(file) =>
                    writeln!(file, "{}, {}, {}", n,
                             estimator.get_error_count(), last_error)
                        .expect("Could not write to log file"),
                Logger::LogVec(v) => v.push(last_error),
            }
        }

        // Should we stop because of (delta, q)-convergence?
        if let Some(ref mut checker) = convergence_checker {
            checker.add_estimate(last_error);
            if checker.all_converged() {
                break;
            }
        }
    }


    // Log the individual test errors corresponding to the smallest estimate.
    if let Some(logger) = individual_error_logger {
        match logger {
            Logger::LogFile(file) =>
                writeln!(file, "{:?}", min_individual_errors)
                    .expect("Could not write to log file"),
            Logger::LogVec(v) => v.extend(min_individual_errors.iter()
                                                               .cloned()),
        }
    }

    (min_error, last_error)
}
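
// A hedged usage sketch in the form of a test module: it assumes
// `Label = usize` and that the toy data below is acceptable to the
// estimators; it is illustrative rather than part of the crate's test suite.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};

    #[test]
    fn nn_estimate_on_separable_toy_data() {
        // Two perfectly separable secrets: observation 0.0 always maps to
        // label 0, observation 1.0 always maps to label 1.
        let train_x = arr2(&[[0.0], [1.0], [0.0], [1.0], [0.0], [1.0]]);
        let train_y = arr1(&[0, 1, 0, 1, 0, 1]);
        let test_x = arr2(&[[0.0], [1.0], [0.0], [1.0]]);
        let test_y = arr1(&[0, 1, 0, 1]);

        // Keep the error estimates in memory instead of writing a .csv file.
        let mut error_log = Some(Logger::LogVec(Vec::new()));

        let (min_error, last_error, random_guessing) = run_fbleau(
            train_x, train_y, test_x, test_y,
            Estimate::NN, None, None,
            &mut error_log, &mut None,
            None, None, false, false);

        // Random guessing is 0.5 on balanced labels; the NN estimate should
        // not do worse, and the minimum never exceeds the last estimate.
        assert!(last_error <= random_guessing);
        assert!(min_error <= last_error);
    }
}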