tangram_linear/multiclass_classifier.rs

1use crate::Progress;
2
3use super::{
4	shap::{compute_shap_values_for_example, ComputeShapValuesForExampleOutput},
5	train_early_stopping_split, EarlyStoppingMonitor, TrainOptions, TrainProgressEvent,
6};
7use ndarray::{self, prelude::*};
8use num::{clamp, ToPrimitive};
9use rayon::{self, prelude::*};
10use std::num::NonZeroUsize;
11use tangram_metrics::{CrossEntropy, CrossEntropyInput};
12use tangram_progress_counter::ProgressCounter;
13use tangram_table::prelude::*;
14use tangram_zip::{pzip, zip};
15
/// This struct describes a linear multiclass classifier model. You can train one by calling `MulticlassClassifier::train`.
#[derive(Clone, Debug)]
pub struct MulticlassClassifier {
	/// These are the biases the model learned, one per class.
	pub biases: Array1<f32>,
	/// These are the weights the model learned. The shape is (n_features, n_classes).
	pub weights: Array2<f32>,
	/// These are the mean values of each feature in the training set. They are used to compute SHAP values.
	pub means: Vec<f32>,
}
26
/// This struct is returned by `MulticlassClassifier::train`.
pub struct MulticlassClassifierTrainOutput {
	/// This is the model you just trained.
	pub model: MulticlassClassifier,
	/// These are the loss values for each epoch. Present only when `TrainOptions::compute_losses` was set.
	pub losses: Option<Vec<f32>>,
	/// These are the importances of each feature, normalized to sum to one.
	pub feature_importances: Option<Vec<f32>>,
}
36
37impl MulticlassClassifier {
	/// Train a linear multiclass classifier.
	///
	/// `features` has shape (n_examples, n_features); `labels` holds one-based
	/// class indices. Runs mini-batch gradient descent for up to
	/// `train_options.max_epochs` epochs with optional early stopping on a
	/// held-out split and optional per-epoch loss tracking, reporting progress
	/// through `progress`.
	pub fn train(
		features: ArrayView2<f32>,
		labels: EnumTableColumnView,
		train_options: &TrainOptions,
		progress: Progress,
	) -> MulticlassClassifierTrainOutput {
		let n_classes = labels.variants().len();
		let n_features = features.ncols();
		// Split off a held-out slice for early stopping; a fraction of 0.0
		// keeps everything in the training split.
		let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
			train_early_stopping_split(
				features,
				labels.as_slice().into(),
				train_options
					.early_stopping_options
					.as_ref()
					.map(|o| o.early_stopping_fraction)
					.unwrap_or(0.0),
			);
		// Per-feature means over the training split, stored on the model for
		// SHAP value computation later.
		let means = features_train
			.axis_iter(Axis(1))
			.map(|column| column.mean().unwrap())
			.collect();
		// Start from all-zero weights and biases.
		let mut model = MulticlassClassifier {
			biases: <Array1<f32>>::zeros(n_classes),
			weights: <Array2<f32>>::zeros((n_features, n_classes)),
			means,
		};
		let mut early_stopping_monitor =
			train_options
				.early_stopping_options
				.as_ref()
				.map(|early_stopping_options| {
					EarlyStoppingMonitor::new(
						early_stopping_options.min_decrease_in_loss_for_significant_change,
						early_stopping_options.n_rounds_without_improvement_to_stop,
					)
				});
		let progress_counter = ProgressCounter::new(train_options.max_epochs.to_u64().unwrap());
		(progress.handle_progress_event)(TrainProgressEvent::Train(progress_counter.clone()));
		// Scratch buffer `train_batch` fills with softmax outputs so the
		// per-epoch loss can be computed without a second forward pass.
		// NOTE(review): sized with `labels.len()` (the full label count), not
		// `labels_train.len()` — these differ when an early stopping split is
		// taken; confirm the trailing rows are intentionally unused.
		let mut probabilities_buffer: Array2<f32> = Array2::zeros((labels.len(), n_classes));
		let mut losses = if train_options.compute_losses {
			Some(Vec::new())
		} else {
			None
		};
		let kill_chip = progress.kill_chip;
		for _ in 0..train_options.max_epochs {
			progress_counter.inc(1);
			let n_examples_per_batch = train_options.n_examples_per_batch;
			// Wrapper to move a raw `*mut` to the model into the rayon closure.
			struct MulticlassClassifierPtr(*mut MulticlassClassifier);
			// SAFETY(review): these impls let every rayon worker alias the same
			// model mutably; batches apply gradient updates without
			// synchronization (lock-free, hogwild-style racy updates). Confirm
			// this is the intended scheme.
			unsafe impl Send for MulticlassClassifierPtr {}
			unsafe impl Sync for MulticlassClassifierPtr {}
			let model_ptr = MulticlassClassifierPtr(&mut model);
			// Process batches in parallel; each batch writes its softmax
			// outputs into its own chunk of `probabilities_buffer`.
			pzip!(
				features_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
				labels_train.axis_chunks_iter(Axis(0), n_examples_per_batch),
				probabilities_buffer.axis_chunks_iter_mut(Axis(0), n_examples_per_batch),
			)
			.for_each(|(features, labels, probabilities)| {
				// SAFETY(review): mutable aliasing across threads — see above.
				let model = unsafe { &mut *model_ptr.0 };
				MulticlassClassifier::train_batch(
					model,
					features,
					labels,
					probabilities,
					train_options,
					kill_chip,
				);
			});
			// Record the epoch loss from the probabilities the batches wrote.
			if let Some(losses) = &mut losses {
				let loss =
					MulticlassClassifier::compute_loss(probabilities_buffer.view(), labels_train);
				losses.push(loss);
			}
			// Evaluate on the held-out split and stop if there has been no
			// significant improvement for the configured number of rounds.
			if let Some(early_stopping_monitor) = early_stopping_monitor.as_mut() {
				let early_stopping_metric_value =
					MulticlassClassifier::compute_early_stopping_metric_value(
						&model,
						features_early_stopping,
						labels_early_stopping,
						train_options,
					);
				let should_stop = early_stopping_monitor.update(early_stopping_metric_value);
				if should_stop {
					break;
				}
			}
			// Check if we should stop training.
			if progress.kill_chip.is_activated() {
				break;
			}
		}
		(progress.handle_progress_event)(TrainProgressEvent::TrainDone);
		let feature_importances = MulticlassClassifier::compute_feature_importances(&model);
		MulticlassClassifierTrainOutput {
			model,
			losses,
			feature_importances: Some(feature_importances),
		}
	}
139
140	fn compute_feature_importances(model: &MulticlassClassifier) -> Vec<f32> {
141		// Compute the absolute value of each of the weights.
142		let mut feature_importances = model
143			.weights
144			.axis_iter(Axis(0))
145			.map(|weights_each_class| {
146				weights_each_class
147					.iter()
148					.map(|weight| weight.abs())
149					.sum::<f32>() / model.weights.ncols().to_f32().unwrap()
150			})
151			.collect::<Vec<_>>();
152		// Compute the sum and normalize so the importances sum to 1.
153		let feature_importances_sum: f32 = feature_importances.iter().sum::<f32>();
154		feature_importances
155			.iter_mut()
156			.for_each(|feature_importance| *feature_importance /= feature_importances_sum);
157		feature_importances
158	}
159
	/// Run one mini-batch gradient descent step.
	///
	/// Computes softmax(features · weights + biases), copies the resulting
	/// probabilities into `probabilities` for the caller's loss computation,
	/// then applies the averaged cross-entropy gradient to the weights and
	/// biases scaled by the configured learning rate.
	///
	/// Panics if any label is `None` (labels are one-based class indices).
	fn train_batch(
		&mut self,
		features: ArrayView2<f32>,
		labels: ArrayView1<Option<NonZeroUsize>>,
		mut probabilities: ArrayViewMut2<f32>,
		train_options: &TrainOptions,
		kill_chip: &tangram_kill_chip::KillChip,
	) {
		// Skip the batch entirely if training was cancelled.
		if kill_chip.is_activated() {
			return;
		}
		let learning_rate = train_options.learning_rate;
		let n_classes = self.weights.ncols();
		// Forward pass: logits, then softmax in place.
		let mut logits = features.dot(&self.weights) + &self.biases;
		softmax(logits.view_mut());
		// Save the probabilities for the caller before the same buffer is
		// turned into gradients below.
		for (probability, logit) in zip!(probabilities.iter_mut(), logits.iter()) {
			*probability = *logit;
		}
		// Gradient of cross entropy w.r.t. the logits: p - one_hot(label).
		let mut predictions = logits;
		for (mut predictions, label) in zip!(predictions.axis_iter_mut(Axis(0)), labels) {
			for (class_index, prediction) in predictions.iter_mut().enumerate() {
				// Labels are one-based, so class k corresponds to label k + 1.
				*prediction -= if class_index == label.unwrap().get() - 1 {
					1.0
				} else {
					0.0
				};
			}
		}
		let py = predictions;
		// Average each class's gradient over the batch and step downhill.
		for class_index in 0..n_classes {
			let weight_gradients = (&features * &py.column(class_index).insert_axis(Axis(1)))
				.mean_axis(Axis(0))
				.unwrap();
			for (weight, weight_gradient) in zip!(
				self.weights.column_mut(class_index),
				weight_gradients.iter()
			) {
				*weight += -learning_rate * weight_gradient
			}
			// The bias gradient is the batch mean of p - one_hot.
			let bias_gradients = py
				.column(class_index)
				.insert_axis(Axis(1))
				.mean_axis(Axis(0))
				.unwrap();
			self.biases[class_index] += -learning_rate * bias_gradients[0];
		}
	}
207
208	pub fn compute_loss(
209		probabilities: ArrayView2<f32>,
210		labels: ArrayView1<Option<NonZeroUsize>>,
211	) -> f32 {
212		let mut loss = 0.0;
213		for (label, probabilities) in zip!(labels.into_iter(), probabilities.axis_iter(Axis(0))) {
214			for (index, &probability) in probabilities.indexed_iter() {
215				let probability = clamp(probability, std::f32::EPSILON, 1.0 - std::f32::EPSILON);
216				if index == (label.unwrap().get() - 1) {
217					loss += -probability.ln();
218				}
219			}
220		}
221		loss / labels.len().to_f32().unwrap()
222	}
223
	/// Compute the early stopping metric (mean cross entropy) of the model
	/// over the held-out early stopping split.
	fn compute_early_stopping_metric_value(
		&self,
		features: ArrayView2<f32>,
		labels: ArrayView1<Option<NonZeroUsize>>,
		train_options: &TrainOptions,
	) -> f32 {
		let n_classes = self.biases.len();
		// Score batches in parallel: each rayon worker folds with its own
		// prediction buffer and CrossEntropy accumulator, and the per-worker
		// accumulators are merged at the end.
		pzip!(
			features.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
			labels.axis_chunks_iter(Axis(0), train_options.n_examples_per_batch),
		)
		.fold(
			|| {
				// SAFETY(review): the buffer starts uninitialized; `predict`
				// overwrites every row that is subsequently read, but
				// `assume_init` on uninitialized f32s is dubious — consider
				// zero-initializing instead.
				let predictions = unsafe {
					<Array2<f32>>::uninit((train_options.n_examples_per_batch, n_classes))
						.assume_init()
				};
				let metric = CrossEntropy::default();
				(predictions, metric)
			},
			|(mut predictions, mut metric), (features, labels)| {
				// The final batch may be short, so only use its rows.
				let slice = s![0..features.nrows(), ..];
				let mut predictions_slice = predictions.slice_mut(slice);
				self.predict(features, predictions_slice.view_mut());
				for (prediction, label) in zip!(predictions_slice.axis_iter(Axis(0)), labels.iter())
				{
					metric.update(CrossEntropyInput {
						probabilities: prediction,
						label: *label,
					});
				}
				(predictions, metric)
			},
		)
		.map(|(_, metric)| metric)
		.reduce(CrossEntropy::new, |mut a, b| {
			a.merge(b);
			a
		})
		.finalize()
		.0
		.unwrap()
	}
267
268	/// Write predicted probabilities into `probabilities` for the input `features`.
269	pub fn predict(&self, features: ArrayView2<f32>, mut probabilities: ArrayViewMut2<f32>) {
270		for mut row in probabilities.axis_iter_mut(Axis(0)) {
271			row.assign(&self.biases.view());
272		}
273		ndarray::linalg::general_mat_mul(1.0, &features, &self.weights, 1.0, &mut probabilities);
274		softmax(probabilities);
275	}
276
277	pub fn compute_feature_contributions(
278		&self,
279		features: ArrayView2<f32>,
280	) -> Vec<Vec<ComputeShapValuesForExampleOutput>> {
281		features
282			.axis_iter(Axis(0))
283			.map(|features| {
284				zip!(self.weights.axis_iter(Axis(1)), self.biases.view())
285					.map(|(weights, bias)| {
286						compute_shap_values_for_example(
287							features.as_slice().unwrap(),
288							*bias,
289							weights.view(),
290							&self.means,
291						)
292					})
293					.collect()
294			})
295			.collect()
296	}
297
	/// Deserialize a `MulticlassClassifier` from a buffalo reader.
	pub fn from_reader(
		multiclass_classifier: crate::serialize::MulticlassClassifierReader,
	) -> MulticlassClassifier {
		crate::serialize::deserialize_multiclass_classifier(multiclass_classifier)
	}
303
	/// Serialize this model into the given buffalo writer, returning the
	/// position of the written value.
	pub fn to_writer(
		&self,
		writer: &mut buffalo::Writer,
	) -> buffalo::Position<crate::serialize::MulticlassClassifierWriter> {
		crate::serialize::serialize_multiclass_classifier(self, writer)
	}
310
	/// Deserialize a `MulticlassClassifier` from bytes produced by `to_bytes`.
	///
	/// NOTE(review): `&self` is never used here — this could be an associated
	/// function; the receiver is kept to avoid changing the public API.
	pub fn from_bytes(&self, bytes: &[u8]) -> MulticlassClassifier {
		let reader = buffalo::read::<crate::serialize::MulticlassClassifierReader>(bytes);
		Self::from_reader(reader)
	}
315
316	pub fn to_bytes(&self) -> Vec<u8> {
317		// Create the writer.
318		let mut writer = buffalo::Writer::new();
319		self.to_writer(&mut writer);
320		writer.into_bytes()
321	}
322}
323
324fn softmax(mut logits: ArrayViewMut2<f32>) {
325	for mut logits in logits.axis_iter_mut(Axis(0)) {
326		let max = logits.iter().fold(std::f32::MIN, |a, &b| f32::max(a, b));
327		for logit in logits.iter_mut() {
328			*logit = (*logit - max).exp();
329		}
330		let sum = logits.iter().sum::<f32>();
331		for logit in logits.iter_mut() {
332			*logit /= sum;
333		}
334	}
335}