Skip to main content

trueno/vector/ops/arithmetic/
mod.rs

1//! Arithmetic operations for Vector<f32>
2//!
3//! This module provides element-wise arithmetic operations:
4//! - Basic: `add`, `sub`, `mul`, `div`
5//! - Scalar: `scale`
6//! - Fused: `fma` (fused multiply-add)
7
8#[cfg(target_arch = "x86_64")]
9use crate::backends::avx2::Avx2Backend;
10#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
11use crate::backends::neon::NeonBackend;
12use crate::backends::scalar::ScalarBackend;
13#[cfg(target_arch = "x86_64")]
14use crate::backends::sse2::Sse2Backend;
15#[cfg(target_arch = "wasm32")]
16use crate::backends::wasm::WasmBackend;
17use crate::backends::VectorBackend;
18use crate::vector::Vector;
19use crate::{dispatch_binary_op, Backend, Result, TruenoError};
20
21impl Vector<f32> {
22    /// Element-wise addition
23    ///
24    /// # Performance
25    ///
26    /// Auto-selects the best available backend:
27    /// - **AVX2**: ~4x faster than scalar for 1K+ elements
28    /// - **GPU**: ~50x faster than scalar for 10M+ elements
29    ///
30    /// # Examples
31    ///
32    /// ```
33    /// use trueno::Vector;
34    ///
35    /// let a = Vector::from_slice(&[1.0, 2.0, 3.0]);
36    /// let b = Vector::from_slice(&[4.0, 5.0, 6.0]);
37    /// let result = a.add(&b)?;
38    ///
39    /// assert_eq!(result.as_slice(), &[5.0, 7.0, 9.0]);
40    /// # Ok::<(), trueno::TruenoError>(())
41    /// ```
42    ///
43    /// # Errors
44    ///
45    /// Returns [`TruenoError::SizeMismatch`] if vectors have different lengths.
46    pub fn add(&self, other: &Self) -> Result<Self> {
47        if self.len() != other.len() {
48            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
49        }
50
51        // Uninit allocation: avoids the zero-fill cost (70µs+ at 1M elements)
52        // since every element will be overwritten by dispatch_binary_op below.
53        // SAFETY: dispatch_binary_op!(..., add, a, b, out) writes to EVERY element
54        // of `out` (it's an element-wise add). No reads before writes.
55        let n = self.len();
56        let mut result: Vec<f32> = Vec::with_capacity(n);
57        unsafe {
58            result.set_len(n);
59        }
60
61        // Use parallel processing for large arrays
62        #[cfg(feature = "parallel")]
63        {
64            const PARALLEL_THRESHOLD: usize = 100_000; // Threshold for element-wise ops
65            const CHUNK_SIZE: usize = 65536; // 64K elements = 256KB, cache-friendly
66
67            if self.len() >= PARALLEL_THRESHOLD {
68                use rayon::prelude::*;
69
70                self.data
71                    .par_chunks(CHUNK_SIZE)
72                    .zip(other.data.par_chunks(CHUNK_SIZE))
73                    .zip(result.par_chunks_mut(CHUNK_SIZE))
74                    .for_each(|((chunk_a, chunk_b), chunk_out)| {
75                        dispatch_binary_op!(self.backend, add, chunk_a, chunk_b, chunk_out);
76                    });
77
78                return Ok(Self { data: result, backend: self.backend });
79            }
80        }
81
82        dispatch_binary_op!(self.backend, add, &self.data, &other.data, &mut result);
83
84        Ok(Self { data: result, backend: self.backend })
85    }
86
87    /// Element-wise subtraction
88    ///
89    /// # Performance
90    ///
91    /// Auto-selects the best available backend:
92    /// - **AVX2**: ~4x faster than scalar for 1K+ elements
93    /// - **GPU**: ~50x faster than scalar for 10M+ elements
94    ///
95    /// # Examples
96    ///
97    /// ```
98    /// use trueno::Vector;
99    ///
100    /// let a = Vector::from_slice(&[5.0, 7.0, 9.0]);
101    /// let b = Vector::from_slice(&[1.0, 2.0, 3.0]);
102    /// let result = a.sub(&b)?;
103    ///
104    /// assert_eq!(result.as_slice(), &[4.0, 5.0, 6.0]);
105    /// # Ok::<(), trueno::TruenoError>(())
106    /// ```
107    ///
108    /// # Errors
109    ///
110    /// Returns [`TruenoError::SizeMismatch`] if vectors have different lengths.
111    pub fn sub(&self, other: &Self) -> Result<Self> {
112        if self.len() != other.len() {
113            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
114        }
115
116        // Uninit allocation: skip zero-fill since dispatch_binary_op writes all elements.
117        let n = self.len();
118        let mut result: Vec<f32> = Vec::with_capacity(n);
119        // SAFETY: Every element is written before any read (by element-wise op below).
120        unsafe {
121            result.set_len(n);
122        }
123
124        // Use parallel processing for large arrays
125        #[cfg(feature = "parallel")]
126        {
127            const PARALLEL_THRESHOLD: usize = 100_000;
128            const CHUNK_SIZE: usize = 65536;
129
130            if self.len() >= PARALLEL_THRESHOLD {
131                use rayon::prelude::*;
132
133                self.data
134                    .par_chunks(CHUNK_SIZE)
135                    .zip(other.data.par_chunks(CHUNK_SIZE))
136                    .zip(result.par_chunks_mut(CHUNK_SIZE))
137                    .for_each(|((chunk_a, chunk_b), chunk_out)| {
138                        dispatch_binary_op!(self.backend, sub, chunk_a, chunk_b, chunk_out);
139                    });
140
141                return Ok(Self { data: result, backend: self.backend });
142            }
143        }
144
145        dispatch_binary_op!(self.backend, sub, &self.data, &other.data, &mut result);
146
147        Ok(Self { data: result, backend: self.backend })
148    }
149
150    /// Element-wise multiplication
151    ///
152    /// # Examples
153    ///
154    /// ```
155    /// use trueno::Vector;
156    ///
157    /// let a = Vector::from_slice(&[2.0, 3.0, 4.0]);
158    /// let b = Vector::from_slice(&[5.0, 6.0, 7.0]);
159    /// let result = a.mul(&b)?;
160    ///
161    /// assert_eq!(result.as_slice(), &[10.0, 18.0, 28.0]);
162    /// # Ok::<(), trueno::TruenoError>(())
163    /// ```
164    pub fn mul(&self, other: &Self) -> Result<Self> {
165        if self.len() != other.len() {
166            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
167        }
168
169        // Uninit allocation: skip zero-fill since dispatch_binary_op writes all elements.
170        let n = self.len();
171        let mut result: Vec<f32> = Vec::with_capacity(n);
172        // SAFETY: Every element is written before any read (by element-wise op below).
173        unsafe {
174            result.set_len(n);
175        }
176
177        // Use parallel processing for large arrays
178        #[cfg(feature = "parallel")]
179        {
180            const PARALLEL_THRESHOLD: usize = 100_000;
181            const CHUNK_SIZE: usize = 65536;
182
183            if self.len() >= PARALLEL_THRESHOLD {
184                use rayon::prelude::*;
185
186                self.data
187                    .par_chunks(CHUNK_SIZE)
188                    .zip(other.data.par_chunks(CHUNK_SIZE))
189                    .zip(result.par_chunks_mut(CHUNK_SIZE))
190                    .for_each(|((chunk_a, chunk_b), chunk_out)| {
191                        dispatch_binary_op!(self.backend, mul, chunk_a, chunk_b, chunk_out);
192                    });
193
194                return Ok(Self { data: result, backend: self.backend });
195            }
196        }
197
198        dispatch_binary_op!(self.backend, mul, &self.data, &other.data, &mut result);
199
200        Ok(Self { data: result, backend: self.backend })
201    }
202
203    /// Element-wise division
204    ///
205    /// # Examples
206    ///
207    /// ```
208    /// use trueno::Vector;
209    ///
210    /// let a = Vector::from_slice(&[10.0, 20.0, 30.0]);
211    /// let b = Vector::from_slice(&[2.0, 4.0, 5.0]);
212    /// let result = a.div(&b)?;
213    ///
214    /// assert_eq!(result.as_slice(), &[5.0, 5.0, 6.0]);
215    /// # Ok::<(), trueno::TruenoError>(())
216    /// ```
217    pub fn div(&self, other: &Self) -> Result<Self> {
218        if self.len() != other.len() {
219            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
220        }
221
222        // Uninit allocation: skip zero-fill since dispatch_binary_op writes all elements.
223        let n = self.len();
224        let mut result: Vec<f32> = Vec::with_capacity(n);
225        // SAFETY: Every element is written before any read (by element-wise op below).
226        unsafe {
227            result.set_len(n);
228        }
229
230        // Use parallel processing for large arrays
231        #[cfg(feature = "parallel")]
232        {
233            const PARALLEL_THRESHOLD: usize = 100_000;
234            const CHUNK_SIZE: usize = 65536;
235
236            if self.len() >= PARALLEL_THRESHOLD {
237                use rayon::prelude::*;
238
239                self.data
240                    .par_chunks(CHUNK_SIZE)
241                    .zip(other.data.par_chunks(CHUNK_SIZE))
242                    .zip(result.par_chunks_mut(CHUNK_SIZE))
243                    .for_each(|((chunk_a, chunk_b), chunk_out)| {
244                        dispatch_binary_op!(self.backend, div, chunk_a, chunk_b, chunk_out);
245                    });
246
247                return Ok(Self { data: result, backend: self.backend });
248            }
249        }
250
251        dispatch_binary_op!(self.backend, div, &self.data, &other.data, &mut result);
252
253        Ok(Self { data: result, backend: self.backend })
254    }
255
256    /// Scalar multiplication (scale all elements by a scalar value)
257    ///
258    /// Returns a new vector where each element is multiplied by the scalar.
259    ///
260    /// # Examples
261    ///
262    /// ```
263    /// use trueno::Vector;
264    ///
265    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0]);
266    /// let result = v.scale(2.0)?;
267    ///
268    /// assert_eq!(result.as_slice(), &[2.0, 4.0, 6.0, 8.0]);
269    /// # Ok::<(), trueno::TruenoError>(())
270    /// ```
271    ///
272    /// # Scaling by Zero
273    ///
274    /// ```
275    /// use trueno::Vector;
276    ///
277    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
278    /// let result = v.scale(0.0)?;
279    /// assert_eq!(result.as_slice(), &[0.0, 0.0, 0.0]);
280    /// # Ok::<(), trueno::TruenoError>(())
281    /// ```
282    ///
283    /// # Negative Scaling
284    ///
285    /// ```
286    /// use trueno::Vector;
287    ///
288    /// let v = Vector::from_slice(&[1.0, -2.0, 3.0]);
289    /// let result = v.scale(-2.0)?;
290    /// assert_eq!(result.as_slice(), &[-2.0, 4.0, -6.0]);
291    /// # Ok::<(), trueno::TruenoError>(())
292    /// ```
293    pub fn scale(&self, scalar: f32) -> Result<Vector<f32>> {
294        // Uninit allocation: backend writes all elements.
295        let n = self.len();
296        let mut result_data: Vec<f32> = Vec::with_capacity(n);
297        // SAFETY: backend scale() writes every element before any read.
298        unsafe {
299            result_data.set_len(n);
300        }
301
302        if !self.data.is_empty() {
303            // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
304            unsafe {
305                match self.backend {
306                    Backend::Scalar => ScalarBackend::scale(&self.data, scalar, &mut result_data),
307                    #[cfg(target_arch = "x86_64")]
308                    Backend::SSE2 | Backend::AVX => {
309                        Sse2Backend::scale(&self.data, scalar, &mut result_data)
310                    }
311                    #[cfg(target_arch = "x86_64")]
312                    Backend::AVX2 | Backend::AVX512 => {
313                        Avx2Backend::scale(&self.data, scalar, &mut result_data)
314                    }
315                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
316                    Backend::NEON => NeonBackend::scale(&self.data, scalar, &mut result_data),
317                    #[cfg(target_arch = "wasm32")]
318                    Backend::WasmSIMD => WasmBackend::scale(&self.data, scalar, &mut result_data),
319                    Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
320                    Backend::Auto => {
321                        // Auto should have been resolved at creation time
322                        return Err(TruenoError::UnsupportedBackend(Backend::Auto));
323                    }
324                    #[cfg(not(target_arch = "x86_64"))]
325                    Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
326                        ScalarBackend::scale(&self.data, scalar, &mut result_data)
327                    }
328                    #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
329                    Backend::NEON => ScalarBackend::scale(&self.data, scalar, &mut result_data),
330                    #[cfg(not(target_arch = "wasm32"))]
331                    Backend::WasmSIMD => ScalarBackend::scale(&self.data, scalar, &mut result_data),
332                }
333            }
334        }
335
336        Ok(Vector { data: result_data, backend: self.backend })
337    }
338
339    /// Fused multiply-add: result\[i\] = self\[i\] * b\[i\] + c\[i\]
340    ///
341    /// Computes element-wise fused multiply-add operation. On hardware with FMA support
342    /// (AVX2, NEON), this is a single instruction with better performance and numerical
343    /// accuracy (no intermediate rounding). On platforms without FMA (SSE2, WASM), uses
344    /// separate multiply and add operations.
345    ///
346    /// # Arguments
347    ///
348    /// * `b` - The second vector to multiply with
349    /// * `c` - The vector to add to the product
350    ///
351    /// # Returns
352    ///
353    /// A new vector where each element is `self\[i\] * b\[i\] + c\[i\]`
354    ///
355    /// # Errors
356    ///
357    /// Returns `SizeMismatch` if vector lengths don't match
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use trueno::Vector;
363    ///
364    /// let a = Vector::from_slice(&[2.0, 3.0, 4.0]);
365    /// let b = Vector::from_slice(&[5.0, 6.0, 7.0]);
366    /// let c = Vector::from_slice(&[1.0, 2.0, 3.0]);
367    /// let result = a.fma(&b, &c)?;
368    /// assert_eq!(result.as_slice(), &[11.0, 20.0, 31.0]);  // [2*5+1, 3*6+2, 4*7+3]
369    /// # Ok::<(), trueno::TruenoError>(())
370    /// ```
371    ///
372    /// # Use Cases
373    ///
374    /// - Neural networks: matrix multiplication, backpropagation
375    /// - Scientific computing: polynomial evaluation, numerical integration
376    /// - Graphics: transformation matrices, shader computations
377    /// - Physics simulations: force calculations, particle systems
378    pub fn fma(&self, b: &Vector<f32>, c: &Vector<f32>) -> Result<Vector<f32>> {
379        if self.len() != b.len() {
380            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: b.len() });
381        }
382        if self.len() != c.len() {
383            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: c.len() });
384        }
385
386        // Uninit allocation: backend fma writes all elements.
387        let n = self.len();
388        let mut result_data: Vec<f32> = Vec::with_capacity(n);
389        // SAFETY: backend fma() writes every element before any read.
390        unsafe {
391            result_data.set_len(n);
392        }
393
394        if !self.data.is_empty() {
395            // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
396            unsafe {
397                match self.backend {
398                    Backend::Scalar => {
399                        ScalarBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
400                    }
401                    #[cfg(target_arch = "x86_64")]
402                    Backend::SSE2 | Backend::AVX => {
403                        Sse2Backend::fma(&self.data, &b.data, &c.data, &mut result_data)
404                    }
405                    #[cfg(target_arch = "x86_64")]
406                    Backend::AVX2 | Backend::AVX512 => {
407                        Avx2Backend::fma(&self.data, &b.data, &c.data, &mut result_data)
408                    }
409                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
410                    Backend::NEON => {
411                        NeonBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
412                    }
413                    #[cfg(target_arch = "wasm32")]
414                    Backend::WasmSIMD => {
415                        WasmBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
416                    }
417                    Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
418                    Backend::Auto => {
419                        return Err(TruenoError::UnsupportedBackend(Backend::Auto));
420                    }
421                    #[cfg(not(target_arch = "x86_64"))]
422                    Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
423                        ScalarBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
424                    }
425                    #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
426                    Backend::NEON => {
427                        ScalarBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
428                    }
429                    #[cfg(not(target_arch = "wasm32"))]
430                    Backend::WasmSIMD => {
431                        ScalarBackend::fma(&self.data, &b.data, &c.data, &mut result_data)
432                    }
433                }
434            }
435        }
436
437        Ok(Vector { data: result_data, backend: self.backend })
438    }
439}
440
// Unit-test module, declared out-of-line (its body lives in `tests.rs` or
// `tests/mod.rs` next to this file) and compiled only for test builds.
#[cfg(test)]
mod tests;