pub mod lir_unary;
pub mod mir;
pub mod mir_unary;
pub mod pack;

use crate::internal::*;
use tract_itertools::Itertools;
use tract_ndarray::prelude::*;

pub use self::mir::MatMul;
pub use self::mir_unary::MatMulUnary;
use self::pack::MatMatMulPack;
use crate::ops::quant::QParams;

/// Computes the logical dimensions (m, k, n) and the output shape of a
/// (possibly batched, possibly transposed) matrix multiplication.
///
/// The last two axes of `ashape`/`bshape` are the matrix dimensions; all
/// leading axes are broadcast together to form the output prefix.
/// With no transposition, `a` is `[..., m, k]` and `b` is `[..., k, n]`;
/// `a_trans`/`b_trans` swap the corresponding pair, and `c_trans` makes the
/// output `[..., n, m]` instead of `[..., m, n]`.
///
/// Errors if the two k dimensions disagree, or if the leading axes cannot be
/// broadcast. NOTE(review): both shapes are assumed to have rank >= 2 — the
/// `len() - 2` slicing/indexing below would panic otherwise; confirm callers
/// guarantee this.
pub fn compute_shape<D: DimLike>(
    ashape: &[D],
    bshape: &[D],
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
) -> TractResult<(D, D, D, TVec<D>)> {
    // Broadcast the batch prefixes (everything but the last two axes).
    let mut c_shape = crate::broadcast::multi_broadcast(&[
        &ashape[..(ashape.len() - 2)],
        &bshape[..(bshape.len() - 2)],
    ])
    .ok_or_else(|| format_err!("Could not broadcast"))?;
    // Read the matrix dims assuming no transposition, then swap as needed.
    let (mut m, mut ka) = (ashape[ashape.len() - 2].clone(), ashape[ashape.len() - 1].clone());
    let (mut kb, mut n) = (bshape[bshape.len() - 2].clone(), bshape[bshape.len() - 1].clone());
    if a_trans {
        std::mem::swap(&mut m, &mut ka);
    }
    if b_trans {
        std::mem::swap(&mut kb, &mut n);
    }
    // The inner (contracted) dimension must agree between a and b.
    if ka != kb {
        bail!(
            "Inconsistent matmul: a: {} b: {}, a_trans: {} b_trans: {} c_trans: {}",
            ashape.iter().join(","),
            bshape.iter().join(","),
            a_trans,
            b_trans,
            c_trans
        );
    }
    // Append the matrix axes to the broadcast prefix, in output order.
    if c_trans {
        c_shape.push(n.clone());
        c_shape.push(m.clone());
    } else {
        c_shape.push(m.clone());
        c_shape.push(n.clone());
    }
    Ok((m, ka, n, c_shape))
}

/// Evaluates a batched matrix multiplication `c = a * b` with optional
/// transposition of each operand and optional quantization parameters.
///
/// For each broadcast batch prefix, packs the relevant slices of `a` and `b`
/// into kernel-aligned scratch buffers and runs the tract-linalg matrix
/// multiplier selected for the operand/output datum types. The output datum
/// type comes from `q_params` when present, otherwise from `a`.
///
/// Errors if no kernel exists for the datum-type combination, or if shape
/// computation / quantization injection fails.
pub(super) fn eval(
    a: &Tensor,
    b: &Tensor,
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
    q_params: Option<&QParams>,
) -> TractResult<Tensor> {
    // SAFETY(review): relies on uninitialized tensors being fully written by
    // the pack/run calls below before being read — confirm against the
    // tract-linalg kernel contract.
    unsafe {
        let rank = a.rank();
        let (m, k, n, c_shape) = compute_shape(a.shape(), b.shape(), a_trans, b_trans, c_trans)?;
        // Output datum type: quantized type if q_params is set, else a's type.
        let c_dt = q_params.map(|q| q.c_datum_type).unwrap_or(a.datum_type());
        let mut mm = tract_linalg::ops()
            .mmm(a.datum_type(), b.datum_type(), c_dt, m, k, n)
            .with_context(|| {
                format!(
                    "No matrix multiplier for {:?}x{:?} to {:?}",
                    a.datum_type(),
                    b.datum_type(),
                    c_dt
                )
            })?;
        // Row-major output strides: the unit stride goes on the row axis when
        // the output is transposed, on the column axis otherwise; the other
        // stride is the length of the innermost output axis.
        mm.c_from_data_and_strides(
            if c_trans { 1 } else { c_shape[rank - 1] as isize },
            if !c_trans { 1 } else { c_shape[rank - 1] as isize },
        );
        // Let the quantization parameters configure the kernel (zero points,
        // scale, etc. — as implemented by QParams::inject_into_mmm).
        if let Some(q) = q_params {
            q.inject_into_mmm(&mut *mm)?;
        }
        let mut c = Tensor::uninitialized_dt(c_dt, &c_shape)?;
        // Pre-allocate one packing buffer per operand, sized and aligned per
        // the kernel's packing requirements, reused across all batch items.
        let a_pack = mm.a_pack();
        let b_pack = mm.b_pack();
        let mut packed_a = Tensor::uninitialized_aligned_dt(a.datum_type(), &[a_pack.len(m)], a_pack.alignment())?;
        let mut packed_b = Tensor::uninitialized_aligned_dt(b.datum_type(), &[b_pack.len(n)], b_pack.alignment())?;
        // Iterate every batch coordinate of the broadcast output prefix.
        for prefix in tract_ndarray::indices(&c_shape[..rank - 2]).into_iter() {
            let mut a_prefix = tvec!();
            let mut b_prefix = tvec!();
            for (axis, &dim) in prefix.slice().iter().enumerate() {
                // Clamp each coordinate to the operand's own extent: a
                // broadcast (size-1) axis always reads index 0.
                a_prefix.push(dim.min(a.shape()[axis] - 1));
                b_prefix.push(dim.min(b.shape()[axis] - 1));
            }
            // Pack the current matrix slices. The two trailing arguments pick
            // which input axes map to the kernel's packing axes, flipping with
            // the transposition flags — presumably (k-axis, m/n-axis) order;
            // TODO(review): confirm against the pack API.
            a_pack.pack(packed_a.view_mut(), &a.view_at_prefix(&a_prefix)?, !a_trans as usize, a_trans as usize);
            b_pack.pack(packed_b.view_mut(), &b.view_at_prefix(&b_prefix)?, b_trans as usize, !b_trans as usize);
            // Run the kernel for this batch item, writing into c's slice.
            mm.run(
                &packed_a.view(),
                &packed_b.view(),
                &mut c.view_at_prefix_mut(prefix.slice())?,
                &[],
            )?;
        }
        Ok(c)
    }
}