parallel_operations/
lib.rs

1use rayon::prelude::*;
2
3/// Gets the initial value for a binary operation.
4///
5/// This function determines the initial value based on the result of the operation
6/// when applied to two sample values. It is used to determine the initial value
7/// for parallel binary operations.
8///
9/// # Parameters
10/// - `operation`: A closure that takes two operands of type `T` and returns a result of type `T`.
11///
12/// # Returns
13/// The initial value for the binary operation based on the sample result.
14/// For now either 0 or 1.
15fn get_initial_value<T>(operation: fn(T, T) -> T) -> T
16where
17    T: Copy + Send + Sync + 'static + Default + PartialEq + From<u8>,
18{
19    let test_result = operation(T::from(8), T::from(8));
20    match test_result {
21        _ if test_result == T::from(16) => T::from(0), // For addition, use 0 as initial value
22        _ if test_result == T::from(64) => T::from(1), // For multiplication, use 1 as initial value
23        _ if test_result == T::from(0) => T::from(0),  // For subtraction, use 0 as initial value
24        _ if test_result == T::from(1) => T::from(1),  // For division, use 1 as initial value
25        _ => T::default(),                             // Default case
26    }
27}
28
29/// Performs a parallel binary operation on a vector of data.
30///
31/// This function divides the data into chunks, processes each chunk in parallel using
32/// multiple threads, and combines the results using the provided binary operation.
33///
34/// # Parameters
35/// - `data`: A vector of type `T` that contains the data to operate on.
36/// - `operation`: A closure that takes two operands of type `T` and returns a result of type `T`.
37///
38/// # Returns
39/// The result of applying the binary operation to all elements of the vector.
40///
41pub fn parallel_binary_operation<T>(data: Vec<T>, operation: fn(T, T) -> T) -> T
42where
43    T: Copy + Send + Sync + 'static + Default + PartialEq + From<u8>,
44{
45    if data.is_empty() {
46        return T::default();
47    }
48    if data.len() == 1 {
49        return data[0];
50    }
51
52    let initial = get_initial_value(operation);
53
54    let threads = num_cpus::get(); // Automatically use the number of available cores
55    let chunk_size = (data.len() + threads - 1) / threads;
56
57    // Perform the operation in parallel across chunks of data
58    data.par_chunks(chunk_size)
59        .map(|chunk| chunk.iter().copied().fold(initial, |a, b| operation(a, b)))
60        .reduce(|| initial, |a, b| operation(a, b)) // Reduce results using operation
61}
62
#[cfg(test)]
mod tests {
    use super::*; // Brings `parallel_binary_operation` into scope.

    /// Addition over a small vector.
    #[test]
    fn test_parallel_addition() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_binary_operation(data, |a, b| a + b);
        assert_eq!(result, 15); // 1 + 2 + 3 + 4 + 5
    }

    /// Multiplication over a small vector.
    #[test]
    fn test_parallel_multiplication() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_binary_operation(data, |a, b| a * b);
        assert_eq!(result, 120); // 1 * 2 * 3 * 4 * 5
    }

    /// A single element is returned unchanged.
    #[test]
    fn test_single_element() {
        let data = vec![42];
        let result = parallel_binary_operation(data, |a, b| a + b);
        assert_eq!(result, 42);
    }

    /// An empty vector yields `T::default()`, which is 0 for `i32`.
    #[test]
    fn test_empty_vector() {
        let data: Vec<i32> = Vec::new();
        let result = parallel_binary_operation(data, |a, b| a + b);
        assert_eq!(result, 0);
    }

    /// Uneven chunking: the input length is prime, so for any chunk count between
    /// 2 and 1008 the chunks cannot all have the same length.
    #[test]
    fn test_odd_number_of_elements() {
        let data: Vec<i32> = (1..=1009).collect();
        let result = parallel_binary_operation(data, |a, b| a + b);
        assert_eq!(result, 1009 * 1010 / 2); // Gauss sum of 1..=1009
    }

    /// Multiplication with enough elements to be spread over several chunks on a
    /// multi-core machine; 20! still fits in a `u64`.
    #[test]
    fn test_multiplication_across_chunks() {
        let data: Vec<u64> = (1..=20).collect();
        let result = parallel_binary_operation(data, |a, b| a * b);
        assert_eq!(result, 2_432_902_008_176_640_000); // 20!
    }
}