tangram_linear 0.7.0

Tangram makes it easy for programmers to train, deploy, and monitor machine learning models.
Documentation
use maplit::btreemap;
use ndarray::prelude::*;
use rayon::prelude::*;
use serde_json::json;
use std::path::Path;
use tangram_linear::Progress;
use tangram_table::prelude::*;
use tangram_zip::{pzip, zip};

/// Benchmarks the `tangram_linear` binary classifier on the HIGGS CSV files.
///
/// Loads `data/higgs_train.csv` and `data/higgs_test.csv`, splits off the target column,
/// normalizes every remaining column with statistics computed on the training table, trains
/// a model, predicts a probability for each test row, and prints the AUC ROC to stdout as a
/// JSON object of the form `{"auc_roc": ...}`.
///
/// # Panics
///
/// Panics if either CSV fails to load, if the target column is not an enum column, if any
/// remaining column is not a number column or has no name, or if a test label cannot be
/// unwrapped.
fn main() {
	// Load the data.
	let csv_file_path_train = Path::new("data/higgs_train.csv");
	let csv_file_path_test = Path::new("data/higgs_test.csv");
	// The target is taken from the first column of each table. It must be an enum column
	// (see the `as_enum().unwrap()` calls below); `signal` is the only enum type declared here.
	let target_column_index = 0;
	let options = tangram_table::FromCsvOptions {
		column_types: Some(btreemap! {
			"signal".to_owned() => TableColumnType::Enum { variants: vec!["false".to_owned(), "true".to_owned()] },
			"lepton_pt".to_owned() => TableColumnType::Number,
			"lepton_eta".to_owned() => TableColumnType::Number,
			"lepton_phi".to_owned() => TableColumnType::Number,
			"missing_energy_magnitude".to_owned() => TableColumnType::Number,
			"missing_energy_phi".to_owned() => TableColumnType::Number,
			"jet_1_pt".to_owned() => TableColumnType::Number,
			"jet_1_eta".to_owned() => TableColumnType::Number,
			"jet_1_phi".to_owned() => TableColumnType::Number,
			"jet_1_b_tag".to_owned() => TableColumnType::Number,
			"jet_2_pt".to_owned() => TableColumnType::Number,
			"jet_2_eta".to_owned() => TableColumnType::Number,
			"jet_2_phi".to_owned() => TableColumnType::Number,
			"jet_2_b_tag".to_owned() => TableColumnType::Number,
			"jet_3_pt".to_owned() => TableColumnType::Number,
			"jet_3_eta".to_owned() => TableColumnType::Number,
			"jet_3_phi".to_owned() => TableColumnType::Number,
			"jet_3_b_tag".to_owned() => TableColumnType::Number,
			"jet_4_pt".to_owned() => TableColumnType::Number,
			"jet_4_eta".to_owned() => TableColumnType::Number,
			"jet_4_phi".to_owned() => TableColumnType::Number,
			"jet_4_b_tag".to_owned() => TableColumnType::Number,
			"m_jj".to_owned() => TableColumnType::Number,
			"m_jjj".to_owned() => TableColumnType::Number,
			"m_lv".to_owned() => TableColumnType::Number,
			"m_jlv".to_owned() => TableColumnType::Number,
			"m_bb".to_owned() => TableColumnType::Number,
			"m_wbb".to_owned() => TableColumnType::Number,
			"m_wwbb".to_owned() => TableColumnType::Number,
		}),
		..Default::default()
	};
	// The options are needed again for the test table, so clone them for the first load only.
	let mut features_train =
		Table::from_path(csv_file_path_train, options.clone(), &mut |_| {}).unwrap();
	let labels_train = features_train.columns_mut().remove(target_column_index);
	let labels_train = labels_train.as_enum().unwrap();
	// Last use of `options`: move it instead of cloning.
	let mut features_test = Table::from_path(csv_file_path_test, options, &mut |_| {}).unwrap();
	let labels_test = features_test.columns_mut().remove(target_column_index);
	let labels_test = labels_test.as_enum().unwrap();

	// Build one normalized feature group per remaining training column. The mean and variance
	// come from the training table only and are reused below for the test table.
	let feature_groups: Vec<tangram_features::FeatureGroup> = features_train
		.columns()
		.iter()
		.map(|column| match column {
			TableColumn::Number(column) => {
				let mean_variance = tangram_metrics::MeanVariance::compute(
					column.view().as_slice().iter().cloned(),
				);
				tangram_features::FeatureGroup::Normalized(
					tangram_features::NormalizedFeatureGroup {
						source_column_name: column.name().clone().unwrap(),
						mean: mean_variance.mean,
						variance: mean_variance.variance,
					},
				)
			}
			// With the target removed, every column was declared as a number above.
			_ => unreachable!(),
		})
		.collect();
	let features_train = tangram_features::compute_features_array_f32(
		&features_train.view(),
		feature_groups.as_slice(),
		&|| {},
	);
	let features_test = tangram_features::compute_features_array_f32(
		&features_test.view(),
		feature_groups.as_slice(),
		&|| {},
	);

	// Train the model.
	let train_output = tangram_linear::BinaryClassifier::train(
		features_train.view(),
		labels_train.view(),
		&tangram_linear::TrainOptions {
			learning_rate: 0.01,
			max_epochs: 1,
			n_examples_per_batch: 1000,
			..Default::default()
		},
		// Progress events are ignored and the kill chip is left at its default.
		Progress {
			kill_chip: &tangram_kill_chip::KillChip::default(),
			handle_progress_event: &mut |_| {},
		},
	);

	// Make predictions on the test data.
	// Ceiling division gives at most one chunk of rows per rayon thread. The result is clamped
	// to at least 1 because the chunk iterators panic on a chunk size of zero, which the
	// division would otherwise produce for a test table with no rows.
	let n_threads = rayon::current_num_threads();
	let chunk_size = ((features_test.nrows() + n_threads - 1) / n_threads).max(1);
	let mut probabilities = Array::zeros(features_test.nrows());
	pzip!(
		features_test.axis_chunks_iter(Axis(0), chunk_size),
		probabilities.axis_chunks_iter_mut(Axis(0), chunk_size),
	)
	.for_each(|(features_test_chunk, probabilities_chunk)| {
		train_output
			.model
			.predict(features_test_chunk, probabilities_chunk);
	});

	// Compute metrics.
	// Pair each predicted probability with its test label; `unwrap` panics on a label that
	// cannot be unwrapped.
	let input = zip!(probabilities.iter(), labels_test.iter())
		.map(|(probability, label)| (*probability, label.unwrap()))
		.collect();
	let auc_roc = tangram_metrics::AucRoc::compute(input);

	let output = json!({
		"auc_roc": auc_roc,
	});
	println!("{}", output);
}