1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(clippy::manual_is_multiple_of)]
7#![allow(unexpected_cfgs)]
8#![allow(unused_macros)]
9#[macro_use]
10extern crate derive_new;
11extern crate lazy_static;
12extern crate log;
13extern crate num_traits;
14#[macro_use]
15extern crate pastey;
16#[cfg(test)]
17extern crate proptest;
18
19include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
20
21#[macro_use]
22mod frame;
23pub mod generic;
24pub mod multithread;
25pub use frame::weights::WeightType;
26pub use generic::{ScaleShiftAndRound, Scaler};
27use lazy_static::lazy_static;
28use mmm::{MMMInputFormat, MatMatMul, PanelExtractor};
29use tract_data::internal::TensorView;
30#[cfg(target_arch = "x86_64")]
31pub mod x86_64_fma;
32
33pub mod hwbench;
34
35#[cfg(target_arch = "aarch64")]
36pub mod arm64;
37
38#[cfg(target_arch = "aarch64")]
39pub use arm64::has_fp16;
40use tract_itertools::Itertools;
41
/// Fallback used on every target other than `aarch64`: always returns `false`.
///
/// On `aarch64` this function is not compiled; `arm64::has_fp16` is re-exported under the
/// same name instead.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
46
// 32-bit ARM kernels. Same predicate as the one guarding `arm32::plug` in `best()`.
// NOTE(review): "armv7" does not look like a value `target_arch` can take (armv7 targets
// report "arm"); it is kept only to stay aligned with `best()` - confirm before removing.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
49
50#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
51pub mod wasm;
52
53pub use self::frame::*;
54
55use tract_data::prelude::*;
56
/// Boxed factory returning a matrix-matrix kernel.
///
/// The three optional `usize` arguments are the `m`, `k` and `n` values that `Ops::mmm`
/// forwards, in that order.
pub type MMMImpl = Box<
    dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
>;

/// Boxed factory returning the kernel `Ops::mmm` uses when `n == Some(1)`.
///
/// It only receives the `m` and `k` values.
type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
62
/// Table of kernels and kernel factories.
///
/// `generic()` builds one from the portable implementations, `best()` starts from that table
/// and lets the architecture-specific modules adjust it, and `ops()` returns the shared
/// instance.
#[allow(clippy::type_complexity)]
pub struct Ops {
    // Registered matrix-multiplication kernels, exposed through `Ops::mmm_impls`.
    mmm_impls: Vec<Box<dyn mmm::MatMatMul>>,
    // Registered panel extractors, exposed through `Ops::panel_extractors`.
    panel_extractors: Vec<mmm::PanelExtractor>,

    // Per-accumulator kernel factories used by `Ops::mmm`: the `mmv_*` / `qmmv_*` entry is
    // called when `n == Some(1)`, the `mmm_*` / `qmmm_*` entry otherwise.
    mmm_f64: MMMImpl,
    mmv_f64: MMVImpl,

    mmm_f32: MMMImpl,
    mmv_f32: MMVImpl,

    mmm_f16: MMMImpl,
    mmv_f16: MMVImpl,

    qmmm_i32: MMMImpl,
    qmmv_i32: MMVImpl,

    // Factories for element-wise kernels that take an extra parameter of the element type
    // (second type argument of `ElementWise`).
    pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
    pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
    pub mul_by_scalar_f32:
        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
    pub mul_by_scalar_f16:
        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,

    // Factories for element-wise kernels declared with a single type argument.
    pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub hardswish_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub hardswish_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub silu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub silu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub gelu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub gelu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    // Factory building a `lut::Lut` from a byte slice.
    pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,

    // Factories for reduction kernels.
    pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
    pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

    pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
    pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

