ds_r1_rs/utils/math.rs

//! # Mathematical Utilities
//!
//! Mathematical functions and utilities for the model.

use crate::utils::error::{ModelError, Result};

/// Mathematical utility functions
pub struct MathUtils;

impl MathUtils {
    /// Compute softmax over a vector
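    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch; it assumes the crate is named `ds_r1_rs`
    /// (per the file path) and that the `utils::math` module is public:
    ///
    /// ```
    /// use ds_r1_rs::utils::math::MathUtils;
    ///
    /// let probs = MathUtils::softmax(&[1.0, 2.0, 3.0]).unwrap();
    /// // The output is a probability distribution: positive values summing to 1.
    /// assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    /// assert!(probs.iter().all(|&p| p > 0.0));
    /// ```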
    pub fn softmax(input: &[f32]) -> Result<Vec<f32>> {
        if input.is_empty() {
            return Err(ModelError::Math("Input vector is empty".to_string()));
        }

        // Find maximum for numerical stability
        let max_val = input.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        // Compute exponentials
        let exp_values: Vec<f32> = input.iter().map(|&x| (x - max_val).exp()).collect();

        // Compute sum
        let sum: f32 = exp_values.iter().sum();

        if sum == 0.0 {
            return Err(ModelError::Math("Softmax sum is zero".to_string()));
        }

        // Normalize
        Ok(exp_values.iter().map(|&x| x / sum).collect())
    }

    /// Compute layer normalization
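    ///
    /// Normalizes to zero mean and unit variance:
    /// `(x_i - mean) / sqrt(variance + eps)`, where `eps` keeps the
    /// denominator positive for near-constant inputs.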
    pub fn layer_norm(input: &[f32], eps: f32) -> Result<Vec<f32>> {
        if input.is_empty() {
            return Err(ModelError::Math("Input vector is empty".to_string()));
        }

        let n = input.len() as f32;

        // Compute mean
        let mean = input.iter().sum::<f32>() / n;

        // Compute variance
        let variance = input.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;

        // Normalize
        let std_dev = (variance + eps).sqrt();
        Ok(input.iter().map(|&x| (x - mean) / std_dev).collect())
    }

    /// Compute GELU activation
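    ///
    /// Uses the tanh approximation
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.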
    pub fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
    }

    /// Compute SwiGLU activation: `x * Swish(gate)`,
    /// where `Swish(z) = z * sigmoid(z)`
    pub fn swiglu(x: f32, gate: f32) -> f32 {
        x * gate * Self::sigmoid(gate)
    }

    /// Compute sigmoid activation
    pub fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Compute ReLU activation
    pub fn relu(x: f32) -> f32 {
        x.max(0.0)
    }

    /// Generate a random sample from a normal distribution (Box-Muller transform)
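    ///
    /// Maps two independent uniform samples `u1, u2 ~ U(0, 1)` to a standard
    /// normal sample via `z0 = sqrt(-2 ln u1) * cos(2 pi u2)`, then scales and
    /// shifts by `std_dev` and `mean`.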
    pub fn random_normal(mean: f32, std_dev: f32) -> f32 {
        use rand::Rng;
        let mut rng = rand::rng();

        // `random()` yields values in [0, 1); keep u1 strictly positive so
        // that `ln(u1)` stays finite.
        let u1: f32 = rng.random::<f32>().max(f32::MIN_POSITIVE);
        let u2: f32 = rng.random();

        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
        mean + std_dev * z0
    }

    /// Compute cosine similarity between two vectors
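    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch under the same crate-name assumption as
    /// [`MathUtils::softmax`]:
    ///
    /// ```
    /// use ds_r1_rs::utils::math::MathUtils;
    ///
    /// // Orthogonal vectors score 0; identical directions score 1.
    /// let sim = MathUtils::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
    /// assert!(sim.abs() < 1e-6);
    /// ```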
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(ModelError::Math("Vector dimensions must match".to_string()));
        }

        if a.is_empty() {
            return Err(ModelError::Math("Vectors cannot be empty".to_string()));
        }

        let dot_product: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return Err(ModelError::Math(
                "Cannot compute similarity with zero vector".to_string(),
            ));
        }

        Ok(dot_product / (norm_a * norm_b))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax() {
        let input = vec![1.0, 2.0, 3.0];
        let result = MathUtils::softmax(&input).unwrap();

        // Check that probabilities sum to 1
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Check that all values are positive
        assert!(result.iter().all(|&x| x > 0.0));

        // Check that larger inputs have larger probabilities
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    #[test]
    fn test_layer_norm() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let result = MathUtils::layer_norm(&input, 1e-5).unwrap();

        // Check that mean is approximately 0
        let mean: f32 = result.iter().sum::<f32>() / result.len() as f32;
        assert!(mean.abs() < 1e-6);

        // Check that variance is approximately 1
        let variance: f32 =
            result.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / result.len() as f32;
        assert!((variance - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_activations() {
        // Test GELU
        assert!(MathUtils::gelu(0.0).abs() < 1e-6);
        assert!(MathUtils::gelu(1.0) > 0.0);
        assert!(MathUtils::gelu(-1.0) < 0.0);

        // Test sigmoid
        assert!((MathUtils::sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(MathUtils::sigmoid(10.0) > 0.9);
        assert!(MathUtils::sigmoid(-10.0) < 0.1);

        // Test ReLU
        assert_eq!(MathUtils::relu(5.0), 5.0);
        assert_eq!(MathUtils::relu(-5.0), 0.0);
        assert_eq!(MathUtils::relu(0.0), 0.0);
    }
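
    // A sketch of added coverage for the gating helper; the expected values
    // follow directly from SwiGLU(x, gate) = x * gate * sigmoid(gate).
    #[test]
    fn test_swiglu() {
        // Swish(0) = 0, so the gate fully closes the output.
        assert!(MathUtils::swiglu(3.0, 0.0).abs() < 1e-6);

        // For large positive gates, Swish(gate) approaches gate.
        let y = MathUtils::swiglu(2.0, 10.0);
        assert!((y - 2.0 * 10.0).abs() < 1e-2);
    }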

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = MathUtils::cosine_similarity(&a, &b).unwrap();
        assert!((similarity - 1.0).abs() < 1e-6);

        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let similarity = MathUtils::cosine_similarity(&a, &b).unwrap();
        assert!(similarity.abs() < 1e-6);
    }
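
    // A loose statistical sanity check; with 10_000 samples the sample mean
    // of a standard normal lies well within +/- 0.1 of the true mean.
    #[test]
    fn test_random_normal() {
        let n = 10_000;
        let samples: Vec<f32> = (0..n).map(|_| MathUtils::random_normal(0.0, 1.0)).collect();

        // Box-Muller should never produce infinities or NaNs.
        assert!(samples.iter().all(|x| x.is_finite()));

        let mean: f32 = samples.iter().sum::<f32>() / n as f32;
        assert!(mean.abs() < 0.1);
    }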

    #[test]
    fn test_error_cases() {
        // Empty vector for softmax
        assert!(MathUtils::softmax(&[]).is_err());

        // Empty vector for layer norm
        assert!(MathUtils::layer_norm(&[], 1e-5).is_err());

        // Mismatched vector sizes for cosine similarity
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        assert!(MathUtils::cosine_similarity(&a, &b).is_err());
    }
}