tract-gpu 0.23.0-dev.6

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use crate::rule_ensure;
use crate::sync::{DeviceSync, DeviceSyncKind};
use crate::tensor::DeviceTensorExt;
use tract_core::internal::*;
use tract_core::ops::konst::Const;
use tract_core::tract_data::itertools::Itertools;
use tract_core::tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};

/// Applies the sync-rewiring rules to `model`.
///
/// A [`Rewriter`] is set up with two rules, registered in this order:
/// - `"remove-back-and-forth-sync"`, backed by [`rewire_back_and_forth_sync`];
/// - `"remove-sync-after-const"`, backed by [`rewire_sync_after_const`].
///
/// # Errors
/// Returns whatever error the rewriter reports while processing `model`.
pub fn rewire_syncs(model: &mut TypedModel) -> TractResult<()> {
    let rewriter = Rewriter::default()
        .with_rule_for("remove-back-and-forth-sync", rewire_back_and_forth_sync)
        .with_rule_for("remove-sync-after-const", rewire_sync_after_const);
    rewriter.rewrite(&(), model)
}

pub fn rewire_back_and_forth_sync(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    _node_name: &str,
    op: &DeviceSync,
) -> TractResult<Option<TypedModelPatch>> {
    // Search pattern => ToHost => ToDevice
    rule_ensure!(op.kind == DeviceSyncKind::ToDevice);

    // Identify precessor ToHost
    rule_if_some!(sync_to_host_prec = model.single_prec(node.id)?);
    rule_if_some!(sync_to_host_prec_op = sync_to_host_prec.op_as::<DeviceSync>());
    rule_ensure!(sync_to_host_prec_op.kind == DeviceSyncKind::ToHost);

    let patch =
        TypedModelPatch::rewire(model, &sync_to_host_prec.inputs, &[node.id.into()], &|_p, xs| {
            Ok(xs.into())
        })?;
    Ok(Some(patch))
}

pub fn rewire_sync_after_const(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    node_name: &str,
    op: &Const,
) -> TractResult<Option<TypedModelPatch>> {
    // Search pattern => Const => ToHost

    rule_if_some!(device_const = op.val().as_device_tensor());

    // Identify successors ToHost
    rule_if_some!(next_nodes = model.all_succ(node.id)?);

    let sync_to_hosts = next_nodes
        .into_iter()
        .filter(|n| n.op_as::<DeviceSync>().is_some_and(|sync| sync.kind == DeviceSyncKind::ToHost))
        .collect_vec();

    rule_if!(!sync_to_hosts.is_empty());

    let host_const = device_const.to_host()?;
    let exotic_fact: Option<Box<dyn ExoticFact>> =
        host_const.storage_as::<BlockQuantStorage>().map(|bqs| {
            Box::new(BlockQuantFact::new(
                tract_core::dyn_clone::clone_box(bqs.format()),
                host_const.shape().into(),
            )) as Box<dyn ExoticFact>
        });

    let mut patch = TypedModelPatch::default();
    let out = patch.wire_node(
        node_name.to_string(),
        Const::new_with_opt_exotic_fact(host_const, exotic_fact)?,
        &[],
    )?;

    for sync_node in sync_to_hosts {
        patch.shunt_outside(model, sync_node.id.into(), out[0])?;
    }

    Ok(Some(patch))
}