use crate::Progress;

use super::{
    shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
    train_early_stopping_split, EarlyStoppingMonitor, TrainOptions, TrainProgressEvent,
};
use ndarray::{self, prelude::*};
use num::{clamp, ToPrimitive};
use rayon::{self, prelude::*};
use std::num::NonZeroUsize;
use tangram_metrics::{CrossEntropy, CrossEntropyInput};
use tangram_progress_counter::ProgressCounter;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};

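/// This struct describes a linear multiclass classifier model. You can train one by calling `MulticlassClassifier::train`.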
#[derive(Clone, Debug)]
pub struct MulticlassClassifier {
    /// These are the biases the model learned, one per class.
    pub biases: Array1<f32>,
    /// These are the weights the model learned, with shape (n_features, n_classes).
    pub weights: Array2<f32>,
    /// These are the mean values of each feature in the training set. They are used when computing SHAP values.
    pub means: Vec<f32>,
}

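/// This struct is returned by `MulticlassClassifier::train`.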
pub struct MulticlassClassifierTrainOutput {
    /// This is the model you just trained.
    pub model: MulticlassClassifier,
    /// If `TrainOptions::compute_losses` was true, this contains the loss for each epoch.
    pub losses: Option<Vec<f32>>,
    /// These are the importances of each feature, normalized to sum to one.
    pub feature_importances: Option<Vec<f32>>,
}

impl MulticlassClassifier {
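    /// Train a linear multiclass classifier.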
    pub fn train(
        features: ArrayView2<f32>,
        labels: EnumTableColumnView,
        train_options: &TrainOptions,
        progress: Progress,
    ) -> MulticlassClassifierTrainOutput {
        let n_classes = labels.variants().len();
        let n_features = features.ncols();
        let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
            train_early_stopping_split(
                features,
                labels.as_slice().into(),
                train_options
                    .early_stopping_options
                    .as_ref()
                    .map(|o| o.early_stopping_fraction)
                    .unwrap_or(0.0),
            );
        let means = features_train
            .axis_iter(Axis(1))
            .map(|column| column.mean().unwrap())
            .collect();
        let mut model = MulticlassClassifier {
            biases: <Array1<f32>>::zeros(n_classes),
            weights: <Array2<f32>>::zeros((n_features, n_classes)),
            means,
        };
        let mut early_stopping_monitor =
            train_options
                .early_stopping_options
                .as_ref()
                .map(|early_stopping_options| {
                    EarlyStoppingMonitor::new(
                        early_stopping_options.min_decrease_in_loss_for_significant_change,
                        early_stopping_options.n_rounds_without_improvement_to_stop,
                    )
                });
        let progress_counter = ProgressCounter::new(train_options.max_epochs.to_u64().unwrap());
        (progress.handle_progress_event)(TrainProgressEvent::Train(progress_counter.clone()));
        let mut probabilities_buffer: Array2<f32> = Array2::zeros((labels.len(), n_classes));
        let mut losses = if train_options.compute_losses {
            Some(Vec::new())
        } else {
            None
        };
        let kill_chip = progress.kill_chip;
        for _ in 0..train_options.max_epochs {
            progress_counter.inc(1);
            let n_examples_per_batch = train_options.n_examples_per_batch;
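            // Train the batches in parallel. Each batch updates the shared model through a raw
            // pointer without synchronization, so updates from concurrent batches may interleave.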
            struct MulticlassClassifierPtr(*mut MulticlassClassifier);
            unsafe impl Send for MulticlassClassifierPtr {}
            unsafe impl Sync for MulticlassClassifierPtr {}
            let model_ptr = MulticlassClassifierPtr(&mut model);
            pzip!(
                features_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
                labels_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
                probabilities_buffer.axis_chunks_iter_mut(Axis(0), n_examples_per_batch),
            )
            .for_each(|(features, labels, probabilities)| {
                let model = unsafe { &mut *model_ptr.0 };
                MulticlassClassifier::train_batch(
                    model,
                    features,
                    labels,
                    probabilities,
                    train_options,
                    kill_chip,
                );
            });
            if let Some(losses) = &mut losses {
                let loss =
                    MulticlassClassifier::compute_loss(probabilities_buffer.view(), labels_train);
                losses.push(loss);
            }
            if let Some(early_stopping_monitor) = early_stopping_monitor.as_mut() {
                let early_stopping_metric_value =
                    MulticlassClassifier::compute_early_stopping_metric_value(
                        &model,
                        features_early_stopping,
                        labels_early_stopping,
                        train_options,
                    );
                let should_stop = early_stopping_monitor.update(early_stopping_metric_value);
                if should_stop {
                    break;
                }
            }
            if progress.kill_chip.is_activated() {
                break;
            }
        }
        (progress.handle_progress_event)(TrainProgressEvent::TrainDone);
        let feature_importances = MulticlassClassifier::compute_feature_importances(&model);
        MulticlassClassifierTrainOutput {
            model,
            losses,
            feature_importances: Some(feature_importances),
        }
    }

    /// Compute the importance of each feature as the mean absolute weight across classes, normalized to sum to one.
    fn compute_feature_importances(model: &MulticlassClassifier) -> Vec<f32> {
        // Compute the mean of the absolute values of the weights for each feature across all classes.
        let mut feature_importances = model
            .weights
            .axis_iter(Axis(0))
            .map(|weights_each_class| {
                weights_each_class
                    .iter()
                    .map(|weight| weight.abs())
                    .sum::<f32>() / model.weights.ncols().to_f32().unwrap()
            })
            .collect::<Vec<_>>();
        // Normalize the feature importances so that they sum to one.
        let feature_importances_sum: f32 = feature_importances.iter().sum::<f32>();
        feature_importances
            .iter_mut()
            .for_each(|feature_importance| *feature_importance /= feature_importances_sum);
        feature_importances
    }

    /// Train a single batch of examples with one step of gradient descent.
    fn train_batch(
        &mut self,
        features: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
        mut probabilities: ArrayViewMut2<f32>,
        train_options: &TrainOptions,
        kill_chip: &tangram_kill_chip::KillChip,
    ) {
        if kill_chip.is_activated() {
            return;
        }
        let learning_rate = train_options.learning_rate;
        let n_classes = self.weights.ncols();
        // Compute the predicted probabilities for this batch and copy them into the output buffer.
        let mut logits = features.dot(&self.weights) + &self.biases;
        softmax(logits.view_mut());
        for (probability, logit) in zip!(probabilities.iter_mut(), logits.iter()) {
            *probability = *logit;
        }
        // Subtract the one-hot encoded labels from the probabilities. This is the gradient of the
        // cross entropy loss with respect to the logits.
        let mut predictions = logits;
        for (mut predictions, label) in zip!(predictions.axis_iter_mut(Axis(0)), labels) {
            for (class_index, prediction) in predictions.iter_mut().enumerate() {
                *prediction -= if class_index == label.unwrap().get() - 1 {
                    1.0
                } else {
                    0.0
                };
            }
        }
        let py = predictions;
        // Take one gradient descent step for the weights and bias of each class.
        for class_index in 0..n_classes {
            let weight_gradients = (&features * &py.column(class_index).insert_axis(Axis(1)))
                .mean_axis(Axis(0))
                .unwrap();
            for (weight, weight_gradient) in zip!(
                self.weights.column_mut(class_index),
                weight_gradients.iter()
            ) {
                *weight += -learning_rate * weight_gradient;
            }
            let bias_gradients = py
                .column(class_index)
                .insert_axis(Axis(1))
                .mean_axis(Axis(0))
                .unwrap();
            self.biases[class_index] += -learning_rate * bias_gradients[0];
        }
    }

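    /// Compute the mean cross entropy loss from the predicted probabilities and the labels.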
    pub fn compute_loss(
        probabilities: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
    ) -> f32 {
        let mut loss = 0.0;
        for (label, probabilities) in zip!(labels.into_iter(), probabilities.axis_iter(Axis(0))) {
            for (index, &probability) in probabilities.indexed_iter() {
                let probability = clamp(probability, std::f32::EPSILON, 1.0 - std::f32::EPSILON);
                if index == (label.unwrap().get() - 1) {
                    loss += -probability.ln();
                }
            }
        }
        loss / labels.len().to_f32().unwrap()
    }

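    /// Compute the cross entropy of the model's predictions on the early stopping split, processing the examples in parallel batches.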
    fn compute_early_stopping_metric_value(
        &self,
        features: ArrayView2<f32>,
        labels: ArrayView1<Option<NonZeroUsize>>,
        train_options: &TrainOptions,
    ) -> f32 {
        let n_classes = self.biases.len();
        pzip!(
            features.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
            labels.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
        )
        .fold(
            || {
                let predictions = unsafe {
                    <Array2<f32>>::uninit((train_options.n_examples_per_batch, n_classes))
                        .assume_init()
                };
                let metric = CrossEntropy::default();
                (predictions, metric)
            },
            |(mut predictions, mut metric), (features, labels)| {
                let slice = s![0..features.nrows(), ..];
                let mut predictions_slice = predictions.slice_mut(slice);
                self.predict(features, predictions_slice.view_mut());
                for (prediction, label) in
                    zip!(predictions_slice.axis_iter(Axis(0)), labels.iter())
                {
                    metric.update(CrossEntropyInput {
                        probabilities: prediction,
                        label: *label,
                    });
                }
                (predictions, metric)
            },
        )
        .map(|(_, metric)| metric)
        .reduce(CrossEntropy::new, |mut a, b| {
            a.merge(b);
            a
        })
        .finalize()
        .0
        .unwrap()
    }

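    /// Write the predicted class probabilities for `features` into `probabilities`.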
    pub fn predict(&self, features: ArrayView2<f32>, mut probabilities: ArrayViewMut2<f32>) {
        for mut row in probabilities.axis_iter_mut(Axis(0)) {
            row.assign(&self.biases.view());
        }
        ndarray::linalg::general_mat_mul(1.0, &features, &self.weights, 1.0, &mut probabilities);
        softmax(probabilities);
    }

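    /// Compute the SHAP feature contributions for each example, with one inner entry per class.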
    pub fn compute_feature_contributions(
        &self,
        features: ArrayView2<f32>,
    ) -> Vec<Vec<ComputeShapValuesForExampleOutput>> {
        features
            .axis_iter(Axis(0))
            .map(|features| {
                zip!(self.weights.axis_iter(Axis(1)), self.biases.view())
                    .map(|(weights, bias)| {
                        compute_shap_values_for_example(
                            features.as_slice().unwrap(),
                            *bias,
                            weights.view(),
                            &self.means,
                        )
                    })
                    .collect()
            })
            .collect()
    }

    /// Deserialize a `MulticlassClassifier` from a `MulticlassClassifierReader`.
    pub fn from_reader(
        multiclass_classifier: crate::serialize::MulticlassClassifierReader,
    ) -> MulticlassClassifier {
        crate::serialize::deserialize_multiclass_classifier(multiclass_classifier)
    }

    /// Serialize this model with the provided `buffalo::Writer`.
    pub fn to_writer(
        &self,
        writer: &mut buffalo::Writer,
    ) -> buffalo::Position<crate::serialize::MulticlassClassifierWriter> {
        crate::serialize::serialize_multiclass_classifier(self, writer)
    }

    /// Deserialize a `MulticlassClassifier` from a slice of bytes.
    pub fn from_bytes(&self, bytes: &[u8]) -> MulticlassClassifier {
        let reader = buffalo::read::<crate::serialize::MulticlassClassifierReader>(bytes);
        Self::from_reader(reader)
    }

    /// Serialize this model to a `Vec` of bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut writer = buffalo::Writer::new();
        self.to_writer(&mut writer);
        writer.into_bytes()
    }
}

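/// Compute the softmax of each row of `logits` in place, subtracting the row maximum before exponentiating for numerical stability.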
fn softmax(mut logits: ArrayViewMut2<f32>) {
    for mut logits in logits.axis_iter_mut(Axis(0)) {
        let max = logits.iter().fold(std::f32::MIN, |a, &b| f32::max(a, b));
        for logit in logits.iter_mut() {
            *logit = (*logit - max).exp();
        }
        let sum = logits.iter().sum::<f32>();
        for logit in logits.iter_mut() {
            *logit /= sum;
        }
    }
}