tract_linalg/
lib.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(clippy::manual_is_multiple_of)]
7#![allow(unexpected_cfgs)]
8#![allow(unused_macros)]
9#[macro_use]
10extern crate derive_new;
11extern crate lazy_static;
12extern crate log;
13extern crate num_traits;
14#[macro_use]
15extern crate pastey;
16#[cfg(test)]
17extern crate proptest;
18
19include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
20
21#[macro_use]
22mod frame;
23pub mod generic;
24pub mod multithread;
25pub use frame::weights::WeightType;
26pub use generic::{ScaleShiftAndRound, Scaler};
27use lazy_static::lazy_static;
28use mmm::{MMMInputFormat, MatMatMul, PanelExtractor};
29use tract_data::internal::TensorView;
30#[cfg(target_arch = "x86_64")]
31pub mod x86_64_fma;
32
33pub mod hwbench;
34
35#[cfg(target_arch = "aarch64")]
36pub mod arm64;
37
38#[cfg(target_arch = "aarch64")]
39pub use arm64::has_fp16;
40use tract_itertools::Itertools;
41
/// Fallback `has_fp16` for every target other than `aarch64`.
///
/// On `aarch64` the crate re-exports `arm64::has_fp16` instead; everywhere
/// else this definition is compiled and it unconditionally answers `false`.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
46
// 32-bit ARM kernels. `target_arch = "arm"` used to be listed twice in this
// predicate; the duplicate is dropped so the condition reads the same as the
// one guarding `arm32::plug` in `best()`.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
49
50#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
51pub mod wasm;
52
53pub use self::frame::*;
54
55use tract_data::prelude::*;
56
/// Boxed, thread-safe factory producing a matrix-matrix product kernel.
///
/// The three optional `usize` arguments are the values `Ops::mmm` forwards
/// as `m`, `k` and `n`.
pub type MMMImpl = Box<
    dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
>;

/// Boxed, thread-safe factory for the matrix-vector case: it only receives
/// the two values `Ops::mmm` forwards as `m` and `k`, and is the one chosen
/// there when `n == Some(1)`.
type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
62
/// Table of kernel factories.
///
/// `generic()` builds it from the portable implementations and `best()` then
/// hands it to the architecture-specific `plug` functions compiled for the
/// current target. Every factory field is a boxed `Send + Sync` closure that
/// returns a boxed kernel when called.
#[allow(clippy::type_complexity)]
pub struct Ops {
    // Registered matrix multiplication kernels and panel extractors; read
    // through the accessors of the same name and searched by
    // `all_possible_packing` / `filter_impls`.
    mmm_impls: Vec<Box<dyn mmm::MatMatMul>>,
    panel_extractors: Vec<mmm::PanelExtractor>,

    // Matrix product factories, one matrix-matrix (`mmm_*`) and one
    // matrix-vector (`mmv_*`) entry per accumulator type. `Ops::mmm` picks
    // between the two based on `n == Some(1)`.
    mmm_f64: MMMImpl,
    mmv_f64: MMVImpl,

    mmm_f32: MMMImpl,
    mmv_f32: MMVImpl,

    mmm_f16: MMMImpl,
    mmv_f16: MMVImpl,

    qmmm_i32: MMMImpl,
    qmmv_i32: MMVImpl,

    // Element-wise factories whose kernels take an extra parameter of the
    // element type (second generic argument of `ElementWise`).
    pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
    pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
    pub mul_by_scalar_f32:
        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
    pub mul_by_scalar_f16:
        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,

    // Element-wise factories relying on the default second generic argument
    // of `ElementWise`.
    pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
    pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
    // Builds a `Lut` from the byte table passed in.
    pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,

    // Reduction factories.
    pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
    pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

    pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
    pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,

