1use ndarray::*;
12use std::fs::File;
13use std::io::Write;
14
15use crate::Label;
16use crate::estimates::*;
17use crate::utils::{prepare_data, estimate_random_guessing,has_integer_support};
18
19
/// Destination for values produced while an estimate is running.
///
/// `run_forward_strategy` writes formatted lines (via `writeln!`) to a
/// `LogFile` and pushes the raw values onto a `LogVec`.
#[derive(Debug)]
pub enum Logger<T> {
    /// Write one formatted line per logged value to an open file.
    LogFile(File),
    /// Collect the logged values in memory.
    LogVec(Vec<T>),
}
25
26pub fn run_fbleau(train_x: Array2<f64>, train_y: Array1<Label>,
29 test_x: Array2<f64>, test_y: Array1<Label>,
30 estimate: Estimate, knn_strategy: Option<KNNStrategy>,
31 distance: Option<String>,
32 error_logger: &mut Option<Logger<f64>>,
33 individual_error_logger: &mut Option<Logger<bool>>,
34 delta: Option<f64>, qstop: Option<usize>, absolute: bool,
35 scale: bool)
36 -> (f64, f64, f64) {
37
38 let (train_x, train_y, test_x, test_y, nlabels) =
40 prepare_data(train_x, train_y, test_x, test_y, scale);
41
42 let convergence_checker = if let Some(delta) = delta {
44 let q = match qstop {
45 Some(q) => q,
46 None => (train_x.len() as f64 * 0.1) as usize,
47 };
48 println!("will stop when (delta={}, q={})-converged", delta, q);
49 Some(ForwardChecker::new(&[delta], q, !absolute))
50 } else if qstop.is_some() {
51 panic!("--qstop should only be specified with --delta");
52 } else {
53 None
55 };
56
57 let random_guessing = estimate_random_guessing(&test_y.view());
59 println!("Random guessing error: {}", random_guessing);
60 println!("Estimating leakage measures...");
61
62 let distance = match distance.as_ref().map(String::as_ref) {
64 Some("euclidean") => euclidean_distance,
65 Some("levenshtein") => levenshtein_distance,
66 _ => euclidean_distance,
67 };
68
69 let (min_error, last_error) = match estimate {
71 Estimate::Frequentist => {
72 if !has_integer_support(&train_x) || !has_integer_support(&test_x) {
73 println!("Warning: frequentist discouraged for continuous observations!");
74 }
75 let estimator = FrequentistEstimator::new(nlabels,
76 &test_x.view(),
77 &test_y.view());
78 run_forward_strategy(estimator, convergence_checker, error_logger,
79 individual_error_logger, train_x, train_y)
80 },
81 Estimate::NN => {
82 if !has_integer_support(&train_x) || !has_integer_support(&test_x) {
83 println!("Warning: NN discouraged for continuous observations!");
84 }
85 let estimator = KNNEstimator::new(&test_x.view(), &test_y.view(),
86 train_x.nrows(), distance,
87 KNNStrategy::NN);
88 run_forward_strategy(estimator, convergence_checker, error_logger,
89 individual_error_logger, train_x, train_y)
90 },
91 Estimate::KNN => {
92 let estimator = KNNEstimator::new(&test_x.view(), &test_y.view(),
93 train_x.nrows(), distance,
94 knn_strategy.expect(
95 "Specify a k-NN strategy."));
96 run_forward_strategy(estimator, convergence_checker, error_logger,
97 individual_error_logger, train_x, train_y)
98 },
99 Estimate::NNBound => {
100 let estimator = NNBoundEstimator::new(&test_x.view(), &test_y.view(),
101 distance, nlabels);
102 run_forward_strategy(estimator, convergence_checker, error_logger,
103 individual_error_logger, train_x, train_y)
104 },
105 };
106 (min_error, last_error, random_guessing)
107}
108
109fn run_forward_strategy<E>(mut estimator: E,
119 mut convergence_checker: Option<ForwardChecker>,
120 error_logger: &mut Option<Logger<f64>>,
121 individual_error_logger: &mut Option<Logger<bool>>,
122 train_x: Array2<f64>, train_y: Array1<Label>)
123 -> (f64, f64)
124where E: BayesEstimator {
125 if let Some(ref mut logger) = error_logger {
127 if let Logger::LogFile(file) = logger {
128 writeln!(file, "n, error-count, estimate")
129 .expect("Could not write to log file");
130 }
131 }
132
133 let mut min_error = 1.0;
135 let mut last_error = 1.0;
136 let mut min_individual_errors = vec![];
137
138 for (n, (x, y)) in train_x.outer_iter().zip(train_y.iter()).enumerate() {
139 estimator.add_example(&x, *y)
141 .expect("Could not add more examples.");
142 last_error = estimator.get_error();
143
144 if min_error > last_error {
145 min_error = last_error;
146 min_individual_errors = estimator.get_individual_errors();
147 }
148
149 if let Some(ref mut logger) = error_logger {
151 match logger {
152 Logger::LogFile(file) =>
153 writeln!(file, "{}, {}, {}", n,
154 estimator.get_error_count(), last_error)
155 .expect("Could not write to log file"),
156 Logger::LogVec(v) => v.push(last_error),
157 }
158 }
159
160 if let Some(ref mut checker) = convergence_checker {
162 checker.add_estimate(last_error);
163 if checker.all_converged() {
164 break;
165 }
166 }
167 }
168
169 if let Some(logger) = individual_error_logger {
171 match logger {
172 Logger::LogFile(file) =>
173 writeln!(file, "{:?}", min_individual_errors)
174 .expect("Could not write to log file"),
175 Logger::LogVec(v) => v.extend(min_individual_errors.iter()
176 .cloned()),
177 }
178 }
179
180 (min_error, last_error)
181}