fast_distances/
utils.rs

1use ndarray;
2
3use ndarray::{Array1, Array2};
4use num::{Float, Zero};
5
6/// sign is a function that takes a floating-point number (f64 in this case) as input and returns an integer (i32).
7/// If a is less than 0.0, it returns -1, otherwise it returns 1.
8pub fn sign<T>(a: T) -> i32
9where
10    T: Zero + PartialOrd + Copy,
11{
12    if a < T::zero() {
13        -1
14    } else {
15        1
16    }
17}
18
19// Function to generate an identity matrix of size n x n for a given type T
20pub fn identity_matrix<T>(n: usize) -> Array2<T>
21where
22    T: Float,
23{
24    let mut identity = Array2::<T>::zeros((n, n));
25    for i in 0..n {
26        identity[(i, i)] = T::one(); // Set diagonal elements to 1.0
27    }
28    identity
29}
30
31// Function to generate a ones vector of size n for a given type T
32pub fn ones_vector<T>(n: usize) -> Array1<T>
33where
34    T: Float,
35{
36    Array1::<T>::ones(n)
37}
38
39// Function to generate a cost matrix (1.0 - identity matrix) of size n x n for a given type T
40pub fn cost_matrix<T>(n: usize) -> Array2<T>
41where
42    T: Float,
43{
44    let identity = identity_matrix::<T>(n);
45    Array2::<T>::from_elem((n, n), T::one()) - identity
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that `m` has shape `n x n`, with `T::one()` on the diagonal and `T::zero()`
    /// everywhere else.
    fn assert_identity<T>(m: &Array2<T>, n: usize)
    where
        T: Float + std::fmt::Debug,
    {
        assert_eq!(m.dim(), (n, n), "identity matrix should be {}x{}", n, n);
        for ((i, j), &value) in m.indexed_iter() {
            let expected = if i == j { T::one() } else { T::zero() };
            assert_eq!(
                value, expected,
                "element at ({}, {}) should be {:?}",
                i, j, expected
            );
        }
    }

    // sign with f64: negative, zero, positive.
    #[test]
    fn test_sign_f64() {
        assert_eq!(sign(-3.5), -1);
        assert_eq!(sign(0.0), 1);
        assert_eq!(sign(4.5), 1);
    }

    // sign with f32: negative, zero, positive.
    #[test]
    fn test_sign_f32() {
        assert_eq!(sign(-2.7f32), -1);
        assert_eq!(sign(0.0f32), 1);
        assert_eq!(sign(3.3f32), 1);
    }

    // sign with i32: negative, zero, positive.
    #[test]
    fn test_sign_i32() {
        assert_eq!(sign(-10), -1);
        assert_eq!(sign(0), 1);
        assert_eq!(sign(25), 1);
    }

    // sign with non-finite floats. NaN never compares less than zero, so it maps to 1.
    #[test]
    fn test_sign_edge_cases() {
        assert_eq!(sign(f64::INFINITY), 1);
        assert_eq!(sign(f64::NEG_INFINITY), -1);
        assert_eq!(sign(f64::NAN), 1);
    }

    // sign with negative zero: -0.0 is not less than 0.0, so it maps to 1 like +0.0.
    #[test]
    fn test_sign_negative_zero() {
        assert_eq!(sign(-0.0_f64), 1);
        assert_eq!(sign(-0.0_f32), 1);
    }

    #[test]
    fn test_identity_matrix_f64() {
        assert_identity(&identity_matrix::<f64>(3), 3);
    }

    #[test]
    fn test_identity_matrix_f32() {
        assert_identity(&identity_matrix::<f32>(2), 2);
    }

    #[test]
    fn test_ones_vector_f64() {
        let expected: Array1<f64> = array![1.0, 1.0, 1.0, 1.0];
        assert_eq!(ones_vector::<f64>(4), expected);
    }

    #[test]
    fn test_ones_vector_f32() {
        let expected: Array1<f32> = array![1.0, 1.0, 1.0];
        assert_eq!(ones_vector::<f32>(3), expected);
    }

    // Compare against literal values rather than against identity_matrix, so the test is
    // an independent check of cost_matrix.
    #[test]
    fn test_cost_matrix_f64() {
        let expected: Array2<f64> = array![[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]];
        assert_eq!(cost_matrix::<f64>(3), expected);
    }

    #[test]
    fn test_cost_matrix_f32() {
        let expected: Array2<f32> = array![[0.0, 1.0], [1.0, 0.0]];
        assert_eq!(cost_matrix::<f32>(2), expected);
    }

    // 1x1 identity: checks both the shape and the single element.
    #[test]
    fn test_identity_matrix_1x1_f64() {
        assert_identity(&identity_matrix::<f64>(1), 1);
    }

    #[test]
    fn test_identity_matrix_1x1_f32() {
        assert_identity(&identity_matrix::<f32>(1), 1);
    }

    // n == 0 must produce empty containers rather than panicking.
    #[test]
    fn test_empty_inputs() {
        assert_eq!(identity_matrix::<f64>(0).dim(), (0, 0));
        assert_eq!(cost_matrix::<f64>(0).dim(), (0, 0));
        assert!(ones_vector::<f64>(0).is_empty());
    }
}