    // Map-reduce factories used for softmax.
    pub softmax2_fastcompact_f16:
        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
    pub softmax2_fastcompact_f32:
        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
}
105
106impl Ops {
107    pub fn mmm_impls(&self) -> &[Box<dyn mmm::MatMatMul>] {
108        &self.mmm_impls
109    }
110
111    pub fn all_possible_packing(
112        &self,
113        weight_type: impl Into<WeightType>,
114    ) -> impl Iterator<Item = &dyn MMMInputFormat> {
115        let weight_type = weight_type.into();
116        self.mmm_impls
117            .iter()
118            .flat_map(|m| m.packings())
119            .map(|p| &*p.0)
120            .flat_map(move |p| {
121                let mut packs: Vec<&dyn MMMInputFormat> = vec![];
122                if p.precursor() == weight_type {
123                    packs.push(p)
124                };
125                for pe in &self.panel_extractors {
126                    if pe.from.precursor() == weight_type && pe.to.same_as(p) {
127                        packs.push(&*pe.from);
128                    }
129                }
130                packs.into_iter()
131            })
132            .sorted_by_key(|p| p.to_string())
133            .dedup()
134    }
135
136    pub fn filter_impls<'o>(
137        &'o self,
138        weight: &'o dyn MMMInputFormat,
139        acc: &[DatumType],
140        act: DatumType,
141        store: DatumType,
142    ) -> impl Iterator<
143        Item = (
144            &'o dyn MatMatMul,
145            usize,
146            &'o dyn MMMInputFormat,
147            Option<&'o PanelExtractor>,
148            &'o dyn MMMInputFormat,
149        ),
150    > {
151        let acc = acc.to_vec();
152        self.mmm_impls
153            .iter()
154            .filter(move |mmm| acc.contains(&mmm.internal_type()) && mmm.stores().contains(&store))
155            .flat_map(|mmm| {
156                mmm.packings()
157                    .iter()
158                    .enumerate()
159                    .map(|(pack_ix, (a, b))| (&**mmm, pack_ix, &**a, &**b))
160            })
161            .filter_map(|(mmm, ix, a, b)| {
162                if a.same_as(weight) {
163                    Some((mmm, ix, a, None, b))
164                } else {
165                    self.panel_extractors
166                        .iter()
167                        .find(|pe| pe.from.same_as(weight) && pe.to.same_as(a))
168                        .map(|pe| (mmm, ix, a, Some(pe), b))
169                }
170            })
171            .filter(move |(_mmm, _ix, _a, _pe, b)| {
172                b.precursor().as_dt().is_some_and(|dt| dt == act)
173            })
174    }
175
176    pub fn panel_extractors(&self) -> &[mmm::panel_extract::PanelExtractor] {
177        &self.panel_extractors
178    }
179
180    pub fn mmm(
181        &self,
182        accumulator: DatumType,
183        m: Option<usize>,
184        k: Option<usize>,
185        n: Option<usize>,
186    ) -> Option<Box<dyn mmm::MatMatMul>> {
187        use DatumType::*;
188        match accumulator {
189            F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
190            F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
191            F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
192            I32 => {
193                Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
194            }
195            _ => None,
196        }
197    }
198}
199
200pub fn generic() -> Ops {
201    use crate::generic::mmm::*;
202    use element_wise::ElementWiseKer;
203    use reduce::{MapReduceKer, ReduceKer};
204    let mut ops = Ops {
205        mmm_impls: vec![],
206        panel_extractors: vec![],
207        mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
208        mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
209        mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
210        mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
211        mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
212        mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
213        qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
214        qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
215        leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
216        leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
217        mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
218        mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
219        sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
220        sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
221        tanh_f16: Box::new(|| generic::HTanh8::ew()),
222        tanh_f32: Box::new(|| generic::STanh4::ew()),
223        erf_f32: Box::new(|| generic::SErf4::ew()),
224        lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
225        max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
226        max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
227        sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
228        sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
229        /*
230        activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
231        */
232        softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
233        softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
234    };
235    crate::generic::mmm::plug(&mut ops);
236    ops
237}
238
239#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
240pub fn best() -> Ops {
241    let mut ops = generic();
242    #[cfg(target_arch = "x86_64")]
243    x86_64_fma::plug(&mut ops);
244    #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
245    arm32::plug(&mut ops);
246    #[cfg(target_arch = "aarch64")]
247    arm64::plug(&mut ops);
248    #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
249    wasm::plug(&mut ops);
250
251    ops
252}
253
254lazy_static::lazy_static! {
255    static ref OPS: Ops = {
256        best()
257    };
258}
259
/// Binary operations that key the kernel registries.
///
/// `Sub` and `SubF` form a pair that [`BinOp::flip`] maps onto each other
/// (presumably `SubF` is subtraction with its operands swapped -- confirm
/// against the kernel implementations).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    SubF,
}

