tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
#![allow(clippy::type_complexity)]

use dyn_clone::clone_box;
use dyn_eq::DynEq;
use tract_itertools::Itertools;
use tract_linalg::WeightType;
use tract_linalg::block_quant::BlockQuantFact;
use tract_linalg::mmm::{ImplementationQuality, MMMInputFormat, MatMatMul, PanelExtractor};

use crate::internal::*;
use crate::ops::matmul::ModePicker;

use super::einsum_matmul::EinSumMatMul;

/// One candidate matmul implementation.
///
/// Fields, in order:
/// * the kernel itself,
/// * an index into that kernel's `packings()` list,
/// * an optional panel extractor; `list_impls` sets it when the kernel's
///   first-operand packing does not directly match the first operand's weight type.
pub type Impl = (Box<dyn MatMatMul>, usize, Option<PanelExtractor>);
/// A kernel selection strategy as returned by `strategize`.
///
/// Fields, in order:
/// * the mode picker (`Single` or `VecVsMat` in this file),
/// * the input format retained for the first (left) operand,
/// * the implementations the mode picker chooses from.
pub type Strat = (ModePicker, Box<dyn MMMInputFormat>, Vec<Impl>);

/// Builds a `ModePicker::Single` strategy around a single implementation.
///
/// The strategy's input format is a clone of the first element of the packing
/// entry that the implementation designates in its kernel's `packings()` list.
///
/// Panics if the packing index is out of bounds for that list.
fn single_strat(it: Impl) -> Strat {
    let (kernel, packing_ix, _) = &it;
    let left_format = kernel.packings()[*packing_ix].0.clone();
    (ModePicker::Single, left_format, vec![it])
}

