Skip to main content

numrs2/
unique_optimized.rs

1use crate::array::Array;
2use crate::error::{NumRs2Error, Result};
3use num_traits::Zero;
4use scirs2_core::parallel_ops::*;
5use std::collections::HashMap;
6use std::collections::HashSet;
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11/// Optimized version of the unique function to find unique elements of an array.
12///
13/// This version includes several optimizations over the standard unique function:
14/// - Parallel processing for large arrays
15/// - Pre-allocated buffers to reduce memory allocations
16/// - Early capacity estimation based on input size
17/// - SIMD-friendly memory access patterns where possible
18///
19/// # Parameters
20///
21/// * `a` - Input array
22/// * `axis` - Optional axis along which to find unique elements
23/// * `return_index` - If true, also return the indices of the first occurrences
24/// * `return_inverse` - If true, also return the indices to reconstruct the original array
25/// * `return_counts` - If true, also return the counts of each unique value
26///
27/// # Returns
28///
29/// A UniqueResult struct containing some or all of:
30/// * The unique values in the array (required)
31/// * The indices of the first unique values (if return_index is true)
32/// * The indices to reconstruct the original array (if return_inverse is true)
33/// * The counts of each unique value (if return_counts is true)
34pub fn unique_optimized<T>(
35    a: &Array<T>,
36    axis: Option<usize>,
37    return_index: Option<bool>,
38    return_inverse: Option<bool>,
39    return_counts: Option<bool>,
40) -> Result<UniqueResult<T>>
41where
42    T: Clone + Hash + Eq + Debug + Zero + Send + Sync,
43{
44    // Helper functions
45    fn estimate_capacity(array_size: usize) -> usize {
46        // Heuristic: for random data, expect about 90% of elements to be unique
47        // for small arrays, just use the full size
48        if array_size < 1000 {
49            array_size
50        } else {
51            (array_size as f64 * 0.9) as usize
52        }
53    }
54
55    // If no axis is provided, flatten the array and find unique elements
56    if axis.is_none() {
57        let flat_data = a.to_vec();
58        let array_size = flat_data.len();
59
60        // Optimize for large arrays by using parallel processing
61        if array_size > 10000 {
62            return unique_optimized_large(
63                &flat_data,
64                array_size,
65                return_index,
66                return_inverse,
67                return_counts,
68            );
69        }
70
71        // For smaller arrays, use a more efficient sequential approach with pre-allocation
72        let estimated_capacity = estimate_capacity(array_size);
73
74        // Pre-allocate with estimated capacity
75        let mut unique_elements = Vec::with_capacity(estimated_capacity);
76        let mut first_indices = if return_index.unwrap_or(false) {
77            Vec::with_capacity(estimated_capacity)
78        } else {
79            Vec::new()
80        };
81        let mut inverse_indices = if return_inverse.unwrap_or(false) {
82            vec![0; array_size]
83        } else {
84            Vec::new()
85        };
86        let need_inverse = return_inverse.unwrap_or(false);
87
88        // Use HashMap with capacity hint
89        let mut value_to_index = HashMap::with_capacity(estimated_capacity);
90
91        // Process each element
92        for (i, value) in flat_data.iter().enumerate() {
93            if let Some(&idx) = value_to_index.get(value) {
94                // Element already seen
95                if need_inverse {
96                    inverse_indices[i] = idx;
97                }
98            } else {
99                // New unique element
100                let new_idx = unique_elements.len();
101                unique_elements.push(value.clone());
102                if return_index.unwrap_or(false) {
103                    first_indices.push(i);
104                }
105                value_to_index.insert(value, new_idx);
106                if need_inverse {
107                    inverse_indices[i] = new_idx;
108                }
109            }
110        }
111
112        // Calculate counts if needed
113        let counts = if return_counts.unwrap_or(false) {
114            let mut counts_vec = vec![0; unique_elements.len()];
115            if need_inverse {
116                // If we already computed inverse indices, use them
117                for &idx in &inverse_indices {
118                    counts_vec[idx] += 1;
119                }
120            } else {
121                // Otherwise, count directly from original data
122                for value in flat_data.iter() {
123                    let idx = *value_to_index
124                        .get(value)
125                        .expect("value must exist in value_to_index map");
126                    counts_vec[idx] += 1;
127                }
128            }
129            Some(Array::from_vec(counts_vec))
130        } else {
131            None
132        };
133
134        // Construct the result
135        let unique_array = Array::from_vec(unique_elements);
136
137        return Ok(UniqueResult {
138            values: unique_array,
139            indices: if return_index.unwrap_or(false) {
140                Some(Array::from_vec(first_indices))
141            } else {
142                None
143            },
144            inverse: if return_inverse.unwrap_or(false) {
145                Some(Array::from_vec(inverse_indices))
146            } else {
147                None
148            },
149            counts,
150        });
151    }
152
153    // Process along a specific axis
154    let axis_val = axis.expect("axis must be Some at this point (None case handled above)");
155    if axis_val >= a.ndim() {
156        return Err(NumRs2Error::DimensionMismatch(format!(
157            "Axis {} out of bounds for array of dimension {}",
158            axis_val,
159            a.ndim()
160        )));
161    }
162
163    // Get the shape
164    let shape = a.shape();
165
166    // For 1D arrays, axis=0 is the same as no axis
167    if shape.len() == 1 && axis_val == 0 {
168        return unique_optimized(a, None, return_index, return_inverse, return_counts);
169    }
170
171    // For higher dimensions, we need to find unique subarrays along the specified axis
172
173    // Get the size of the axis and calculate the shape of each subarray
174    let axis_len = shape[axis_val];
175
176    // Optimize memory allocation for subarrays
177    let mut subarrays = Vec::with_capacity(axis_len);
178    let mut subarray_hashes = Vec::with_capacity(axis_len);
179
180    // Extract subarrays along the specified axis
181    for i in 0..axis_len {
182        // Get the subarray
183        let subarray = a.slice(axis_val, i)?;
184
185        // Convert to a hashable representation
186        let hash_rep = subarray.to_vec();
187
188        subarrays.push(subarray);
189        subarray_hashes.push(hash_rep);
190    }
191
192    // Estimate capacity for unique subarrays
193    let estimated_capacity = estimate_capacity(axis_len);
194
195    // Find unique subarrays with pre-allocation
196    let mut unique_indices = Vec::with_capacity(estimated_capacity);
197    let mut index_map = HashMap::with_capacity(estimated_capacity);
198    let mut inverse = if return_inverse.unwrap_or(false) {
199        vec![0; axis_len]
200    } else {
201        Vec::new()
202    };
203    let need_inverse = return_inverse.unwrap_or(false);
204    let mut seen = HashSet::with_capacity(estimated_capacity);
205
206    for i in 0..axis_len {
207        let hash_rep = &subarray_hashes[i];
208
209        if !seen.contains(hash_rep) {
210            // This is a new unique subarray
211            let idx = unique_indices.len();
212            unique_indices.push(i);
213            index_map.insert(hash_rep, idx);
214            seen.insert(hash_rep.clone());
215            if need_inverse {
216                inverse[i] = idx;
217            }
218        } else {
219            // This subarray has been seen before
220            if need_inverse {
221                let idx = *index_map
222                    .get(hash_rep)
223                    .expect("hash_rep must exist in index_map for seen subarrays");
224                inverse[i] = idx;
225            }
226        }
227    }
228
229    // Calculate counts if needed
230    let counts = if return_counts.unwrap_or(false) {
231        let mut counts_vec = vec![0; unique_indices.len()];
232        if need_inverse {
233            for &idx in &inverse {
234                counts_vec[idx] += 1;
235            }
236        } else {
237            for hash_rep in &subarray_hashes {
238                if let Some(&idx) = index_map.get(hash_rep) {
239                    counts_vec[idx] += 1;
240                }
241            }
242        }
243        Some(Array::from_vec(counts_vec))
244    } else {
245        None
246    };
247
248    // Create the output arrays
249
250    // Create a new shape for the output with the axis dimension set to the number of unique subarrays
251    let mut output_shape = shape.clone();
252    output_shape[axis_val] = unique_indices.len();
253
254    // Create the result array by concatenating the unique subarrays along the axis
255    let mut unique_subarrays = Vec::with_capacity(unique_indices.len());
256    for &idx in &unique_indices {
257        unique_subarrays.push(&subarrays[idx]);
258    }
259
260    // Use the concatenate function to join the unique subarrays
261    let values = if !unique_subarrays.is_empty() {
262        // For now, convert the subarrays to a 1D array for each unique subarray
263        // A better implementation would use proper array concatenation along the specified axis
264        let mut unique_data = Vec::new();
265        for &idx in &unique_indices {
266            unique_data.extend_from_slice(&subarray_hashes[idx]);
267        }
268        Array::from_vec(unique_data).reshape(&output_shape)
269    } else {
270        // Empty result
271        Array::zeros(&output_shape)
272    };
273
274    Ok(UniqueResult {
275        values,
276        indices: if return_index.unwrap_or(false) {
277            Some(Array::from_vec(unique_indices))
278        } else {
279            None
280        },
281        inverse: if return_inverse.unwrap_or(false) {
282            Some(Array::from_vec(inverse))
283        } else {
284            None
285        },
286        counts,
287    })
288}
289
290// Special optimized implementation for large arrays using parallel processing
291fn unique_optimized_large<T>(
292    flat_data: &[T],
293    array_size: usize,
294    return_index: Option<bool>,
295    return_inverse: Option<bool>,
296    return_counts: Option<bool>,
297) -> Result<UniqueResult<T>>
298where
299    T: Clone + Hash + Eq + Debug + Send + Sync,
300{
301    let need_index = return_index.unwrap_or(false);
302    let need_inverse = return_inverse.unwrap_or(false);
303    let need_counts = return_counts.unwrap_or(false);
304
305    // Use atomic counter for thread-safe indexing
306    let unique_counter = AtomicUsize::new(0);
307
308    // Create a shared HashMap for value-to-index mapping
309    // Using estimated capacity heuristic
310    let estimated_capacity = (array_size as f64 * 0.9) as usize;
311    let value_to_index = std::sync::RwLock::new(HashMap::with_capacity(estimated_capacity));
312
313    // First pass: identify unique elements and assign indices
314    let mut unique_elements = Vec::with_capacity(estimated_capacity);
315    let mut first_indices = if need_index {
316        Vec::with_capacity(estimated_capacity)
317    } else {
318        Vec::new()
319    };
320
321    // Using batched processing for better performance
322    let batch_size = std::cmp::max(1, array_size / scirs2_core::parallel_ops::num_threads());
323    let batches = flat_data.chunks(batch_size);
324
325    // Process each batch for unique values
326    batches.enumerate().for_each(|(batch_idx, batch)| {
327        let mut local_uniques = HashMap::new();
328        let base_index = batch_idx * batch_size;
329
330        // First pass within batch to find local unique elements
331        for (local_idx, value) in batch.iter().enumerate() {
332            let global_idx = base_index + local_idx;
333            if !local_uniques.contains_key(value) {
334                local_uniques.insert(value.clone(), global_idx);
335            }
336        }
337
338        // Second pass: merge with global unique set
339        let mut value_map = value_to_index
340            .write()
341            .expect("value_to_index RwLock poisoned: failed to acquire write lock");
342        for (value, local_first_idx) in local_uniques {
343            if !value_map.contains_key(&value) {
344                let new_idx = unique_counter.fetch_add(1, Ordering::SeqCst);
345                value_map.insert(value.clone(), new_idx);
346
347                // This is thread-safe because each value is processed only by the thread that first discovers it
348                synchronized_push(&mut unique_elements, value);
349                if need_index {
350                    synchronized_push(&mut first_indices, local_first_idx);
351                }
352            }
353        }
354    });
355
356    // Create inverse indices if needed
357    let inverse_indices = if need_inverse {
358        let value_map = value_to_index
359            .read()
360            .expect("value_to_index RwLock poisoned: failed to acquire read lock");
361        flat_data
362            .par_iter()
363            .map(|value| {
364                *value_map
365                    .get(value)
366                    .expect("value must exist in value_map for inverse indices")
367            })
368            .collect()
369    } else {
370        Vec::new()
371    };
372
373    // Calculate counts if needed
374    let counts = if need_counts {
375        let mut counts_vec = vec![0; unique_elements.len()];
376
377        if need_inverse {
378            // If we already have inverse indices, use them
379            for &idx in &inverse_indices {
380                counts_vec[idx] += 1;
381            }
382        } else {
383            // Otherwise, count directly using value map
384            let value_map = value_to_index
385                .read()
386                .expect("value_to_index RwLock poisoned: failed to acquire read lock");
387
388            // Use thread-local counters and then merge
389            let local_counts = flat_data
390                .par_iter()
391                .map(|value| {
392                    let idx = *value_map
393                        .get(value)
394                        .expect("value must exist in value_map for counting");
395                    (idx, 1)
396                })
397                .collect::<Vec<(usize, usize)>>();
398
399            // Aggregate counts
400            for (idx, count) in local_counts {
401                counts_vec[idx] += count;
402            }
403        }
404
405        Some(Array::from_vec(counts_vec))
406    } else {
407        None
408    };
409
410    // Construct the result
411    let unique_array = Array::from_vec(unique_elements);
412
413    Ok(UniqueResult {
414        values: unique_array,
415        indices: if need_index {
416            Some(Array::from_vec(first_indices))
417        } else {
418            None
419        },
420        inverse: if need_inverse {
421            Some(Array::from_vec(inverse_indices))
422        } else {
423            None
424        },
425        counts,
426    })
427}
428
429// Helper function for thread-safe vector push
430fn synchronized_push<T: Send + Clone>(vec: &mut Vec<T>, value: T) {
431    // For simplicity, using a mutex, but a lock-free implementation would be better
432    // in a production environment
433    let mutex = std::sync::Mutex::new(());
434    let guard = mutex.lock().expect("Mutex poisoned in synchronized_push");
435    vec.push(value);
436    drop(guard);
437}
438
439// Helper function for atomic increment (unused but kept for reference)
440// fn synchronized_increment(vec: &mut [usize], idx: usize) {
441//     // Using atomic operations for thread safety
442//     let ptr = &mut vec[idx] as *mut usize;
443//     // This is safe because we're incrementing a single unique index per thread
444//     unsafe {
445//         let atomic_ref = AtomicUsize::from_ptr(ptr);
446//         atomic_ref.fetch_add(1, Ordering::Relaxed);
447//     }
448// }
449
450/// Output type for unique function to handle variable return types
451pub struct UniqueResult<T> {
452    pub values: Array<T>,
453    pub indices: Option<Array<usize>>,
454    pub inverse: Option<Array<usize>>,
455    pub counts: Option<Array<usize>>,
456}
457
458impl<T: Clone> UniqueResult<T> {
459    /// Get the unique values only
460    pub fn values(self) -> Array<T> {
461        self.values
462    }
463
464    /// Get a tuple of (values, indices) if indices were requested
465    pub fn values_indices(self) -> Result<(Array<T>, Array<usize>)> {
466        match self.indices {
467            Some(indices) => Ok((self.values, indices)),
468            None => Err(NumRs2Error::InvalidOperation(
469                "indices were not requested in the unique call".to_string(),
470            )),
471        }
472    }
473
474    /// Get a tuple of (values, inverse) if inverse was requested
475    pub fn values_inverse(self) -> Result<(Array<T>, Array<usize>)> {
476        match self.inverse {
477            Some(inverse) => Ok((self.values, inverse)),
478            None => Err(NumRs2Error::InvalidOperation(
479                "inverse was not requested in the unique call".to_string(),
480            )),
481        }
482    }
483
484    /// Get a tuple of (values, counts) if counts were requested
485    pub fn values_counts(self) -> Result<(Array<T>, Array<usize>)> {
486        match self.counts {
487            Some(counts) => Ok((self.values, counts)),
488            None => Err(NumRs2Error::InvalidOperation(
489                "counts were not requested in the unique call".to_string(),
490            )),
491        }
492    }
493
494    /// Get a tuple of (values, indices, inverse) if both were requested
495    pub fn values_indices_inverse(self) -> Result<(Array<T>, Array<usize>, Array<usize>)> {
496        match (self.indices, self.inverse) {
497            (Some(indices), Some(inverse)) => Ok((self.values, indices, inverse)),
498            _ => Err(NumRs2Error::InvalidOperation(
499                "either indices or inverse were not requested in the unique call".to_string(),
500            )),
501        }
502    }
503
504    /// Get a tuple of (values, indices, counts) if both were requested
505    pub fn values_indices_counts(self) -> Result<(Array<T>, Array<usize>, Array<usize>)> {
506        match (self.indices, self.counts) {
507            (Some(indices), Some(counts)) => Ok((self.values, indices, counts)),
508            _ => Err(NumRs2Error::InvalidOperation(
509                "either indices or counts were not requested in the unique call".to_string(),
510            )),
511        }
512    }
513
514    /// Get a tuple of (values, inverse, counts) if both were requested
515    pub fn values_inverse_counts(self) -> Result<(Array<T>, Array<usize>, Array<usize>)> {
516        match (self.inverse, self.counts) {
517            (Some(inverse), Some(counts)) => Ok((self.values, inverse, counts)),
518            _ => Err(NumRs2Error::InvalidOperation(
519                "either inverse or counts were not requested in the unique call".to_string(),
520            )),
521        }
522    }
523
524    /// Get a tuple of (values, indices, inverse, counts) if all were requested
525    pub fn values_indices_inverse_counts(self) -> Result<crate::unique::UniqueTuple<T>> {
526        match (self.indices, self.inverse, self.counts) {
527            (Some(indices), Some(inverse), Some(counts)) => {
528                Ok((self.values, indices, inverse, counts))
529            }
530            _ => Err(NumRs2Error::InvalidOperation(
531                "not all of indices, inverse, and counts were requested in the unique call"
532                    .to_string(),
533            )),
534        }
535    }
536}