// numr/tensor/ops.rs

1//! Convenience methods on Tensor that delegate to Client ops
2//!
3//! These methods provide ergonomic `tensor.add(&other)` style calls
4//! that internally get the client and delegate to the appropriate trait.
5
6use crate::dtype::DType;
7use crate::error::Result;
8use crate::ops::traits::{
9    ActivationOps, BinaryOps, CompareOps, ConvOps, CumulativeOps, IndexingOps, MatmulOps,
10    NormalizationOps, PaddingMode, ReduceOps, ScalarOps, ShapeOps, TypeConversionOps, UnaryOps,
11    UtilityOps,
12};
13use crate::runtime::Runtime;
14use crate::tensor::Tensor;
15
16// ============================================================================
17// Binary arithmetic
18// ============================================================================
19
20impl<R: Runtime> Tensor<R>
21where
22    R::Client: BinaryOps<R>,
23{
24    /// Element-wise addition: self + other
25    pub fn add(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
26        let client = R::default_client(self.device());
27        client.add(self, other)
28    }
29
30    /// Element-wise subtraction: self - other
31    pub fn sub(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
32        let client = R::default_client(self.device());
33        client.sub(self, other)
34    }
35
36    /// Element-wise multiplication: self * other
37    pub fn mul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
38        let client = R::default_client(self.device());
39        client.mul(self, other)
40    }
41
42    /// Element-wise division: self / other
43    pub fn div(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
44        let client = R::default_client(self.device());
45        client.div(self, other)
46    }
47
48    /// Element-wise power: self ^ other
49    pub fn pow(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
50        let client = R::default_client(self.device());
51        client.pow(self, other)
52    }
53
54    /// Element-wise maximum: max(self, other)
55    pub fn maximum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
56        let client = R::default_client(self.device());
57        client.maximum(self, other)
58    }
59
60    /// Element-wise minimum: min(self, other)
61    pub fn minimum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
62        let client = R::default_client(self.device());
63        client.minimum(self, other)
64    }
65}
66
67// ============================================================================
68// Unary operations
69// ============================================================================
70
71impl<R: Runtime> Tensor<R>
72where
73    R::Client: UnaryOps<R>,
74{
75    /// Element-wise negation
76    pub fn neg(&self) -> Result<Tensor<R>> {
77        let client = R::default_client(self.device());
78        client.neg(self)
79    }
80
81    /// Element-wise absolute value
82    pub fn abs(&self) -> Result<Tensor<R>> {
83        let client = R::default_client(self.device());
84        client.abs(self)
85    }
86
87    /// Element-wise square root
88    pub fn sqrt(&self) -> Result<Tensor<R>> {
89        let client = R::default_client(self.device());
90        client.sqrt(self)
91    }
92
93    /// Element-wise exponential
94    pub fn exp(&self) -> Result<Tensor<R>> {
95        let client = R::default_client(self.device());
96        client.exp(self)
97    }
98
99    /// Element-wise natural log
100    pub fn log(&self) -> Result<Tensor<R>> {
101        let client = R::default_client(self.device());
102        client.log(self)
103    }
104
105    /// Element-wise sine
106    pub fn sin(&self) -> Result<Tensor<R>> {
107        let client = R::default_client(self.device());
108        client.sin(self)
109    }
110
111    /// Element-wise cosine
112    pub fn cos(&self) -> Result<Tensor<R>> {
113        let client = R::default_client(self.device());
114        client.cos(self)
115    }
116
117    /// Element-wise tangent
118    pub fn tan(&self) -> Result<Tensor<R>> {
119        let client = R::default_client(self.device());
120        client.tan(self)
121    }
122
123    /// Element-wise hyperbolic tangent
124    pub fn tanh(&self) -> Result<Tensor<R>> {
125        let client = R::default_client(self.device());
126        client.tanh(self)
127    }
128
129    /// Element-wise reciprocal (1/x)
130    pub fn recip(&self) -> Result<Tensor<R>> {
131        let client = R::default_client(self.device());
132        client.recip(self)
133    }
134
135    /// Element-wise floor
136    pub fn floor(&self) -> Result<Tensor<R>> {
137        let client = R::default_client(self.device());
138        client.floor(self)
139    }
140
141    /// Element-wise ceil
142    pub fn ceil(&self) -> Result<Tensor<R>> {
143        let client = R::default_client(self.device());
144        client.ceil(self)
145    }
146
147    /// Element-wise round
148    pub fn round(&self) -> Result<Tensor<R>> {
149        let client = R::default_client(self.device());
150        client.round(self)
151    }
152}
153
154// ============================================================================
155// Scalar operations
156// ============================================================================
157
158impl<R: Runtime> Tensor<R>
159where
160    R::Client: ScalarOps<R>,
161{
162    /// Add scalar: self + scalar
163    pub fn add_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
164        let client = R::default_client(self.device());
165        client.add_scalar(self, scalar)
166    }
167
168    /// Multiply by scalar: self * scalar
169    pub fn mul_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
170        let client = R::default_client(self.device());
171        client.mul_scalar(self, scalar)
172    }
173
174    /// Scale alias for mul_scalar
175    pub fn scale(&self, scalar: f64) -> Result<Tensor<R>> {
176        self.mul_scalar(scalar)
177    }
178}
179
180// ============================================================================
181// Activation functions
182// ============================================================================
183
184impl<R: Runtime> Tensor<R>
185where
186    R::Client: ActivationOps<R>,
187{
188    /// ReLU activation: max(0, x)
189    pub fn relu(&self) -> Result<Tensor<R>> {
190        let client = R::default_client(self.device());
191        client.relu(self)
192    }
193
194    /// Sigmoid activation: 1 / (1 + exp(-x))
195    pub fn sigmoid(&self) -> Result<Tensor<R>> {
196        let client = R::default_client(self.device());
197        client.sigmoid(self)
198    }
199
200    /// GELU activation
201    pub fn gelu(&self) -> Result<Tensor<R>> {
202        let client = R::default_client(self.device());
203        client.gelu(self)
204    }
205
206    /// SiLU/Swish activation: x * sigmoid(x)
207    pub fn silu(&self) -> Result<Tensor<R>> {
208        let client = R::default_client(self.device());
209        client.silu(self)
210    }
211
212    /// Softmax along dimension
213    pub fn softmax(&self, dim: isize) -> Result<Tensor<R>> {
214        let client = R::default_client(self.device());
215        client.softmax(self, dim)
216    }
217
218    /// Log-softmax along dimension: log(softmax(x, dim))
219    pub fn log_softmax(&self, dim: isize) -> Result<Tensor<R>> {
220        let client = R::default_client(self.device());
221        client.log_softmax(self, dim)
222    }
223
224    /// Dropout: randomly zero elements with probability `p` during training
225    pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor<R>> {
226        let client = R::default_client(self.device());
227        client.dropout(self, p, training)
228    }
229}
230
231// ============================================================================
232// Reduction operations
233// ============================================================================
234
235impl<R: Runtime> Tensor<R>
236where
237    R::Client: ReduceOps<R>,
238{
239    /// Sum along dimensions
240    pub fn sum(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
241        let client = R::default_client(self.device());
242        client.sum(self, dims, keepdim)
243    }
244
245    /// Mean along dimensions
246    pub fn mean(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
247        let client = R::default_client(self.device());
248        client.mean(self, dims, keepdim)
249    }
250
251    /// Max along dimensions
252    pub fn max(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
253        let client = R::default_client(self.device());
254        client.max(self, dims, keepdim)
255    }
256
257    /// Min along dimensions
258    pub fn min(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
259        let client = R::default_client(self.device());
260        client.min(self, dims, keepdim)
261    }
262}
263
264// ============================================================================
265// Matrix operations
266// ============================================================================
267
268impl<R: Runtime> Tensor<R>
269where
270    R::Client: MatmulOps<R>,
271{
272    /// Matrix multiplication: self @ other
273    pub fn matmul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
274        let client = R::default_client(self.device());
275        client.matmul(self, other)
276    }
277}
278
279// ============================================================================
280// Normalization
281// ============================================================================
282
283impl<R: Runtime> Tensor<R>
284where
285    R::Client: NormalizationOps<R>,
286{
287    /// RMS normalization: x / RMS(x) * weight
288    pub fn rms_norm(&self, weight: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
289        let client = R::default_client(self.device());
290        client.rms_norm(self, weight, eps)
291    }
292
293    /// Layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
294    pub fn layer_norm(&self, weight: &Tensor<R>, bias: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
295        let client = R::default_client(self.device());
296        client.layer_norm(self, weight, bias, eps)
297    }
298}
299
300// ============================================================================
301// Comparison operations
302// ============================================================================
303
304impl<R: Runtime> Tensor<R>
305where
306    R::Client: CompareOps<R>,
307{
308    /// Element-wise equality
309    pub fn eq(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
310        let client = R::default_client(self.device());
311        client.eq(self, other)
312    }
313
314    /// Element-wise greater than
315    pub fn gt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
316        let client = R::default_client(self.device());
317        client.gt(self, other)
318    }
319
320    /// Element-wise less than
321    pub fn lt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
322        let client = R::default_client(self.device());
323        client.lt(self, other)
324    }
325}
326
327// ============================================================================
328// Indexing operations
329// ============================================================================
330
331impl<R: Runtime> Tensor<R>
332where
333    R::Client: IndexingOps<R>,
334{
335    /// Select elements along a dimension using indices
336    pub fn index_select(&self, dim: usize, indices: &Tensor<R>) -> Result<Tensor<R>> {
337        let client = R::default_client(self.device());
338        client.index_select(self, dim, indices)
339    }
340
341    /// Argmax along a dimension
342    pub fn argmax(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
343        let client = R::default_client(self.device());
344        client.argmax(self, dim, keepdim)
345    }
346
347    /// Argmin along a dimension
348    pub fn argmin(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
349        let client = R::default_client(self.device());
350        client.argmin(self, dim, keepdim)
351    }
352
353    /// Fill tensor with value where mask is true
354    pub fn masked_fill(&self, mask: &Tensor<R>, value: f64) -> Result<Tensor<R>> {
355        let client = R::default_client(self.device());
356        client.masked_fill(self, mask, value)
357    }
358
359    /// Assign `src` into a slice of `self` along `dim` starting at `start`.
360    ///
361    /// Returns a new tensor with the slice region replaced by `src`.
362    pub fn slice_assign(&self, src: &Tensor<R>, dim: usize, start: usize) -> Result<Tensor<R>> {
363        let client = R::default_client(self.device());
364        client.slice_assign(self, src, dim, start)
365    }
366}
367
368// ============================================================================
369// Shape operations
370// ============================================================================
371
372impl<R: Runtime> Tensor<R>
373where
374    R::Client: ShapeOps<R>,
375{
376    /// Concatenate tensors along a dimension
377    pub fn cat(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
378        if tensors.is_empty() {
379            return Err(crate::error::Error::InvalidArgument {
380                arg: "tensors",
381                reason: "cannot concatenate empty list".into(),
382            });
383        }
384        let client = R::default_client(tensors[0].device());
385        client.cat(tensors, dim)
386    }
387
388    /// Stack tensors along a new dimension
389    pub fn stack(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
390        if tensors.is_empty() {
391            return Err(crate::error::Error::InvalidArgument {
392                arg: "tensors",
393                reason: "cannot stack empty list".into(),
394            });
395        }
396        let client = R::default_client(tensors[0].device());
397        client.stack(tensors, dim)
398    }
399}
400
401// ============================================================================
402// Cumulative operations
403// ============================================================================
404
405impl<R: Runtime> Tensor<R>
406where
407    R::Client: CumulativeOps<R>,
408{
409    /// Cumulative sum along a dimension
410    pub fn cumsum(&self, dim: isize) -> Result<Tensor<R>> {
411        let client = R::default_client(self.device());
412        client.cumsum(self, dim)
413    }
414
415    /// Cumulative product along a dimension
416    pub fn cumprod(&self, dim: isize) -> Result<Tensor<R>> {
417        let client = R::default_client(self.device());
418        client.cumprod(self, dim)
419    }
420
421    /// Log-sum-exp along specified dimensions (numerically stable)
422    ///
423    /// Computes `log(sum(exp(x)))` in a numerically stable way:
424    /// `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))`
425    pub fn logsumexp(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
426        let client = R::default_client(self.device());
427        client.logsumexp(self, dims, keepdim)
428    }
429}
430
431// ============================================================================
432// Type conversion
433// ============================================================================
434
435impl<R: Runtime> Tensor<R>
436where
437    R::Client: TypeConversionOps<R>,
438{
439    /// Convert tensor to a different dtype
440    pub fn to_dtype(&self, dtype: DType) -> Result<Tensor<R>> {
441        let client = R::default_client(self.device());
442        client.cast(self, dtype)
443    }
444}
445
446// ============================================================================
447// Utility operations
448// ============================================================================
449
450impl<R: Runtime> Tensor<R>
451where
452    R::Client: UtilityOps<R>,
453{
454    /// Clamp values to [min, max]
455    pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor<R>> {
456        let client = R::default_client(self.device());
457        client.clamp(self, min, max)
458    }
459
460    /// One-hot encode indices
461    pub fn one_hot(&self, num_classes: usize) -> Result<Tensor<R>> {
462        let client = R::default_client(self.device());
463        client.one_hot(self, num_classes)
464    }
465}
466
467// ============================================================================
468// Convolution operations
469// ============================================================================
470
471impl<R: Runtime> Tensor<R>
472where
473    R::Client: ConvOps<R>,
474{
475    /// 1D convolution
476    pub fn conv1d(
477        &self,
478        weight: &Tensor<R>,
479        bias: Option<&Tensor<R>>,
480        stride: usize,
481        padding: PaddingMode,
482        dilation: usize,
483        groups: usize,
484    ) -> Result<Tensor<R>> {
485        let client = R::default_client(self.device());
486        client.conv1d(self, weight, bias, stride, padding, dilation, groups)
487    }
488}