1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, TensorData, TensorPrimitive,
6 element::ElementConversion,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, TensorKind,
10 },
11};
12
// Dispatches a binary float operation over two `TensorPrimitive` operands.
//
// - both quantized: the backend's quantized kernel `B::$q_op` decides the output
//   primitive itself;
// - exactly one quantized: that side is dequantized with `B::dequantize` and the float
//   kernel `B::$op` runs, producing a plain float primitive;
// - both float: the float kernel `B::$op` runs directly.
//
// `$lhs` / `$rhs` are consumed; `B` and `TensorPrimitive` resolve at the call site.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match $lhs {
            TensorPrimitive::QFloat(left) => match $rhs {
                TensorPrimitive::QFloat(right) => B::$q_op(left, right),
                // Mixed operands: fall back to float math on the dequantized lhs.
                TensorPrimitive::Float(right) => {
                    TensorPrimitive::Float(B::$op(B::dequantize(left), right))
                }
            },
            TensorPrimitive::Float(left) => match $rhs {
                // Mixed operands: fall back to float math on the dequantized rhs.
                TensorPrimitive::QFloat(right) => {
                    TensorPrimitive::Float(B::$op(left, B::dequantize(right)))
                }
                TensorPrimitive::Float(right) => TensorPrimitive::Float(B::$op(left, right)),
            },
        }
    };
}
29
30impl<B: Backend> BasicOps<B> for Float {
31 type Elem = B::FloatElem;
32
33 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
34 TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
35 }
36
37 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
38 TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
39 }
40 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41 TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
42 }
43
44 fn full<E: ElementConversion>(
45 shape: Shape,
46 fill_value: E,
47 device: &Device<B>,
48 dtype: DType,
49 ) -> Self::Primitive {
50 TensorPrimitive::Float(B::float_full(
51 shape,
52 fill_value.elem(),
53 device,
54 dtype.into(),
55 ))
56 }
57
58 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
59 tr.register_float(tensor);
60 }
61
62 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
63 match tensor {
64 TensorPrimitive::Float(tensor) => {
65 TensorPrimitive::Float(B::float_reshape(tensor, shape))
66 }
67 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
68 }
69 }
70
71 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
72 match tensor {
73 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
74 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
75 }
76 }
77
78 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
79 match tensor {
80 TensorPrimitive::Float(tensor) => {
81 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
82 }
83 TensorPrimitive::QFloat(tensor) => {
84 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
85 }
86 }
87 }
88
89 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
90 match tensor {
91 TensorPrimitive::Float(tensor) => {
92 TensorPrimitive::Float(B::float_slice(tensor, slices))
93 }
94 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
95 }
96 }
97
98 fn slice_assign(
99 tensor: Self::Primitive,
100 slices: &[Slice],
101 value: Self::Primitive,
102 ) -> Self::Primitive {
103 TensorPrimitive::Float(B::float_slice_assign(
104 tensor.tensor(),
105 slices,
106 value.tensor(),
107 ))
108 }
109
110 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
111 match tensor {
112 TensorPrimitive::Float(tensor) => {
113 TensorPrimitive::Float(B::float_select(tensor, dim, indices))
114 }
115 TensorPrimitive::QFloat(tensor) => {
116 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
117 }
118 }
119 }
120
121 fn select_assign(
122 tensor: Self::Primitive,
123 dim: usize,
124 indices: IntTensor<B>,
125 values: Self::Primitive,
126 update: IndexingUpdateOp,
127 ) -> Self::Primitive {
128 match update {
130 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
131 tensor.tensor(),
132 dim,
133 indices,
134 values.tensor(),
135 )),
136 }
137 }
138
139 fn mask_where(
140 tensor: Self::Primitive,
141 mask: B::BoolTensorPrimitive,
142 source: Self::Primitive,
143 ) -> Self::Primitive {
144 TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
145 }
146
147 fn mask_fill(
148 tensor: Self::Primitive,
149 mask: B::BoolTensorPrimitive,
150 value: Self::Elem,
151 ) -> Self::Primitive {
152 TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
153 }
154
155 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
156 match tensor {
157 TensorPrimitive::Float(tensor) => {
158 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
159 }
160 TensorPrimitive::QFloat(tensor) => {
161 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
162 }
163 }
164 }
165
166 fn scatter(
167 dim: usize,
168 tensor: Self::Primitive,
169 indices: IntTensor<B>,
170 values: Self::Primitive,
171 update: IndexingUpdateOp,
172 ) -> Self::Primitive {
173 match update {
174 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
175 dim,
176 tensor.tensor(),
177 indices,
178 values.tensor(),
179 )),
180 }
181 }
182
183 fn device(tensor: &Self::Primitive) -> Device<B> {
184 match tensor {
185 TensorPrimitive::Float(tensor) => B::float_device(tensor),
186 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
187 }
188 }
189
190 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
191 match tensor {
192 TensorPrimitive::Float(tensor) => {
193 TensorPrimitive::Float(B::float_to_device(tensor, device))
194 }
195 TensorPrimitive::QFloat(tensor) => {
196 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
197 }
198 }
199 }
200
201 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
202 match tensor {
203 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
204 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
205 }
206 }
207
208 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
209 match &data.dtype {
210 DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
211 _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
212 }
213 }
214
215 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
216 match dtype {
217 DType::QFloat(_scheme) => {
218 TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
219 }
220 _ if dtype.is_float() => {
221 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
222 }
223 _ => panic!("Expected float dtype, got {dtype:?}"),
224 }
225 }
226
227 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
228 match tensor {
229 TensorPrimitive::Float(tensor) => {
230 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
231 }
232 TensorPrimitive::QFloat(tensor) => {
233 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
234 }
235 }
236 }
237
238 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
239 match vectors.first().unwrap() {
240 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
241 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
242 dim,
243 )),
244 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
245 vectors
246 .into_iter()
247 .map(|tensor| {
248 if let TensorPrimitive::QFloat(t) = tensor {
249 t
250 } else {
251 panic!("Concatenation only works with vector of QFloat")
252 }
253 })
254 .collect(),
255 dim,
256 )),
257 }
258 }
259
260 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
261 B::float_equal(lhs.tensor(), rhs.tensor())
262 }
263
264 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
265 B::float_not_equal(lhs.tensor(), rhs.tensor())
266 }
267
268 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
269 B::float_equal_elem(lhs.tensor(), rhs)
270 }
271
272 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
273 B::float_not_equal_elem(lhs.tensor(), rhs)
274 }
275
276 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
277 B::float_any(tensor.tensor())
278 }
279
280 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
281 B::float_any_dim(tensor.tensor(), dim)
282 }
283
284 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
285 B::float_all(tensor.tensor())
286 }
287
288 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
289 B::float_all_dim(tensor.tensor(), dim)
290 }
291
292 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
293 match tensor {
294 TensorPrimitive::Float(tensor) => {
295 TensorPrimitive::Float(B::float_permute(tensor, axes))
296 }
297 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
298 }
299 }
300
301 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
302 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
303 }
304
305 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
306 match tensor {
307 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
308 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
309 }
310 }
311
312 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
313 TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
314 }
315}
316
317impl<B: Backend> Numeric<B> for Float {
318 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
319 q_bin_ops!(lhs, rhs, float_add, q_add)
320 }
321
322 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
323 match lhs {
324 TensorPrimitive::Float(lhs) => {
325 TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
326 }
327 TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs.elem()),
328 }
329 }
330
331 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
332 q_bin_ops!(lhs, rhs, float_sub, q_sub)
333 }
334
335 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
336 match lhs {
337 TensorPrimitive::Float(lhs) => {
338 TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
339 }
340 TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs.elem()),
341 }
342 }
343
344 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
345 q_bin_ops!(lhs, rhs, float_div, q_div)
346 }
347
348 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
349 match lhs {
350 TensorPrimitive::Float(lhs) => {
351 TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
352 }
353 TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs.elem()),
354 }
355 }
356 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
357 TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
358 }
359
360 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
361 TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem()))
362 }
363
364 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
365 q_bin_ops!(lhs, rhs, float_mul, q_mul)
366 }
367
368 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
369 match lhs {
370 TensorPrimitive::Float(lhs) => {
371 TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
372 }
373 TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs.elem()),
374 }
375 }
376 fn neg(tensor: Self::Primitive) -> Self::Primitive {
377 match tensor {
378 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
379 TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
380 }
381 }
382
383 fn sum(tensor: Self::Primitive) -> Self::Primitive {
384 match tensor {
385 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
386 TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
387 }
388 }
389
390 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
391 match tensor {
392 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
393 TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
394 }
395 }
396
397 fn prod(tensor: Self::Primitive) -> Self::Primitive {
398 match tensor {
399 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
400 TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
401 }
402 }
403
404 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
405 match tensor {
406 TensorPrimitive::Float(tensor) => {
407 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
408 }
409 TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
410 }
411 }
412
413 fn mean(tensor: Self::Primitive) -> Self::Primitive {
414 match tensor {
415 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
416 TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
417 }
418 }
419
420 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
421 match tensor {
422 TensorPrimitive::Float(tensor) => {
423 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
424 }
425 TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
426 }
427 }
428
429 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
430 match tensor {
431 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
432 TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
433 }
434 }
435
436 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
437 match tensor {
438 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
439 TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
440 }
441 }
442
443 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
444 match tensor {
445 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
446 TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
447 }
448 }
449
450 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
451 match tensor {
452 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
453 TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
454 }
455 }
456
457 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
458 B::float_greater(lhs.tensor(), rhs.tensor())
459 }
460
461 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
462 B::float_greater_elem(lhs.tensor(), rhs)
463 }
464
465 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
466 B::float_greater_equal(lhs.tensor(), rhs.tensor())
467 }
468
469 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
470 B::float_greater_equal_elem(lhs.tensor(), rhs)
471 }
472
473 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
474 B::float_lower(lhs.tensor(), rhs.tensor())
475 }
476
477 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
478 B::float_lower_elem(lhs.tensor(), rhs)
479 }
480
481 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
482 B::float_lower_equal(lhs.tensor(), rhs.tensor())
483 }
484
485 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
486 B::float_lower_equal_elem(lhs.tensor(), rhs)
487 }
488
489 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
490 match tensor {
491 TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
492 TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
493 }
494 }
495
496 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
497 match tensor {
498 TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
499 TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
500 }
501 }
502
503 fn max(tensor: Self::Primitive) -> Self::Primitive {
504 match tensor {
505 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
506 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
507 }
508 }
509
510 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
511 match tensor {
512 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
513 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
514 }
515 }
516
517 fn max_dim_with_indices(
518 tensor: Self::Primitive,
519 dim: usize,
520 ) -> (Self::Primitive, IntTensor<B>) {
521 match tensor {
522 TensorPrimitive::Float(tensor) => {
523 let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
524 (TensorPrimitive::Float(values), indices)
525 }
526 TensorPrimitive::QFloat(tensor) => {
527 let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
528 (TensorPrimitive::QFloat(values), indices)
529 }
530 }
531 }
532
533 fn min(tensor: Self::Primitive) -> Self::Primitive {
534 match tensor {
535 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
536 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
537 }
538 }
539
540 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
541 match tensor {
542 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
543 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
544 }
545 }
546
547 fn min_dim_with_indices(
548 tensor: Self::Primitive,
549 dim: usize,
550 ) -> (Self::Primitive, IntTensor<B>) {
551 match tensor {
552 TensorPrimitive::Float(tensor) => {
553 let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
554 (TensorPrimitive::Float(values), indices)
555 }
556 TensorPrimitive::QFloat(tensor) => {
557 let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
558 (TensorPrimitive::QFloat(values), indices)
559 }
560 }
561 }
562
563 fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
564 match tensor {
565 TensorPrimitive::Float(tensor) => {
566 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
567 }
568 TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
569 }
570 }
571
572 fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
573 match tensor {
574 TensorPrimitive::Float(tensor) => {
575 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
576 }
577 TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
578 }
579 }
580
581 fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
582 match tensor {
583 TensorPrimitive::Float(tensor) => {
584 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
585 }
586 TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
587 }
588 }
589
590 fn abs(tensor: Self::Primitive) -> Self::Primitive {
591 match tensor {
592 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
593 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
594 }
595 }
596
597 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
598 q_bin_ops!(lhs, rhs, float_powf, q_powf)
599 }
600
601 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
602 match lhs {
603 TensorPrimitive::Float(lhs) => {
604 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
605 }
606 TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
607 }
608 }
609
610 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
611 q_bin_ops!(lhs, rhs, float_powf, q_powf)
612 }
613
614 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
615 match lhs {
616 TensorPrimitive::Float(lhs) => {
617 TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
618 }
619 TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
620 }
621 }
622
623 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
624 TensorPrimitive::Float(B::float_random(shape, distribution, device))
625 }
626
627 fn sign(tensor: Self::Primitive) -> Self::Primitive {
628 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
629 }
630
631 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
632 match tensor {
633 TensorPrimitive::Float(tensor) => {
634 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
635 }
636 TensorPrimitive::QFloat(tensor) => {
637 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
638 }
639 }
640 }
641
642 fn sort_with_indices(
643 tensor: Self::Primitive,
644 dim: usize,
645 descending: bool,
646 ) -> (Self::Primitive, IntTensor<B>) {
647 match tensor {
648 TensorPrimitive::Float(tensor) => {
649 let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
650 (TensorPrimitive::Float(values), indices)
651 }
652 TensorPrimitive::QFloat(tensor) => {
653 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
654 (TensorPrimitive::QFloat(values), indices)
655 }
656 }
657 }
658
659 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
660 match tensor {
661 TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
662 TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
663 }
664 }
665
666 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
667 match tensor {
668 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
669 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
670 }
671 }
672
673 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
674 match tensor {
675 TensorPrimitive::Float(tensor) => {
676 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
677 }
678 TensorPrimitive::QFloat(tensor) => {
679 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
680 }
681 }
682 }
683
684 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
692 match (lhs, rhs) {
693 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
694 TensorPrimitive::Float(B::float_matmul(lhs, rhs))
695 }
696 (lhs, rhs) => B::q_matmul(lhs, rhs),
697 }
698 }
699}
700
701impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
702 type InnerKind = Float;
703
704 fn inner(
705 tensor: <Self as TensorKind<B>>::Primitive,
706 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
707 match tensor {
708 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
709 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
710 }
711 }
712
713 fn from_inner(
714 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
715 ) -> <Self as TensorKind<B>>::Primitive {
716 match inner {
717 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
718 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
719 }
720 }
721}