Skip to main content

lak/
traits.rs

// traits.rs
2
3use std::fmt::Display; 
4use std::ops::{Add, Mul, Sub}; 
5use std::simd::{Simd, StdFloat};
6
7use crate::l3::{dgemm, sgemm};
8use crate::types::{MatMut, MatRef, Transpose}; 
9
/// Multiply-add over scalar and SIMD floating-point types.
///
/// `x.fma(b, c)` evaluates `(x * b) + c`. The trait is named for, and intended to map onto,
/// a fused multiply-add (one rounding step instead of two); whether a particular
/// implementation actually fuses is up to that implementation.
pub trait Fma {
    /// Computes `(self * b) + c`.
    fn fma(self, b: Self, c: Self) -> Self;
}
15
16impl Fma for f32 { 
17    #[inline(always)]
18    fn fma(self, b: Self, c: Self) -> Self { 
19        self.mul_add(b, c)
20    }
21}
22
23impl Fma for f64 { 
24    #[inline(always)]
25    fn fma(self, b: Self, c: Self) -> Self { 
26        self.mul_add(b, c)
27    }
28}
29
30impl<const LANES: usize> Fma for Simd<f32, LANES> {
31    #[inline(always)]
32    fn fma(self, b: Self, c: Self) -> Self { 
33        self.mul_add(b, c) 
34    }
35}
36
37impl<const LANES: usize> Fma for Simd<f64, LANES> {
38    /// computes (self * a) + b 
39    #[inline(always)]
40    fn fma(self, b: Self, c: Self) -> Self { 
41        self.mul_add(b, c)
42    }   
43}
44
/// Absolute value, exposed as a trait so generic code (such as [`TestFloat`]) can require it.
pub trait Abs {
    /// Returns the magnitude of `self`.
    fn abs(self) -> Self;
}
49
50impl Abs for f32 { 
51    /// computes absolute value of [f32]
52    fn abs(self) -> Self { 
53        f32::abs(self)
54    }
55}
56
57impl Abs for f64 { 
58    /// computes absolute value of [f64]
59    fn abs(self) -> Self { 
60        f64::abs(self)
61    }
62}
63
64
/// Square root, exposed as a trait so generic numeric code can require it.
pub trait Sqrt {
    /// Returns the square root of `self`.
    fn sqrt(self) -> Self;
}
69
70impl Sqrt for f32 { 
71    /// computes square root if [f32]
72    fn sqrt(self) -> Self { 
73        f32::sqrt(self)
74    }
75}
76
77impl Sqrt for f64 { 
78    /// computes square root if [f64]
79    fn sqrt(self) -> Self { 
80        f64::sqrt(self)
81    }
82}
83
/// Binary maximum, exposed as a trait so generic code (such as [`TestFloat`]) can require it.
pub trait Max {
    /// Returns whichever of `self` and `other` is larger.
    ///
    /// How unordered inputs (for floats, NaN) are treated is left to each implementation.
    fn max(self, other: Self) -> Self;
}
88
89impl Max for f64 { 
90    /// computes max between two [f64]s 
91    fn max(self, other: Self) -> Self { 
92        if self >= other { 
93            self 
94        } else { 
95            other
96        }
97    }
98}
99
100
101/// used in tests
102pub trait TestFloat:
103    Copy
104    + PartialOrd
105    + Sub<Output = Self>
106    + Add<Output = Self>
107    + Mul<Output = Self>
108    + Abs
109    + Max 
110    + Display
111{
112    const RTOL: Self;
113    const ATOL: Self;
114}
115
116impl TestFloat for f64 {
117    const RTOL: Self = 1e-14;
118    const ATOL: Self = 1e-14;
119}
120
121
122
/// Per-scalar SIMD configuration for the crate's L1 routines.
pub trait SimdScalarL1: Copy {
    /// Number of lanes (elements of `Self`) per SIMD vector used by the L1 kernels.
    const LANES: usize;
}
127
128impl SimdScalarL1 for f32 { 
129    const LANES: usize = 32; 
130}
131
132impl SimdScalarL1 for f64 { 
133    const LANES: usize = 16; 
134}
135
136
137/// dispatches generic [crate::l3::gemm] calls to [sgemm]/[dgemm]. 
138pub trait GemmDispatch: Sized {
139    /// calls the concrete [sgemm]/[dgemm] implementation for self 
140    fn gemm( 
141        atrans: Transpose, 
142        btrans: Transpose, 
143        alpha: Self, 
144        beta: Self, 
145        a: MatRef<'_, Self>, 
146        b: MatRef<'_, Self>, 
147        c: MatMut<'_, Self>,
148    ); 
149}
150
151impl GemmDispatch for f32 {
152    /// dispatches to [sgemm]
153    fn gemm( 
154        atrans: Transpose, 
155        btrans: Transpose, 
156        alpha: f32, 
157        beta: f32, 
158        a: MatRef<'_, f32>, 
159        b: MatRef<'_, f32>, 
160        c: MatMut<'_, f32>,      
161    ) { 
162        sgemm(atrans, btrans, alpha, beta, a, b, c);
163    }
164}
165
166impl GemmDispatch for f64 { 
167    /// dispatches to [dgemm]
168    fn gemm( 
169        atrans: Transpose, 
170        btrans: Transpose, 
171        alpha: f64, 
172        beta: f64, 
173        a: MatRef<'_, f64>, 
174        b: MatRef<'_, f64>, 
175        c: MatMut<'_, f64>,      
176    ) {
177        dgemm(atrans, btrans, alpha, beta, a, b, c);
178    }
179}
180
181
182