1use ndarray;
2
3use ndarray::{Array1, Array2};
4use num::{Float, Zero};
5
6pub fn sign<T>(a: T) -> i32
9where
10 T: Zero + PartialOrd + Copy,
11{
12 if a < T::zero() {
13 -1
14 } else {
15 1
16 }
17}
18
19pub fn identity_matrix<T>(n: usize) -> Array2<T>
21where
22 T: Float,
23{
24 let mut identity = Array2::<T>::zeros((n, n));
25 for i in 0..n {
26 identity[(i, i)] = T::one(); }
28 identity
29}
30
31pub fn ones_vector<T>(n: usize) -> Array1<T>
33where
34 T: Float,
35{
36 Array1::<T>::ones(n)
37}
38
39pub fn cost_matrix<T>(n: usize) -> Array2<T>
41where
42 T: Float,
43{
44 let identity = identity_matrix::<T>(n);
45 Array2::<T>::from_elem((n, n), T::one()) - identity
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Asserts that `identity_matrix::<T>(n)` has shape `(n, n)`, with ones on
    /// the main diagonal and zeros everywhere else.
    fn check_identity<T>(n: usize)
    where
        T: Float + Debug,
    {
        let identity = identity_matrix::<T>(n);
        assert_eq!(identity.dim(), (n, n), "identity matrix should be {}x{}", n, n);
        for ((i, j), &value) in identity.indexed_iter() {
            if i == j {
                assert_eq!(
                    value,
                    T::one(),
                    "Diagonal element at ({}, {}) should be 1.0",
                    i,
                    j
                );
            } else {
                assert_eq!(
                    value,
                    T::zero(),
                    "Off-diagonal element at ({}, {}) should be 0.0",
                    i,
                    j
                );
            }
        }
    }

    /// Asserts that `ones_vector::<T>(n)` has exactly `n` entries, all one.
    fn check_ones<T>(n: usize)
    where
        T: Float + Debug,
    {
        let ones = ones_vector::<T>(n);
        assert_eq!(ones.len(), n, "ones vector should have {} elements", n);
        for (i, &value) in ones.iter().enumerate() {
            assert_eq!(value, T::one(), "Element at index {} should be 1.0", i);
        }
    }

    /// Asserts that `cost_matrix::<T>(n)` has shape `(n, n)`, with zeros on the
    /// diagonal and ones elsewhere. The expected values are stated directly
    /// rather than derived from `identity_matrix`, so a bug there cannot mask
    /// one here.
    fn check_cost<T>(n: usize)
    where
        T: Float + Debug,
    {
        let cost = cost_matrix::<T>(n);
        assert_eq!(cost.dim(), (n, n), "cost matrix should be {}x{}", n, n);
        for ((i, j), &value) in cost.indexed_iter() {
            let expected = if i == j { T::zero() } else { T::one() };
            assert_eq!(
                value,
                expected,
                "Element at ({}, {}) should be 1.0 - identity",
                i,
                j
            );
        }
    }

    #[test]
    fn test_sign_f64() {
        assert_eq!(sign(-3.5), -1);
        assert_eq!(sign(0.0), 1);
        assert_eq!(sign(4.5), 1);
    }

    #[test]
    fn test_sign_f32() {
        assert_eq!(sign(-2.7f32), -1);
        assert_eq!(sign(0.0f32), 1);
        assert_eq!(sign(3.3f32), 1);
    }

    #[test]
    fn test_sign_i32() {
        assert_eq!(sign(-10), -1);
        assert_eq!(sign(0), 1);
        assert_eq!(sign(25), 1);
    }

    #[test]
    fn test_sign_edge_cases() {
        assert_eq!(sign(f64::INFINITY), 1);
        assert_eq!(sign(f64::NEG_INFINITY), -1);
        // NaN is unordered with zero, so it is never "less than zero".
        assert_eq!(sign(f64::NAN), 1);
    }

    #[test]
    fn test_sign_signed_zero_and_extremes() {
        // -0.0 compares equal to 0.0, so it is not reported as negative.
        assert_eq!(sign(-0.0f64), 1);
        assert_eq!(sign(-f64::MIN_POSITIVE), -1);
        assert_eq!(sign(f64::MIN_POSITIVE), 1);
        assert_eq!(sign(i32::MIN), -1);
        assert_eq!(sign(i32::MAX), 1);
    }

    #[test]
    fn test_identity_matrix_f64() {
        check_identity::<f64>(3);
    }

    #[test]
    fn test_identity_matrix_f32() {
        check_identity::<f32>(2);
    }

    #[test]
    fn test_identity_matrix_1x1_f64() {
        check_identity::<f64>(1);
    }

    #[test]
    fn test_identity_matrix_1x1_f32() {
        check_identity::<f32>(1);
    }

    #[test]
    fn test_identity_matrix_empty() {
        check_identity::<f64>(0);
    }

    #[test]
    fn test_ones_vector_f64() {
        check_ones::<f64>(4);
    }

    #[test]
    fn test_ones_vector_f32() {
        check_ones::<f32>(3);
    }

    #[test]
    fn test_ones_vector_empty() {
        check_ones::<f64>(0);
    }

    #[test]
    fn test_cost_matrix_f64() {
        check_cost::<f64>(3);
    }

    #[test]
    fn test_cost_matrix_f32() {
        check_cost::<f32>(2);
    }

    #[test]
    fn test_cost_matrix_1x1() {
        // A single diagonal entry, so the whole matrix is one zero.
        check_cost::<f64>(1);
    }

    #[test]
    fn test_cost_matrix_empty() {
        check_cost::<f32>(0);
    }
}