    // Factories for map-reduce kernels.
    pub softmax2_fastcompact_f16:
        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
    pub softmax2_fastcompact_f32:
        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
}
111
112impl Ops {
113 pub fn mmm_impls(&self) -> &[Box<dyn mmm::MatMatMul>] {
114 &self.mmm_impls
115 }
116
117 pub fn all_possible_packing(
118 &self,
119 weight_type: impl Into<WeightType>,
120 ) -> impl Iterator<Item = &dyn MMMInputFormat> {
121 let weight_type = weight_type.into();
122 self.mmm_impls
123 .iter()
124 .flat_map(|m| m.packings())
125 .map(|p| &*p.0)
126 .flat_map(move |p| {
127 let mut packs: Vec<&dyn MMMInputFormat> = vec![];
128 if p.precursor() == weight_type {
129 packs.push(p)
130 };
131 for pe in &self.panel_extractors {
132 if pe.from.precursor() == weight_type && pe.to.dyn_eq(p) {
133 packs.push(&*pe.from);
134 }
135 }
136 packs.into_iter()
137 })
138 .sorted_by_key(|p| p.to_string())
139 .dedup()
140 }
141
142 pub fn filter_impls<'o>(
143 &'o self,
144 weight: &'o dyn MMMInputFormat,
145 acc: &[DatumType],
146 act: DatumType,
147 store: DatumType,
148 ) -> impl Iterator<
149 Item = (
150 &'o dyn MatMatMul,
151 usize,
152 &'o dyn MMMInputFormat,
153 Option<&'o PanelExtractor>,
154 &'o dyn MMMInputFormat,
155 ),
156 > {
157 let acc = acc.to_vec();
158 self.mmm_impls
159 .iter()
160 .filter(move |mmm| acc.contains(&mmm.internal_type()) && mmm.stores().contains(&store))
161 .flat_map(|mmm| {
162 mmm.packings()
163 .iter()
164 .enumerate()
165 .map(|(pack_ix, (a, b))| (&**mmm, pack_ix, &**a, &**b))
166 })
167 .filter_map(|(mmm, ix, a, b)| {
168 if a.dyn_eq(weight) {
169 Some((mmm, ix, a, None, b))
170 } else {
171 self.panel_extractors
172 .iter()
173 .find(|pe| pe.from.dyn_eq(weight) && pe.to.dyn_eq(a))
174 .map(|pe| (mmm, ix, a, Some(pe), b))
175 }
176 })
177 .filter(move |(_mmm, _ix, _a, _pe, b)| {
178 b.precursor().as_dt().is_some_and(|dt| dt == act)
179 })
180 }
181
182 pub fn panel_extractors(&self) -> &[mmm::panel_extract::PanelExtractor] {
183 &self.panel_extractors
184 }
185
186 pub fn mmm(
187 &self,
188 accumulator: DatumType,
189 m: Option<usize>,
190 k: Option<usize>,
191 n: Option<usize>,
192 ) -> Option<Box<dyn mmm::MatMatMul>> {
193 use DatumType::*;
194 match accumulator {
195 F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
196 F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
197 F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
198 I32 => {
199 Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
200 }
201 _ => None,
202 }
203 }
204}
205
206pub fn generic() -> Ops {
207 use crate::generic::mmm::*;
208 use element_wise::ElementWiseKer;
209 use reduce::{MapReduceKer, ReduceKer};
210 let mut ops = Ops {
211 mmm_impls: vec![],
212 panel_extractors: vec![],
213 mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
214 mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
215 mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
216 mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
217 mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
218 mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
219 qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
220 qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
221 leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
222 leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
223 mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
224 mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
225 sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
226 sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
227 tanh_f16: Box::new(|| generic::HTanh8::ew()),
228 tanh_f32: Box::new(|| generic::STanh4::ew()),
229 erf_f32: Box::new(|| generic::SErf4::ew()),
230 hardswish_f16: Box::new(|| generic::HHardSwish8::ew()),
231 hardswish_f32: Box::new(|| generic::SHardSwish4::ew()),
232 silu_f16: Box::new(|| generic::HSiLU8::ew()),
233 silu_f32: Box::new(|| generic::SSiLU4::ew()),
234 gelu_f16: Box::new(|| generic::HGelu8::ew()),
235 gelu_f32: Box::new(|| generic::SGelu4::ew()),
236 lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
237 max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
238 max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
239 sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
240 sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
241 softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
245 softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
246 };
247 crate::generic::mmm::plug(&mut ops);
248 ops
249}
250
251#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
252pub fn best() -> Ops {
253 let mut ops = generic();
254 #[cfg(target_arch = "x86_64")]
255 x86_64_fma::plug(&mut ops);
256 #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
257 arm32::plug(&mut ops);
258 #[cfg(target_arch = "aarch64")]
259 arm64::plug(&mut ops);
260 #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
261 wasm::plug(&mut ops);
262
263 ops
264}
265
266lazy_static::lazy_static! {
267 static ref OPS: Ops = {
268 best()
269 };
270}
271
/// Binary operations used as keys, together with a datum type, in the kernel registries.
///
/// `Sub` and `SubF` are exchanged by [`BinOp::flip`]; presumably `SubF` is the subtraction
/// with its operands swapped - confirm against the kernels.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    SubF,
}

