pub mod lir_unary;
pub mod mir;
pub mod mir_unary;
pub mod pack;

use crate::internal::*;
use tract_itertools::Itertools;
use tract_ndarray::prelude::*;

pub use self::mir::MatMul;
pub use self::mir_unary::MatMulUnary;
use self::pack::MatMatMulPack;
use crate::ops::quant::QParams;

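/// Infers the geometry of a (possibly batched) matrix product.
///
/// Given the shapes of `a` and `b` and the transposition flags, broadcasts
/// the leading (batch) axes of both inputs against each other and reads the
/// matrix dimensions off the two trailing axes. Returns `(m, k, n)` along
/// with the full output shape, or fails if the inner `k` dimensions of `a`
/// and `b` disagree.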
pub fn compute_shape<D: DimLike>(
    ashape: &[D],
    bshape: &[D],
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
) -> TractResult<(D, D, D, TVec<D>)> {
    let mut c_shape = crate::broadcast::multi_broadcast(&[
        &ashape[..(ashape.len() - 2)],
        &bshape[..(bshape.len() - 2)],
    ])
    .ok_or_else(|| {
        format_err!(
            "Could not broadcast batch dimensions of a: {} and b: {}",
            ashape.iter().join(","),
            bshape.iter().join(",")
        )
    })?;
    let (mut m, mut ka) = (ashape[ashape.len() - 2].clone(), ashape[ashape.len() - 1].clone());
    let (mut kb, mut n) = (bshape[bshape.len() - 2].clone(), bshape[bshape.len() - 1].clone());
    if a_trans {
        std::mem::swap(&mut m, &mut ka);
    }
    if b_trans {
        std::mem::swap(&mut kb, &mut n);
    }
    if ka != kb {
        bail!(
            "Inconsistent matmul: a: {} b: {}, a_trans: {} b_trans: {} c_trans: {}",
            ashape.iter().join(","),
            bshape.iter().join(","),
            a_trans,
            b_trans,
            c_trans
        );
    }
    if c_trans {
        c_shape.push(n.clone());
        c_shape.push(m.clone());
    } else {
        c_shape.push(m.clone());
        c_shape.push(n.clone());
    }
    Ok((m, ka, n, c_shape))
}

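/// Evaluates a matrix product over two constant tensors using the
/// matrix-matrix multiplication kernels from tract-linalg.
///
/// A kernel is selected from the input and output datum types and the
/// problem geometry, `a` and `b` are packed into the layouts it expects,
/// and the kernel is run once per broadcast batch prefix.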
pub(super) fn eval(
    a: &Tensor,
    b: &Tensor,
    a_trans: bool,
    b_trans: bool,
    c_trans: bool,
    q_params: Option<&QParams>,
) -> TractResult<Tensor> {
    unsafe {
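        // The output and packing buffers below are allocated uninitialized:
        // the packers and the kernel are in charge of filling them entirely
        // before anything reads them.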
        let rank = a.rank();
        let (m, k, n, c_shape) = compute_shape(a.shape(), b.shape(), a_trans, b_trans, c_trans)?;
        let c_dt = q_params.map(|q| q.c_datum_type).unwrap_or(a.datum_type());
        let mut mm = tract_linalg::ops()
            .mmm(a.datum_type(), b.datum_type(), c_dt, m, k, n)
            .with_context(|| {
                format!(
                    "No matrix multiplier for {:?}x{:?} to {:?}",
                    a.datum_type(),
                    b.datum_type(),
                    c_dt
                )
            })?;
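        // Describe the output layout to the kernel: row-major m x n, or the
        // transposed n x m layout when c_trans is set. Either way the
        // innermost axis stays contiguous (stride 1).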
        mm.c_from_data_and_strides(
            if c_trans { 1 } else { c_shape[rank - 1] as isize },
            if !c_trans { 1 } else { c_shape[rank - 1] as isize },
        );
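        // For quantized products, hand the quantization parameters (zero
        // points, output scaling) over to the kernel.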
        if let Some(q) = q_params {
            q.inject_into_mmm(&mut *mm)?;
        }
        let mut c = Tensor::uninitialized_dt(c_dt, &c_shape)?;

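        // Both operands must be packed into the kernel's preferred panel
        // layout and alignment. The two scratch buffers are allocated once
        // and reused for every batch element.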
        let a_pack = mm.a_pack();
        let b_pack = mm.b_pack();

        let mut packed_a =
            Tensor::uninitialized_aligned_dt(a.datum_type(), &[a_pack.len(m)], a_pack.alignment())?;
        let mut packed_b =
            Tensor::uninitialized_aligned_dt(b.datum_type(), &[b_pack.len(n)], b_pack.alignment())?;

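        // Iterate over the broadcast batch dimensions of the output, running
        // one m x k x n multiplication per prefix.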
        for prefix in tract_ndarray::indices(&c_shape[..rank - 2]).into_iter() {
            let mut a_prefix = tvec!();
            let mut b_prefix = tvec!();
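            // Map the output prefix back onto each input, clamping axes of
            // size 1 to index 0 so a broadcast input reuses its single slice.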
            for (axis, &dim) in prefix.slice().iter().enumerate() {
                a_prefix.push(dim.min(a.shape()[axis] - 1));
                b_prefix.push(dim.min(b.shape()[axis] - 1));
            }
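            // Pack both operands, telling the packer which input axis is k
            // and which is m (resp. n) according to the transposition flags.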
            a_pack.pack(
                packed_a.view_mut(),
                &a.view_at_prefix(&a_prefix)?,
                !a_trans as usize,
                a_trans as usize,
            );
            b_pack.pack(
                packed_b.view_mut(),
                &b.view_at_prefix(&b_prefix)?,
                b_trans as usize,
                !b_trans as usize,
            );
            mm.run(
                &packed_a.view(),
                &packed_b.view(),
                &mut c.view_at_prefix_mut(prefix.slice())?,
                &[],
            )?;
        }
        Ok(c)
    }
}