Skip to main content

trueno/vector/ops/transforms/
mod.rs

//! Vector transformation operations
//!
//! This module provides element-wise transformation methods:
//! - `abs()` - Element-wise absolute value
//! - `clamp()` / `clip()` - Clamp values to a closed range `[min, max]`
//! - `lerp()` - Linear interpolation between two vectors
//! - `sqrt()` - Element-wise square root (in `math` submodule)
//! - `recip()` - Element-wise reciprocal (1/x) (in `math` submodule)
//! - `pow()` - Element-wise power (in `math` submodule)
10
11mod math;
12
13#[cfg(target_arch = "x86_64")]
14use crate::backends::avx2::Avx2Backend;
15#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
16use crate::backends::neon::NeonBackend;
17use crate::backends::scalar::ScalarBackend;
18#[cfg(target_arch = "x86_64")]
19use crate::backends::sse2::Sse2Backend;
20#[cfg(target_arch = "wasm32")]
21use crate::backends::wasm::WasmBackend;
22use crate::backends::VectorBackend;
23use crate::{Backend, Result, TruenoError, Vector};
24
25impl Vector<f32> {
26    /// Compute element-wise absolute value
27    ///
28    /// Returns a new vector where each element is the absolute value of the corresponding input element.
29    ///
30    /// # Examples
31    ///
32    /// ```
33    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
34    /// use trueno::Vector;
35    ///
36    /// let v = Vector::from_slice(&[3.0, -4.0, 5.0, -2.0]);
37    /// let result = v.abs()?;
38    ///
39    /// assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 2.0]);
40    /// # Ok(())
41    /// # }
42    /// ```
43    ///
44    /// # Empty Vector
45    ///
46    /// ```
47    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
48    /// use trueno::Vector;
49    ///
50    /// let v: Vector<f32> = Vector::from_slice(&[]);
51    /// let result = v.abs()?;
52    /// assert_eq!(result.len(), 0);
53    /// # Ok(())
54    /// # }
55    /// ```
56    pub fn abs(&self) -> Result<Vector<f32>> {
57        // Uninit: backend writes every element before any read.
58        let n = self.len();
59        let mut result_data: Vec<f32> = Vec::with_capacity(n);
60        // SAFETY: Backend writes all elements before any read.
61        unsafe {
62            result_data.set_len(n);
63        }
64
65        if !self.as_slice().is_empty() {
66            // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
67            unsafe {
68                match self.backend() {
69                    Backend::Scalar => ScalarBackend::abs(self.as_slice(), &mut result_data),
70                    #[cfg(target_arch = "x86_64")]
71                    Backend::SSE2 | Backend::AVX => {
72                        Sse2Backend::abs(self.as_slice(), &mut result_data)
73                    }
74                    #[cfg(target_arch = "x86_64")]
75                    Backend::AVX2 | Backend::AVX512 => {
76                        Avx2Backend::abs(self.as_slice(), &mut result_data)
77                    }
78                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
79                    Backend::NEON => NeonBackend::abs(self.as_slice(), &mut result_data),
80                    #[cfg(target_arch = "wasm32")]
81                    Backend::WasmSIMD => WasmBackend::abs(self.as_slice(), &mut result_data),
82                    Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
83                    Backend::Auto => {
84                        return Err(TruenoError::UnsupportedBackend(Backend::Auto));
85                    }
86                    #[cfg(not(target_arch = "x86_64"))]
87                    Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
88                        ScalarBackend::abs(self.as_slice(), &mut result_data)
89                    }
90                    #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
91                    Backend::NEON => ScalarBackend::abs(self.as_slice(), &mut result_data),
92                    #[cfg(not(target_arch = "wasm32"))]
93                    Backend::WasmSIMD => ScalarBackend::abs(self.as_slice(), &mut result_data),
94                }
95            }
96        }
97
98        // Construct directly (no copy) — from_slice_with_backend would copy 4MB!
99        Ok(Vector { data: result_data, backend: self.backend() })
100    }
101
102    /// Clip values to a specified range [min_val, max_val]
103    ///
104    /// Constrains each element to be within the specified range:
105    /// - Values below min_val become min_val
106    /// - Values above max_val become max_val
107    /// - Values within range stay unchanged
108    ///
109    /// This is useful for outlier handling, gradient clipping in neural networks,
110    /// and ensuring values stay within valid bounds.
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
116    /// use trueno::Vector;
117    ///
118    /// let v = Vector::from_slice(&[-5.0, 0.0, 5.0, 10.0, 15.0]);
119    /// let clipped = v.clip(0.0, 10.0)?;
120    ///
121    /// // Values: [-5, 0, 5, 10, 15] → [0, 0, 5, 10, 10]
122    /// assert_eq!(clipped.as_slice(), &[0.0, 0.0, 5.0, 10.0, 10.0]);
123    /// # Ok(())
124    /// # }
125    /// ```
126    ///
127    /// # Invalid range
128    ///
129    /// Returns InvalidInput error if min_val > max_val.
130    ///
131    /// ```
132    /// use trueno::{Vector, TruenoError};
133    ///
134    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
135    /// let result = v.clip(10.0, 5.0); // min > max
136    /// assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
137    /// ```
138    pub fn clip(&self, min_val: f32, max_val: f32) -> Result<Self> {
139        if min_val > max_val {
140            return Err(TruenoError::InvalidInput(format!(
141                "min_val ({}) must be <= max_val ({})",
142                min_val, max_val
143            )));
144        }
145
146        // Scalar fallback: Element-wise clamp
147        let data: Vec<f32> = self.as_slice().iter().map(|&x| x.max(min_val).min(max_val)).collect();
148
149        Ok(Vector::from_vec(data))
150    }
151
152    /// Clamp elements to range [min_val, max_val]
153    ///
154    /// Returns a new vector where each element is constrained to the specified range.
155    /// Elements below min_val become min_val, elements above max_val become max_val.
156    ///
157    /// # Examples
158    ///
159    /// ```
160    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
161    /// use trueno::Vector;
162    ///
163    /// let v = Vector::from_slice(&[-5.0, 0.0, 5.0, 10.0, 15.0]);
164    /// let result = v.clamp(0.0, 10.0)?;
165    ///
166    /// assert_eq!(result.as_slice(), &[0.0, 0.0, 5.0, 10.0, 10.0]);
167    /// # Ok(())
168    /// # }
169    /// ```
170    ///
171    /// # Negative Range
172    ///
173    /// ```
174    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
175    /// use trueno::Vector;
176    ///
177    /// let v = Vector::from_slice(&[-10.0, -5.0, 0.0, 5.0]);
178    /// let result = v.clamp(-8.0, -2.0)?;
179    /// assert_eq!(result.as_slice(), &[-8.0, -5.0, -2.0, -2.0]);
180    /// # Ok(())
181    /// # }
182    /// ```
183    ///
184    /// # Errors
185    ///
186    /// Returns `InvalidInput` if min_val > max_val.
187    pub fn clamp(&self, min_val: f32, max_val: f32) -> Result<Vector<f32>> {
188        // Validate range
189        if min_val > max_val {
190            return Err(TruenoError::InvalidInput(format!(
191                "Invalid clamp range: min ({}) > max ({})",
192                min_val, max_val
193            )));
194        }
195
196        // Uninit: backend writes every element before any read.
197        let n = self.len();
198        let mut result_data: Vec<f32> = Vec::with_capacity(n);
199        // SAFETY: Backend writes all elements before any read.
200        unsafe {
201            result_data.set_len(n);
202        }
203
204        if !self.as_slice().is_empty() {
205            // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
206            unsafe {
207                match self.backend() {
208                    Backend::Scalar => {
209                        ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
210                    }
211                    #[cfg(target_arch = "x86_64")]
212                    Backend::SSE2 | Backend::AVX => {
213                        Sse2Backend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
214                    }
215                    #[cfg(target_arch = "x86_64")]
216                    Backend::AVX2 | Backend::AVX512 => {
217                        Avx2Backend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
218                    }
219                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
220                    Backend::NEON => {
221                        NeonBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
222                    }
223                    #[cfg(target_arch = "wasm32")]
224                    Backend::WasmSIMD => {
225                        WasmBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
226                    }
227                    Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
228                    Backend::Auto => {
229                        return Err(TruenoError::UnsupportedBackend(Backend::Auto));
230                    }
231                    #[cfg(not(target_arch = "x86_64"))]
232                    Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
233                        ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
234                    }
235                    #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
236                    Backend::NEON => {
237                        ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
238                    }
239                    #[cfg(not(target_arch = "wasm32"))]
240                    Backend::WasmSIMD => {
241                        ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
242                    }
243                }
244            }
245        }
246
247        Ok(Vector { data: result_data, backend: self.backend() })
248    }
249
250    /// Linear interpolation between two vectors
251    ///
252    /// Computes element-wise linear interpolation: `result\[i\] = a\[i\] + t * (b\[i\] - a\[i\])`
253    ///
254    /// - When `t = 0.0`, returns `self`
255    /// - When `t = 1.0`, returns `other`
256    /// - Values outside `[0, 1]` perform extrapolation
257    ///
258    /// # Examples
259    ///
260    /// ```
261    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
262    /// use trueno::Vector;
263    ///
264    /// let a = Vector::from_slice(&[0.0, 10.0, 20.0]);
265    /// let b = Vector::from_slice(&[100.0, 110.0, 120.0]);
266    /// let result = a.lerp(&b, 0.5)?;
267    ///
268    /// assert_eq!(result.as_slice(), &[50.0, 60.0, 70.0]);
269    /// # Ok(())
270    /// # }
271    /// ```
272    ///
273    /// # Extrapolation
274    ///
275    /// ```
276    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
277    /// use trueno::Vector;
278    ///
279    /// let a = Vector::from_slice(&[0.0, 10.0]);
280    /// let b = Vector::from_slice(&[10.0, 20.0]);
281    ///
282    /// // t > 1.0 extrapolates beyond b
283    /// let result = a.lerp(&b, 2.0)?;
284    /// assert_eq!(result.as_slice(), &[20.0, 30.0]);
285    /// # Ok(())
286    /// # }
287    /// ```
288    ///
289    /// # Errors
290    ///
291    /// Returns `SizeMismatch` if vectors have different lengths.
292    pub fn lerp(&self, other: &Vector<f32>, t: f32) -> Result<Vector<f32>> {
293        if self.len() != other.len() {
294            return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
295        }
296
297        // Uninit: backend writes every element before any read.
298        let n = self.len();
299        let mut result_data: Vec<f32> = Vec::with_capacity(n);
300        // SAFETY: Backend writes all elements before any read.
301        unsafe {
302            result_data.set_len(n);
303        }
304
305        if !self.as_slice().is_empty() {
306            // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
307            unsafe {
308                match self.backend() {
309                    Backend::Scalar => {
310                        ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
311                    }
312                    #[cfg(target_arch = "x86_64")]
313                    Backend::SSE2 | Backend::AVX => {
314                        Sse2Backend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
315                    }
316                    #[cfg(target_arch = "x86_64")]
317                    Backend::AVX2 | Backend::AVX512 => {
318                        Avx2Backend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
319                    }
320                    #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
321                    Backend::NEON => {
322                        NeonBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
323                    }
324                    #[cfg(target_arch = "wasm32")]
325                    Backend::WasmSIMD => {
326                        WasmBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
327                    }
328                    Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
329                    Backend::Auto => {
330                        return Err(TruenoError::UnsupportedBackend(Backend::Auto));
331                    }
332                    #[cfg(not(target_arch = "x86_64"))]
333                    Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
334                        ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
335                    }
336                    #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
337                    Backend::NEON => {
338                        ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
339                    }
340                    #[cfg(not(target_arch = "wasm32"))]
341                    Backend::WasmSIMD => {
342                        ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
343                    }
344                }
345            }
346        }
347
348        Ok(Vector { data: result_data, backend: self.backend() })
349    }
350}
351
352#[cfg(test)]
353mod tests;