1pub mod csv_io;
18pub mod matrix_operations;
19pub mod multivariate_regression;
20
21use csv_io::{read_input_output, train_test_split};
22
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multivariate_regression::cost_fn::cost_fn::CostFn;
    use crate::multivariate_regression::gradient::Gradient;
    use crate::multivariate_regression::regularization::regularization::Regularization;
    use crate::multivariate_regression::training::train::train;
    use crate::multivariate_regression::training::train_config::TrainConfigBuilder;
    use crate::multivariate_regression::update_weight::{MiniBatchSize, UpdatationMethod};

    /// End-to-end smoke test: reads the CSV dataset, splits it, builds a
    /// training configuration and runs `train` on the training partition.
    ///
    /// The test makes no assertions about the trained model; it fails only if
    /// reading the CSV, splitting the data, or training itself panics.
    #[test]
    fn train_test() {
        // Relative path: the CSV must be reachable from the directory the test
        // binary runs in (NOTE(review): presumably the package root — confirm).
        let filepath = "Student_Performance.csv".to_string();
        let output_cols = vec!["Performance Index".to_string()];
        // No input columns are excluded.
        let input_exclude_cols: Vec<String> = Vec::new();

        let (input, output) = read_input_output(filepath, output_cols, input_exclude_cols)
            .expect("Failed to read input and output from CSV");

        // The held-out partition is not evaluated yet; the leading underscores
        // keep the bindings without triggering `unused_variables` warnings.
        // NOTE(review): 0.7 is presumably the training fraction — confirm
        // against `train_test_split`.
        let (x_train, y_train, _x_test, _y_test) = train_test_split(input, output, 0.7)
            .expect("Failed to split dataset");

        // Single source of truth for the L1 strength so the regularization
        // passed to the config and the one passed to the gradient cannot drift.
        let l1_strength = 0.9;

        let config = TrainConfigBuilder::new()
            .epochs(100)
            .print_log(true)
            .cost_fn(CostFn::mean_absolute_error())
            .delta(0.9)
            .regularization(Regularization::l1(l1_strength))
            .learning_rate(0.001)
            .optimizer(UpdatationMethod::BGD)
            .gradient_fn(Gradient::mean_absolute_error(Regularization::l1(l1_strength)))
            .build();

        train(x_train, y_train, config);
    }
}