/// Picks the matmul kernel(s) to use for `op` on `node`.
///
/// Selection proceeds as follows:
///
/// 1. Fast path: when `m`, `k` and `n` are all concrete, both inputs are plain,
///    both input datum types equal the operating datum type, and
///    `tract_linalg::ops().mmm(..)` yields a kernel whose quality is
///    `ManuallyOptimized`, that kernel is used alone with its packing 0.
/// 2. Otherwise candidates come from [`list_impls`]. Only those with the best
///    score (`-(quality cost * 1000) + dynamic_boost`) are kept.
/// 3. If a single candidate remains, it is used alone.
/// 4. If `n` is one, prefer `nr() == 1`, then no panel extractor, then the
///    largest `mr()`.
/// 5. If `n` is a known value greater than one, prefer no panel extractor,
///    then the largest `nr() * mr()`.
/// 6. Otherwise candidates are grouped by mergeable left packing, and the best
///    group provides either one kernel (`Single`) or a pair (`VecVsMat`).
///
/// # Errors
///
/// Fails if the node input facts cannot be fetched, if [`list_impls`] fails,
/// or if no candidate implementation is found.
///
/// # Panics
///
/// Panics if the node has fewer than two inputs (input facts are indexed
/// directly).
pub fn strategize(model: &TypedModel, node: &TypedNode, op: &EinSumMatMul) -> TractResult<Strat> {
    let input_facts = model.node_input_facts(node.id)?;
    if let (Some(m), Some(k), Some(n)) = (op.m.as_i64(), op.k.as_i64(), op.n.as_i64())
        && input_facts[0].is_plain()
        && input_facts[1].is_plain()
        && op.operating_dt == input_facts[0].datum_type
        && op.operating_dt == input_facts[1].datum_type
        && let Some(mmm) = tract_linalg::ops().mmm(
            op.operating_dt,
            Some(m as usize),
            Some(k as usize),
            Some(n as usize),
        )
        && mmm.quality() == ImplementationQuality::ManuallyOptimized
    {
        // Same result as building the tuple by hand with packing index 0.
        return Ok(single_strat((mmm, 0, None)));
    }

    let mut impls = list_impls(model, node, op)?;
    ensure!(
        !impls.is_empty(),
        "No matrix multiplication kernel found for node #{} (inputs: {:?})",
        node.id,
        input_facts
    );

    // Higher is better: a lower quality cost dominates, dynamic boost refines.
    fn score(mmm: &dyn MatMatMul) -> isize {
        -(mmm.quality().cost() as isize * 1000) + mmm.dynamic_boost()
    }
    // `impls` is known to be non-empty here, so `max()` yields a value.
    let wanted_quality = impls.iter().map(|(mmm, _, _)| score(&**mmm)).max().unwrap();
    impls.retain(|(mmm, _, _)| score(&**mmm) == wanted_quality);
    if impls.len() == 1 {
        return Ok(single_strat(impls.remove(0)));
    }
    if op.n.is_one() {
        // n == 1: favour nr == 1 kernels, then no panel extractor, then the largest mr.
        let it =
            impls.into_iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none(), m.mr())).unwrap();
        return Ok(single_strat(it));
    }
    if op.n.as_i64().is_some_and(|n| n > 1) {
        // Known n > 1: favour no panel extractor, then the largest mr * nr product.
        let it =
            impls.into_iter().max_by_key(|(m, _, pe)| (pe.is_none(), m.nr() * m.mr())).unwrap();
        return Ok(single_strat(it));
    }

    // n is not a known value: group candidates whose left-hand packing can be
    // merged, so that several kernels can share one packed left operand.
    let mut grouped_by_left_packing = Vec::<(&dyn MMMInputFormat, Vec<_>)>::new();
    'mmm: for (m, p, pe) in &impls {
        // With a panel extractor, the left packing is the extractor's source format.
        let left_packing: &dyn MMMInputFormat =
            pe.as_ref().map(|pe| &*pe.from).unwrap_or(&*m.packings()[*p].0);
        for kit in &mut grouped_by_left_packing {
            if let Some(merged) = kit.0.merge_with(left_packing) {
                kit.0 = merged;
                kit.1.push((m, p, pe));
                continue 'mmm;
            }
        }
        grouped_by_left_packing.push((left_packing, vec![(m, p, pe)]));
    }
    // For each group pick one candidate for the nr == 1 case (`mmv`) and one with
    // the largest nr (`mmm`), then keep the best group.
    let (p, mmv, mmm) = grouped_by_left_packing
        .iter()
        .map(|(p, kit)| {
            let best_for_mmv =
                kit.iter().max_by_key(|(m, _, pe)| (m.nr() == 1, pe.is_none())).unwrap();
            let best_for_mmm = kit.iter().max_by_key(|(m, _, _)| m.nr()).unwrap();
            (p, best_for_mmv, best_for_mmm)
        })
        .max_by_key(|(_, mmv, mmm)| {
            (mmv.0.nr() == 1 && mmm.0.nr() > 1, mmv.2.is_none(), mmm.0.mr(), mmm.0.nr())
        })
        .unwrap();

    if mmm == mmv {
        // Both roles resolved to the same candidate: no need for a mode switch.
        Ok((ModePicker::Single, clone_box(*p), vec![(mmv.0.clone(), *mmv.1, mmv.2.clone())]))
    } else {
        // Order matters: the `mmv` candidate comes first, the `mmm` candidate second.
        Ok((
            ModePicker::VecVsMat,
            clone_box(*p),
            vec![(mmv.0.clone(), *mmv.1, mmv.2.clone()), (mmm.0.clone(), *mmm.1, mmm.2.clone())],
        ))
    }
}

pub fn list_impls(
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSumMatMul,
) -> TractResult<Vec<Impl>> {
    let (a_fact, b_fact) = model.node_input_facts(node.id)?.into_iter().collect_tuple().unwrap();
    let a_dt = a_fact.datum_type;
    let b_dt = b_fact.datum_type;

    let a_weight: WeightType = if let Some(of) = a_fact.exotic_fact() {
        if let Some(bqf) = of.downcast_ref::<BlockQuantFact>() {
            WeightType::BlockQuant(bqf.format.clone())
        } else {
            bail!("Can not translate to matmul operand {a_fact:?}");
        }
    } else {
        a_dt.into()
    };

    let impls = tract_linalg::ops()
        .mmm_impls()
        .iter()
        .filter(|mmm| {
            op.acceptable_accumulators().contains(&mmm.internal_type())
                && mmm.stores().contains(&op.operating_dt.unquantized())
        })
        .flat_map(move |mmm| {
            mmm.packings().iter().enumerate().map(|(ix, p)| (mmm.clone(), ix, &p.0, &p.1))
        })
        .filter_map(|(m, p, pa, pb)| {
            if pb.precursor().as_dt().is_none_or(|dt| dt != b_dt.unquantized()) {
                return None;
            }
            if pa.precursor() == a_weight {
                Some((m, p, None))
            } else {
                tract_linalg::ops()
                    .panel_extractors()
                    .iter()
                    .find(|pe| pe.from.precursor() == a_weight && pe.to.dyn_eq(&**pa))
                    .map(|pe| (m, p, Some(pe.clone())))
            }
        })
        .collect_vec();
    Ok(impls)
}