Skip to main content

next_plaid/
utils.rs

1//! Utility functions for next-plaid
2
3use ndarray::{Array1, Array2, ArrayView1, Axis};
4
5/// Compute the k-th quantile of a 1D array using linear interpolation.
6///
7/// # Arguments
8///
9/// * `arr` - Input array (will be sorted)
10/// * `q` - Quantile to compute (between 0.0 and 1.0)
11///
12/// # Returns
13///
14/// The quantile value
15pub fn quantile(arr: &Array1<f32>, q: f64) -> f32 {
16    if arr.is_empty() {
17        return 0.0;
18    }
19
20    let mut sorted: Vec<f32> = arr.iter().copied().collect();
21    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
22
23    let n = sorted.len();
24    let idx_float = q * (n - 1) as f64;
25    let lower_idx = idx_float.floor() as usize;
26    let upper_idx = idx_float.ceil() as usize;
27
28    if lower_idx == upper_idx {
29        sorted[lower_idx]
30    } else {
31        let weight = (idx_float - lower_idx as f64) as f32;
32        sorted[lower_idx] * (1.0 - weight) + sorted[upper_idx] * weight
33    }
34}
35
36/// Compute multiple quantiles efficiently.
37///
38/// # Arguments
39///
40/// * `arr` - Input array
41/// * `quantiles` - Array of quantiles to compute
42///
43/// # Returns
44///
45/// Array of quantile values
46pub fn quantiles(arr: &Array1<f32>, qs: &[f64]) -> Vec<f32> {
47    if arr.is_empty() {
48        return vec![0.0; qs.len()];
49    }
50
51    let mut sorted: Vec<f32> = arr.iter().copied().collect();
52    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
53
54    let n = sorted.len();
55
56    qs.iter()
57        .map(|&q| {
58            let idx_float = q * (n - 1) as f64;
59            let lower_idx = idx_float.floor() as usize;
60            let upper_idx = idx_float.ceil() as usize;
61
62            if lower_idx == upper_idx {
63                sorted[lower_idx]
64            } else {
65                let weight = (idx_float - lower_idx as f64) as f32;
66                sorted[lower_idx] * (1.0 - weight) + sorted[upper_idx] * weight
67            }
68        })
69        .collect()
70}
71
72/// Normalize rows of a 2D array to unit length.
73///
74/// # Arguments
75///
76/// * `arr` - Input array of shape `[N, dim]`
77///
78/// # Returns
79///
80/// Normalized array
81pub fn normalize_rows(arr: &Array2<f32>) -> Array2<f32> {
82    let mut result = arr.clone();
83    for mut row in result.axis_iter_mut(Axis(0)) {
84        let norm = row.dot(&row).sqrt().max(1e-12);
85        row /= norm;
86    }
87    result
88}
89
90/// Compute L2 norm of each row.
91///
92/// # Arguments
93///
94/// * `arr` - Input array of shape `[N, dim]`
95///
96/// # Returns
97///
98/// Array of norms of shape `[N]`
99pub fn row_norms(arr: &Array2<f32>) -> Array1<f32> {
100    arr.axis_iter(Axis(0))
101        .map(|row| row.dot(&row).sqrt())
102        .collect()
103}
104
/// Pack bits into bytes (big-endian).
///
/// The first bit of each group of eight lands in the most significant bit of
/// the output byte. If `bits.len()` is not a multiple of eight, the final byte
/// is padded with zeros in its low-order positions.
///
/// # Arguments
///
/// * `bits` - Array of bits; `0` is a cleared bit and any nonzero value is
///   treated as a set bit
///
/// # Returns
///
/// Packed bytes, one per started group of eight input bits
pub fn packbits(bits: &[u8]) -> Vec<u8> {
    bits.chunks(8)
        .map(|chunk| {
            chunk.iter().enumerate().fold(0u8, |byte, (i, &bit)| {
                // Reduce to 0/1 before shifting so a value such as 3 cannot
                // spill into the neighbouring bit positions.
                byte | (u8::from(bit != 0) << (7 - i))
            })
        })
        .collect()
}
125
/// Expand bytes into individual bits (big-endian).
///
/// Each byte yields eight entries, most significant bit first.
///
/// # Arguments
///
/// * `bytes` - Packed bytes
///
/// # Returns
///
/// Vector of `bytes.len() * 8` values, each `0` or `1`
pub fn unpackbits(bytes: &[u8]) -> Vec<u8> {
    let mut unpacked = Vec::with_capacity(bytes.len() * 8);
    for &byte in bytes {
        // Walk a single-bit mask from the top bit (0x80) down to the bottom bit.
        unpacked.extend((0..8u32).map(|pos| u8::from((byte & (0x80u8 >> pos)) != 0)));
    }
    unpacked
}
144
145/// Create a boolean mask from sequence lengths.
146///
147/// # Arguments
148///
149/// * `lengths` - Array of sequence lengths
150/// * `max_len` - Maximum sequence length
151///
152/// # Returns
153///
154/// Boolean mask of shape `[batch_size, max_len]`
155pub fn create_mask(lengths: &ArrayView1<i64>, max_len: usize) -> Array2<bool> {
156    let batch_size = lengths.len();
157    let mut mask = Array2::from_elem((batch_size, max_len), false);
158
159    for (i, &len) in lengths.iter().enumerate() {
160        for j in 0..(len as usize).min(max_len) {
161            mask[[i, j]] = true;
162        }
163    }
164
165    mask
166}
167
168/// Pad sequences to uniform length.
169///
170/// # Arguments
171///
172/// * `sequences` - List of sequence arrays
173/// * `pad_value` - Value to use for padding
174///
175/// # Returns
176///
177/// Tuple of (padded array, lengths)
178pub fn pad_sequences(sequences: &[Array2<f32>], pad_value: f32) -> (Array2<f32>, Array1<i64>) {
179    if sequences.is_empty() {
180        return (Array2::zeros((0, 0)), Array1::zeros(0));
181    }
182
183    let max_len = sequences.iter().map(|s| s.nrows()).max().unwrap_or(0);
184    let dim = sequences[0].ncols();
185    let batch_size = sequences.len();
186
187    let mut padded = Array2::from_elem((batch_size * max_len, dim), pad_value);
188    let mut lengths = Array1::<i64>::zeros(batch_size);
189
190    for (i, seq) in sequences.iter().enumerate() {
191        let len = seq.nrows();
192        lengths[i] = len as i64;
193        for j in 0..len {
194            for k in 0..dim {
195                padded[[i * max_len + j, k]] = seq[[j, k]];
196            }
197        }
198    }
199
200    (padded, lengths)
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Assert that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_quantile() {
        let arr = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Median first, then the two extremes.
        for (q, expected) in [(0.5, 3.0), (0.0, 1.0), (1.0, 5.0)] {
            assert_close(quantile(&arr, q), expected);
        }
    }

    #[test]
    fn test_packbits_unpackbits() {
        let bits = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0];

        let packed = packbits(&bits);
        assert_eq!(packed, vec![0b10101010, 0b11110000]);

        // Unpacking must give back exactly the bits that were packed.
        assert_eq!(unpackbits(&packed), bits);
    }

    #[test]
    fn test_normalize_rows() {
        let arr = Array2::from_shape_vec((2, 3), vec![3.0, 0.0, 4.0, 0.0, 5.0, 0.0]).unwrap();
        let normalized = normalize_rows(&arr);

        // Row 0 is [3, 0, 4] with norm 5, so it becomes [0.6, 0, 0.8].
        assert_close(normalized[[0, 0]], 0.6);
        assert_close(normalized[[0, 2]], 0.8);

        // Row 1 is [0, 5, 0] with norm 5, so it becomes [0, 1, 0].
        assert_close(normalized[[1, 1]], 1.0);
    }
}