1use ndarray::{Array1, Array2, ArrayView1, Axis};
4
5pub fn quantile(arr: &Array1<f32>, q: f64) -> f32 {
16 if arr.is_empty() {
17 return 0.0;
18 }
19
20 let mut sorted: Vec<f32> = arr.iter().copied().collect();
21 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
22
23 let n = sorted.len();
24 let idx_float = q * (n - 1) as f64;
25 let lower_idx = idx_float.floor() as usize;
26 let upper_idx = idx_float.ceil() as usize;
27
28 if lower_idx == upper_idx {
29 sorted[lower_idx]
30 } else {
31 let weight = (idx_float - lower_idx as f64) as f32;
32 sorted[lower_idx] * (1.0 - weight) + sorted[upper_idx] * weight
33 }
34}
35
36pub fn quantiles(arr: &Array1<f32>, qs: &[f64]) -> Vec<f32> {
47 if arr.is_empty() {
48 return vec![0.0; qs.len()];
49 }
50
51 let mut sorted: Vec<f32> = arr.iter().copied().collect();
52 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
53
54 let n = sorted.len();
55
56 qs.iter()
57 .map(|&q| {
58 let idx_float = q * (n - 1) as f64;
59 let lower_idx = idx_float.floor() as usize;
60 let upper_idx = idx_float.ceil() as usize;
61
62 if lower_idx == upper_idx {
63 sorted[lower_idx]
64 } else {
65 let weight = (idx_float - lower_idx as f64) as f32;
66 sorted[lower_idx] * (1.0 - weight) + sorted[upper_idx] * weight
67 }
68 })
69 .collect()
70}
71
72pub fn normalize_rows(arr: &Array2<f32>) -> Array2<f32> {
82 let mut result = arr.clone();
83 for mut row in result.axis_iter_mut(Axis(0)) {
84 let norm = row.dot(&row).sqrt().max(1e-12);
85 row /= norm;
86 }
87 result
88}
89
90pub fn row_norms(arr: &Array2<f32>) -> Array1<f32> {
100 arr.axis_iter(Axis(0))
101 .map(|row| row.dot(&row).sqrt())
102 .collect()
103}
104
/// Packs a sequence of bits into bytes, most-significant bit first
/// (the bit order of NumPy's `packbits` with `bitorder="big"`).
///
/// Any nonzero input element counts as a set bit. Values other than `0`/`1`
/// therefore cannot spill into neighbouring bit positions. When `bits.len()`
/// is not a multiple of 8, the last byte is padded with zero bits on the right.
pub fn packbits(bits: &[u8]) -> Vec<u8> {
    bits.chunks(8)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                // Normalise to 0/1 first: shifting a raw value like 255 would set
                // several output bits at once.
                .fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit != 0) << (7 - i)))
        })
        .collect()
}
125
/// Expands each byte into its 8 bits, most-significant bit first.
///
/// The output always has `8 * bytes.len()` elements, each `0` or `1`. For input
/// lengths that are multiples of 8, this inverts `packbits`.
pub fn unpackbits(bytes: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(8 * bytes.len());
    for &b in bytes {
        out.extend((0..8).map(|k| (b >> (7 - k)) & 1));
    }
    out
}
144
145pub fn create_mask(lengths: &ArrayView1<i64>, max_len: usize) -> Array2<bool> {
156 let batch_size = lengths.len();
157 let mut mask = Array2::from_elem((batch_size, max_len), false);
158
159 for (i, &len) in lengths.iter().enumerate() {
160 for j in 0..(len as usize).min(max_len) {
161 mask[[i, j]] = true;
162 }
163 }
164
165 mask
166}
167
168pub fn pad_sequences(sequences: &[Array2<f32>], pad_value: f32) -> (Array2<f32>, Array1<i64>) {
179 if sequences.is_empty() {
180 return (Array2::zeros((0, 0)), Array1::zeros(0));
181 }
182
183 let max_len = sequences.iter().map(|s| s.nrows()).max().unwrap_or(0);
184 let dim = sequences[0].ncols();
185 let batch_size = sequences.len();
186
187 let mut padded = Array2::from_elem((batch_size * max_len, dim), pad_value);
188 let mut lengths = Array1::<i64>::zeros(batch_size);
189
190 for (i, seq) in sequences.iter().enumerate() {
191 let len = seq.nrows();
192 lengths[i] = len as i64;
193 for j in 0..len {
194 for k in 0..dim {
195 padded[[i * max_len + j, k]] = seq[[j, k]];
196 }
197 }
198 }
199
200 (padded, lengths)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantile() {
        let arr = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((quantile(&arr, 0.5) - 3.0).abs() < 1e-6);
        assert!((quantile(&arr, 0.0) - 1.0).abs() < 1e-6);
        assert!((quantile(&arr, 1.0) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_quantile_interpolates_and_handles_empty() {
        let arr = Array1::from_vec(vec![4.0, 1.0, 3.0, 2.0]);
        assert!((quantile(&arr, 0.5) - 2.5).abs() < 1e-6);
        assert!((quantile(&arr, 0.25) - 1.75).abs() < 1e-6);

        let empty = Array1::<f32>::from_vec(vec![]);
        assert_eq!(quantile(&empty, 0.5), 0.0);
    }

    #[test]
    fn test_quantiles() {
        let arr = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let qs = quantiles(&arr, &[0.0, 0.25, 0.5, 1.0]);
        let expected = [1.0, 2.0, 3.0, 5.0];
        assert_eq!(qs.len(), expected.len());
        for (got, want) in qs.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }

        let empty = Array1::<f32>::from_vec(vec![]);
        assert_eq!(quantiles(&empty, &[0.1, 0.9]), vec![0.0, 0.0]);
    }

    #[test]
    fn test_packbits_unpackbits() {
        let bits = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0];
        let packed = packbits(&bits);
        assert_eq!(packed, vec![0b10101010, 0b11110000]);

        let unpacked = unpackbits(&packed);
        assert_eq!(unpacked, bits);
    }

    #[test]
    fn test_packbits_pads_partial_byte() {
        assert_eq!(packbits(&[1, 1, 1]), vec![0b1110_0000]);
        assert_eq!(unpackbits(&[0b1110_0000]), vec![1, 1, 1, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_normalize_rows() {
        let arr = Array2::from_shape_vec((2, 3), vec![3.0, 0.0, 4.0, 0.0, 5.0, 0.0]).unwrap();
        let normalized = normalize_rows(&arr);

        assert!((normalized[[0, 0]] - 0.6).abs() < 1e-6);
        assert!(normalized[[0, 1]].abs() < 1e-6);
        assert!((normalized[[0, 2]] - 0.8).abs() < 1e-6);

        assert!((normalized[[1, 1]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_rows_keeps_zero_row_finite() {
        let arr = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        let normalized = normalize_rows(&arr);
        assert!(normalized.iter().all(|v| *v == 0.0));
    }

    #[test]
    fn test_row_norms() {
        let arr = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 0.0, 0.0]).unwrap();
        let norms = row_norms(&arr);
        assert_eq!(norms.len(), 2);
        assert!((norms[0] - 5.0).abs() < 1e-6);
        assert_eq!(norms[1], 0.0);
    }

    #[test]
    fn test_create_mask() {
        let lengths = Array1::from_vec(vec![2i64, 0, 5]);
        let mask = create_mask(&lengths.view(), 3);
        let expected = Array2::from_shape_vec(
            (3, 3),
            vec![true, true, false, false, false, false, true, true, true],
        )
        .unwrap();
        assert_eq!(mask, expected);
    }

    #[test]
    fn test_pad_sequences() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((1, 2), vec![5.0, 6.0]).unwrap();
        let (padded, lengths) = pad_sequences(&[a, b], -1.0);

        let expected =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, -1.0, -1.0])
                .unwrap();
        assert_eq!(padded, expected);
        assert_eq!(lengths, Array1::from_vec(vec![2i64, 1]));
    }

    #[test]
    fn test_pad_sequences_empty() {
        let (padded, lengths) = pad_sequences(&[], 0.0);
        assert_eq!(padded.shape(), &[0, 0]);
        assert_eq!(lengths.len(), 0);
    }
}