trueno/vector/ops/transforms/mod.rs
1//! Vector transformation operations
2//!
3//! This module provides element-wise transformation methods:
4//! - `abs()` - Element-wise absolute value
5//! - `clamp()` / `clip()` - Clamp values to a range
6//! - `lerp()` - Linear interpolation between two vectors
7//! - `sqrt()` - Element-wise square root (in `math` submodule)
8//! - `recip()` - Element-wise reciprocal (1/x) (in `math` submodule)
9//! - `pow()` - Element-wise power (in `math` submodule)
10
11mod math;
12
13#[cfg(target_arch = "x86_64")]
14use crate::backends::avx2::Avx2Backend;
15#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
16use crate::backends::neon::NeonBackend;
17use crate::backends::scalar::ScalarBackend;
18#[cfg(target_arch = "x86_64")]
19use crate::backends::sse2::Sse2Backend;
20#[cfg(target_arch = "wasm32")]
21use crate::backends::wasm::WasmBackend;
22use crate::backends::VectorBackend;
23use crate::{Backend, Result, TruenoError, Vector};
24
25impl Vector<f32> {
26 /// Compute element-wise absolute value
27 ///
28 /// Returns a new vector where each element is the absolute value of the corresponding input element.
29 ///
30 /// # Examples
31 ///
32 /// ```
33 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
34 /// use trueno::Vector;
35 ///
36 /// let v = Vector::from_slice(&[3.0, -4.0, 5.0, -2.0]);
37 /// let result = v.abs()?;
38 ///
39 /// assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 2.0]);
40 /// # Ok(())
41 /// # }
42 /// ```
43 ///
44 /// # Empty Vector
45 ///
46 /// ```
47 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
48 /// use trueno::Vector;
49 ///
50 /// let v: Vector<f32> = Vector::from_slice(&[]);
51 /// let result = v.abs()?;
52 /// assert_eq!(result.len(), 0);
53 /// # Ok(())
54 /// # }
55 /// ```
56 pub fn abs(&self) -> Result<Vector<f32>> {
57 // Uninit: backend writes every element before any read.
58 let n = self.len();
59 let mut result_data: Vec<f32> = Vec::with_capacity(n);
60 // SAFETY: Backend writes all elements before any read.
61 unsafe {
62 result_data.set_len(n);
63 }
64
65 if !self.as_slice().is_empty() {
66 // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
67 unsafe {
68 match self.backend() {
69 Backend::Scalar => ScalarBackend::abs(self.as_slice(), &mut result_data),
70 #[cfg(target_arch = "x86_64")]
71 Backend::SSE2 | Backend::AVX => {
72 Sse2Backend::abs(self.as_slice(), &mut result_data)
73 }
74 #[cfg(target_arch = "x86_64")]
75 Backend::AVX2 | Backend::AVX512 => {
76 Avx2Backend::abs(self.as_slice(), &mut result_data)
77 }
78 #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
79 Backend::NEON => NeonBackend::abs(self.as_slice(), &mut result_data),
80 #[cfg(target_arch = "wasm32")]
81 Backend::WasmSIMD => WasmBackend::abs(self.as_slice(), &mut result_data),
82 Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
83 Backend::Auto => {
84 return Err(TruenoError::UnsupportedBackend(Backend::Auto));
85 }
86 #[cfg(not(target_arch = "x86_64"))]
87 Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
88 ScalarBackend::abs(self.as_slice(), &mut result_data)
89 }
90 #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
91 Backend::NEON => ScalarBackend::abs(self.as_slice(), &mut result_data),
92 #[cfg(not(target_arch = "wasm32"))]
93 Backend::WasmSIMD => ScalarBackend::abs(self.as_slice(), &mut result_data),
94 }
95 }
96 }
97
98 // Construct directly (no copy) — from_slice_with_backend would copy 4MB!
99 Ok(Vector { data: result_data, backend: self.backend() })
100 }
101
102 /// Clip values to a specified range [min_val, max_val]
103 ///
104 /// Constrains each element to be within the specified range:
105 /// - Values below min_val become min_val
106 /// - Values above max_val become max_val
107 /// - Values within range stay unchanged
108 ///
109 /// This is useful for outlier handling, gradient clipping in neural networks,
110 /// and ensuring values stay within valid bounds.
111 ///
112 /// # Examples
113 ///
114 /// ```
115 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
116 /// use trueno::Vector;
117 ///
118 /// let v = Vector::from_slice(&[-5.0, 0.0, 5.0, 10.0, 15.0]);
119 /// let clipped = v.clip(0.0, 10.0)?;
120 ///
121 /// // Values: [-5, 0, 5, 10, 15] → [0, 0, 5, 10, 10]
122 /// assert_eq!(clipped.as_slice(), &[0.0, 0.0, 5.0, 10.0, 10.0]);
123 /// # Ok(())
124 /// # }
125 /// ```
126 ///
127 /// # Invalid range
128 ///
129 /// Returns InvalidInput error if min_val > max_val.
130 ///
131 /// ```
132 /// use trueno::{Vector, TruenoError};
133 ///
134 /// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
135 /// let result = v.clip(10.0, 5.0); // min > max
136 /// assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
137 /// ```
138 pub fn clip(&self, min_val: f32, max_val: f32) -> Result<Self> {
139 if min_val > max_val {
140 return Err(TruenoError::InvalidInput(format!(
141 "min_val ({}) must be <= max_val ({})",
142 min_val, max_val
143 )));
144 }
145
146 // Scalar fallback: Element-wise clamp
147 let data: Vec<f32> = self.as_slice().iter().map(|&x| x.max(min_val).min(max_val)).collect();
148
149 Ok(Vector::from_vec(data))
150 }
151
152 /// Clamp elements to range [min_val, max_val]
153 ///
154 /// Returns a new vector where each element is constrained to the specified range.
155 /// Elements below min_val become min_val, elements above max_val become max_val.
156 ///
157 /// # Examples
158 ///
159 /// ```
160 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
161 /// use trueno::Vector;
162 ///
163 /// let v = Vector::from_slice(&[-5.0, 0.0, 5.0, 10.0, 15.0]);
164 /// let result = v.clamp(0.0, 10.0)?;
165 ///
166 /// assert_eq!(result.as_slice(), &[0.0, 0.0, 5.0, 10.0, 10.0]);
167 /// # Ok(())
168 /// # }
169 /// ```
170 ///
171 /// # Negative Range
172 ///
173 /// ```
174 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
175 /// use trueno::Vector;
176 ///
177 /// let v = Vector::from_slice(&[-10.0, -5.0, 0.0, 5.0]);
178 /// let result = v.clamp(-8.0, -2.0)?;
179 /// assert_eq!(result.as_slice(), &[-8.0, -5.0, -2.0, -2.0]);
180 /// # Ok(())
181 /// # }
182 /// ```
183 ///
184 /// # Errors
185 ///
186 /// Returns `InvalidInput` if min_val > max_val.
187 pub fn clamp(&self, min_val: f32, max_val: f32) -> Result<Vector<f32>> {
188 // Validate range
189 if min_val > max_val {
190 return Err(TruenoError::InvalidInput(format!(
191 "Invalid clamp range: min ({}) > max ({})",
192 min_val, max_val
193 )));
194 }
195
196 // Uninit: backend writes every element before any read.
197 let n = self.len();
198 let mut result_data: Vec<f32> = Vec::with_capacity(n);
199 // SAFETY: Backend writes all elements before any read.
200 unsafe {
201 result_data.set_len(n);
202 }
203
204 if !self.as_slice().is_empty() {
205 // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
206 unsafe {
207 match self.backend() {
208 Backend::Scalar => {
209 ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
210 }
211 #[cfg(target_arch = "x86_64")]
212 Backend::SSE2 | Backend::AVX => {
213 Sse2Backend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
214 }
215 #[cfg(target_arch = "x86_64")]
216 Backend::AVX2 | Backend::AVX512 => {
217 Avx2Backend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
218 }
219 #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
220 Backend::NEON => {
221 NeonBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
222 }
223 #[cfg(target_arch = "wasm32")]
224 Backend::WasmSIMD => {
225 WasmBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
226 }
227 Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
228 Backend::Auto => {
229 return Err(TruenoError::UnsupportedBackend(Backend::Auto));
230 }
231 #[cfg(not(target_arch = "x86_64"))]
232 Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
233 ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
234 }
235 #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
236 Backend::NEON => {
237 ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
238 }
239 #[cfg(not(target_arch = "wasm32"))]
240 Backend::WasmSIMD => {
241 ScalarBackend::clamp(self.as_slice(), min_val, max_val, &mut result_data)
242 }
243 }
244 }
245 }
246
247 Ok(Vector { data: result_data, backend: self.backend() })
248 }
249
250 /// Linear interpolation between two vectors
251 ///
252 /// Computes element-wise linear interpolation: `result\[i\] = a\[i\] + t * (b\[i\] - a\[i\])`
253 ///
254 /// - When `t = 0.0`, returns `self`
255 /// - When `t = 1.0`, returns `other`
256 /// - Values outside `[0, 1]` perform extrapolation
257 ///
258 /// # Examples
259 ///
260 /// ```
261 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
262 /// use trueno::Vector;
263 ///
264 /// let a = Vector::from_slice(&[0.0, 10.0, 20.0]);
265 /// let b = Vector::from_slice(&[100.0, 110.0, 120.0]);
266 /// let result = a.lerp(&b, 0.5)?;
267 ///
268 /// assert_eq!(result.as_slice(), &[50.0, 60.0, 70.0]);
269 /// # Ok(())
270 /// # }
271 /// ```
272 ///
273 /// # Extrapolation
274 ///
275 /// ```
276 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
277 /// use trueno::Vector;
278 ///
279 /// let a = Vector::from_slice(&[0.0, 10.0]);
280 /// let b = Vector::from_slice(&[10.0, 20.0]);
281 ///
282 /// // t > 1.0 extrapolates beyond b
283 /// let result = a.lerp(&b, 2.0)?;
284 /// assert_eq!(result.as_slice(), &[20.0, 30.0]);
285 /// # Ok(())
286 /// # }
287 /// ```
288 ///
289 /// # Errors
290 ///
291 /// Returns `SizeMismatch` if vectors have different lengths.
292 pub fn lerp(&self, other: &Vector<f32>, t: f32) -> Result<Vector<f32>> {
293 if self.len() != other.len() {
294 return Err(TruenoError::SizeMismatch { expected: self.len(), actual: other.len() });
295 }
296
297 // Uninit: backend writes every element before any read.
298 let n = self.len();
299 let mut result_data: Vec<f32> = Vec::with_capacity(n);
300 // SAFETY: Backend writes all elements before any read.
301 unsafe {
302 result_data.set_len(n);
303 }
304
305 if !self.as_slice().is_empty() {
306 // SAFETY: Unsafe block delegates to backend implementation which maintains safety invariants
307 unsafe {
308 match self.backend() {
309 Backend::Scalar => {
310 ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
311 }
312 #[cfg(target_arch = "x86_64")]
313 Backend::SSE2 | Backend::AVX => {
314 Sse2Backend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
315 }
316 #[cfg(target_arch = "x86_64")]
317 Backend::AVX2 | Backend::AVX512 => {
318 Avx2Backend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
319 }
320 #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
321 Backend::NEON => {
322 NeonBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
323 }
324 #[cfg(target_arch = "wasm32")]
325 Backend::WasmSIMD => {
326 WasmBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
327 }
328 Backend::GPU => return Err(TruenoError::UnsupportedBackend(Backend::GPU)),
329 Backend::Auto => {
330 return Err(TruenoError::UnsupportedBackend(Backend::Auto));
331 }
332 #[cfg(not(target_arch = "x86_64"))]
333 Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
334 ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
335 }
336 #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
337 Backend::NEON => {
338 ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
339 }
340 #[cfg(not(target_arch = "wasm32"))]
341 Backend::WasmSIMD => {
342 ScalarBackend::lerp(self.as_slice(), other.as_slice(), t, &mut result_data)
343 }
344 }
345 }
346 }
347
348 Ok(Vector { data: result_data, backend: self.backend() })
349 }
350}
351
352#[cfg(test)]
353mod tests;