1use std::fmt::Display;
4use std::ops::{Add, Mul, Sub};
5use std::simd::{Simd, StdFloat};
6
7use crate::l3::{dgemm, sgemm};
8use crate::types::{MatMut, MatRef, Transpose};
9
/// Fused multiply-add, implemented for scalar floats (and, elsewhere in this
/// file, for `Simd` vectors of them).
pub trait Fma {
    /// Computes `self * b + c` with a single rounding step.
    fn fma(self, b: Self, c: Self) -> Self;
}

impl Fma for f32 {
    #[inline(always)]
    fn fma(self, b: Self, c: Self) -> Self {
        // Path-qualified call to the inherent method, so it cannot be
        // mistaken for a recursive call to `Fma::fma`.
        f32::mul_add(self, b, c)
    }
}

impl Fma for f64 {
    #[inline(always)]
    fn fma(self, b: Self, c: Self) -> Self {
        f64::mul_add(self, b, c)
    }
}
29
30impl<const LANES: usize> Fma for Simd<f32, LANES> {
31 #[inline(always)]
32 fn fma(self, b: Self, c: Self) -> Self {
33 self.mul_add(b, c)
34 }
35}
36
37impl<const LANES: usize> Fma for Simd<f64, LANES> {
38 #[inline(always)]
40 fn fma(self, b: Self, c: Self) -> Self {
41 self.mul_add(b, c)
42 }
43}
44
/// Absolute value, abstracted over the scalar float types.
pub trait Abs {
    /// Returns `|self|` (the sign bit cleared, per the inherent float `abs`).
    fn abs(self) -> Self;
}

impl Abs for f32 {
    fn abs(self) -> Self {
        // Resolves to the inherent method, not back to `Abs::abs`.
        f32::abs(self)
    }
}

impl Abs for f64 {
    fn abs(self) -> Self {
        f64::abs(self)
    }
}
63
64
/// Square root, abstracted over the scalar float types.
pub trait Sqrt {
    /// Returns `√self`; per the inherent float `sqrt`, negative inputs
    /// (other than `-0.0`) yield NaN.
    fn sqrt(self) -> Self;
}

impl Sqrt for f32 {
    fn sqrt(self) -> Self {
        // Resolves to the inherent method, not back to `Sqrt::sqrt`.
        f32::sqrt(self)
    }
}

impl Sqrt for f64 {
    fn sqrt(self) -> Self {
        f64::sqrt(self)
    }
}
83
/// Maximum of two values, abstracted over the scalar float types.
///
/// NaN handling is symmetric: when exactly one operand is NaN the other
/// operand is returned, so `max(a, b)` and `max(b, a)` agree.
pub trait Max {
    /// Returns the larger of `self` and `other`.
    fn max(self, other: Self) -> Self;
}

impl Max for f32 {
    fn max(self, other: Self) -> Self {
        // Inherent `f32::max` returns the non-NaN operand if one is NaN.
        f32::max(self, other)
    }
}

impl Max for f64 {
    fn max(self, other: Self) -> Self {
        // Inherent `f64::max` returns the non-NaN operand if one is NaN, so
        // the result does not depend on argument order (a bare `>=` compare
        // would return NaN only when NaN is the second argument).
        f64::max(self, other)
    }
}
99
100
101pub trait TestFloat:
103 Copy
104 + PartialOrd
105 + Sub<Output = Self>
106 + Add<Output = Self>
107 + Mul<Output = Self>
108 + Abs
109 + Max
110 + Display
111{
112 const RTOL: Self;
113 const ATOL: Self;
114}
115
116impl TestFloat for f64 {
117 const RTOL: Self = 1e-14;
118 const ATOL: Self = 1e-14;
119}
120
121
122
/// Scalar element type for the SIMD kernels, carrying its chosen lane count.
///
/// Both implementations below pick a lane count that spans 128 bytes
/// (32 × 4-byte `f32`, 16 × 8-byte `f64`).
pub trait SimdScalarL1: Copy {
    /// Number of `Self` elements per SIMD vector.
    const LANES: usize;
}

impl SimdScalarL1 for f64 {
    const LANES: usize = 16;
}

impl SimdScalarL1 for f32 {
    const LANES: usize = 32;
}
135
136
137pub trait GemmDispatch: Sized {
139 fn gemm(
141 atrans: Transpose,
142 btrans: Transpose,
143 alpha: Self,
144 beta: Self,
145 a: MatRef<'_, Self>,
146 b: MatRef<'_, Self>,
147 c: MatMut<'_, Self>,
148 );
149}
150
151impl GemmDispatch for f32 {
152 fn gemm(
154 atrans: Transpose,
155 btrans: Transpose,
156 alpha: f32,
157 beta: f32,
158 a: MatRef<'_, f32>,
159 b: MatRef<'_, f32>,
160 c: MatMut<'_, f32>,
161 ) {
162 sgemm(atrans, btrans, alpha, beta, a, b, c);
163 }
164}
165
166impl GemmDispatch for f64 {
167 fn gemm(
169 atrans: Transpose,
170 btrans: Transpose,
171 alpha: f64,
172 beta: f64,
173 a: MatRef<'_, f64>,
174 b: MatRef<'_, f64>,
175 c: MatMut<'_, f64>,
176 ) {
177 dgemm(atrans, btrans, alpha, beta, a, b, c);
178 }
179}
180
181
182