1use ndarray::{Array2, ArrayView2};
4use rayon::prelude::*;
5
/// An optimisation strategy a caller may request.
///
/// Within this file these values are only stored and queried (see
/// [`PerformanceHints`]); nothing here changes behaviour based on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceHint {
    /// Request cache-friendly processing.
    CacheFriendly,
    /// Request vectorised processing.
    Vectorize,
    /// Request parallel processing.
    Parallel,
    /// Request GPU-accelerated processing.
    GpuAccelerated,
    /// Request low-latency processing.
    LowLatency,
    /// Request high-throughput processing.
    HighThroughput,
}

/// An ordered list of [`PerformanceHint`]s assembled with a fluent builder.
///
/// Hints are kept in insertion order and duplicates are retained.
#[derive(Debug, Clone, Default)]
pub struct PerformanceHints {
    hints: Vec<PerformanceHint>,
}

impl PerformanceHints {
    /// Returns a hint list containing no hints.
    pub fn new() -> Self {
        PerformanceHints { hints: Vec::new() }
    }

    /// Appends `hint` to the list and hands the list back for chaining.
    pub fn with_hint(self, hint: PerformanceHint) -> Self {
        let mut collected = self.hints;
        collected.push(hint);
        PerformanceHints { hints: collected }
    }

    /// Reports whether `hint` has been added at least once.
    pub fn has_hint(&self, hint: PerformanceHint) -> bool {
        self.hints.iter().any(|&candidate| candidate == hint)
    }

    /// Borrows the hints in the order they were added.
    pub fn hints(&self) -> &[PerformanceHint] {
        self.hints.as_slice()
    }
}
51
52pub struct AudioKNN {
54 k_neighbors: usize,
55}
56
57impl AudioKNN {
58 pub fn new(k_neighbors: usize) -> Self {
60 Self { k_neighbors }
61 }
62
63 pub fn find_neighbors(
66 &self,
67 query: ArrayView2<f32>,
68 data: ArrayView2<f32>,
69 ) -> Vec<Vec<(usize, f32)>> {
70 let n_queries = query.nrows();
71 let n_data = data.nrows();
72
73 (0..n_queries)
74 .into_par_iter()
75 .map(|q_idx| {
76 let query_point = query.row(q_idx);
77 let mut distances: Vec<(usize, f32)> = (0..n_data)
78 .map(|d_idx| {
79 let data_point = data.row(d_idx);
80 let dist = Self::euclidean_distance(query_point, data_point);
81 (d_idx, dist)
82 })
83 .collect();
84
85 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
87 distances.truncate(self.k_neighbors);
88 distances
89 })
90 .collect()
91 }
92
93 #[inline(always)]
95 fn euclidean_distance(a: ndarray::ArrayView1<f32>, b: ndarray::ArrayView1<f32>) -> f32 {
96 a.iter()
97 .zip(b.iter())
98 .map(|(x, y)| {
99 let diff = x - y;
100 diff * diff
101 })
102 .sum::<f32>()
103 .sqrt()
104 }
105
106 pub fn find_similar_segments(
108 &self,
109 segment: ArrayView2<f32>,
110 audio: ArrayView2<f32>,
111 hop_size: usize,
112 ) -> Vec<(usize, f32)> {
113 let segment_len = segment.ncols();
114 let audio_len = audio.ncols();
115
116 if segment_len > audio_len {
117 return Vec::new();
118 }
119
120 let num_windows = (audio_len - segment_len) / hop_size + 1;
121
122 let mut similarities: Vec<(usize, f32)> = (0..num_windows)
123 .into_par_iter()
124 .map(|i| {
125 let start = i * hop_size;
126 let end = start + segment_len;
127 let window = audio.slice(ndarray::s![.., start..end]);
128
129 let distance = Self::matrix_distance(&segment, &window);
131 (start, distance)
132 })
133 .collect();
134
135 similarities.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
137 similarities.truncate(self.k_neighbors);
138 similarities
139 }
140
141 fn matrix_distance(a: &ArrayView2<f32>, b: &ArrayView2<f32>) -> f32 {
143 a.iter()
144 .zip(b.iter())
145 .map(|(x, y)| {
146 let diff = x - y;
147 diff * diff
148 })
149 .sum::<f32>()
150 .sqrt()
151 }
152}
153
154pub struct BatchProcessor {
156 #[allow(dead_code)]
157 batch_size: usize,
158 num_threads: Option<usize>,
159}
160
161impl BatchProcessor {
162 pub fn new(batch_size: usize) -> Self {
164 Self {
165 batch_size,
166 num_threads: None,
167 }
168 }
169
170 pub fn with_threads(mut self, num_threads: usize) -> Self {
172 self.num_threads = Some(num_threads);
173 self
174 }
175
176 pub fn process<T, F, R>(&self, items: Vec<T>, f: F) -> Vec<R>
178 where
179 T: Send,
180 F: Fn(T) -> R + Send + Sync,
181 R: Send,
182 {
183 if let Some(threads) = self.num_threads {
184 rayon::ThreadPoolBuilder::new()
185 .num_threads(threads)
186 .build()
187 .unwrap()
188 .install(|| items.into_par_iter().map(f).collect())
189 } else {
190 items.into_par_iter().map(f).collect()
191 }
192 }
193
194 pub fn process_chunks<F>(&self, data: &Array2<f32>, chunk_size: usize, f: F) -> Vec<Array2<f32>>
196 where
197 F: Fn(ArrayView2<f32>) -> Array2<f32> + Send + Sync,
198 {
199 let n_samples = data.ncols();
200 let n_chunks = n_samples.div_ceil(chunk_size);
201
202 (0..n_chunks)
203 .into_par_iter()
204 .map(|i| {
205 let start = i * chunk_size;
206 let end = (start + chunk_size).min(n_samples);
207 let chunk = data.slice(ndarray::s![.., start..end]);
208 f(chunk)
209 })
210 .collect()
211 }
212}
213
/// Element-wise and reduction helpers over `f32` slices.
///
/// Despite the name, these are plain iterator loops with no explicit SIMD
/// intrinsics; vectorisation, if any, is left to the compiler.
pub struct SimdOps;