impl BinOp {
    /// Returns the operation obtained by exchanging `Sub` and `SubF`.
    ///
    /// `Min`, `Max`, `Add` and `Mul` are returned unchanged.
    pub fn flip(&self) -> BinOp {
        match *self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            same @ (BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul) => same,
        }
    }
}
280
281fn register_all_unicast(registry: &mut LinalgRegistry) {
282    generic::register_all_unicast(registry);
283    #[cfg(target_arch = "aarch64")]
284    arm64::register_all_unicast(registry);
285}
286
287fn register_all_by_scalar(registry: &mut LinalgRegistry) {
288    generic::register_all_by_scalar(registry);
289    #[cfg(target_arch = "aarch64")]
290    arm64::register_all_by_scalar(registry);
291}
292
/// Signature of a binary kernel: it receives one mutable and one shared
/// `TensorView` and reports success or failure through `TractResult<()>`.
pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
/// Registry mapping an `(operation, datum type)` pair to a factory that
/// produces the matching boxed `LinalgFn`.
type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
295lazy_static! {
296    static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
297        let mut registry = HashMap::default();
298        register_all_unicast(&mut registry);
299        Mutex::new(registry)
300    };
301    static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
302        let mut registry = HashMap::default();
303        register_all_by_scalar(&mut registry);
304        Mutex::new(registry)
305    };
306}
307
308pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
309    let map = BIN_BY_SCALAR_OPS.lock().unwrap();
310    if (dt == DatumType::F16) && !has_fp16() {
311        return None;
312    }
313    map.get(&(bin, dt)).map(|it| (it)())
314}
315
316pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
317    let map = BIN_UNICAST_OPS.lock().unwrap();
318    if (dt == DatumType::F16) && !has_fp16() {
319        return None;
320    }
321    map.get(&(bin, dt)).map(|it| (it)())
322}
323
/// Returns the shared kernel table held in the `OPS` lazy static, whose
/// initialiser calls `best()`.
pub fn ops() -> &'static Ops {
    &OPS
}
327
328use num_traits::*;
329use std::collections::HashMap;
330use std::fmt::Debug;
331use std::ops::*;
332use std::sync::Mutex;
333
/// Bundle of bounds shared by the element types this file implements it for
/// (`f16`, `f32`, `f64`, `u8`, `i8`, `i32`).
///
/// Besides the supertraits listed below, test builds require a proptest
/// strategy generating values of the type.
pub trait LADatum:
    Sized
    + std::fmt::Display
    + Debug
    + Copy
    + Clone
    + Zero
    + One
    + 'static
    + Add<Output = Self>
    + Sub<Output = Self>
    // Unlike `Add` and `Sub`, `Mul` carries no `Output = Self` constraint here.
    + Mul
    + AddAssign
    + PartialOrd
    + Bounded
    + tract_data::prelude::Datum
{
    /// Strategy producing test values; only part of the trait under `cfg(test)`.
    #[cfg(test)]
    fn strat() -> proptest::prelude::BoxedStrategy<Self>;
}
354
355#[cfg(test)]
356use proptest::prelude::*;
357
358impl LADatum for f16 {
359    #[cfg(test)]
360    fn strat() -> BoxedStrategy<Self> {
361        f32::strat().prop_map(|f| f.as_()).boxed()
362    }
363}
364
365impl LADatum for f32 {
366    #[cfg(test)]
367    fn strat() -> BoxedStrategy<Self> {
368        (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
369    }
370}
371
372impl LADatum for f64 {
373    #[cfg(test)]
374    fn strat() -> BoxedStrategy<Self> {
375        (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
376    }
377}
378
379impl LADatum for u8 {
380    #[cfg(test)]
381    fn strat() -> BoxedStrategy<Self> {
382        any::<u8>().boxed()
383    }
384}
385
386impl LADatum for i8 {
387    #[cfg(test)]
388    fn strat() -> BoxedStrategy<Self> {
389        any::<i8>().boxed()
390    }
391}
392
393impl LADatum for i32 {
394    #[cfg(test)]
395    fn strat() -> BoxedStrategy<Self> {
396        any::<i32>().boxed()
397    }
398}
399
/// Test helper: tries to install an `env_logger` logger configured from the
/// `TRACT_LOG` environment variable.
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    // `try_init` reports failure through its return value; that result is
    // deliberately discarded so calling this helper never fails a test.
    let _ = env_logger::Builder::from_env("TRACT_LOG").try_init();
}