impl BinOp {
    /// Returns the counterpart of this operation: `Sub` and `SubF` are exchanged, every other
    /// variant is returned unchanged.
    pub fn flip(&self) -> BinOp {
        match *self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul => *self,
        }
    }
}
292
/// Fills `registry` with the unicast binary kernels.
///
/// The `generic` module registers first; on `aarch64` the `arm64` module is then given the
/// same registry (presumably to replace some generic entries with tuned ones - confirm).
fn register_all_unicast(registry: &mut LinalgRegistry) {
    generic::register_all_unicast(registry);
    #[cfg(target_arch = "aarch64")]
    arm64::register_all_unicast(registry);
}
298
/// Fills `registry` with the by-scalar binary kernels.
///
/// The `generic` module registers first; on `aarch64` the `arm64` module is then given the
/// same registry (presumably to replace some generic entries with tuned ones - confirm).
fn register_all_by_scalar(registry: &mut LinalgRegistry) {
    generic::register_all_by_scalar(registry);
    #[cfg(target_arch = "aarch64")]
    arm64::register_all_by_scalar(registry);
}
304
/// Signature of a binary kernel: it receives one mutable and one shared tensor view and
/// reports failure through `TractResult`.
// NOTE(review): presumably the mutable view holds the left operand and receives the result -
// confirm against the kernel implementations.
pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
/// Maps an `(operation, datum type)` pair to a factory producing the matching [`LinalgFn`].
type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
307lazy_static! {
308 static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
309 let mut registry = HashMap::default();
310 register_all_unicast(&mut registry);
311 Mutex::new(registry)
312 };
313 static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
314 let mut registry = HashMap::default();
315 register_all_by_scalar(&mut registry);
316 Mutex::new(registry)
317 };
318}
319
320pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
321 let map = BIN_BY_SCALAR_OPS.lock().unwrap();
322 if (dt == DatumType::F16) && !has_fp16() {
323 return None;
324 }
325 map.get(&(bin, dt)).map(|it| (it)())
326}
327
328pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
329 let map = BIN_UNICAST_OPS.lock().unwrap();
330 if (dt == DatumType::F16) && !has_fp16() {
331 return None;
332 }
333 map.get(&(bin, dt)).map(|it| (it)())
334}
335
/// Returns the shared kernel table held in the `OPS` static, which is built with `best()`.
pub fn ops() -> &'static Ops {
    &OPS
}
339
340use dyn_eq::DynEq;
341use num_traits::*;
342use std::collections::HashMap;
343use std::fmt::Debug;
344use std::ops::*;
345use std::sync::Mutex;
346
/// Bundle of bounds required from the element types handled by this crate's kernels.
///
/// This file implements it for `f16`, `f32`, `f64`, `u8`, `i8` and `i32`.
pub trait LADatum:
    Sized
    + std::fmt::Display
    + Debug
    + Copy
    + Clone
    + Zero
    + One
    + 'static
    + Add<Output = Self>
    + Sub<Output = Self>
    // NOTE(review): unlike `Add` and `Sub`, `Mul` does not pin `Output = Self` - confirm
    // whether that is intended.
    + Mul
    + AddAssign
    + PartialOrd
    + Bounded
    + tract_data::prelude::Datum
{
    /// Proptest strategy generating values of this type; only present in test builds.
    #[cfg(test)]
    fn strat() -> proptest::prelude::BoxedStrategy<Self>;
}
367
368#[cfg(test)]
369use proptest::prelude::*;
370
371impl LADatum for f16 {
372 #[cfg(test)]
373 fn strat() -> BoxedStrategy<Self> {
374 f32::strat().prop_map(|f| f.as_()).boxed()
375 }
376}
377
378impl LADatum for f32 {
379 #[cfg(test)]
380 fn strat() -> BoxedStrategy<Self> {
381 (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
382 }
383}
384
385impl LADatum for f64 {
386 #[cfg(test)]
387 fn strat() -> BoxedStrategy<Self> {
388 (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
389 }
390}
391
392impl LADatum for u8 {
393 #[cfg(test)]
394 fn strat() -> BoxedStrategy<Self> {
395 any::<u8>().boxed()
396 }
397}
398
399impl LADatum for i8 {
400 #[cfg(test)]
401 fn strat() -> BoxedStrategy<Self> {
402 any::<i8>().boxed()
403 }
404}
405
406impl LADatum for i32 {
407 #[cfg(test)]
408 fn strat() -> BoxedStrategy<Self> {
409 any::<i32>().boxed()
410 }
411}
412
/// Initialises `env_logger` from the `TRACT_LOG` environment variable for tests.
///
/// The result of `try_init` is deliberately ignored.
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    let mut builder = env_logger::Builder::from_env("TRACT_LOG");
    let _ = builder.try_init();
}