ferrite_rs/
lib.rs

//! # Ferrite - An ML Library
//!
//! A Rust-based machine learning library for multivariate regression and matrix operations.
//!
//! ## Features
//! - CSV Input/Output
//! - Multivariate Regression
//! - Matrix Operations
//! - Training with Different Optimizers
//!
//! ## Example Usage
//! ```rust
//! use ferrite::csv_io::read_input_output;
//! ```

// Public modules of the library
17pub mod csv_io;
18pub mod matrix_operations;
19pub mod multivariate_regression;
20
21use csv_io::{read_input_output, train_test_split};
22
// Unit tests for the crate: an end-to-end smoke test covering CSV loading,
// train/test splitting, and training.
#[cfg(test)]
mod tests {
    // `read_input_output` and `train_test_split` come in through the crate root's `use`.
    use super::*;
    use crate::multivariate_regression::cost_fn::cost_fn::CostFn;
    use crate::multivariate_regression::gradient::Gradient;
    use crate::multivariate_regression::regularization::regularization::Regularization;
    use crate::multivariate_regression::training::train::train;
    use crate::multivariate_regression::training::train_config::TrainConfigBuilder;
    use crate::multivariate_regression::update_weight::UpdatationMethod;

    /// Smoke test: trains a model on `Student_Performance.csv` with
    /// `Performance Index` as the output column.
    ///
    /// Training runs for 100 epochs with the `BGD` optimizer, a mean-absolute-error
    /// cost and gradient, and L1 regularization. The test passes if loading,
    /// splitting, and training complete without panicking. It makes no assertions
    /// about the trained model.
    ///
    /// # Panics
    ///
    /// Panics if the CSV cannot be read or the dataset cannot be split.
    #[test]
    fn train_test() {
        // NOTE(review): relative path, resolved against the process working directory
        // (presumably the package root under `cargo test`). Confirm the CSV lives there.
        let filepath = "Student_Performance.csv".to_string();
        let output_cols = vec!["Performance Index".to_string()];
        // No columns are excluded from the inputs (presumably every non-output
        // column becomes a feature).
        let input_exclude_cols: Vec<String> = Vec::new();

        let (input, output) = read_input_output(filepath, output_cols, input_exclude_cols)
            .expect("Failed to read input and output from CSV");
        // 0.7 is presumably the training fraction. The held-out split is not evaluated
        // yet; the leading underscores keep it bound without an `unused_variables` warning.
        let (x_train, y_train, _x_test, _y_test) = train_test_split(input, output, 0.7)
            .expect("Failed to split dataset");

        // One L1 strength feeds both the config's regularizer and the gradient's
        // regularizer, so the two cannot drift apart.
        let l1_lambda = 0.9;

        let config = TrainConfigBuilder::new()
            .epochs(100)
            .print_log(true)
            .cost_fn(CostFn::mean_absolute_error())
            // NOTE(review): `delta` looks like a Huber-style parameter and may have no
            // effect with the MAE cost; confirm.
            .delta(0.9)
            .regularization(Regularization::l1(l1_lambda))
            .learning_rate(0.001)
            .optimizer(UpdatationMethod::BGD)
            .gradient_fn(Gradient::mean_absolute_error(Regularization::l1(l1_lambda)))
            .build();

        train(x_train, y_train, config);
    }
}