1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorPrimitive,
6 ops::TransactionPrimitive,
7 tensor::{
8 BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,
9 TensorKind,
10 },
11};
12
/// Dispatches a binary float operation over every `Float`/`QFloat` operand pairing.
///
/// * `$lhs`, `$rhs` - identifiers bound to the two `TensorPrimitive` operands.
/// * `$op` - backend function applied to two float tensors; its result is wrapped in
///   `TensorPrimitive::Float`.
/// * `$q_op` - backend function applied when both operands are quantized; its result is
///   returned as-is, without re-wrapping.
///
/// When exactly one operand is quantized, that operand is passed through `B::dequantize` and
/// the float function `$op` is used. The expansion refers to a backend type named `B`, which
/// must be in scope at the call site.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::Float(left), TensorPrimitive::Float(right)) => {
                let result = B::$op(left, right);
                TensorPrimitive::Float(result)
            }
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::QFloat(q_right)) => {
                B::$q_op(q_left, q_right)
            }
            // Mixed pairings: only the quantized side is dequantized.
            (TensorPrimitive::QFloat(q_left), TensorPrimitive::Float(right)) => {
                let left = B::dequantize(q_left);
                TensorPrimitive::Float(B::$op(left, right))
            }
            (TensorPrimitive::Float(left), TensorPrimitive::QFloat(q_right)) => {
                let right = B::dequantize(q_right);
                TensorPrimitive::Float(B::$op(left, right))
            }
        }
    };
}
29
30impl<B: Backend> BasicOps<B> for Float {
31 type Elem = B::FloatElem;
32
33 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
34 TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))
35 }
36
37 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
38 TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
39 }
40 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
41 TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
42 }
43
44 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
45 TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))
46 }
47
48 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
49 tr.register_float(tensor);
50 }
51
52 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
53 match tensor {
54 TensorPrimitive::Float(tensor) => {
55 TensorPrimitive::Float(B::float_reshape(tensor, shape))
56 }
57 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
58 }
59 }
60
61 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
62 match tensor {
63 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
64 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
65 }
66 }
67
68 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
69 match tensor {
70 TensorPrimitive::Float(tensor) => {
71 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
72 }
73 TensorPrimitive::QFloat(tensor) => {
74 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
75 }
76 }
77 }
78
79 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
80 match tensor {
81 TensorPrimitive::Float(tensor) => {
82 TensorPrimitive::Float(B::float_slice(tensor, slices))
83 }
84 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),
85 }
86 }
87
88 fn slice_assign(
89 tensor: Self::Primitive,
90 slices: &[Slice],
91 value: Self::Primitive,
92 ) -> Self::Primitive {
93 TensorPrimitive::Float(B::float_slice_assign(
94 tensor.tensor(),
95 slices,
96 value.tensor(),
97 ))
98 }
99
100 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
101 match tensor {
102 TensorPrimitive::Float(tensor) => {
103 TensorPrimitive::Float(B::float_select(tensor, dim, indices))
104 }
105 TensorPrimitive::QFloat(tensor) => {
106 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))
107 }
108 }
109 }
110
111 fn select_assign(
112 tensor: Self::Primitive,
113 dim: usize,
114 indices: IntTensor<B>,
115 values: Self::Primitive,
116 update: IndexingUpdateOp,
117 ) -> Self::Primitive {
118 match update {
120 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(
121 tensor.tensor(),
122 dim,
123 indices,
124 values.tensor(),
125 )),
126 }
127 }
128
129 fn mask_where(
130 tensor: Self::Primitive,
131 mask: B::BoolTensorPrimitive,
132 source: Self::Primitive,
133 ) -> Self::Primitive {
134 TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
135 }
136
137 fn mask_fill(
138 tensor: Self::Primitive,
139 mask: B::BoolTensorPrimitive,
140 value: Scalar,
141 ) -> Self::Primitive {
142 TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
143 }
144
145 fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
146 match tensor {
147 TensorPrimitive::Float(tensor) => {
148 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
149 }
150 TensorPrimitive::QFloat(tensor) => {
151 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
152 }
153 }
154 }
155
156 fn scatter(
157 dim: usize,
158 tensor: Self::Primitive,
159 indices: IntTensor<B>,
160 values: Self::Primitive,
161 update: IndexingUpdateOp,
162 ) -> Self::Primitive {
163 match update {
164 IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(
165 dim,
166 tensor.tensor(),
167 indices,
168 values.tensor(),
169 )),
170 }
171 }
172
173 fn device(tensor: &Self::Primitive) -> Device<B> {
174 match tensor {
175 TensorPrimitive::Float(tensor) => B::float_device(tensor),
176 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
177 }
178 }
179
180 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
181 match tensor {
182 TensorPrimitive::Float(tensor) => {
183 TensorPrimitive::Float(B::float_to_device(tensor, device))
184 }
185 TensorPrimitive::QFloat(tensor) => {
186 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
187 }
188 }
189 }
190
191 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
192 match tensor {
193 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
194 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
195 }
196 }
197
198 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
199 match &data.dtype {
200 DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
201 _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
202 }
203 }
204
205 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
206 match dtype {
207 DType::QFloat(_scheme) => {
208 TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
209 }
210 _ if dtype.is_float() => {
211 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
212 }
213 _ => panic!("Expected float dtype, got {dtype:?}"),
214 }
215 }
216
217 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
218 match tensor {
219 TensorPrimitive::Float(tensor) => {
220 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
221 }
222 TensorPrimitive::QFloat(tensor) => {
223 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
224 }
225 }
226 }
227
228 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
229 match vectors.first().unwrap() {
230 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
231 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
232 dim,
233 )),
234 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
235 vectors
236 .into_iter()
237 .map(|tensor| {
238 if let TensorPrimitive::QFloat(t) = tensor {
239 t
240 } else {
241 panic!("Concatenation only works with vector of QFloat")
242 }
243 })
244 .collect(),
245 dim,
246 )),
247 }
248 }
249
250 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
251 B::float_equal(lhs.tensor(), rhs.tensor())
252 }
253
254 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
255 B::float_not_equal(lhs.tensor(), rhs.tensor())
256 }
257
258 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
259 B::float_equal_elem(lhs.tensor(), rhs)
260 }
261
262 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
263 B::float_not_equal_elem(lhs.tensor(), rhs)
264 }
265
266 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
267 B::float_any(tensor.tensor())
268 }
269
270 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
271 B::float_any_dim(tensor.tensor(), dim)
272 }
273
274 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
275 B::float_all(tensor.tensor())
276 }
277
278 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
279 B::float_all_dim(tensor.tensor(), dim)
280 }
281
282 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
283 match tensor {
284 TensorPrimitive::Float(tensor) => {
285 TensorPrimitive::Float(B::float_permute(tensor, axes))
286 }
287 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
288 }
289 }
290
291 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
292 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
293 }
294
295 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
296 match tensor {
297 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
298 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
299 }
300 }
301
302 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
303 TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))
304 }
305}
306
307impl<B: Backend> Numeric<B> for Float {
308 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
309 q_bin_ops!(lhs, rhs, float_add, q_add)
310 }
311
312 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
313 match lhs {
314 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),
315 TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),
316 }
317 }
318
319 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
320 q_bin_ops!(lhs, rhs, float_sub, q_sub)
321 }
322
323 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
324 match lhs {
325 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),
326 TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),
327 }
328 }
329
330 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
331 q_bin_ops!(lhs, rhs, float_div, q_div)
332 }
333
334 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
335 match lhs {
336 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),
337 TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),
338 }
339 }
340 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
341 TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
342 }
343
344 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
345 TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))
346 }
347
348 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
349 q_bin_ops!(lhs, rhs, float_mul, q_mul)
350 }
351
352 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
353 match lhs {
354 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),
355 TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),
356 }
357 }
358 fn neg(tensor: Self::Primitive) -> Self::Primitive {
359 match tensor {
360 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
361 TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
362 }
363 }
364
365 fn sum(tensor: Self::Primitive) -> Self::Primitive {
366 match tensor {
367 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
368 TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
369 }
370 }
371
372 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
373 match tensor {
374 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
375 TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
376 }
377 }
378
379 fn prod(tensor: Self::Primitive) -> Self::Primitive {
380 match tensor {
381 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
382 TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
383 }
384 }
385
386 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
387 match tensor {
388 TensorPrimitive::Float(tensor) => {
389 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
390 }
391 TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
392 }
393 }
394
395 fn mean(tensor: Self::Primitive) -> Self::Primitive {
396 match tensor {
397 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
398 TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
399 }
400 }
401
402 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
403 match tensor {
404 TensorPrimitive::Float(tensor) => {
405 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
406 }
407 TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
408 }
409 }
410
411 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
412 match tensor {
413 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
414 TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
415 }
416 }
417
418 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
419 match tensor {
420 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
421 TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
422 }
423 }
424
425 fn abs(tensor: Self::Primitive) -> Self::Primitive {
426 match tensor {
427 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
428 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
429 }
430 }
431
432 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
433 q_bin_ops!(lhs, rhs, float_powf, q_powf)
434 }
435
436 fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
437 match lhs {
438 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),
439 TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),
440 }
441 }
442
443 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
444 q_bin_ops!(lhs, rhs, float_powf, q_powf)
445 }
446
447 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
448 match lhs {
449 TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),
450 TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),
451 }
452 }
453
454 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
455 TensorPrimitive::Float(B::float_random(shape, distribution, device))
456 }
457
458 fn sign(tensor: Self::Primitive) -> Self::Primitive {
459 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
460 }
461
462 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
470 match (lhs, rhs) {
471 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
472 TensorPrimitive::Float(B::float_matmul(lhs, rhs))
473 }
474 (lhs, rhs) => B::q_matmul(lhs, rhs),
475 }
476 }
477}
478impl<B: Backend> Ordered<B> for Float {
479 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
480 match tensor {
481 TensorPrimitive::Float(tensor) => {
482 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
483 }
484 TensorPrimitive::QFloat(tensor) => {
485 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
486 }
487 }
488 }
489
490 fn sort_with_indices(
491 tensor: Self::Primitive,
492 dim: usize,
493 descending: bool,
494 ) -> (Self::Primitive, IntTensor<B>) {
495 match tensor {
496 TensorPrimitive::Float(tensor) => {
497 let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
498 (TensorPrimitive::Float(values), indices)
499 }
500 TensorPrimitive::QFloat(tensor) => {
501 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
502 (TensorPrimitive::QFloat(values), indices)
503 }
504 }
505 }
506
507 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
508 match tensor {
509 TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
510 TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
511 }
512 }
513
514 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
515 match tensor {
516 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
517 TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
518 }
519 }
520
521 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
522 match tensor {
523 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
524 TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
525 }
526 }
527
528 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
529 B::float_greater(lhs.tensor(), rhs.tensor())
530 }
531
532 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
533 B::float_greater_elem(lhs.tensor(), rhs)
534 }
535
536 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
537 B::float_greater_equal(lhs.tensor(), rhs.tensor())
538 }
539
540 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
541 B::float_greater_equal_elem(lhs.tensor(), rhs)
542 }
543
544 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
545 B::float_lower(lhs.tensor(), rhs.tensor())
546 }
547
548 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
549 B::float_lower_elem(lhs.tensor(), rhs)
550 }
551
552 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
553 B::float_lower_equal(lhs.tensor(), rhs.tensor())
554 }
555
556 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
557 B::float_lower_equal_elem(lhs.tensor(), rhs)
558 }
559
560 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
561 match tensor {
562 TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
563 TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
564 }
565 }
566
567 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
568 match tensor {
569 TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
570 TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
571 }
572 }
573
574 fn max(tensor: Self::Primitive) -> Self::Primitive {
575 match tensor {
576 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
577 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
578 }
579 }
580
581 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
582 match tensor {
583 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
584 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
585 }
586 }
587
588 fn max_dim_with_indices(
589 tensor: Self::Primitive,
590 dim: usize,
591 ) -> (Self::Primitive, IntTensor<B>) {
592 match tensor {
593 TensorPrimitive::Float(tensor) => {
594 let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
595 (TensorPrimitive::Float(values), indices)
596 }
597 TensorPrimitive::QFloat(tensor) => {
598 let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
599 (TensorPrimitive::QFloat(values), indices)
600 }
601 }
602 }
603
604 fn min(tensor: Self::Primitive) -> Self::Primitive {
605 match tensor {
606 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
607 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
608 }
609 }
610
611 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
612 match tensor {
613 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
614 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
615 }
616 }
617
618 fn min_dim_with_indices(
619 tensor: Self::Primitive,
620 dim: usize,
621 ) -> (Self::Primitive, IntTensor<B>) {
622 match tensor {
623 TensorPrimitive::Float(tensor) => {
624 let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
625 (TensorPrimitive::Float(values), indices)
626 }
627 TensorPrimitive::QFloat(tensor) => {
628 let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
629 (TensorPrimitive::QFloat(values), indices)
630 }
631 }
632 }
633
634 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
635 match tensor {
636 TensorPrimitive::Float(tensor) => {
637 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
638 }
639 TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
640 }
641 }
642
643 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
644 match tensor {
645 TensorPrimitive::Float(tensor) => {
646 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
647 }
648 TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
649 }
650 }
651
652 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
653 match tensor {
654 TensorPrimitive::Float(tensor) => {
655 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
656 }
657 TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
658 }
659 }
660
661 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
662 match tensor {
663 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
664 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
665 }
666 }
667
668 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
669 match tensor {
670 TensorPrimitive::Float(tensor) => {
671 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
672 }
673 TensorPrimitive::QFloat(tensor) => {
674 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
675 }
676 }
677 }
678}
679
680impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
681 type InnerKind = Float;
682
683 fn inner(
684 tensor: <Self as TensorKind<B>>::Primitive,
685 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
686 match tensor {
687 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
688 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
689 }
690 }
691
692 fn from_inner(
693 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
694 ) -> <Self as TensorKind<B>>::Primitive {
695 match inner {
696 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
697 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
698 }
699 }
700}