1use crate::dtype::DType;
7use crate::error::Result;
8use crate::ops::traits::{
9 ActivationOps, BinaryOps, CompareOps, ConvOps, CumulativeOps, IndexingOps, MatmulOps,
10 NormalizationOps, PaddingMode, ReduceOps, ScalarOps, ShapeOps, TypeConversionOps, UnaryOps,
11 UtilityOps,
12};
13use crate::runtime::Runtime;
14use crate::tensor::Tensor;
15
16impl<R: Runtime> Tensor<R>
21where
22 R::Client: BinaryOps<R>,
23{
24 pub fn add(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
26 let client = R::default_client(self.device());
27 client.add(self, other)
28 }
29
30 pub fn sub(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
32 let client = R::default_client(self.device());
33 client.sub(self, other)
34 }
35
36 pub fn mul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
38 let client = R::default_client(self.device());
39 client.mul(self, other)
40 }
41
42 pub fn div(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
44 let client = R::default_client(self.device());
45 client.div(self, other)
46 }
47
48 pub fn pow(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
50 let client = R::default_client(self.device());
51 client.pow(self, other)
52 }
53
54 pub fn maximum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
56 let client = R::default_client(self.device());
57 client.maximum(self, other)
58 }
59
60 pub fn minimum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
62 let client = R::default_client(self.device());
63 client.minimum(self, other)
64 }
65}
66
67impl<R: Runtime> Tensor<R>
72where
73 R::Client: UnaryOps<R>,
74{
75 pub fn neg(&self) -> Result<Tensor<R>> {
77 let client = R::default_client(self.device());
78 client.neg(self)
79 }
80
81 pub fn abs(&self) -> Result<Tensor<R>> {
83 let client = R::default_client(self.device());
84 client.abs(self)
85 }
86
87 pub fn sqrt(&self) -> Result<Tensor<R>> {
89 let client = R::default_client(self.device());
90 client.sqrt(self)
91 }
92
93 pub fn exp(&self) -> Result<Tensor<R>> {
95 let client = R::default_client(self.device());
96 client.exp(self)
97 }
98
99 pub fn log(&self) -> Result<Tensor<R>> {
101 let client = R::default_client(self.device());
102 client.log(self)
103 }
104
105 pub fn sin(&self) -> Result<Tensor<R>> {
107 let client = R::default_client(self.device());
108 client.sin(self)
109 }
110
111 pub fn cos(&self) -> Result<Tensor<R>> {
113 let client = R::default_client(self.device());
114 client.cos(self)
115 }
116
117 pub fn tan(&self) -> Result<Tensor<R>> {
119 let client = R::default_client(self.device());
120 client.tan(self)
121 }
122
123 pub fn tanh(&self) -> Result<Tensor<R>> {
125 let client = R::default_client(self.device());
126 client.tanh(self)
127 }
128
129 pub fn recip(&self) -> Result<Tensor<R>> {
131 let client = R::default_client(self.device());
132 client.recip(self)
133 }
134
135 pub fn floor(&self) -> Result<Tensor<R>> {
137 let client = R::default_client(self.device());
138 client.floor(self)
139 }
140
141 pub fn ceil(&self) -> Result<Tensor<R>> {
143 let client = R::default_client(self.device());
144 client.ceil(self)
145 }
146
147 pub fn round(&self) -> Result<Tensor<R>> {
149 let client = R::default_client(self.device());
150 client.round(self)
151 }
152}
153
154impl<R: Runtime> Tensor<R>
159where
160 R::Client: ScalarOps<R>,
161{
162 pub fn add_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
164 let client = R::default_client(self.device());
165 client.add_scalar(self, scalar)
166 }
167
168 pub fn mul_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
170 let client = R::default_client(self.device());
171 client.mul_scalar(self, scalar)
172 }
173
174 pub fn scale(&self, scalar: f64) -> Result<Tensor<R>> {
176 self.mul_scalar(scalar)
177 }
178}
179
180impl<R: Runtime> Tensor<R>
185where
186 R::Client: ActivationOps<R>,
187{
188 pub fn relu(&self) -> Result<Tensor<R>> {
190 let client = R::default_client(self.device());
191 client.relu(self)
192 }
193
194 pub fn sigmoid(&self) -> Result<Tensor<R>> {
196 let client = R::default_client(self.device());
197 client.sigmoid(self)
198 }
199
200 pub fn gelu(&self) -> Result<Tensor<R>> {
202 let client = R::default_client(self.device());
203 client.gelu(self)
204 }
205
206 pub fn silu(&self) -> Result<Tensor<R>> {
208 let client = R::default_client(self.device());
209 client.silu(self)
210 }
211
212 pub fn softmax(&self, dim: isize) -> Result<Tensor<R>> {
214 let client = R::default_client(self.device());
215 client.softmax(self, dim)
216 }
217
218 pub fn log_softmax(&self, dim: isize) -> Result<Tensor<R>> {
220 let client = R::default_client(self.device());
221 client.log_softmax(self, dim)
222 }
223
224 pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor<R>> {
226 let client = R::default_client(self.device());
227 client.dropout(self, p, training)
228 }
229}
230
231impl<R: Runtime> Tensor<R>
236where
237 R::Client: ReduceOps<R>,
238{
239 pub fn sum(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
241 let client = R::default_client(self.device());
242 client.sum(self, dims, keepdim)
243 }
244
245 pub fn mean(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
247 let client = R::default_client(self.device());
248 client.mean(self, dims, keepdim)
249 }
250
251 pub fn max(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
253 let client = R::default_client(self.device());
254 client.max(self, dims, keepdim)
255 }
256
257 pub fn min(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
259 let client = R::default_client(self.device());
260 client.min(self, dims, keepdim)
261 }
262}
263
264impl<R: Runtime> Tensor<R>
269where
270 R::Client: MatmulOps<R>,
271{
272 pub fn matmul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
274 let client = R::default_client(self.device());
275 client.matmul(self, other)
276 }
277}
278
279impl<R: Runtime> Tensor<R>
284where
285 R::Client: NormalizationOps<R>,
286{
287 pub fn rms_norm(&self, weight: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
289 let client = R::default_client(self.device());
290 client.rms_norm(self, weight, eps)
291 }
292
293 pub fn layer_norm(&self, weight: &Tensor<R>, bias: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
295 let client = R::default_client(self.device());
296 client.layer_norm(self, weight, bias, eps)
297 }
298}
299
300impl<R: Runtime> Tensor<R>
305where
306 R::Client: CompareOps<R>,
307{
308 pub fn eq(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
310 let client = R::default_client(self.device());
311 client.eq(self, other)
312 }
313
314 pub fn gt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
316 let client = R::default_client(self.device());
317 client.gt(self, other)
318 }
319
320 pub fn lt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
322 let client = R::default_client(self.device());
323 client.lt(self, other)
324 }
325}
326
327impl<R: Runtime> Tensor<R>
332where
333 R::Client: IndexingOps<R>,
334{
335 pub fn index_select(&self, dim: usize, indices: &Tensor<R>) -> Result<Tensor<R>> {
337 let client = R::default_client(self.device());
338 client.index_select(self, dim, indices)
339 }
340
341 pub fn argmax(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
343 let client = R::default_client(self.device());
344 client.argmax(self, dim, keepdim)
345 }
346
347 pub fn argmin(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
349 let client = R::default_client(self.device());
350 client.argmin(self, dim, keepdim)
351 }
352
353 pub fn masked_fill(&self, mask: &Tensor<R>, value: f64) -> Result<Tensor<R>> {
355 let client = R::default_client(self.device());
356 client.masked_fill(self, mask, value)
357 }
358
359 pub fn slice_assign(&self, src: &Tensor<R>, dim: usize, start: usize) -> Result<Tensor<R>> {
363 let client = R::default_client(self.device());
364 client.slice_assign(self, src, dim, start)
365 }
366}
367
368impl<R: Runtime> Tensor<R>
373where
374 R::Client: ShapeOps<R>,
375{
376 pub fn cat(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
378 if tensors.is_empty() {
379 return Err(crate::error::Error::InvalidArgument {
380 arg: "tensors",
381 reason: "cannot concatenate empty list".into(),
382 });
383 }
384 let client = R::default_client(tensors[0].device());
385 client.cat(tensors, dim)
386 }
387
388 pub fn stack(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
390 if tensors.is_empty() {
391 return Err(crate::error::Error::InvalidArgument {
392 arg: "tensors",
393 reason: "cannot stack empty list".into(),
394 });
395 }
396 let client = R::default_client(tensors[0].device());
397 client.stack(tensors, dim)
398 }
399}
400
401impl<R: Runtime> Tensor<R>
406where
407 R::Client: CumulativeOps<R>,
408{
409 pub fn cumsum(&self, dim: isize) -> Result<Tensor<R>> {
411 let client = R::default_client(self.device());
412 client.cumsum(self, dim)
413 }
414
415 pub fn cumprod(&self, dim: isize) -> Result<Tensor<R>> {
417 let client = R::default_client(self.device());
418 client.cumprod(self, dim)
419 }
420
421 pub fn logsumexp(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
426 let client = R::default_client(self.device());
427 client.logsumexp(self, dims, keepdim)
428 }
429}
430
431impl<R: Runtime> Tensor<R>
436where
437 R::Client: TypeConversionOps<R>,
438{
439 pub fn to_dtype(&self, dtype: DType) -> Result<Tensor<R>> {
441 let client = R::default_client(self.device());
442 client.cast(self, dtype)
443 }
444}
445
446impl<R: Runtime> Tensor<R>
451where
452 R::Client: UtilityOps<R>,
453{
454 pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor<R>> {
456 let client = R::default_client(self.device());
457 client.clamp(self, min, max)
458 }
459
460 pub fn one_hot(&self, num_classes: usize) -> Result<Tensor<R>> {
462 let client = R::default_client(self.device());
463 client.one_hot(self, num_classes)
464 }
465}
466
467impl<R: Runtime> Tensor<R>
472where
473 R::Client: ConvOps<R>,
474{
475 pub fn conv1d(
477 &self,
478 weight: &Tensor<R>,
479 bias: Option<&Tensor<R>>,
480 stride: usize,
481 padding: PaddingMode,
482 dilation: usize,
483 groups: usize,
484 ) -> Result<Tensor<R>> {
485 let client = R::default_client(self.device());
486 client.conv1d(self, weight, bias, stride, padding, dilation, groups)
487 }
488}