1use crate::utils::error::{ModelError, Result};
6
/// Zero-sized namespace type grouping stateless math helpers: activations,
/// normalization, random sampling and vector similarity.
#[derive(Debug, Clone, Copy, Default)]
pub struct MathUtils;
9
10impl MathUtils {
11 pub fn softmax(input: &[f32]) -> Result<Vec<f32>> {
13 if input.is_empty() {
14 return Err(ModelError::Math("Input vector is empty".to_string()));
15 }
16
17 let max_val = input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
19
20 let exp_values: Vec<f32> = input.iter().map(|&x| (x - max_val).exp()).collect();
22
23 let sum: f32 = exp_values.iter().sum();
25
26 if sum == 0.0 {
27 return Err(ModelError::Math("Softmax sum is zero".to_string()));
28 }
29
30 Ok(exp_values.iter().map(|&x| x / sum).collect())
32 }
33
34 pub fn layer_norm(input: &[f32], eps: f32) -> Result<Vec<f32>> {
36 if input.is_empty() {
37 return Err(ModelError::Math("Input vector is empty".to_string()));
38 }
39
40 let n = input.len() as f32;
41
42 let mean = input.iter().sum::<f32>() / n;
44
45 let variance = input.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
47
48 let std_dev = (variance + eps).sqrt();
50 Ok(input.iter().map(|&x| (x - mean) / std_dev).collect())
51 }
52
53 pub fn gelu(x: f32) -> f32 {
55 0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
56 }
57
58 pub fn swiglu(x: f32, gate: f32) -> f32 {
60 x * Self::sigmoid(gate)
61 }
62
63 pub fn sigmoid(x: f32) -> f32 {
65 1.0 / (1.0 + (-x).exp())
66 }
67
68 pub fn relu(x: f32) -> f32 {
70 x.max(0.0)
71 }
72
73 pub fn random_normal(mean: f32, std_dev: f32) -> f32 {
75 use rand::Rng;
76 let mut rng = rand::rng();
77
78 let u1: f32 = rng.random();
79 let u2: f32 = rng.random();
80
81 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
82 mean + std_dev * z0
83 }
84
85 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
87 if a.len() != b.len() {
88 return Err(ModelError::Math("Vector dimensions must match".to_string()));
89 }
90
91 if a.is_empty() {
92 return Err(ModelError::Math("Vectors cannot be empty".to_string()));
93 }
94
95 let dot_product: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
96 let norm_a: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
97 let norm_b: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
98
99 if norm_a == 0.0 || norm_b == 0.0 {
100 return Err(ModelError::Math(
101 "Cannot compute similarity with zero vector".to_string(),
102 ));
103 }
104
105 Ok(dot_product / (norm_a * norm_b))
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let input = [1.0, 2.0, 3.0];
        let result = MathUtils::softmax(&input).unwrap();

        // Probabilities must sum to one.
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Every probability is strictly positive for finite input.
        assert!(result.iter().all(|&x| x > 0.0));

        // Softmax preserves the ordering of the logits.
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    #[test]
    fn test_softmax_large_logits_are_stable() {
        // exp(1000.0) overflows f32. Max-subtraction must keep the result finite.
        let result = MathUtils::softmax(&[1000.0, 1001.0, 1002.0]).unwrap();
        assert!(result.iter().all(|x| x.is_finite()));

        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_masked_entry() {
        // A -inf logit acts as a mask and receives exactly zero probability.
        let result = MathUtils::softmax(&[0.0, f32::NEG_INFINITY]).unwrap();
        assert_eq!(result, [1.0, 0.0]);
    }

    #[test]
    fn test_layer_norm() {
        let input = [1.0, 2.0, 3.0, 4.0];
        let result = MathUtils::layer_norm(&input, 1e-5).unwrap();

        // Output is centred on zero.
        let mean: f32 = result.iter().sum::<f32>() / result.len() as f32;
        assert!(mean.abs() < 1e-6);

        // Output has (approximately) unit variance. eps shrinks it very slightly.
        let variance: f32 =
            result.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / result.len() as f32;
        assert!((variance - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_layer_norm_constant_input() {
        // Zero variance is kept finite by eps, so every element maps to exactly 0.
        let result = MathUtils::layer_norm(&[3.0, 3.0, 3.0], 1e-5).unwrap();
        assert!(result.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_activations() {
        // GELU: zero at the origin, sign-preserving, approximately identity for large x.
        assert!(MathUtils::gelu(0.0).abs() < 1e-6);
        assert!(MathUtils::gelu(1.0) > 0.0);
        assert!(MathUtils::gelu(-1.0) < 0.0);
        assert!((MathUtils::gelu(10.0) - 10.0).abs() < 1e-4);

        // Sigmoid: 0.5 at the origin, saturating towards 1 and 0.
        assert!((MathUtils::sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(MathUtils::sigmoid(10.0) > 0.9);
        assert!(MathUtils::sigmoid(-10.0) < 0.1);

        // ReLU: identity for positives, zero otherwise.
        assert_eq!(MathUtils::relu(5.0), 5.0);
        assert_eq!(MathUtils::relu(-5.0), 0.0);
        assert_eq!(MathUtils::relu(0.0), 0.0);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors.
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        let similarity = MathUtils::cosine_similarity(&a, &b).unwrap();
        assert!((similarity - 1.0).abs() < 1e-6);

        // Orthogonal vectors.
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        let similarity = MathUtils::cosine_similarity(&a, &b).unwrap();
        assert!(similarity.abs() < 1e-6);

        // Anti-parallel vectors.
        let a = [1.0, 2.0];
        let b = [-1.0, -2.0];
        let similarity = MathUtils::cosine_similarity(&a, &b).unwrap();
        assert!((similarity + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_error_cases() {
        // Empty input to softmax.
        assert!(MathUtils::softmax(&[]).is_err());

        // Empty input to layer_norm.
        assert!(MathUtils::layer_norm(&[], 1e-5).is_err());

        // Mismatched dimensions.
        let a = [1.0, 2.0];
        let b = [1.0];
        assert!(MathUtils::cosine_similarity(&a, &b).is_err());

        // Empty vectors and zero-norm vectors.
        assert!(MathUtils::cosine_similarity(&[], &[]).is_err());
        assert!(MathUtils::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]).is_err());
    }
}