TreeRustler
TreeRustler is a simple implementation of a decision tree classifier in Rust. The project is a beginner's exploration of the language through building a decision tree classifier, and it has two main modules: the data module and the tree module.
Data Module
The data module provides the Data struct, which is the input type the tree implementation expects. The Data struct holds the training features used to build the decision tree classifier. This module covers only the data operations the decision tree needs; it does not yet handle loading, preprocessing, or splitting the data.
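The README does not document the fields of Data. As a rough illustration only, a feature container of this kind might look like the hypothetical sketch below. It assumes a row-major Vec<Vec<f64>> layout, and the method names new, n_samples, n_features, get, and column are invented for this example; the real struct may differ.

```rust
/// Hypothetical feature container; the real `Data` struct's fields and
/// methods are not documented in this README and may differ.
#[derive(Debug, Clone)]
pub struct Data {
    /// Assumed row-major layout: one inner Vec per sample.
    rows: Vec<Vec<f64>>,
}

impl Data {
    /// Builds a `Data` value, rejecting rows with differing feature counts.
    pub fn new(rows: Vec<Vec<f64>>) -> Result<Self, String> {
        if let Some(first) = rows.first() {
            let n = first.len();
            if rows.iter().any(|r| r.len() != n) {
                return Err("all rows must have the same number of features".to_string());
            }
        }
        Ok(Data { rows })
    }

    /// Number of samples (rows).
    pub fn n_samples(&self) -> usize {
        self.rows.len()
    }

    /// Number of features (columns); 0 for an empty dataset.
    pub fn n_features(&self) -> usize {
        self.rows.first().map_or(0, |r| r.len())
    }

    /// Value of feature `j` for sample `i`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        self.rows[i][j]
    }

    /// Copies out one feature column, the values a split search would scan.
    pub fn column(&self, j: usize) -> Vec<f64> {
        self.rows.iter().map(|r| r[j]).collect()
    }
}

fn main() {
    let x = Data::new(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]).unwrap();
    println!("{} samples, {} features", x.n_samples(), x.n_features());
    println!("first value: {}, column 1: {:?}", x.get(0, 0), x.column(1));
}
```

Validating row lengths once in the constructor means later code can index any (sample, feature) pair without re-checking shapes.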
Tree Module
The tree module contains the DecisionTreeClassifier struct. This struct implements the decision tree classifier and accepts the following parameters:
- max_depth: The maximum depth of the decision tree, which limits how deep the tree can grow during training. A smaller value can help prevent overfitting, but a value that is too small may cause underfitting.
- min_samples_split: The minimum number of samples an internal node must contain to be split during training, which controls when node splitting stops. A higher value can help prevent overfitting, but a value that is too high may cause underfitting.
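To make the two parameters concrete, here is a hypothetical stopping rule showing how max_depth and min_samples_split typically gate recursion while a tree is grown. The function name should_split, the usize types, counting the root as depth 0, and the extra "node is already pure" check are all assumptions; the crate's actual checks may differ.

```rust
/// Hypothetical stopping rule (not the crate's code): returns true if a node
/// at `depth` holding `labels` should be split further. The root is depth 0.
fn should_split(depth: usize, labels: &[u8], max_depth: usize, min_samples_split: usize) -> bool {
    // Stop once the tree has reached the maximum allowed depth.
    if depth >= max_depth {
        return false;
    }
    // Stop if the node holds too few samples to be split.
    if labels.len() < min_samples_split {
        return false;
    }
    // Extra common check (an assumption here): stop if all labels are identical.
    let first = match labels.first() {
        Some(&l) => l,
        None => return false,
    };
    labels.iter().any(|&l| l != first)
}

fn main() {
    let labels = [0u8, 1, 0, 1];
    // Depth limit reached, so no split.
    println!("{}", should_split(3, &labels, 3, 2));
    // Only 4 samples but 5 required, so no split.
    println!("{}", should_split(0, &labels, 3, 5));
    // Shallow node with mixed labels and enough samples, so split.
    println!("{}", should_split(1, &labels, 3, 2));
}
```

Lowering max_depth or raising min_samples_split makes this function return false earlier, which is exactly why both settings trade overfitting against underfitting.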
The DecisionTreeClassifier struct has the following methods:
- fit(x: &Data, y: &Vec<u8>): Fits the decision tree classifier to the provided training data, using the features in the Data struct x and the labels in the separate vector y.
- predict_proba(x: &Data) -> Vec<f64>: Predicts class probabilities for the provided data using the trained decision tree. It returns a vector of probabilities for each class.
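The README does not say how fit chooses splits or how predict_proba computes its probabilities. The following is a minimal, self-contained sketch of one common approach, not the crate's code. Its assumptions: binary 0/1 labels, Gini impurity as the split criterion, splits of the form feature <= threshold, plain Vec<Vec<f64>> rows in place of Data, usize parameters, and predict_proba returning P(y = 1) for each row, computed as the fraction of class-1 training samples that reached the leaf.

```rust
// Minimal sketch, NOT the crate's implementation. Assumptions: binary 0/1
// labels, Gini impurity, `feature <= threshold` splits, plain rows instead of `Data`.
enum Node {
    Leaf { p1: f64 }, // fraction of class-1 training samples in this leaf
    Split { feature: usize, threshold: f64, left: Box<Node>, right: Box<Node> },
}

struct DecisionTreeClassifier {
    max_depth: usize,
    min_samples_split: usize,
    root: Option<Node>,
}

/// Gini impurity of 0/1 labels: 1 - p0^2 - p1^2.
fn gini(labels: &[u8]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let p1 = labels.iter().filter(|&&l| l == 1).count() as f64 / labels.len() as f64;
    1.0 - p1 * p1 - (1.0 - p1) * (1.0 - p1)
}

impl DecisionTreeClassifier {
    fn new(max_depth: usize, min_samples_split: usize) -> Self {
        DecisionTreeClassifier { max_depth, min_samples_split, root: None }
    }

    fn fit(&mut self, x: &[Vec<f64>], y: &[u8]) {
        let idx: Vec<usize> = (0..y.len()).collect();
        self.root = Some(self.build(x, y, &idx, 0));
    }

    /// Recursively grows the tree over the samples listed in `idx`.
    fn build(&self, x: &[Vec<f64>], y: &[u8], idx: &[usize], depth: usize) -> Node {
        let labels: Vec<u8> = idx.iter().map(|&i| y[i]).collect();
        let p1 = labels.iter().filter(|&&l| l == 1).count() as f64 / labels.len().max(1) as f64;
        let pure = labels.windows(2).all(|w| w[0] == w[1]);
        // Stopping rules: depth limit, too few samples, or a pure node.
        if depth >= self.max_depth || idx.len() < self.min_samples_split || pure {
            return Node::Leaf { p1 };
        }
        // Try every observed value of every feature as a threshold and keep
        // the split with the lowest sample-weighted Gini impurity.
        let mut best: Option<(f64, usize, f64)> = None;
        for f in 0..x[idx[0]].len() {
            for &i in idx {
                let t = x[i][f];
                let (l, r): (Vec<usize>, Vec<usize>) = idx.iter().partition(|&&j| x[j][f] <= t);
                if l.is_empty() || r.is_empty() {
                    continue;
                }
                let ly: Vec<u8> = l.iter().map(|&j| y[j]).collect();
                let ry: Vec<u8> = r.iter().map(|&j| y[j]).collect();
                let score = (ly.len() as f64 * gini(&ly) + ry.len() as f64 * gini(&ry)) / idx.len() as f64;
                if best.map_or(true, |(s, _, _)| score < s) {
                    best = Some((score, f, t));
                }
            }
        }
        match best {
            None => Node::Leaf { p1 }, // no threshold separates the samples
            Some((_, feature, threshold)) => {
                let (l, r): (Vec<usize>, Vec<usize>) =
                    idx.iter().partition(|&&j| x[j][feature] <= threshold);
                Node::Split {
                    feature,
                    threshold,
                    left: Box::new(self.build(x, y, &l, depth + 1)),
                    right: Box::new(self.build(x, y, &r, depth + 1)),
                }
            }
        }
    }

    /// Returns P(y = 1) for each row by walking from the root to a leaf.
    fn predict_proba(&self, x: &[Vec<f64>]) -> Vec<f64> {
        x.iter()
            .map(|row| {
                let mut node = self.root.as_ref().expect("call fit before predict_proba");
                loop {
                    match node {
                        Node::Leaf { p1 } => return *p1,
                        Node::Split { feature, threshold, left, right } => {
                            node = if row[*feature] <= *threshold { &**left } else { &**right };
                        }
                    }
                }
            })
            .collect()
    }
}

fn main() {
    // Toy 1-D data: the label is 1 when the feature is above 2.
    let x = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
    let y = vec![0u8, 0, 1, 1];
    let mut clf = DecisionTreeClassifier::new(3, 2);
    clf.fit(&x, &y);
    println!("{:?}", clf.predict_proba(&[vec![1.5], vec![3.5]]));
}
```

The exhaustive threshold search is the simplest correct choice but costs O(n^2) per feature at each node; sorting each column once and sweeping it is the usual optimization.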
Usage
To use the TreeRustler project, follow these steps:
- Clone the repository:
    git clone https://github.com/EduardoPach/treerustler.git
- Navigate to the project directory:
    cd treerustler
- Make sure you have Rust installed. If not, install it from https://www.rust-lang.org/.
- Load your data, then convert your feature data to the Data struct and your labels to a Vec<u8>.
- Create an instance of DecisionTreeClassifier from the tree module, specifying the desired max_depth and min_samples_split values.
- Call the fit method on the classifier instance, passing your training features and labels.
- Use the predict_proba method to predict class probabilities for new data points.
use crate::data::Data; // path assumes this code lives inside the crate
use crate::tree::DecisionTreeClassifier;
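The usage steps leave loading your data to you. As a sketch of the "convert your features and labels" step, here is a hypothetical std-only parser, parse_rows, which assumes comma-separated numeric features with an integer class label in the last column. It does not handle header rows or quoted fields. The resulting rows would still need to be wrapped in a Data value, whose constructor is not documented here.

```rust
/// Hypothetical loader (not part of the crate): parses comma-separated lines
/// where every column except the last is an f64 feature and the last column
/// is a u8 class label.
fn parse_rows(text: &str) -> Result<(Vec<Vec<f64>>, Vec<u8>), String> {
    let mut features = Vec::new();
    let mut labels = Vec::new();
    for (n, line) in text.lines().enumerate() {
        let line = line.trim();
        if line.is_empty() {
            continue; // skip blank lines
        }
        let fields: Vec<&str> = line.split(',').map(str::trim).collect();
        // Need at least one feature column plus the label column.
        if fields.len() < 2 {
            return Err(format!("line {}: expected at least 2 columns", n + 1));
        }
        let (feat_fields, label_field) = fields.split_at(fields.len() - 1);
        let row = feat_fields
            .iter()
            .map(|s| s.parse::<f64>().map_err(|e| format!("line {}: {}", n + 1, e)))
            .collect::<Result<Vec<f64>, String>>()?;
        let label = label_field[0]
            .parse::<u8>()
            .map_err(|e| format!("line {}: {}", n + 1, e))?;
        features.push(row);
        labels.push(label);
    }
    Ok((features, labels))
}

fn main() {
    let csv = "5.1,3.5,0\n6.2,2.9,1\n";
    let (x, y) = parse_rows(csv).unwrap();
    println!("features: {:?}", x);
    println!("labels: {:?}", y);
}
```

Returning a Result with the offending line number keeps malformed input from silently producing a misaligned feature matrix and label vector.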