use nalgebra::{DMatrix, DVector};
use rand::Rng;

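/// Activation functions applied element-wise to the layer output;
/// `LeakyReLU(alpha)` and `ELU(alpha)` carry their slope/scale parameter.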
pub enum ActivationFunction {
    Sigmoid,
    ReLU,
    Tanh,
    LeakyReLU(f64),
    ELU(f64),
}

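/// Weight-update regularization schemes; each variant carries its strength
/// (lambda for L1/L2, drop probability for Dropout).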
pub enum Regularization {
    L2(f64),
    L1(f64),
    Dropout(f64),
}

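/// A single 10x10 linear layer. `qft` and `laplace` are hyperparameters
/// stored at construction; the current forward and training code does not
/// use them.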
pub struct Hextral {
    h: DMatrix<f64>,
    qft: f64,
    laplace: f64,
}

impl Hextral {
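    /// Builds a `Hextral` with a 10x10 weight matrix drawn uniformly from [0, 0.1).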
    pub fn new(qft: f64, laplace: f64) -> Self {
        let h = DMatrix::from_fn(10, 10, |_, _| rand::thread_rng().gen::<f64>() * 0.1);
        Hextral { h, qft, laplace }
    }

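    /// Applies the linear layer `h * input`, then the chosen activation
    /// function element-wise.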
    pub fn forward_pass(&self, input: &DVector<f64>, activation: ActivationFunction) -> DVector<f64> {
        let output = &self.h * input;

        match activation {
            ActivationFunction::Sigmoid => output.map(sigmoid),
            ActivationFunction::ReLU => output.map(|x| x.max(0.0)),
            ActivationFunction::Tanh => output.map(|x| x.tanh()),
            ActivationFunction::LeakyReLU(alpha) => output.map(|x| if x >= 0.0 { x } else { alpha * x }),
            ActivationFunction::ELU(alpha) => output.map(|x| if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) }),
        }
    }

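    /// Trains the weight matrix with mini-batch gradient descent on the
    /// mean squared error loss, using a sigmoid output layer.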
    pub fn train(&mut self, inputs: &[DVector<f64>], targets: &[DVector<f64>], learning_rate: f64, regularization: Regularization, epochs: usize, batch_size: usize) {
        for _ in 0..epochs {
            for batch_start in (0..inputs.len()).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(inputs.len());
                let batch_inputs = &inputs[batch_start..batch_end];
                let batch_targets = &targets[batch_start..batch_end];

                let mut batch_gradients = DMatrix::zeros(self.h.nrows(), self.h.ncols());

                for (input, target) in batch_inputs.iter().zip(batch_targets.iter()) {
                    let output = self.forward_pass(input, ActivationFunction::Sigmoid);
                    // Chain rule for MSE through the sigmoid:
                    // dL/dz = (y_hat - y) * y_hat * (1 - y_hat), element-wise.
                    let delta = (&output - target).component_mul(&output.map(|y| y * (1.0 - y)));
                    let gradients = delta * input.transpose();
                    batch_gradients += gradients;
                }

                // Average over the actual batch length; the last batch may be short.
                batch_gradients /= (batch_end - batch_start) as f64;

                self.update_parameters(learning_rate, &batch_gradients, &regularization);
            }
        }
    }

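    /// Applies one gradient step to the weights under the chosen regularization.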
    pub fn update_parameters(&mut self, learning_rate: f64, gradients: &DMatrix<f64>, regularization: &Regularization) {
        let gradient_update = learning_rate * gradients;

        match regularization {
            Regularization::L2(lambda) => {
                // L2 weight decay: shrink weights toward zero, then take the gradient step.
                self.h *= 1.0 - learning_rate * *lambda;
                self.h -= &gradient_update;
            }
            Regularization::L1(lambda) => {
                // L1 subgradient step: the penalty contributes lambda * sign(h)
                // to the gradient, pushing weights toward zero.
                let signum = self.h.map(|x| x.signum());
                self.h -= &gradient_update;
                self.h -= learning_rate * *lambda * signum;
            }
            Regularization::Dropout(rate) => {
                // Inverted-dropout mask applied to the weights themselves (closer
                // to DropConnect than to standard activation dropout), followed by
                // the gradient step.
                let dropout_mask = DMatrix::from_fn(gradients.nrows(), gradients.ncols(), |_, _| {
                    if rand::thread_rng().gen::<f64>() < *rate {
                        0.0
                    } else {
                        1.0 / (1.0 - *rate)
                    }
                });
                self.h = self.h.component_mul(&dropout_mask) - &gradient_update;
            }
        }
    }

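    /// Runs a forward pass with the sigmoid activation used during training.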
    pub fn predict(&self, input: &DVector<f64>) -> DVector<f64> {
        self.forward_pass(input, ActivationFunction::Sigmoid)
    }

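    /// Returns the mean squared error over the given samples.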
    pub fn evaluate(&self, inputs: &[DVector<f64>], targets: &[DVector<f64>]) -> f64 {
        let mut total_loss = 0.0;
        for (input, target) in inputs.iter().zip(targets.iter()) {
            let output = self.predict(input);
            let loss = (&output - target).norm_squared();
            total_loss += loss;
        }
        total_loss / inputs.len() as f64
    }
}

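/// Logistic sigmoid: 1 / (1 + e^(-x)).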
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    let mut hextral = Hextral::new(0.1, 0.2);

    // Random 10-dimensional inputs and targets as a smoke test.
    let num_samples = 1000;
    let inputs: Vec<DVector<f64>> = (0..num_samples)
        .map(|_| DVector::from_iterator(10, (0..10).map(|_| rand::thread_rng().gen::<f64>())))
        .collect();

    let targets: Vec<DVector<f64>> = (0..num_samples)
        .map(|_| DVector::from_iterator(10, (0..10).map(|_| rand::thread_rng().gen::<f64>())))
        .collect();

    hextral.train(&inputs, &targets, 0.01, Regularization::L2(0.001), 100, 32);

    let input = DVector::from_iterator(10, (0..10).map(|_| rand::thread_rng().gen::<f64>()));
    let prediction = hextral.predict(&input);
    println!("Prediction: {:?}", prediction);

    let evaluation_loss = hextral.evaluate(&inputs, &targets);
    println!("Evaluation Loss: {}", evaluation_loss);
}