Skip to main content

trueno/vector/ops/reductions/
mod.rs

1//! Reduction operations for Vector<f32>
2//!
3//! This module provides reduction operations that aggregate vector elements:
4//! - Basic: `sum`, `dot`, `max`, `min`
5//! - Index-finding: `argmax`, `argmin`
6//! - Statistical: `mean`, `variance`, `stddev`, `covariance`, `correlation`
7//! - Numerically stable: `sum_kahan`, `sum_of_squares`
8
9mod stats;
10#[cfg(test)]
11mod tests;
12
13#[cfg(target_arch = "x86_64")]
14use crate::backends::avx2::Avx2Backend;
15#[cfg(target_arch = "x86_64")]
16use crate::backends::avx512::Avx512Backend;
17#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
18use crate::backends::neon::NeonBackend;
19use crate::backends::scalar::ScalarBackend;
20#[cfg(target_arch = "x86_64")]
21use crate::backends::sse2::Sse2Backend;
22#[cfg(target_arch = "wasm32")]
23use crate::backends::wasm::WasmBackend;
24use crate::backends::VectorBackend;
25use crate::vector::Vector;
26use crate::{dispatch_reduction, Backend, Result, TruenoError};
27
28impl Vector<f32> {
29    /// Dot product
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use trueno::Vector;
35    ///
36    /// let a = Vector::from_slice(&[1.0, 2.0, 3.0]);
37    /// let b = Vector::from_slice(&[4.0, 5.0, 6.0]);
38    /// let result = a.dot(&b)?;
39    ///
40    /// assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
41    /// # Ok::<(), trueno::TruenoError>(())
42    /// ```
43    pub fn dot(&self, other: &Self) -> Result<f32> {
44        if self.len() != other.len() {
45            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
46        }
47
48        // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
49        let result = unsafe {
50            match self.backend {
51                Backend::Scalar => ScalarBackend::dot(&self.data, &other.data),
52                #[cfg(target_arch = "x86_64")]
53                Backend::SSE2 | Backend::AVX => Sse2Backend::dot(&self.data, &other.data),
54                #[cfg(target_arch = "x86_64")]
55                Backend::AVX2 => Avx2Backend::dot(&self.data, &other.data),
56                #[cfg(target_arch = "x86_64")]
57                Backend::AVX512 => Avx512Backend::dot(&self.data, &other.data),
58                #[cfg(not(target_arch = "x86_64"))]
59                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
60                    ScalarBackend::dot(&self.data, &other.data)
61                }
62                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
63                Backend::NEON => NeonBackend::dot(&self.data, &other.data),
64                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
65                Backend::NEON => ScalarBackend::dot(&self.data, &other.data),
66                #[cfg(target_arch = "wasm32")]
67                Backend::WasmSIMD => WasmBackend::dot(&self.data, &other.data),
68                #[cfg(not(target_arch = "wasm32"))]
69                Backend::WasmSIMD => ScalarBackend::dot(&self.data, &other.data),
70                Backend::GPU | Backend::Auto => ScalarBackend::dot(&self.data, &other.data),
71            }
72        };
73
74        Ok(result)
75    }
76
77    /// Sum all elements
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use trueno::Vector;
83    ///
84    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0]);
85    /// assert_eq!(v.sum()?, 10.0);
86    /// # Ok::<(), trueno::TruenoError>(())
87    /// ```
88    pub fn sum(&self) -> Result<f32> {
89        Ok(dispatch_reduction!(self.backend, sum, &self.data))
90    }
91
92    /// Find maximum element
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use trueno::Vector;
98    ///
99    /// let v = Vector::from_slice(&[1.0, 5.0, 3.0, 2.0]);
100    /// assert_eq!(v.max()?, 5.0);
101    /// # Ok::<(), trueno::TruenoError>(())
102    /// ```
103    ///
104    /// # Errors
105    ///
106    /// Returns [`TruenoError::InvalidInput`] if vector is empty.
107    pub fn max(&self) -> Result<f32> {
108        if self.data.is_empty() {
109            return Err(TruenoError::InvalidInput("Empty vector".to_string()));
110        }
111
112        // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
113        let result = unsafe {
114            match self.backend {
115                Backend::Scalar => ScalarBackend::max(&self.data),
116                #[cfg(target_arch = "x86_64")]
117                Backend::SSE2 | Backend::AVX => Sse2Backend::max(&self.data),
118                #[cfg(target_arch = "x86_64")]
119                Backend::AVX2 | Backend::AVX512 => Avx2Backend::max(&self.data),
120                #[cfg(not(target_arch = "x86_64"))]
121                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
122                    ScalarBackend::max(&self.data)
123                }
124                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
125                Backend::NEON => NeonBackend::max(&self.data),
126                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
127                Backend::NEON => ScalarBackend::max(&self.data),
128                #[cfg(target_arch = "wasm32")]
129                Backend::WasmSIMD => WasmBackend::max(&self.data),
130                #[cfg(not(target_arch = "wasm32"))]
131                Backend::WasmSIMD => ScalarBackend::max(&self.data),
132                Backend::GPU | Backend::Auto => ScalarBackend::max(&self.data),
133            }
134        };
135
136        Ok(result)
137    }
138
139    /// Find minimum value in the vector
140    ///
141    /// Returns the smallest element in the vector using SIMD optimization.
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// use trueno::Vector;
147    ///
148    /// let v = Vector::from_slice(&[1.0, 5.0, 3.0, 2.0]);
149    /// assert_eq!(v.min()?, 1.0);
150    /// # Ok::<(), trueno::TruenoError>(())
151    /// ```
152    ///
153    /// # Errors
154    ///
155    /// Returns [`TruenoError::InvalidInput`] if vector is empty.
156    pub fn min(&self) -> Result<f32> {
157        if self.data.is_empty() {
158            return Err(TruenoError::InvalidInput("Empty vector".to_string()));
159        }
160
161        // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
162        let result = unsafe {
163            match self.backend {
164                Backend::Scalar => ScalarBackend::min(&self.data),
165                #[cfg(target_arch = "x86_64")]
166                Backend::SSE2 | Backend::AVX => Sse2Backend::min(&self.data),
167                #[cfg(target_arch = "x86_64")]
168                Backend::AVX2 | Backend::AVX512 => Avx2Backend::min(&self.data),
169                #[cfg(not(target_arch = "x86_64"))]
170                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
171                    ScalarBackend::min(&self.data)
172                }
173                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
174                Backend::NEON => NeonBackend::min(&self.data),
175                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
176                Backend::NEON => ScalarBackend::min(&self.data),
177                #[cfg(target_arch = "wasm32")]
178                Backend::WasmSIMD => WasmBackend::min(&self.data),
179                #[cfg(not(target_arch = "wasm32"))]
180                Backend::WasmSIMD => ScalarBackend::min(&self.data),
181                Backend::GPU | Backend::Auto => ScalarBackend::min(&self.data),
182            }
183        };
184
185        Ok(result)
186    }
187
188    /// Find index of maximum value in the vector
189    ///
190    /// Returns the index of the first occurrence of the maximum value using SIMD optimization.
191    ///
192    /// # Examples
193    ///
194    /// ```
195    /// use trueno::Vector;
196    ///
197    /// let v = Vector::from_slice(&[1.0, 5.0, 3.0, 2.0]);
198    /// assert_eq!(v.argmax()?, 1); // max value 5.0 is at index 1
199    /// # Ok::<(), trueno::TruenoError>(())
200    /// ```
201    ///
202    /// # Errors
203    ///
204    /// Returns [`TruenoError::InvalidInput`] if vector is empty.
205    pub fn argmax(&self) -> Result<usize> {
206        if self.data.is_empty() {
207            return Err(TruenoError::InvalidInput("Empty vector".to_string()));
208        }
209
210        // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
211        let result = unsafe {
212            match self.backend {
213                Backend::Scalar => ScalarBackend::argmax(&self.data),
214                #[cfg(target_arch = "x86_64")]
215                Backend::SSE2 | Backend::AVX => Sse2Backend::argmax(&self.data),
216                #[cfg(target_arch = "x86_64")]
217                Backend::AVX2 | Backend::AVX512 => Avx2Backend::argmax(&self.data),
218                #[cfg(not(target_arch = "x86_64"))]
219                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
220                    ScalarBackend::argmax(&self.data)
221                }
222                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
223                Backend::NEON => NeonBackend::argmax(&self.data),
224                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
225                Backend::NEON => ScalarBackend::argmax(&self.data),
226                #[cfg(target_arch = "wasm32")]
227                Backend::WasmSIMD => WasmBackend::argmax(&self.data),
228                #[cfg(not(target_arch = "wasm32"))]
229                Backend::WasmSIMD => ScalarBackend::argmax(&self.data),
230                Backend::GPU | Backend::Auto => ScalarBackend::argmax(&self.data),
231            }
232        };
233
234        Ok(result)
235    }
236
237    /// Find index of minimum value in the vector
238    ///
239    /// Returns the index of the first occurrence of the minimum value using SIMD optimization.
240    ///
241    /// # Examples
242    ///
243    /// ```
244    /// use trueno::Vector;
245    ///
246    /// let v = Vector::from_slice(&[1.0, 5.0, 3.0, 2.0]);
247    /// assert_eq!(v.argmin()?, 0); // min value 1.0 is at index 0
248    /// # Ok::<(), trueno::TruenoError>(())
249    /// ```
250    ///
251    /// # Errors
252    ///
253    /// Returns [`TruenoError::InvalidInput`] if vector is empty.
254    pub fn argmin(&self) -> Result<usize> {
255        if self.data.is_empty() {
256            return Err(TruenoError::InvalidInput("Empty vector".to_string()));
257        }
258
259        // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
260        let result = unsafe {
261            match self.backend {
262                Backend::Scalar => ScalarBackend::argmin(&self.data),
263                #[cfg(target_arch = "x86_64")]
264                Backend::SSE2 | Backend::AVX => Sse2Backend::argmin(&self.data),
265                #[cfg(target_arch = "x86_64")]
266                Backend::AVX2 | Backend::AVX512 => Avx2Backend::argmin(&self.data),
267                #[cfg(not(target_arch = "x86_64"))]
268                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
269                    ScalarBackend::argmin(&self.data)
270                }
271                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
272                Backend::NEON => NeonBackend::argmin(&self.data),
273                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
274                Backend::NEON => ScalarBackend::argmin(&self.data),
275                #[cfg(target_arch = "wasm32")]
276                Backend::WasmSIMD => WasmBackend::argmin(&self.data),
277                #[cfg(not(target_arch = "wasm32"))]
278                Backend::WasmSIMD => ScalarBackend::argmin(&self.data),
279                Backend::GPU | Backend::Auto => ScalarBackend::argmin(&self.data),
280            }
281        };
282
283        Ok(result)
284    }
285}