impl SimdOps {
    /// Multiplies `a` by `b` element-wise, in place.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    #[inline]
    pub fn multiply(a: &mut [f32], b: &[f32]) {
        assert_eq!(a.len(), b.len());
        a.iter_mut().zip(b).for_each(|(x, y)| *x *= y);
    }

    /// Adds `b` to `a` element-wise, in place.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    #[inline]
    pub fn add(a: &mut [f32], b: &[f32]) {
        assert_eq!(a.len(), b.len());
        a.iter_mut().zip(b).for_each(|(x, y)| *x += y);
    }

    /// Returns the dot product of `a` and `b`, accumulated in `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` have different lengths.
    #[inline]
    #[must_use]
    pub fn dot(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Returns the root-mean-square of `data`.
    ///
    /// An empty slice yields `0.0` rather than the NaN produced by `0 / 0`.
    #[inline]
    #[must_use]
    pub fn rms(data: &[f32]) -> f32 {
        if data.is_empty() {
            return 0.0;
        }
        let sum_squares: f32 = data.iter().map(|x| x * x).sum();
        (sum_squares / data.len() as f32).sqrt()
    }

    /// Returns the indices of strict local maxima whose value exceeds
    /// `threshold`.
    ///
    /// A peak must be strictly greater than both neighbours. The first and
    /// last samples therefore never qualify, and flat plateaus are not
    /// reported.
    #[must_use]
    pub fn find_peaks(data: &[f32], threshold: f32) -> Vec<usize> {
        data.windows(3)
            .enumerate()
            .filter_map(|(i, window)| {
                let is_peak =
                    window[1] > threshold && window[1] > window[0] && window[1] > window[2];
                // `i` indexes the window start; the candidate sits one after it.
                is_peak.then_some(i + 1)
            })
            .collect()
    }
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    #[test]
    fn test_performance_hints() {
        let hints = PerformanceHints::new()
            .with_hint(PerformanceHint::Parallel)
            .with_hint(PerformanceHint::Vectorize);

        assert!(hints.has_hint(PerformanceHint::Parallel));
        assert!(hints.has_hint(PerformanceHint::Vectorize));
        assert!(!hints.has_hint(PerformanceHint::GpuAccelerated));
    }

    #[test]
    fn test_audio_knn() {
        let data = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0,
            ],
        )
        .unwrap();

