1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorMetadata,
6 TensorPrimitive, get_device_settings,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
10 TensorKind, TransactionOp,
11 },
12};
13
/// Dispatches a binary float operation over the float / quantized variants of
/// `TensorPrimitive`.
///
/// - quantized ⊕ quantized: forwarded to the quantized kernel `$quant_op`, whose return
///   value is already a `TensorPrimitive`.
/// - float ⊕ float: the float kernel `$float_op`, wrapped back into `TensorPrimitive::Float`.
/// - mixed: the quantized operand is dequantized to the dtype of the float operand, then
///   `$float_op` runs, so the result is always a float tensor.
///
/// Must be expanded where a generic `B: Backend` is in scope.
macro_rules! q_bin_ops {
    ($left:ident, $right:ident, $float_op:ident, $quant_op:ident) => {
        match ($left, $right) {
            (TensorPrimitive::QFloat(a), TensorPrimitive::QFloat(b)) => B::$quant_op(a, b),
            (TensorPrimitive::Float(a), TensorPrimitive::Float(b)) => {
                TensorPrimitive::Float(B::$float_op(a, b))
            }
            (TensorPrimitive::QFloat(a), TensorPrimitive::Float(b)) => {
                // Match the float side's precision when leaving the quantized domain.
                let target = b.dtype();
                let a = B::dequantize(a, target.into());
                TensorPrimitive::Float(B::$float_op(a, b))
            }
            (TensorPrimitive::Float(a), TensorPrimitive::QFloat(b)) => {
                let target = a.dtype();
                let b = B::dequantize(b, target.into());
                TensorPrimitive::Float(B::$float_op(a, b))
            }
        }
    };
}
32impl<B: Backend> TransactionOp<B> for Float {
33 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
34 tr.register_float(tensor);
35 }
36}
37impl<B: Backend> BasicOps<B> for Float {
38 type Elem = B::FloatElem;
39
40 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41 TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
42 }
43
44 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
45 TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
46 }
47 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
48 TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
49 }
50
51 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
52 TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
53 }
54
55 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
56 match tensor {
57 TensorPrimitive::Float(tensor) => {
58 TensorPrimitive::Float(B::float_reshape(tensor, shape))
59 }
60 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
61 }
62 }
63
64 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
65 match tensor {
66 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
67 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
68 }
69 }
70
71 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
72 match tensor {
73 TensorPrimitive::Float(tensor) => {
74 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
75 }
76 TensorPrimitive::QFloat(tensor) => {
77 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
78 }
79 }
80 }
81
82 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
83 match tensor {
84 TensorPrimitive::Float(tensor) => {
85 TensorPrimitive::Float(B::float_slice(tensor, slices))
86 }
87 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
88 }
89 }
90
91 fn slice_assign(
92 tensor: Self::Primitive,
93 slices: &[Slice],
94 value: Self::Primitive,
95 ) -> Self::Primitive {
96 TensorPrimitive::Float(B::float_slice_assign(
97 tensor.tensor(),
98 slices,
99 value.tensor(),
100 ))
101 }
102
103 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
104 match tensor {
105 TensorPrimitive::Float(tensor) => {
106 TensorPrimitive::Float(B::float_select(tensor, dim, indices))
107 }
108 TensorPrimitive::QFloat(tensor) => {
109 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
110 }
111 }
112 }
113
114 fn select_assign(
115 tensor: Self::Primitive,
116 dim: usize,
117 indices: IntTensor<B>,
118 values: Self::Primitive,
119 update: IndexingUpdateOp,
120 ) -> Self::Primitive {
121 match update {
123 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
124 tensor.tensor(),
125 dim,
126 indices,
127 values.tensor(),
128 )),
129 }
130 }
131
132 fn mask_where(
133 tensor: Self::Primitive,
134 mask: B::BoolTensorPrimitive,
135 source: Self::Primitive,
136 ) -> Self::Primitive {
137 TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
138 }
139
140 fn mask_fill(
141 tensor: Self::Primitive,
142 mask: B::BoolTensorPrimitive,
143 value: Scalar,
144 ) -> Self::Primitive {
145 TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
146 }
147
148 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
149 match tensor {
150 TensorPrimitive::Float(tensor) => {
151 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
152 }
153 TensorPrimitive::QFloat(tensor) => {
154 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
155 }
156 }
157 }
158
159 fn scatter(
160 dim: usize,
161 tensor: Self::Primitive,
162 indices: IntTensor<B>,
163 values: Self::Primitive,
164 update: IndexingUpdateOp,
165 ) -> Self::Primitive {
166 match update {
167 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
168 dim,
169 tensor.tensor(),
170 indices,
171 values.tensor(),
172 )),
173 }
174 }
175
176 fn device(tensor: &Self::Primitive) -> Device<B> {
177 match tensor {
178 TensorPrimitive::Float(tensor) => B::float_device(tensor),
179 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
180 }
181 }
182
183 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
184 match tensor {
185 TensorPrimitive::Float(tensor) => {
186 TensorPrimitive::Float(B::float_to_device(tensor, device))
187 }
188 TensorPrimitive::QFloat(tensor) => {
189 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
190 }
191 }
192 }
193
194 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
195 match tensor {
196 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
197 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
198 }
199 }
200
201 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
202 if matches!(data.dtype, DType::QFloat(_)) {
203 TensorPrimitive::QFloat(B::q_from_data(data, device))
205 } else if dtype.is_float() {
206 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
207 } else {
208 panic!("Expected float dtype, got {dtype:?}")
209 }
210 }
211
212 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
213 match tensor {
214 TensorPrimitive::Float(tensor) => {
215 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
216 }
217 TensorPrimitive::QFloat(tensor) => {
218 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
219 }
220 }
221 }
222
223 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
224 match vectors.first().unwrap() {
225 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
226 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
227 dim,
228 )),
229 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
230 vectors
231 .into_iter()
232 .map(|tensor| {
233 if let TensorPrimitive::QFloat(t) = tensor {
234 t
235 } else {
236 panic!("Concatenation only works with vector of QFloat")
237 }
238 })
239 .collect(),
240 dim,
241 )),
242 }
243 }
244
245 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
246 let lhs = lhs.tensor();
247 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
248 B::float_equal(lhs, rhs.tensor(), out_dtype)
249 }
250
251 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
252 let lhs = lhs.tensor();
253 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
254 B::float_not_equal(lhs, rhs.tensor(), out_dtype)
255 }
256
257 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
258 let lhs = lhs.tensor();
259 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
260 B::float_equal_elem(lhs, rhs, out_dtype)
261 }
262
263 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
264 let lhs = lhs.tensor();
265 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
266 B::float_not_equal_elem(lhs, rhs, out_dtype)
267 }
268
269 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
270 let tensor = tensor.tensor();
271 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
272 B::float_any(tensor, out_dtype)
273 }
274
275 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
276 let tensor = tensor.tensor();
277 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
278 B::float_any_dim(tensor, dim, out_dtype)
279 }
280
281 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
282 let tensor = tensor.tensor();
283 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
284 B::float_all(tensor, out_dtype)
285 }
286
287 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
288 let tensor = tensor.tensor();
289 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
290 B::float_all_dim(tensor, dim, out_dtype)
291 }
292
293 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
294 match tensor {
295 TensorPrimitive::Float(tensor) => {
296 TensorPrimitive::Float(B::float_permute(tensor, axes))
297 }
298 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
299 }
300 }
301
302 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
303 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
304 }
305
306 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
307 match tensor {
308 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
309 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
310 }
311 }
312
313 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
314 TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
315 }
316}
317
318impl<B: Backend> Numeric<B> for Float {
319 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
320 q_bin_ops!(lhs, rhs, float_add, q_add)
321 }
322
323 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
324 match lhs {
325 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
326 TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
327 }
328 }
329
330 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
331 q_bin_ops!(lhs, rhs, float_sub, q_sub)
332 }
333
334 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
335 match lhs {
336 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
337 TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
338 }
339 }
340
341 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
342 q_bin_ops!(lhs, rhs, float_div, q_div)
343 }
344
345 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
346 match lhs {
347 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
348 TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
349 }
350 }
351 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
352 TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
353 }
354
355 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
356 TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
357 }
358
359 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
360 q_bin_ops!(lhs, rhs, float_mul, q_mul)
361 }
362
363 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
364 match lhs {
365 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
366 TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
367 }
368 }
369 fn neg(tensor: Self::Primitive) -> Self::Primitive {
370 match tensor {
371 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
372 TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
373 }
374 }
375
376 fn sum(tensor: Self::Primitive) -> Self::Primitive {
377 match tensor {
378 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
379 TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
380 }
381 }
382
383 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
384 match tensor {
385 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
386 TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
387 }
388 }
389
390 fn prod(tensor: Self::Primitive) -> Self::Primitive {
391 match tensor {
392 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
393 TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
394 }
395 }
396
397 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
398 match tensor {
399 TensorPrimitive::Float(tensor) => {
400 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
401 }
402 TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
403 }
404 }
405
406 fn mean(tensor: Self::Primitive) -> Self::Primitive {
407 match tensor {
408 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
409 TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
410 }
411 }
412
413 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
414 match tensor {
415 TensorPrimitive::Float(tensor) => {
416 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
417 }
418 TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
419 }
420 }
421
422 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
423 match tensor {
424 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
425 TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
426 }
427 }
428
429 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
430 match tensor {
431 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
432 TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
433 }
434 }
435
436 fn abs(tensor: Self::Primitive) -> Self::Primitive {
437 match tensor {
438 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
439 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
440 }
441 }
442
443 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
444 q_bin_ops!(lhs, rhs, float_powf, q_powf)
445 }
446
447 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
448 match lhs {
449 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
450 TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
451 }
452 }
453
454 fn random(
455 shape: Shape,
456 distribution: Distribution,
457 device: &Device<B>,
458 dtype: DType,
459 ) -> Self::Primitive {
460 TensorPrimitive::Float(B::float_random(shape, distribution, device, dtype.into()))
461 }
462
463 fn sign(tensor: Self::Primitive) -> Self::Primitive {
464 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
465 }
466
467 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
475 match (lhs, rhs) {
476 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
477 TensorPrimitive::Float(B::float_matmul(lhs, rhs))
478 }
479 (lhs, rhs) => B::q_matmul(lhs, rhs),
480 }
481 }
482}
483impl<B: Backend> Ordered<B> for Float {
484 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
485 match tensor {
486 TensorPrimitive::Float(tensor) => {
487 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
488 }
489 TensorPrimitive::QFloat(tensor) => {
490 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
491 }
492 }
493 }
494
495 fn sort_with_indices(
496 tensor: Self::Primitive,
497 dim: usize,
498 descending: bool,
499 ) -> (Self::Primitive, IntTensor<B>) {
500 match tensor {
501 TensorPrimitive::Float(tensor) => {
502 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
503 let (values, indices) =
504 B::float_sort_with_indices(tensor, dim, descending, out_dtype);
505 (TensorPrimitive::Float(values), indices)
506 }
507 TensorPrimitive::QFloat(tensor) => {
508 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
509 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending, out_dtype);
510 (TensorPrimitive::QFloat(values), indices)
511 }
512 }
513 }
514
515 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
516 match tensor {
517 TensorPrimitive::Float(tensor) => {
518 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
519 B::float_argsort(tensor, dim, descending, out_dtype)
520 }
521 TensorPrimitive::QFloat(tensor) => {
522 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
523 B::q_argsort(tensor, dim, descending, out_dtype)
524 }
525 }
526 }
527
528 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
529 match tensor {
530 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
531 TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
532 }
533 }
534
535 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
536 match tensor {
537 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
538 TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
539 }
540 }
541
542 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
543 let lhs = lhs.tensor();
544 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
545 B::float_greater(lhs, rhs.tensor(), out_dtype)
546 }
547
548 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
549 let lhs = lhs.tensor();
550 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
551 B::float_greater_elem(lhs, rhs, out_dtype)
552 }
553
554 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
555 let lhs = lhs.tensor();
556 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
557 B::float_greater_equal(lhs, rhs.tensor(), out_dtype)
558 }
559
560 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
561 let lhs = lhs.tensor();
562 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
563 B::float_greater_equal_elem(lhs, rhs, out_dtype)
564 }
565
566 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
567 let lhs = lhs.tensor();
568 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
569 B::float_lower(lhs, rhs.tensor(), out_dtype)
570 }
571
572 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
573 let lhs = lhs.tensor();
574 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
575 B::float_lower_elem(lhs, rhs, out_dtype)
576 }
577
578 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
579 let lhs = lhs.tensor();
580 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
581 B::float_lower_equal(lhs, rhs.tensor(), out_dtype)
582 }
583
584 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
585 let lhs = lhs.tensor();
586 let out_dtype = get_device_settings::<B>(&B::float_device(&lhs)).bool_dtype;
587 B::float_lower_equal_elem(lhs, rhs, out_dtype)
588 }
589
590 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
591 match tensor {
592 TensorPrimitive::Float(tensor) => {
593 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
594 B::float_argmax(tensor, dim, out_dtype)
595 }
596 TensorPrimitive::QFloat(tensor) => {
597 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
598 B::q_argmax(tensor, dim, out_dtype)
599 }
600 }
601 }
602
603 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
604 match tensor {
605 TensorPrimitive::Float(tensor) => {
606 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
607 B::float_argmin(tensor, dim, out_dtype)
608 }
609 TensorPrimitive::QFloat(tensor) => {
610 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
611 B::q_argmin(tensor, dim, out_dtype)
612 }
613 }
614 }
615
616 fn max(tensor: Self::Primitive) -> Self::Primitive {
617 match tensor {
618 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
619 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
620 }
621 }
622
623 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
624 match tensor {
625 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
626 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
627 }
628 }
629
630 fn max_dim_with_indices(
631 tensor: Self::Primitive,
632 dim: usize,
633 ) -> (Self::Primitive, IntTensor<B>) {
634 match tensor {
635 TensorPrimitive::Float(tensor) => {
636 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
637 let (values, indices) = B::float_max_dim_with_indices(tensor, dim, out_dtype);
638 (TensorPrimitive::Float(values), indices)
639 }
640 TensorPrimitive::QFloat(tensor) => {
641 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
642 let (values, indices) = B::q_max_dim_with_indices(tensor, dim, out_dtype);
643 (TensorPrimitive::QFloat(values), indices)
644 }
645 }
646 }
647
648 fn min(tensor: Self::Primitive) -> Self::Primitive {
649 match tensor {
650 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
651 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
652 }
653 }
654
655 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
656 match tensor {
657 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
658 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
659 }
660 }
661
662 fn min_dim_with_indices(
663 tensor: Self::Primitive,
664 dim: usize,
665 ) -> (Self::Primitive, IntTensor<B>) {
666 match tensor {
667 TensorPrimitive::Float(tensor) => {
668 let out_dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
669 let (values, indices) = B::float_min_dim_with_indices(tensor, dim, out_dtype);
670 (TensorPrimitive::Float(values), indices)
671 }
672 TensorPrimitive::QFloat(tensor) => {
673 let out_dtype = get_device_settings::<B>(&B::q_device(&tensor)).int_dtype;
674 let (values, indices) = B::q_min_dim_with_indices(tensor, dim, out_dtype);
675 (TensorPrimitive::QFloat(values), indices)
676 }
677 }
678 }
679
680 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
681 match tensor {
682 TensorPrimitive::Float(tensor) => {
683 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
684 }
685 TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
686 }
687 }
688
689 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
690 match tensor {
691 TensorPrimitive::Float(tensor) => {
692 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
693 }
694 TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
695 }
696 }
697
698 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
699 match tensor {
700 TensorPrimitive::Float(tensor) => {
701 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
702 }
703 TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
704 }
705 }
706
707 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
708 match tensor {
709 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
710 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
711 }
712 }
713
714 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
715 match tensor {
716 TensorPrimitive::Float(tensor) => {
717 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
718 }
719 TensorPrimitive::QFloat(tensor) => {
720 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
721 }
722 }
723 }
724}
725
726impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
727 type InnerKind = Float;
728
729 fn inner(
730 tensor: <Self as TensorKind<B>>::Primitive,
731 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
732 match tensor {
733 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
734 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
735 }
736 }
737
738 fn from_inner(
739 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
740 ) -> <Self as TensorKind<B>>::Primitive {
741 match inner {
742 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
743 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
744 }
745 }
746}