        let query = Array2::from_shape_vec((1, 3), vec![2.5, 3.5, 4.5]).unwrap();

        let knn = AudioKNN::new(3);
        let neighbors = knn.find_neighbors(query.view(), data.view());

        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].len(), 3);

        // Rows 1 and 2 are both sqrt(0.75) away, and rows 0 and 3 are both
        // sqrt(6.75) away; ties come back in ascending row order.
        let indices: Vec<usize> = neighbors[0].iter().map(|&(i, _)| i).collect();
        assert_eq!(indices, vec![1, 2, 0]);
        assert_abs_diff_eq!(neighbors[0][0].1, 0.75_f32.sqrt(), epsilon = 1e-6);
        assert_abs_diff_eq!(neighbors[0][2].1, 6.75_f32.sqrt(), epsilon = 1e-5);
    }

    #[test]
    fn test_find_similar_segments() {
        let segment = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let audio = Array2::from_shape_vec((1, 5), vec![0.0, 1.0, 2.0, 3.0, 1.0]).unwrap();

        let knn = AudioKNN::new(2);
        let matches = knn.find_similar_segments(segment.view(), audio.view(), 1);

        // The window at column 1 matches exactly. Columns 0 and 2 tie at
        // sqrt(2); the earlier start wins.
        let starts: Vec<usize> = matches.iter().map(|&(s, _)| s).collect();
        assert_eq!(starts, vec![1, 0]);
        assert_abs_diff_eq!(matches[0].1, 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(matches[1].1, 2.0_f32.sqrt(), epsilon = 1e-6);

        // A segment wider than the audio yields no windows.
        let too_long = Array2::<f32>::zeros((1, 6));
        assert!(knn
            .find_similar_segments(too_long.view(), audio.view(), 1)
            .is_empty());
    }

    #[test]
    fn test_simd_ops() {
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 2.0, 2.0, 2.0];

        SimdOps::multiply(&mut a, &b);
        assert_eq!(a, vec![2.0, 4.0, 6.0, 8.0]);

        SimdOps::add(&mut a, &b);
        assert_eq!(a, vec![4.0, 6.0, 8.0, 10.0]);

        assert_abs_diff_eq!(
            SimdOps::dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]),
            32.0,
            epsilon = 1e-6
        );

        let rms = SimdOps::rms(&[1.0, 2.0, 3.0, 4.0]);
        assert_abs_diff_eq!(rms, 2.7386, epsilon = 0.001);

        assert_eq!(SimdOps::find_peaks(&[0.0, 1.0, 0.0, 2.0, 0.5], 0.5), vec![1, 3]);
    }

    #[test]
    fn test_batch_processor() {
        let processor = BatchProcessor::new(10);
        let items: Vec<i32> = (0..100).collect();
        let results = processor.process(items, |x| x * 2);

        assert_eq!(results.len(), 100);
        assert_eq!(results[0], 0);
        assert_eq!(results[99], 198);
    }

    #[test]
    fn test_batch_processor_with_threads() {
        let processor = BatchProcessor::new(10).with_threads(2);
        let results = processor.process((0..10).collect::<Vec<i32>>(), |x| x + 1);

        assert_eq!(results, (1..11).collect::<Vec<i32>>());
    }

    #[test]
    fn test_process_chunks() {
        let processor = BatchProcessor::new(4);
        let data = Array2::from_shape_vec((2, 5), (0..10).map(|v| v as f32).collect()).unwrap();

        let chunks = processor.process_chunks(&data, 2, |chunk| chunk.to_owned());

        // Five columns in chunks of two: the final chunk is one column wide.
        let widths: Vec<usize> = chunks.iter().map(|c| c.ncols()).collect();
        assert_eq!(widths, vec![2, 2, 1]);
        assert_eq!(chunks[2], data.slice(ndarray::s![.., 4..5]).to_owned());
    }
}