use crate::internal::*;
use std::collections::{BTreeSet, HashMap};
use tract_core::axes::AxesMapping;
use tract_core::model::TypedModelPatch;
use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::change_axes::AxisOp;
use tract_core::ops::einsum::EinSum;
use tract_core::ops::nn::{Reduce, Reducer};
use tract_core::transform::ModelTransform;
use tract_transformers::ops::DiagGather;
/// User-facing configuration of the blockify transform (deserializable via serde).
#[derive(Debug, Default, Clone, serde::Deserialize)]
pub struct BlockifyConfig {
/// Name of the streaming symbol to blockify. `BlockifyTransform` falls back to `"S"` when unset.
pub symbol: Option<String>,
}
/// Model property key holding the name of the chunk-count symbol, stored as a one-element
/// string tensor by `rewrite_sections` and read back by `blockify_output`.
pub const BLOCKIFY_CHUNK_SYMBOL: &str = "blockify.chunk_symbol";
/// Model property key holding the chunk size as an i64 scalar tensor (written by
/// `rewrite_sections`, read back by `blockify_output`).
pub const BLOCKIFY_CHUNK_SIZE: &str = "blockify.chunk_size";
/// Model property key holding the original streaming symbol name as a one-element string
/// tensor (written by `BlockifyTransform::transform`).
pub const BLOCKIFY_ORIGINAL_SYMBOL: &str = "blockify.original_symbol";
/// `ModelTransform` that rewrites quadratic (two-streaming-axis) sections into chunked form.
#[derive(Debug)]
pub struct BlockifyTransform(pub BlockifyConfig);
impl ModelTransform for BlockifyTransform {
    fn name(&self) -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("blockify")
    }

    /// Blockifies `model` in place.
    ///
    /// Quadratic sections are searched over the configured streaming symbol (`"S"` when
    /// `BlockifyConfig::symbol` is unset). Without any section the model is left untouched.
    /// Otherwise every section must agree on one chunk size `k`. Then:
    /// 1. the streaming symbol is substituted by `chunk * k` for a fresh `chunk` symbol;
    /// 2. the sections are rewritten by `rewrite_sections`;
    /// 3. the original symbol name is recorded under `BLOCKIFY_ORIGINAL_SYMBOL`.
    ///
    /// # Errors
    /// Fails when sections disagree on their chunk size, or when the symbol substitution or
    /// the section rewrite fails.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        let original_name = self.0.symbol.as_deref().unwrap_or("S");
        let stream_sym = model.symbols.sym(original_name);
        let sections = find_quadratic_sections(model, &stream_sym)?;
        let Some(first) = sections.first() else {
            return Ok(());
        };
        let chunk_size = first.mask.chunk_size;
        if sections.iter().any(|s| s.mask.chunk_size != chunk_size) {
            bail!(
                "Blockify found multiple quadratic sections with mismatched chunk \
                 sizes; a single global substitution cannot cover them. \
                 Refusing to blockify rather than produce a partial rewrite."
            );
        }
        // The fresh chunk symbol is only allocated once we know a rewrite will be attempted.
        let chunk_sym = model.symbols.new_with_prefix("S");
        let substitution: HashMap<Symbol, TDim> =
            [(stream_sym, chunk_sym.to_dim() * chunk_size)].into_iter().collect();
        *model = model.set_symbols(&substitution)?;
        // NOTE(review): the bool returned by `rewrite_sections` is ignored; if no section is
        // found after the substitution, the model keeps the substituted shapes and the
        // ORIGINAL_SYMBOL property without being blockified — confirm this is intended.
        rewrite_sections(model, &chunk_sym, chunk_size)?;
        model.properties.insert(
            BLOCKIFY_ORIGINAL_SYMBOL.to_string(),
            tensor1(&[original_name.to_string()]).into_arc_tensor(),
        );
        Ok(())
    }
}
/// Reports whether `model` contains at least one quadratic section with respect to
/// `stream_sym` (as detected by `find_quadratic_sections`).
///
/// # Errors
/// Propagates any error raised during section discovery.
pub fn has_quadratic_sections(model: &TypedModel, stream_sym: &Symbol) -> TractResult<bool> {
    find_quadratic_sections(model, stream_sym).map(|found| !found.is_empty())
}
/// Rewrites every quadratic section of `model`, found against `chunk_sym`, into its chunked
/// ("blockified") form and records the chunking in the model properties.
///
/// `substitute_multiplier` is stored verbatim under `BLOCKIFY_CHUNK_SIZE`. Callers are expected
/// to pass the factor used when the streaming symbol was substituted by
/// `chunk_sym * substitute_multiplier`. The chunk symbol name is stored under
/// `BLOCKIFY_CHUNK_SYMBOL`.
///
/// Returns `Ok(false)` without touching the model when no section is found, and `Ok(true)`
/// after all section patches have been applied.
///
/// # Errors
/// Fails, before any mutation, when the sections disagree on chunk size. Also fails when
/// building or applying a section patch fails.
pub fn rewrite_sections(
    model: &mut TypedModel,
    chunk_sym: &Symbol,
    substitute_multiplier: i64,
) -> TractResult<bool> {
    let sections = find_quadratic_sections(model, chunk_sym)?;
    let Some(first) = sections.first() else {
        return Ok(false);
    };
    let k = first.mask.chunk_size;
    if !sections.iter().all(|s| s.mask.chunk_size == k) {
        bail!(
            "Blockify found multiple quadratic sections with mismatched chunk \
             sizes; a single global substitution cannot cover them. \
             Refusing to blockify rather than produce a partial rewrite."
        );
    }
    // Every section shares chunk size `k` (checked above), so pass it directly.
    for sec in &sections {
        let patch = build_section_patch(model, sec, chunk_sym, k)?;
        patch.apply(model)?;
    }
    model.properties.insert(
        BLOCKIFY_CHUNK_SYMBOL.to_string(),
        tensor1(&[format!("{chunk_sym}")]).into_arc_tensor(),
    );
    model
        .properties
        .insert(BLOCKIFY_CHUNK_SIZE.to_string(), tensor0(substitute_multiplier).into_arc_tensor());
    Ok(true)
}
/// Reads back the chunk symbol and chunk size recorded by `rewrite_sections`.
///
/// Returns `None` in any of these cases:
/// * either property is missing;
/// * the chunk size does not cast to an i64 scalar;
/// * the symbol tensor is not a non-empty string tensor.
pub fn blockify_output(model: &TypedModel) -> Option<(Symbol, i64)> {
    let props = &model.properties;
    let chunk_size = props.get(BLOCKIFY_CHUNK_SIZE)?.cast_to_scalar::<i64>().ok()?;
    let names = props.get(BLOCKIFY_CHUNK_SYMBOL)?.to_plain_array_view::<String>().ok()?;
    names.iter().next().map(|symbol_name| (model.symbols.sym(symbol_name), chunk_size))
}
/// Returns the id of the single in-section consumer of `einsum_node`'s output 0, provided
/// that consumer is a `DiagGather`.
///
/// Returns `None` when there are zero or several in-section consumers, or when the only one
/// is not a `DiagGather`. Consumers outside the section are ignored.
fn section_only_diag_gather_consumer(
    model: &TypedModel,
    einsum_node: &TypedNode,
    sec: &QuadraticSection,
) -> Option<usize> {
    let successors = model.outlet_successors(OutletId::new(einsum_node.id, 0));
    let mut in_section = successors.iter().filter(|inlet| sec.section.contains(&inlet.node));
    let only = in_section.next()?;
    if in_section.next().is_some() {
        return None;
    }
    let candidate = only.node;
    model.nodes[candidate].op_is::<DiagGather>().then_some(candidate)
}
/// Lists, in increasing order, the axes of `fact` whose dimension expression mentions
/// `stream_sym`.
fn streaming_positions(fact: &TypedFact, stream_sym: &Symbol) -> TVec<usize> {
    (0..fact.shape.len())
        .filter(|&axis| fact.shape[axis].symbols().contains(stream_sym))
        .collect()
}
/// A weakly-connected group of nodes whose outputs carry (at least) two streaming axes,
/// together with what is needed to rewrite it chunk by chunk.
///
/// Every field is read by the rewrite code, so no dead-code allowance is needed.
#[derive(Debug)]
struct QuadraticSection {
    /// Ids of all member nodes.
    section: BTreeSet<usize>,
    /// Members with no input produced by another member (the section's entry points).
    initiators: Vec<usize>,
    /// Non-member nodes consuming output 0 of some member (sorted, deduplicated).
    terminators: Vec<usize>,
    /// Mask structure decoded from a member's `uniform_tdim`.
    mask: MaskForm,
    /// Mask-frame axis (0 or 1) that every terminator collapses.
    contracted_axis: usize,
}
/// Decoded structure of a chunked attention mask over two streaming axes.
///
/// With `d = axis_a / chunk_size - axis_b / chunk_size` (the chunk-index difference of the two
/// coordinates), the mask is decoded as allowing `lower <= d <= upper`. This assumes
/// `TDim::Ge(x, y)` encodes `x >= y`; confirm against TDim. A plain chunk-equality mask decodes
/// to `lower == upper == 0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MaskForm {
    chunk_size: i64,
    lower: i64,
    upper: i64,
    axis_a: usize,
    axis_b: usize,
}
impl MaskForm {
    /// True when only the diagonal block is allowed (the band is exactly `[0, 0]`).
    fn is_block_diag(&self) -> bool {
        matches!((self.lower, self.upper), (0, 0))
    }
}
/// Partitions `nodes` into weakly-connected components.
///
/// Only edges from a member's output 0 to a consumer that is also a member are considered.
/// Components are returned sorted by their smallest node id.
fn connected_components(model: &TypedModel, nodes: &BTreeSet<usize>) -> Vec<BTreeSet<usize>> {
    // Union-find root lookup with path compression. It is iterative rather than recursive:
    // before compression, unions along a linear chain build parent chains as long as the
    // chain itself, and recursion depth would grow with it.
    fn uf_find(p: &mut HashMap<usize, usize>, x: usize) -> usize {
        let mut root = x;
        while p[&root] != root {
            root = p[&root];
        }
        let mut cur = x;
        while cur != root {
            let next = p[&cur];
            p.insert(cur, root);
            cur = next;
        }
        root
    }
    fn uf_union(p: &mut HashMap<usize, usize>, x: usize, y: usize) {
        let rx = uf_find(p, x);
        let ry = uf_find(p, y);
        if rx != ry {
            p.insert(rx, ry);
        }
    }
    let mut parent: HashMap<usize, usize> = nodes.iter().map(|&n| (n, n)).collect();
    for &nid in nodes {
        for cons in model.outlet_successors(OutletId::new(nid, 0)) {
            if nodes.contains(&cons.node) {
                uf_union(&mut parent, nid, cons.node);
            }
        }
    }
    let mut groups: HashMap<usize, BTreeSet<usize>> = HashMap::default();
    for &nid in nodes {
        let root = uf_find(&mut parent, nid);
        groups.entry(root).or_default().insert(nid);
    }
    let mut out: Vec<BTreeSet<usize>> = groups.into_values().collect();
    // BTreeSet::first is the minimum element; groups are never empty.
    out.sort_by_key(|g| g.first().copied().unwrap_or(usize::MAX));
    out
}
/// Discovers the quadratic sections of `model` with respect to `stream_sym`.
///
/// Candidates are single-output nodes whose output has at least two axes mentioning
/// `stream_sym`. They are grouped into weakly-connected components. For each component:
/// * `initiators` are members with no input coming from another member;
/// * `terminators` are non-members consuming output 0 of a member (sorted, deduplicated);
/// * `mask` is the first form `decode_mask` accepts among members carrying a `uniform_tdim`;
/// * `contracted_axis` is the mask-frame axis that every terminator collapses.
///
/// A component is skipped in any of these cases:
/// * no member carries a `uniform_tdim` or `region_of_interest`;
/// * no mask decodes;
/// * it has no terminator;
/// * contracted-axis detection fails for a terminator;
/// * the terminators disagree on the contracted axis.
///
/// # Errors
/// Currently never returns `Err`; the `TractResult` is kept for callers.
fn find_quadratic_sections(
    model: &TypedModel,
    stream_sym: &Symbol,
) -> TractResult<Vec<QuadraticSection>> {
    let stream_axes_of = |fact: &TypedFact| -> TVec<usize> {
        fact.shape
            .iter()
            .enumerate()
            .filter(|(_, d)| d.symbols().contains(stream_sym))
            .map(|(i, _)| i)
            .collect()
    };
    let multi_nodes: BTreeSet<usize> = model
        .nodes
        .iter()
        .filter(|n| n.outputs.len() == 1 && stream_axes_of(&n.outputs[0].fact).len() >= 2)
        .map(|n| n.id)
        .collect();
    if multi_nodes.is_empty() {
        return Ok(vec![]);
    }
    let mut sections: Vec<QuadraticSection> = vec![];
    for section in connected_components(model, &multi_nodes) {
        let initiators: Vec<usize> = section
            .iter()
            .copied()
            .filter(|&nid| !model.nodes[nid].inputs.iter().any(|i| section.contains(&i.node)))
            .collect();
        // Collected through a BTreeSet so terminators come out sorted and deduplicated.
        let mut terminators_set: BTreeSet<usize> = BTreeSet::new();
        for &nid in &section {
            for cons in model.outlet_successors(OutletId::new(nid, 0)) {
                if !section.contains(&cons.node) {
                    terminators_set.insert(cons.node);
                }
            }
        }
        let terminators: Vec<usize> = terminators_set.into_iter().collect();
        let any_annotated = section.iter().any(|&nid| {
            let fact = &model.nodes[nid].outputs[0].fact;
            fact.uniform_tdim.is_some() || fact.region_of_interest.is_some()
        });
        if !any_annotated {
            continue;
        }
        let mask = section.iter().find_map(|&nid| {
            let fact = &model.nodes[nid].outputs[0].fact;
            decode_mask(fact.uniform_tdim.as_ref()?, &stream_axes_of(fact))
        });
        let Some(mask) = mask else {
            continue;
        };
        let Some(contracted_axis) = unique_contracted_axis(model, &terminators, stream_sym) else {
            continue;
        };
        sections.push(QuadraticSection { section, initiators, terminators, mask, contracted_axis });
    }
    Ok(sections)
}

/// Returns the mask-frame axis collapsed by every terminator in `terminators`.
///
/// Returns `None` when the list is empty, when detection fails for any terminator, or when
/// two terminators disagree.
fn unique_contracted_axis(
    model: &TypedModel,
    terminators: &[usize],
    stream_sym: &Symbol,
) -> Option<usize> {
    let mut agreed: Option<usize> = None;
    for &t_id in terminators {
        let axis = detect_contracted_score_axis(model, &model.nodes[t_id], stream_sym).ok()?;
        if agreed.is_some_and(|prev| prev != axis) {
            return None;
        }
        agreed = Some(axis);
    }
    agreed
}
fn detect_contracted_score_axis(
model: &TypedModel,
terminator: &TypedNode,
stream_sym: &Symbol,
) -> TractResult<usize> {
let input_fact = model.outlet_fact(terminator.inputs[0])?;
let streaming_axes = streaming_positions(input_fact, stream_sym);
ensure!(
streaming_axes.len() == 2,
"Terminator score input has {} streaming axes, expected 2",
streaming_axes.len()
);
let score_rank = input_fact.rank();
let rank_diff = score_rank
.checked_sub(2)
.ok_or_else(|| format_err!("Terminator score input rank {score_rank} < 2; expected ≥ 2"))?;
let to_mask_frame = |score_axis: usize| -> TractResult<usize> {
score_axis.checked_sub(rank_diff).ok_or_else(|| {
format_err!(
"Terminator score axis {score_axis} doesn't map to mask frame \
(rank_diff={rank_diff})"
)
})
};
if let Some(reduce) = terminator.op_as::<Reduce>() {
for &ax in &streaming_axes {
if reduce.axes.contains(&ax) {
return to_mask_frame(ax);
}
}
bail!("Reduce terminator doesn't reduce a streaming axis of the score input");
}
if let Some(einsum) = terminator.op_as::<EinSum>() {
for &ax in &streaming_axes {
let mapped = einsum.axes.track_axis((InOut::In(0), ax), InOut::Out(0))?;
if mapped.is_none() {
return to_mask_frame(ax);
}
}
bail!("EinSum terminator doesn't contract any streaming axis of input 0");
}
bail!("Unsupported terminator op for contracted-axis detection: {}", terminator.op.name())
}
/// Decodes a `uniform_tdim` mask expression into a `MaskForm`.
///
/// `streaming_axes` must hold exactly two axes, and the decoded `axis_a`/`axis_b` must be
/// exactly that pair (in either order). Two shapes are recognised:
/// * `Eq(a / k, b / k)`, a block-diagonal mask with `lower == upper == 0`;
/// * a two-term `Mul` of band bounds, in either term order, handed to
///   `decode_banded_terms`.
fn decode_mask(expr: &TDim, streaming_axes: &[usize]) -> Option<MaskForm> {
    let &[first, second] = streaming_axes else {
        return None;
    };
    let expected = BTreeSet::from([first, second]);
    let covers_expected = |a: usize, b: usize| BTreeSet::from([a, b]) == expected;
    match expr {
        TDim::Eq(lhs, rhs) => {
            let (axis_a, k_a) = decode_coord_div(lhs)?;
            let (axis_b, k_b) = decode_coord_div(rhs)?;
            (k_a == k_b && covers_expected(axis_a, axis_b)).then(|| MaskForm {
                chunk_size: k_a as i64,
                lower: 0,
                upper: 0,
                axis_a,
                axis_b,
            })
        }
        TDim::Mul(terms) if terms.len() == 2 => [(&terms[0], &terms[1]), (&terms[1], &terms[0])]
            .into_iter()
            .filter_map(|(a, b)| decode_banded_terms(a, b))
            .find(|form| covers_expected(form.axis_a, form.axis_b)),
        _ => None,
    }
}
/// Decodes the pair `Ge(Val(upper), d)` / `Ge(d, Val(lower))` into a banded `MaskForm`.
///
/// Both bounds must refer to the same expression `d`, and `d` must be a chunk-index difference
/// accepted by `decode_diff`.
fn decode_banded_terms(upper_term: &TDim, lower_term: &TDim) -> Option<MaskForm> {
    let (TDim::Ge(upper_bound, diff_hi), TDim::Ge(diff_lo, lower_bound)) =
        (upper_term, lower_term)
    else {
        return None;
    };
    let (&TDim::Val(upper), &TDim::Val(lower)) = (&**upper_bound, &**lower_bound) else {
        return None;
    };
    if diff_lo != diff_hi {
        return None;
    }
    let (axis_a, axis_b, chunk) = decode_diff(diff_lo)?;
    Some(MaskForm { chunk_size: chunk as i64, lower, upper, axis_a, axis_b })
}
/// Decodes `a / k + (-1) * (b / k)`, with the two terms in either order, into
/// `(axis_of_a, axis_of_b, k)`. Both divisions must use the same `k`.
fn decode_diff(expr: &TDim) -> Option<(usize, usize, u64)> {
    let TDim::Add(terms) = expr else {
        return None;
    };
    let [first, second] = &terms[..] else {
        return None;
    };
    [(first, second), (second, first)].into_iter().find_map(|(pos, neg)| {
        let (axis_a, k_a) = decode_coord_div(pos)?;
        let TDim::MulInt(-1, inner) = neg else {
            return None;
        };
        let (axis_b, k_b) = decode_coord_div(inner)?;
        (k_a == k_b).then_some((axis_a, axis_b, k_a))
    })
}
/// Decodes `coord_sym / k` into `(axis, k)`. `axis` is the axis that the coordinate symbol
/// stands for, as reported by `sym_to_coord_axis`.
fn decode_coord_div(expr: &TDim) -> Option<(usize, u64)> {
    match expr {
        TDim::Div(num, k) => match num.as_ref() {
            TDim::Sym(sym) => {
                tract_core::ops::logic::sym_to_coord_axis(sym).map(|axis| (axis, *k))
            }
            _ => None,
        },
        _ => None,
    }
}
/// Builds the patch that rewrites one quadratic section into its chunked form.
///
/// `k` is the chunk size (the only caller in view passes `sec.mask.chunk_size`). Steps:
/// 1. EinSum initiators whose only in-section consumer is a `DiagGather` are wired as one
///    chunked `DiagGather`. Both node ids then map to that outlet and are marked as wired.
/// 2. Every other initiator is wired with `wire_uniform_tdim_initiator` when its output fact
///    has a `uniform_tdim`, and with `wire_initiator` otherwise.
/// 3. The remaining section members are wired in `model.eval_order()` (presumably
///    topological) with `wire_body`.
/// 4. Each terminator is wired with `wire_terminator`. Its chunked result is merged back by
///    `wire_merge_reshape`, possibly tail-padded by `wire_affine_tail_pad`, and shunted onto
///    the terminator's boundary outlet.
///
/// # Errors
/// Fails in any of these cases:
/// * the mask band excludes the diagonal block (`lower > 0` or `upper < 0`);
/// * no initiator produced a chunked outlet;
/// * any wiring helper fails.
fn build_section_patch(
model: &TypedModel,
sec: &QuadraticSection,
chunk_sym: &Symbol,
k: i64,
) -> TractResult<TypedModelPatch> {
// The allowed band [lower, upper] of chunk offsets must contain 0 (the diagonal block).
ensure!(sec.mask.lower <= 0);
ensure!(sec.mask.upper >= 0);
let mut patch = TypedModelPatch::default();
// Maps original outlets to their chunked counterparts inside the patch.
let mut chunked: HashMap<OutletId, OutletId> = HashMap::default();
let mut already_wired: BTreeSet<usize> = BTreeSet::new();
// (original boundary outlet, chunked replacement) pairs, resolved after all terminators.
let mut shunts: Vec<(OutletId, OutletId)> = vec![];
// Pass 1: EinSum initiators feeding a single in-section DiagGather.
// NOTE(review): the EinSum is a section member, so its output 0 (the DiagGather input) should
// carry >= 2 `chunk_sym` axes; the single-axis check below then seems to always bail when this
// path is reached — confirm the intended condition.
for &nid in &sec.initiators {
let einsum_node = &model.nodes[nid];
if !einsum_node.op_is::<EinSum>() {
continue;
}
let Some(dg_id) = section_only_diag_gather_consumer(model, einsum_node, sec) else {
continue;
};
let dg_node = &model.nodes[dg_id];
let dg_in_fact = model.outlet_fact(dg_node.inputs[0])?;
let dg_in_streaming = streaming_positions(dg_in_fact, chunk_sym);
if dg_in_streaming.len() != 1 {
bail!(
"EinSum+DiagGather initiator: DG input must have a single streaming axis, got {dg_in_streaming:?}"
);
}
// Cannot fail: `section_only_diag_gather_consumer` only returns DiagGather nodes.
let dg_op = dg_node.op_as::<DiagGather>().unwrap();
let dg_chunked = wire_initiator_diag_gather(
&mut patch,
model,
dg_node,
dg_op,
&sec.mask,
sec.contracted_axis,
chunk_sym,
k,
)?;
chunked.insert(OutletId::new(nid, 0), dg_chunked);
chunked.insert(OutletId::new(dg_id, 0), dg_chunked);
already_wired.insert(nid);
already_wired.insert(dg_id);
}
// Pass 2: all remaining initiators.
for &nid in &sec.initiators {
if already_wired.contains(&nid) {
continue;
}
let node = &model.nodes[nid];
let out = if node.outputs[0].fact.uniform_tdim.is_some() {
wire_uniform_tdim_initiator(
&mut patch,
model,
node,
&sec.mask,
sec.contracted_axis,
chunk_sym,
k,
)?
} else {
wire_initiator(&mut patch, model, node, &sec.mask, sec.contracted_axis, chunk_sym, k)?
};
chunked.insert(OutletId::new(nid, 0), out);
}
// Body wiring needs at least one chunked outlet to start from.
ensure!(!chunked.is_empty());
// Pass 3: body nodes, in evaluation order so their chunked inputs already exist.
for &nid in &model.eval_order()? {
if !sec.section.contains(&nid) {
continue;
}
if sec.initiators.contains(&nid) {
continue;
}
if already_wired.contains(&nid) {
continue;
}
let node = &model.nodes[nid];
let out = wire_body(
&mut patch,
model,
node,
&sec.mask,
sec.contracted_axis,
&chunked,
chunk_sym,
k,
)?;
chunked.insert(OutletId::new(nid, 0), out);
}
// Pass 4: terminators produce the chunked form of each boundary outlet.
for &nid in &sec.terminators {
let node = &model.nodes[nid];
let (boundary, chunked_form) = wire_terminator(
&mut patch,
model,
node,
&chunked,
&sec.mask,
sec.contracted_axis,
chunk_sym,
k,
)?;
shunts.push((boundary, chunked_form));
}
// Merge chunks back into the streaming layout and redirect the boundary's consumers.
for (boundary, chunked_form) in shunts {
let merged = wire_merge_reshape(
&mut patch,
&model.nodes[boundary.node].name,
chunked_form,
chunk_sym,
k,
)?;
let merged = wire_affine_tail_pad(&mut patch, model, boundary, merged, chunk_sym, k)?;
patch.shunt_outside(model, boundary, merged)?;
}
Ok(patch)
}
/// Zero-pads `merged` at the tail of a single axis so its shape matches the original
/// `boundary` outlet.
///
/// Padding applies only when the ranks agree and exactly one axis differs. On that axis,
/// `affine_chunk_offset` must report a positive offset `c` for the boundary dim and an offset
/// of 0 for the merged dim; `c` is presumably the constant term in `chunk_sym * k + c`
/// (confirm against `affine_chunk_offset`). That axis is then padded at the end by `c`
/// zeros. In every other case `merged` is returned unchanged.
fn wire_affine_tail_pad(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    boundary: OutletId,
    merged: OutletId,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let boundary_fact = model.outlet_fact(boundary)?;
    let merged_fact = patch.outlet_fact(merged)?.clone();
    let rank = merged_fact.shape.len();
    if boundary_fact.shape.len() != rank {
        return Ok(merged);
    }
    let mut candidate: Option<(usize, i64)> = None;
    for (axis, (b, m)) in boundary_fact.shape.iter().zip(merged_fact.shape.iter()).enumerate() {
        if b == m {
            continue;
        }
        let tail = match (affine_chunk_offset(b, chunk_sym, k), affine_chunk_offset(m, chunk_sym, k))
        {
            (Some(bc), Some(0)) if bc > 0 => bc,
            _ => return Ok(merged),
        };
        // A second mismatching axis means this is not a simple tail-pad situation.
        if candidate.replace((axis, tail)).is_some() {
            return Ok(merged);
        }
    }
    let Some((pad_axis, tail)) = candidate else {
        return Ok(merged);
    };
    let pads: Vec<(usize, usize)> =
        (0..rank).map(|ax| if ax == pad_axis { (0, tail as usize) } else { (0, 0) }).collect();
    let zero = Tensor::zero_scalar_dt(merged_fact.datum_type)?.into_arc_tensor();
    let pad_op =
        tract_core::ops::array::Pad { pads, mode: tract_core::ops::array::PadMode::Constant(zero) };
    let name = format!("{}.affine_tail_pad", &model.nodes[boundary.node].name);
    Ok(patch.wire_node(name, pad_op, &[merged])?[0])
}
/// Dispatches a section initiator to the wiring routine for its operator.
///
/// Handled ops, checked in this order:
/// * `EinSum`;
/// * `MultiBroadcastTo`, using the streaming variant when its input carries a `chunk_sym` axis;
/// * `DiagGather`;
/// * `TypedBinOp`.
///
/// # Errors
/// Fails for any other operator, or when the selected routine fails.
fn wire_initiator(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    if let Some(einsum) = node.op_as::<EinSum>() {
        wire_initiator_einsum(patch, model, node, einsum, mask, contracted_axis, chunk_sym, k)
    } else if node.op_is::<tract_core::ops::array::MultiBroadcastTo>() {
        let source_fact = model.outlet_fact(node.inputs[0])?;
        if streaming_positions(source_fact, chunk_sym).is_empty() {
            wire_initiator_multibroadcastto(patch, model, node, chunk_sym)
        } else {
            wire_initiator_multibroadcastto_streaming(
                patch,
                model,
                node,
                mask,
                contracted_axis,
                chunk_sym,
                k,
            )
        }
    } else if let Some(gather) = node.op_as::<DiagGather>() {
        wire_initiator_diag_gather(patch, model, node, gather, mask, contracted_axis, chunk_sym, k)
    } else if let Some(binop) = node.op_as::<TypedBinOp>() {
        wire_initiator_typed_binop(patch, model, node, binop, mask, contracted_axis, chunk_sym, k)
    } else {
        bail!("Unsupported initiator {node}")
    }
}
/// Wires a chunked `DiagGather` initiator.
///
/// The input has one streaming axis; the output has two contiguous ones. The input is split
/// into chunks of `k` along its streaming axis. A new `DiagGather` then produces rows of
/// `(upper - lower + 1) * k` elements. Its offset is the original centre (the op's offset
/// when it is a concrete integer, otherwise `(R - 1) / 2` with `R` the constant last input
/// dim), shifted by `window_start_for(mask, contracted_axis) * k`.
///
/// # Errors
/// Fails on an unexpected streaming layout, a non-unary node, or a non-constant last axis.
fn wire_initiator_diag_gather(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    op: &DiagGather,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let out_streaming = streaming_positions(&node.outputs[0].fact, chunk_sym);
    let contiguous_pair = out_streaming.len() == 2 && out_streaming[1] == out_streaming[0] + 1;
    ensure!(
        contiguous_pair,
        "Initiator DiagGather output must have two contiguous streaming axes, \
         got {out_streaming:?}"
    );
    ensure!(node.inputs.len() == 1, "DiagGather has 1 input, got {}", node.inputs.len());
    let source = node.inputs[0];
    let in_streaming = streaming_positions(model.outlet_fact(source)?, chunk_sym);
    ensure!(
        in_streaming.len() == 1,
        "Initiator DiagGather input must have exactly one streaming axis, got {in_streaming:?}"
    );
    let tapped = patch.tap_model(model, source)?;
    let tapped_fact = patch.outlet_fact(tapped)?.clone();
    let split = wire_chunk_split(patch, &node.name, tapped, in_streaming[0], chunk_sym, k)?;
    let r = tapped_fact
        .shape
        .last()
        .context("DiagGather input has no last axis")?
        .to_i64()
        .context("DiagGather R axis must be a constant integer")?;
    // Keep a concrete offset as-is; otherwise centre on the R axis.
    let centre = op.offset.to_i64().unwrap_or((r - 1) / 2);
    let band_chunks = mask.upper - mask.lower + 1;
    let offset = centre + window_start_for(mask, contracted_axis) * k;
    let blockified = DiagGather { offset: offset.to_dim(), out_len: (band_chunks * k).to_dim() };
    Ok(patch.wire_node(format!("{}.blockified", node.name), blockified, &[split])?[0])
}
/// Wires a `TypedBinOp` section initiator. The op's output has two contiguous streaming axes;
/// each input carries at most one streaming axis.
///
/// Per input:
/// * no streaming axis: only rank-bumped with leading axes to `score_rank + 1` (the chunked
///   rank);
/// * one streaming axis: split into chunks of `k`. If that axis tracks (in the mask frame) to
///   the contracted axis of a non-block-diagonal mask, a `WindowOnAxis` of
///   `upper - lower + 1` is applied, padded with the op's absorbing element (presumably so
///   padded positions yield the absorbing result). The window is then flattened with the
///   within-chunk axis into `window * k`. Finally the chunk axis is moved to the output's
///   first streaming position.
///
/// The original op is then wired on the chunked inputs.
///
/// # Errors
/// Fails in any of these cases:
/// * the streaming layout is unexpected;
/// * an input stream axis does not track to an output axis inside the mask frame;
/// * windowing is needed but the op has no absorbing element.
fn wire_initiator_typed_binop(
patch: &mut TypedModelPatch,
model: &TypedModel,
node: &TypedNode,
op: &TypedBinOp,
mask: &MaskForm,
contracted_axis: usize,
chunk_sym: &Symbol,
k: i64,
) -> TractResult<OutletId> {
let out_streaming_axes = streaming_positions(&node.outputs[0].fact, chunk_sym);
ensure!(
out_streaming_axes.len() == 2 && out_streaming_axes[1] == out_streaming_axes[0] + 1,
"Initiator TypedBinOp output must have two contiguous streaming axes"
);
// Chunk axes of every streaming input are aligned onto this output position.
let chunks_target_axis = out_streaming_axes[0];
let score_rank = node.outputs[0].fact.rank();
// Offset from score axes to the rank-2 mask frame (the last two axes).
let rank_diff = score_rank.checked_sub(2).ok_or_else(|| {
format_err!("Score rank {score_rank} < 2; cannot translate to mask frame")
})?;
let input_facts: TVec<&TypedFact> =
node.inputs.iter().map(|inp| model.outlet_fact(*inp)).collect::<TractResult<_>>()?;
let output_facts: TVec<&TypedFact> = node.outputs.iter().map(|o| &o.fact).collect();
let mapping = op.axes_mapping(&input_facts, &output_facts)?;
let mut chunked_inputs: TVec<OutletId> = tvec!();
for (ix, &input) in node.inputs.iter().enumerate() {
let in_fact = model.outlet_fact(input)?;
let streaming = streaming_positions(in_fact, chunk_sym);
ensure!(
streaming.len() <= 1,
"Initiator TypedBinOp input {ix} has {} streaming axes, expected 0 or 1",
streaming.len()
);
let tapped = patch.tap_model(model, input)?;
let wire = if streaming.is_empty() {
// Non-streaming operand: only align its rank with the chunked tensors.
let target_rank = score_rank + 1;
bump_rank_to(patch, &node.name, ix, tapped, target_rank)?
} else {
let stream_axis = streaming[0];
let split = wire_chunk_split(
patch,
&format!("{}.{ix}", node.name),
tapped,
stream_axis,
chunk_sym,
k,
)?;
// Locate this input's stream axis in the output, then in the mask frame.
let tracked_in_score = mapping
.track_axis((InOut::In(ix), stream_axis), InOut::Out(0))?
.ok_or_else(|| {
format_err!(
"TypedBinOp stream axis on input {ix} doesn't track to a unique output axis"
)
})?;
let tracked_in_mask = tracked_in_score.checked_sub(rank_diff).ok_or_else(|| {
format_err!(
"Tracked score axis {tracked_in_score} doesn't map to mask frame \
(rank_diff={rank_diff})"
)
})?;
// Only the contracted side of a banded mask needs neighbouring chunks.
let needs_window = !mask.is_block_diag() && tracked_in_mask == contracted_axis;
let after_window = if needs_window {
let window: usize = (mask.upper - mask.lower + 1) as usize;
let start = window_start_for(mask, contracted_axis);
let dt = patch.outlet_fact(split)?.datum_type;
let absorbing = op.0.absorbing_element().ok_or_else(|| {
format_err!(
"TypedBinOp '{}' has no absorbing_element; cannot safely window-pad \
a section-initiator input",
op.0.name()
)
})?;
let pad_value = tensor0(absorbing).cast_to_dt(dt)?.into_owned().into_arc_tensor();
let windowed = patch.wire_node(
format!("{}.{ix}.window", node.name),
tract_pulse_opl::ops::WindowOnAxis {
axis: stream_axis,
window,
start,
pad_value,
},
&[split],
)?[0];
// Merge (window, k) into a single window * k axis right after the chunk axis.
let from = tvec!(window.to_dim(), k.to_dim());
let to = tvec!(((window as i64) * k).to_dim());
patch.wire_node(
format!("{}.{ix}.window_flat", node.name),
AxisOp::Reshape(stream_axis + 1, from, to),
&[windowed],
)?[0]
} else {
split
};
if stream_axis != chunks_target_axis {
patch.wire_node(
format!("{}.{ix}.move_chunks", node.name),
AxisOp::Move(stream_axis, chunks_target_axis),
&[after_window],
)?[0]
} else {
after_window
}
};
chunked_inputs.push(wire);
}
Ok(patch.wire_node(format!("{}.blockified", node.name), op.clone(), &chunked_inputs)?[0])
}
fn wire_initiator_multibroadcastto(
patch: &mut TypedModelPatch,
model: &TypedModel,
node: &TypedNode,
chunk_sym: &Symbol,
) -> TractResult<OutletId> {
ensure!(node.inputs.len() == 1, "MultiBroadcastTo expects 1 input, got {}", node.inputs.len());
let input = node.inputs[0];
let in_fact = model.outlet_fact(input)?;
ensure!(
streaming_positions(in_fact, chunk_sym).is_empty(),
"MultiBroadcastTo initiator with streaming input not supported (input has \
{} streaming axes)",
streaming_positions(in_fact, chunk_sym).len(),
);
let target_rank = node.outputs[0].fact.rank() + 1;
let mut wire = patch.tap_model(model, input)?;
let mut step = 0;
while patch.outlet_fact(wire)?.rank() < target_rank {
wire =
patch.wire_node(format!("{}.bump_rank.{step}", node.name), AxisOp::Add(0), &[wire])?[0];
step += 1;
}
Ok(wire)
}
/// Wires a `MultiBroadcastTo` initiator whose input already carries one streaming axis. The
/// node broadcasts a size-1 input axis into the output's second streaming axis.
///
/// Steps:
/// 1. The input is split into chunks of `k` along its streaming axis.
/// 2. A `WindowOnAxis` of `mask.upper - mask.lower + 1` (zero pad value) is applied to the
///    chunk axis.
/// 3. The window is flattened with the within-chunk axis.
/// 4. The chunk axis is moved to the output's first streaming position.
/// 5. The former broadcast axis is broadcast to `k`.
///
/// # Errors
/// Fails in any of these cases:
/// * the node is not unary;
/// * input and output ranks differ;
/// * the streaming layout is unexpected (not exactly one input streaming axis, not two
///   contiguous output streaming axes, or a broadcast-from axis not of size 1);
/// * the input streaming axis does not match `contracted_axis` in the mask frame.
fn wire_initiator_multibroadcastto_streaming(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    ensure!(node.inputs.len() == 1, "MultiBroadcastTo expects 1 input, got {}", node.inputs.len());
    let input = node.inputs[0];
    let in_fact = model.outlet_fact(input)?;
    let out_fact = &node.outputs[0].fact;
    // All axis bookkeeping below indexes the input with output axis numbers (for example
    // `in_fact.shape[bcast_axis]`, and `in_stream_axis` used as a score axis). A rank-changing
    // broadcast would index out of bounds or pick the wrong axis, so reject it explicitly.
    ensure!(
        in_fact.rank() == out_fact.rank(),
        "MultiBroadcastTo streaming initiator: input rank {} must equal output rank {}",
        in_fact.rank(),
        out_fact.rank()
    );
    let in_streaming = streaming_positions(in_fact, chunk_sym);
    ensure!(
        in_streaming.len() == 1,
        "MultiBroadcastTo streaming initiator: input must have exactly one streaming axis, \
         got {in_streaming:?}"
    );
    let in_stream_axis = in_streaming[0];
    let out_streaming = streaming_positions(out_fact, chunk_sym);
    ensure!(
        out_streaming.len() == 2 && out_streaming[1] == out_streaming[0] + 1,
        "Initiator MultiBroadcastTo output must have two contiguous streaming axes, \
         got {out_streaming:?}"
    );
    // The broadcast axis is the output streaming axis that the input does not stream on.
    let bcast_axis = if out_streaming[0] == in_stream_axis {
        out_streaming[1]
    } else if out_streaming[1] == in_stream_axis {
        out_streaming[0]
    } else {
        bail!(
            "MultiBroadcastTo streaming initiator: input stream axis {in_stream_axis} not in \
             output streaming axes {out_streaming:?}"
        );
    };
    ensure!(
        in_fact.shape[bcast_axis].is_one(),
        "MultiBroadcastTo streaming initiator: broadcast-from axis {bcast_axis} must be 1, \
         got {}",
        in_fact.shape[bcast_axis]
    );
    let score_rank = out_fact.rank();
    let rank_diff = score_rank.checked_sub(2).ok_or_else(|| {
        format_err!("Score rank {score_rank} < 2; cannot translate to mask frame")
    })?;
    let tracked_in_mask = in_stream_axis.checked_sub(rank_diff).ok_or_else(|| {
        format_err!(
            "Tracked score axis {in_stream_axis} doesn't map to mask frame (rank_diff={rank_diff})"
        )
    })?;
    ensure!(
        tracked_in_mask == contracted_axis,
        "MultiBroadcastTo streaming initiator: input stream axis must track to the \
         contracted axis ({contracted_axis}), got {tracked_in_mask}"
    );
    let tapped = patch.tap_model(model, input)?;
    // The split inserts one axis right after `in_stream_axis`.
    let split = wire_chunk_split(patch, &node.name, tapped, in_stream_axis, chunk_sym, k)?;
    let bcast_axis_post_split =
        if bcast_axis > in_stream_axis { bcast_axis + 1 } else { bcast_axis };
    let window: usize = (mask.upper - mask.lower + 1) as usize;
    let start = window_start_for(mask, contracted_axis);
    let dt = patch.outlet_fact(split)?.datum_type;
    let pad_value = Tensor::zero_scalar_dt(dt)?.into_arc_tensor();
    let windowed = patch.wire_node(
        format!("{}.window", node.name),
        tract_pulse_opl::ops::WindowOnAxis { axis: in_stream_axis, window, start, pad_value },
        &[split],
    )?[0];
    // The window inserts one more axis right after the chunk axis.
    let bcast_axis_post_window = if bcast_axis_post_split > in_stream_axis {
        bcast_axis_post_split + 1
    } else {
        bcast_axis_post_split
    };
    let from = tvec!(window.to_dim(), k.to_dim());
    let to = tvec!(((window as i64) * k).to_dim());
    let flat = patch.wire_node(
        format!("{}.window_flat", node.name),
        AxisOp::Reshape(in_stream_axis + 1, from, to),
        &[windowed],
    )?[0];
    // Flattening (window, k) into one axis removes one axis after the chunk axis.
    let bcast_axis_post_flat = if bcast_axis_post_window > in_stream_axis + 1 {
        bcast_axis_post_window - 1
    } else {
        bcast_axis_post_window
    };
    let chunks_target_axis = out_streaming[0];
    let mut bcast_axis_now = bcast_axis_post_flat;
    let mut wire = flat;
    if in_stream_axis != chunks_target_axis {
        wire = patch.wire_node(
            format!("{}.move_chunks", node.name),
            AxisOp::Move(in_stream_axis, chunks_target_axis),
            &[wire],
        )?[0];
        // Mirror the move on the broadcast axis index: axes between the two positions
        // shift by one towards the vacated slot.
        if chunks_target_axis < in_stream_axis {
            if (chunks_target_axis..in_stream_axis).contains(&bcast_axis_now) {
                bcast_axis_now += 1;
            }
        } else if bcast_axis_now > in_stream_axis && bcast_axis_now <= chunks_target_axis {
            bcast_axis_now -= 1;
        }
    }
    let mut target_shape: TVec<TDim> = patch.outlet_fact(wire)?.shape.to_tvec();
    target_shape[bcast_axis_now] = k.to_dim();
    let bcast = tract_core::ops::array::MultiBroadcastTo { shape: target_shape.into() };
    Ok(patch.wire_node(format!("{}.blockified", node.name), bcast, &[wire])?[0])
}
/// Wires an initiator whose output carries a `uniform_tdim`.
///
/// Each input is chunked by `chunkify_uniform_tdim_input`, and the node's op is re-wired on
/// the chunked inputs. If the rewired output's datum type differs from the original output's
/// (for example after a TDim -> i64 cast on the inputs), it is cast back.
///
/// # Errors
/// Propagates chunking and wiring failures.
fn wire_uniform_tdim_initiator(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let chunked_inputs = node
        .inputs
        .iter()
        .enumerate()
        .map(|(ix, &input)| {
            chunkify_uniform_tdim_input(
                patch,
                model,
                input,
                &format!("{}.in{ix}", node.name),
                mask,
                contracted_axis,
                chunk_sym,
                k,
            )
        })
        .collect::<TractResult<TVec<OutletId>>>()?;
    let blockified =
        patch.wire_node(format!("{}.blockified", node.name), node.op.clone(), &chunked_inputs)?[0];
    let wanted_dt = node.outputs[0].fact.datum_type;
    if patch.outlet_fact(blockified)?.datum_type == wanted_dt {
        return Ok(blockified);
    }
    Ok(patch.wire_node(
        format!("{}.blockified.cast_back", node.name),
        tract_core::ops::cast::cast(wanted_dt),
        &[blockified],
    )?[0])
}
/// Chunks one input of a `uniform_tdim` initiator along its single streaming axis.
///
/// Steps:
/// 1. TDim inputs are first cast to i64.
/// 2. The input is split with `wire_chunk_split` at `stream_axis` (assumed to yield the chunk
///    axis followed by the within-chunk axis — TODO confirm).
/// 3. The chunk axis is moved to position 0.
/// 4. When the mask is not block-diagonal and `stream_axis == contracted_axis`, a
///    `WindowOnAxis` of `upper - lower + 1` is applied on axis 0, padded with
///    `sentinel_pad_value`.
/// 5. The window axis is then folded into the within-chunk axis (window-major) by a single
///    reshape.
///
/// # Errors
/// Fails when the input does not have exactly one streaming axis, when `sentinel_pad_value`
/// rejects the dtype (bool), or on wiring failures.
fn chunkify_uniform_tdim_input(
patch: &mut TypedModelPatch,
model: &TypedModel,
input: OutletId,
name_prefix: &str,
mask: &MaskForm,
contracted_axis: usize,
chunk_sym: &Symbol,
k: i64,
) -> TractResult<OutletId> {
let in_fact = model.outlet_fact(input)?;
let positions = streaming_positions(in_fact, chunk_sym);
ensure!(
positions.len() == 1,
"uniform_tdim initiator input must have exactly one streaming axis (got {})",
positions.len(),
);
let stream_axis = positions[0];
let tapped = patch.tap_model(model, input)?;
let mut wire = tapped;
// TDim tensors are cast to i64 before any further wiring.
if patch.outlet_fact(wire)?.datum_type == TDim::datum_type() {
wire = patch.wire_node(
format!("{name_prefix}.cast_i64"),
tract_core::ops::cast::cast(i64::datum_type()),
&[wire],
)?[0];
}
let dt = patch.outlet_fact(wire)?.datum_type;
wire = wire_chunk_split(patch, name_prefix, wire, stream_axis, chunk_sym, k)?;
// Move the chunk axis to the front. The within-chunk axis stays at `stream_axis + 1`.
if stream_axis != 0 {
wire = patch.wire_node(
format!("{name_prefix}.move_chunk"),
AxisOp::Move(stream_axis, 0),
&[wire],
)?[0];
}
// NOTE(review): `stream_axis` is an index in this input's frame, while other initiators first
// translate to the mask frame (subtracting rank_diff) before comparing with
// `contracted_axis`. The two agree only if this input has rank 2 — confirm.
let needs_window = !mask.is_block_diag() && stream_axis == contracted_axis;
if needs_window {
let window_size: usize = (mask.upper - mask.lower + 1) as usize;
let start = window_start_for(mask, contracted_axis);
let sentinel = sentinel_pad_value(dt)?.into_arc_tensor();
wire = patch.wire_node(
format!("{name_prefix}.window"),
tract_pulse_opl::ops::WindowOnAxis {
axis: 0,
window: window_size,
start,
pad_value: sentinel,
},
&[wire],
)?[0];
// Reshape everything after the chunk axis. `from` starts with the window dim; `to` drops
// it and multiplies it into the within-chunk dim found at `from[stream_axis + 1]`.
let post_window = patch.outlet_fact(wire)?.clone();
let rank_after = post_window.rank();
let from: TVec<TDim> = (1..rank_after).map(|i| post_window.shape[i].clone()).collect();
let within_slice_idx = stream_axis;
let mut to: TVec<TDim> = tvec!();
for (i, dim) in from.iter().enumerate() {
if i == 0 {
continue;
}
if i == within_slice_idx + 1 {
let merged = from[0].clone() * dim.clone();
to.push(merged);
} else {
to.push(dim.clone());
}
}
wire = patch.wire_node(
format!("{name_prefix}.flatten_window"),
AxisOp::Reshape(1, from, to),
&[wire],
)?[0];
}
Ok(wire)
}
/// Returns the scalar used to pad windowed `uniform_tdim` inputs: `i32::MAX / 4` as an i64,
/// cast to `dt`. The value is presumably chosen to lie far outside any real coordinate while
/// leaving headroom for later arithmetic — confirm.
///
/// # Errors
/// Fails for `bool` (not expected on this path) or when the cast to `dt` fails.
fn sentinel_pad_value(dt: DatumType) -> TractResult<Tensor> {
    if dt != bool::datum_type() {
        let sentinel = i64::from(i32::MAX / 4);
        return Ok(tensor0(sentinel).cast_to_dt(dt)?.into_owned());
    }
    bail!("uniform_tdim wire of bool dtype not expected as initiator-side input")
}
/// Wires a section member that is neither an initiator nor a terminator on top of its
/// already-chunked inputs.
///
/// Inputs present in `chunked` are used as-is; each may carry at most one axis mentioning
/// `chunk_sym`. All other inputs are tapped from `model` and rank-bumped with leading axes to
/// the largest chunked-input rank. The op is cloned with its axis attributes remapped by
/// `translate_body_op_axes`, and wired under the original node name.
///
/// `_mask`, `_contracted_axis` and `_k` are currently unused.
///
/// # Errors
/// Fails in any of these cases:
/// * no input is chunked;
/// * a chunked input carries more than one streaming axis;
/// * the op's axes mapping disconnects a chunk axis from output 0;
/// * chunked inputs disagree on the chunk-axis position.
fn wire_body(
patch: &mut TypedModelPatch,
model: &TypedModel,
node: &TypedNode,
_mask: &MaskForm,
_contracted_axis: usize,
chunked: &HashMap<OutletId, OutletId>,
chunk_sym: &Symbol,
_k: i64,
) -> TractResult<OutletId> {
let n = node.inputs.len();
let mut new_inputs: TVec<Option<OutletId>> = tvec![None; n];
// (input slot, chunk-axis position) for every chunked input that carries a chunk axis.
let mut chunk_input_axes: Vec<(usize, usize)> = vec![];
let mut chunked_rank: Option<usize> = None;
for (slot, &input) in node.inputs.iter().enumerate() {
if let Some(&c) = chunked.get(&input) {
let cf = patch.outlet_fact(c)?;
let positions = streaming_positions(cf, chunk_sym);
ensure!(
positions.len() <= 1,
"Body op {node}: chunked input slot {slot} has {} streaming axes, expected ≤ 1",
positions.len()
);
if let Some(&ax) = positions.first() {
chunk_input_axes.push((slot, ax));
}
chunked_rank = Some(cf.rank().max(chunked_rank.unwrap_or(0)));
new_inputs[slot] = Some(c);
}
}
let chunked_rank = chunked_rank.ok_or_else(|| {
format_err!("Body op {node} has no chunked input — at least one is required")
})?;
// Non-chunked inputs (e.g. constants) get leading unit axes to match the chunked rank.
for (slot, &input) in node.inputs.iter().enumerate() {
if new_inputs[slot].is_some() {
continue;
}
let tapped = patch.tap_model(model, input)?;
let bumped = bump_rank_to(patch, &node.name, slot, tapped, chunked_rank)?;
new_inputs[slot] = Some(bumped);
}
// Every slot was filled by one of the two loops above.
let new_inputs: TVec<OutletId> = new_inputs.into_iter().map(|o| o.unwrap()).collect();
let input_facts: TVec<TypedFact> =
new_inputs.iter().map(|o| patch.outlet_fact(*o).cloned()).collect::<TractResult<_>>()?;
let in_refs: TVec<&TypedFact> = input_facts.iter().collect();
let output_facts = node.op.output_facts(&in_refs)?;
let out_refs: TVec<&TypedFact> = output_facts.iter().collect();
let am = node.op.axes_mapping(&in_refs, &out_refs)?;
// The op must carry each chunk axis through to its output, or chunks would be mixed.
for &(slot, axis) in &chunk_input_axes {
let tracked = am.track_axis((InOut::In(slot), axis), InOut::Out(0))?;
ensure!(
tracked.is_some(),
"Body op {node} doesn't preserve the chunk axis (input slot {slot}, axis {axis}) \
through to the output — its axes_mapping disconnects it"
);
}
let chunk_pos = chunk_input_axes.iter().map(|&(_, ax)| ax).next();
if let Some(cp) = chunk_pos {
ensure!(
chunk_input_axes.iter().all(|&(_, ax)| ax == cp),
"Body op {node}: chunked inputs disagree on chunk axis position {chunk_input_axes:?}"
);
}
let chunked_op = translate_body_op_axes(node.op.as_ref(), chunk_pos);
Ok(patch.wire_node(&*node.name, chunked_op, &new_inputs)?[0])
}
/// Clones `op` for the chunked graph, remapping axis attributes around the inserted chunk
/// axis at `chunk_pos`.
///
/// * `Softmax`: every axis is remapped with `chunked_axis_index`.
/// * `AxisOp::Move` / `AxisOp::Rm`: axis indices are remapped with `chunked_axis_index`.
/// * `AxisOp::Add`: the insertion point shifts by one only when it lies strictly after the
///   chunk axis.
/// * Anything else, including other `AxisOp` variants such as `Reshape`, is cloned unchanged.
///   NOTE(review): those ops keep their original axis indices.
///
/// With `chunk_pos == None`, axis indices are left as they are.
fn translate_body_op_axes(op: &dyn TypedOp, chunk_pos: Option<usize>) -> Box<dyn TypedOp> {
    use tract_core::ops::nn::{Softmax, SoftmaxKind};
    let remap = |axis: usize| chunk_pos.map_or(axis, |cp| chunked_axis_index(axis, cp));
    let remap_insert = |axis: usize| match chunk_pos {
        Some(cp) if axis > cp => axis + 1,
        _ => axis,
    };
    if let Some(softmax) = op.downcast_ref::<Softmax>() {
        let axes: TVec<usize> = softmax.axes.iter().map(|&a| remap(a)).collect();
        let kind = match &softmax.kind {
            SoftmaxKind::Softmax(exp) => SoftmaxKind::Softmax(*exp),
            SoftmaxKind::LogSoftmax => SoftmaxKind::LogSoftmax,
        };
        return Box::new(Softmax::new(axes, softmax.quant_output_dt, kind));
    }
    let Some(axis_op) = op.downcast_ref::<AxisOp>() else {
        return tract_core::dyn_clone::clone_box(op);
    };
    let translated = match axis_op {
        AxisOp::Move(from, to) => AxisOp::Move(remap(*from), remap(*to)),
        AxisOp::Add(at) => AxisOp::Add(remap_insert(*at)),
        AxisOp::Rm(at) => AxisOp::Rm(remap(*at)),
        other => other.clone(),
    };
    Box::new(translated)
}
/// Prepends unit axes (`AxisOp::Add(0)`) to `outlet` until its rank reaches `target`.
///
/// Nodes are named `{node_name}.bump_rank.{slot}.{step}` with `step` counting from 0. Returns
/// `outlet` unchanged when its rank is already `>= target`.
fn bump_rank_to(
    patch: &mut TypedModelPatch,
    node_name: &str,
    slot: usize,
    mut outlet: OutletId,
    target: usize,
) -> TractResult<OutletId> {
    let missing_axes = target.saturating_sub(patch.outlet_fact(outlet)?.rank());
    for step in 0..missing_axes {
        outlet = patch.wire_node(
            format!("{node_name}.bump_rank.{slot}.{step}"),
            AxisOp::Add(0),
            &[outlet],
        )?[0];
    }
    Ok(outlet)
}
/// Dispatches a section terminator to the wiring routine for its op type.
///
/// Supports `Reduce` and `EinSum` terminators; any other op is rejected with an error.
/// Returns `(original boundary outlet, blockified replacement outlet)`.
fn wire_terminator(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    chunked: &HashMap<OutletId, OutletId>,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<(OutletId, OutletId)> {
    if let Some(reduce) = node.op_as::<Reduce>() {
        wire_terminator_reduce(patch, model, node, reduce, chunked)
    } else if let Some(einsum) = node.op_as::<EinSum>() {
        wire_terminator_einsum(
            patch,
            model,
            node,
            einsum,
            chunked,
            mask,
            contracted_axis,
            chunk_sym,
            k,
        )
    } else {
        bail!("Unsupported operator {node}")
    }
}
/// Wires the blockified replacement for the initiator `EinSum` of a quadratic section.
///
/// The initiator's output (called the "score" below) must have exactly two adjacent
/// positions reported by `streaming_positions` for `chunk_sym`, and each input exactly one.
/// For every input, the streaming axis is split into `(chunk_sym, k)` by
/// `wire_chunk_split`. It is then passed through `wrap_with_window_if_needed`, which only
/// adds a window when the mask is not block-diagonal and the input's streaming axis maps
/// onto `contracted_axis`. Finally the einsum is re-expressed with an extra chunk label by
/// `chunkify_einsum` and wired as `"{node.name}.blockified"`.
///
/// # Errors
/// Fails if the streaming-axis layout differs from the above, if the score rank is below
/// 2, or if an input's streaming axis does not track to an output axis.
fn wire_initiator_einsum(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSum,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let out_streaming_axes = streaming_positions(&node.outputs[0].fact, chunk_sym);
    ensure!(
        out_streaming_axes.len() == 2 && out_streaming_axes[1] == out_streaming_axes[0] + 1,
        "Initiator EinSum output must have two contiguous streaming axes"
    );
    let score_rank = node.outputs[0].fact.rank();
    // Mask axes index the score's trailing two axes as 0 and 1. Subtracting `rank_diff`
    // converts a score axis index into that frame.
    let rank_diff = score_rank.checked_sub(2).ok_or_else(|| {
        format_err!("Score rank {score_rank} < 2; cannot translate to mask frame")
    })?;
    // Collect every input's single streaming axis before wiring anything.
    let mut in_streaming_axes: TVec<usize> = tvec!();
    for &input in &node.inputs {
        let positions = streaming_positions(model.outlet_fact(input)?, chunk_sym);
        ensure!(
            positions.len() == 1,
            "Initiator EinSum input must have exactly one streaming axis"
        );
        in_streaming_axes.push(positions[0]);
    }
    let mut chunked_inputs: TVec<OutletId> = tvec!();
    for (ix, (&input, &stream_axis)) in node.inputs.iter().zip(in_streaming_axes.iter()).enumerate()
    {
        let tapped = patch.tap_model(model, input)?;
        // Split the streaming axis into (chunk_sym, k).
        let chunked = wire_chunk_split(
            patch,
            &format!("{}.{ix}", node.name),
            tapped,
            stream_axis,
            chunk_sym,
            k,
        )?;
        // Find which score axis this input's streaming axis becomes, then express it in
        // the mask frame so it can be compared with `contracted_axis`.
        let tracked_in_score =
            op.axes.track_axis((InOut::In(ix), stream_axis), InOut::Out(0))?.ok_or_else(|| {
                format_err!(
                    "EinSum stream axis on input {ix} doesn't track to a unique output axis"
                )
            })?;
        let tracked_in_mask = tracked_in_score.checked_sub(rank_diff).ok_or_else(|| {
            format_err!(
                "Tracked score axis {tracked_in_score} doesn't map to mask frame \
                (rank_diff={rank_diff})"
            )
        })?;
        // No-op unless the mask is banded and this input feeds the contracted mask axis.
        let chunked = wrap_with_window_if_needed(
            patch,
            chunked,
            stream_axis,
            tracked_in_mask,
            &format!("{}.{ix}", node.name),
            mask,
            contracted_axis,
            k,
        )?;
        chunked_inputs.push(chunked);
    }
    // The new chunk label goes where each streaming axis was, and in front of the first
    // of the two output streaming axes.
    let in_starts: Vec<Option<usize>> = in_streaming_axes.iter().map(|&p| Some(p)).collect();
    let chunked_op = chunkify_einsum(op, &in_starts, Some(out_streaming_axes[0]))?;
    Ok(patch.wire_node(format!("{}.blockified", node.name), chunked_op, &chunked_inputs)?[0])
}
/// Optionally widens a chunked input with a `WindowOnAxis` for banded masks.
///
/// The outlet is returned untouched when the mask is block-diagonal, or when `score_axis`
/// (in the mask frame) is not the contracted axis. Otherwise the following is wired:
/// - a window of `upper - lower + 1` entries is applied on `stream_axis`. It starts at
///   `window_start_for(mask, contracted_axis)` and is padded with zeros of the input's
///   datum type.
/// - the resulting `(window, k)` axes at `stream_axis + 1` are flattened into a single
///   axis of size `window * k`.
fn wrap_with_window_if_needed(
    patch: &mut TypedModelPatch,
    chunked: OutletId,
    stream_axis: usize,
    score_axis: usize,
    name_prefix: &str,
    mask: &MaskForm,
    contracted_axis: usize,
    k: i64,
) -> TractResult<OutletId> {
    let needs_window = !mask.is_block_diag() && score_axis == contracted_axis;
    if !needs_window {
        return Ok(chunked);
    }
    let window = (mask.upper - mask.lower + 1) as usize;
    let pad_value =
        Tensor::zero_scalar_dt(patch.outlet_fact(chunked)?.datum_type)?.into_arc_tensor();
    let window_op = tract_pulse_opl::ops::WindowOnAxis {
        axis: stream_axis,
        window,
        start: window_start_for(mask, contracted_axis),
        pad_value,
    };
    let windowed = patch.wire_node(format!("{name_prefix}.window"), window_op, &[chunked])?[0];
    let flatten = AxisOp::Reshape(
        stream_axis + 1,
        tvec!(window.to_dim(), k.to_dim()),
        tvec!(((window as i64) * k).to_dim()),
    );
    let flat = patch.wire_node(format!("{name_prefix}.window_flat"), flatten, &[windowed])?;
    Ok(flat[0])
}
/// Start offset of the chunk window for the given contracted mask axis.
///
/// Returns `mask.lower` when the contracted axis is `mask.axis_a`, otherwise
/// `-mask.upper`.
fn window_start_for(mask: &MaskForm, contracted_axis: usize) -> i64 {
    if contracted_axis != mask.axis_a {
        return -mask.upper;
    }
    mask.lower
}
/// Wires the blockified form of a terminator `Reduce`. Only a single-axis `Sum` is
/// supported.
///
/// The reduced axis index is remapped into the chunked layout with `chunked_axis_index`.
/// Sometimes the reduction's only consumer is an `AxisOp::Rm` of that same axis index. In
/// that case the `Rm` is also rebuilt on the blockified side, and the returned boundary is
/// the `Rm` node's outlet rather than the `Reduce`'s.
///
/// Returns `(original boundary outlet, blockified replacement outlet)`.
///
/// # Errors
/// Fails if any of the following holds:
/// - the op is not a single-axis `Sum`;
/// - the reduction input has no entry in `chunked`;
/// - the input fact has no streaming axis.
fn wire_terminator_reduce(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    op: &Reduce,
    chunked: &HashMap<OutletId, OutletId>,
) -> TractResult<(OutletId, OutletId)> {
    ensure!(
        op.reducer == Reducer::Sum && op.axes.len() == 1,
        "Blockify: terminator Reduce {node} must be a single-axis Sum (got {:?} over axes {:?})",
        op.reducer,
        op.axes
    );
    // Report an error rather than panicking if the upstream section did not chunk this input.
    let chunked_input = *chunked
        .get(&node.inputs[0])
        .with_context(|| format!("Blockify: terminator {node} input was not chunked"))?;
    let in_fact = model.outlet_fact(node.inputs[0])?;
    // NOTE(review): unlike the EinSum paths, the chunk position comes from the first symbol
    // found in the input shape, not from the chunk symbol. This assumes no other symbolic
    // axis precedes the streaming one — confirm.
    let stream_sym = first_streaming_symbol(in_fact)?;
    let chunk_pos = *streaming_positions(in_fact, &stream_sym)
        .first()
        .with_context(|| format!("Blockify: terminator {node} input has no streaming axis"))?;
    let new_axis = chunked_axis_index(op.axes[0], chunk_pos);
    let new_reduce = Reduce { axes: tvec!(new_axis), reducer: op.reducer };
    let chunked_term =
        patch.wire_node(format!("{}.blockified", node.name), new_reduce, &[chunked_input])?[0];
    // Fold a trailing `Rm` of the reduced axis into the blockified output as well.
    let term_consumers = model.outlet_successors(OutletId::new(node.id, 0));
    if term_consumers.len() == 1 {
        let consumer = &model.nodes[term_consumers[0].node];
        if let Some(AxisOp::Rm(axis)) = consumer.op_as::<AxisOp>()
            && *axis == op.axes[0]
        {
            let chunked_rm = patch.wire_node(
                format!("{}.blockified", consumer.name),
                AxisOp::Rm(new_axis),
                &[chunked_term],
            )?[0];
            return Ok((OutletId::new(consumer.id, 0), chunked_rm));
        }
    }
    Ok((OutletId::new(node.id, 0), chunked_term))
}
/// Wires the blockified replacement for a terminator `EinSum`.
///
/// Each input is handled according to its state:
/// - already present in `chunked` (it comes from inside the section): the chunked outlet
///   is reused.
/// - one `chunk_sym` streaming axis (an auxiliary input from outside the section): the
///   input is tapped and split into `(chunk_sym, k)`. It is windowed when its streaming
///   axis tracks onto the contracted mask axis of the score (input 0).
/// - no streaming axis: the input is tapped and passed through unchanged, and gets no
///   chunk label.
///
/// Any other input is rejected. The einsum is then re-expressed by `chunkify_einsum` and
/// wired as `"{node.name}.blockified"`.
///
/// Returns `(original boundary outlet, blockified replacement outlet)`.
fn wire_terminator_einsum(
    patch: &mut TypedModelPatch,
    model: &TypedModel,
    node: &TypedNode,
    op: &EinSum,
    chunked: &HashMap<OutletId, OutletId>,
    mask: &MaskForm,
    contracted_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<(OutletId, OutletId)> {
    // Input 0 is treated as the score. Its trailing two axes form the mask frame, and
    // `rank_diff` converts a score axis index into that frame.
    let score_rank = model.outlet_fact(node.inputs[0])?.rank();
    let rank_diff = score_rank.checked_sub(2).ok_or_else(|| {
        format_err!("Terminator score rank {score_rank} < 2; cannot translate to mask frame")
    })?;
    let mut chunked_inputs: TVec<OutletId> = tvec!();
    let mut input_starts: Vec<Option<usize>> = vec![];
    for (slot, &input) in node.inputs.iter().enumerate() {
        let positions = streaming_positions(model.outlet_fact(input)?, chunk_sym);
        if let Some(&already_chunked) = chunked.get(&input) {
            chunked_inputs.push(already_chunked);
            input_starts.push(positions.first().copied());
        } else if positions.len() == 1 {
            let tapped = patch.tap_model(model, input)?;
            let in_fact = patch.outlet_fact(tapped)?.clone();
            // Re-locate the streaming axis on the tapped fact, as the first dim mentioning
            // `chunk_sym`.
            let stream_axis = in_fact
                .shape
                .iter()
                .position(|d| d.symbols().contains(chunk_sym))
                .ok_or_else(|| format_err!("auxiliary input lost streaming axis"))?;
            let new_chunked = wire_chunk_split(
                patch,
                &format!("{}.in{slot}", node.name),
                tapped,
                stream_axis,
                chunk_sym,
                k,
            )?;
            // Window the input only when its streaming axis maps onto a score axis that
            // lands in the mask frame. Otherwise the plain split is used.
            let aux_in_score = op.axes.track_axis((InOut::In(slot), stream_axis), InOut::In(0))?;
            let new_chunked = if let Some(score_axis) = aux_in_score
                && let Some(mask_axis) = score_axis.checked_sub(rank_diff)
            {
                wrap_with_window_if_needed(
                    patch,
                    new_chunked,
                    stream_axis,
                    mask_axis,
                    &format!("{}.in{slot}", node.name),
                    mask,
                    contracted_axis,
                    k,
                )?
            } else {
                new_chunked
            };
            chunked_inputs.push(new_chunked);
            input_starts.push(Some(positions[0]));
        } else if positions.is_empty() {
            chunked_inputs.push(patch.tap_model(model, input)?);
            input_starts.push(None);
        } else {
            // NOTE(review): this branch also fires for exactly 2 streaming axes on an input
            // that was not already chunked, so "(max 2)" in the message is misleading.
            bail!(
                "Blockify: EinSum terminator input {slot} has {} streaming axes (max 2)",
                positions.len()
            );
        }
    }
    let out_streaming = streaming_positions(&node.outputs[0].fact, chunk_sym);
    let chunked_op = chunkify_einsum(op, &input_starts, out_streaming.first().copied())?;
    let chunked_term =
        patch.wire_node(format!("{}.blockified", node.name), chunked_op, &chunked_inputs)?[0];
    Ok((OutletId::new(node.id, 0), chunked_term))
}
/// Merges a `(chunk_sym, k)` axis pair back into a single `chunk_sym * k` axis.
///
/// The pair is found as the first axis equal to `chunk_sym` that is immediately followed
/// by an axis equal to `k`. If no such pair exists, `chunked_form` is returned unchanged.
/// Otherwise a reshape node named `"{boundary_name}.blockify_merge"` is wired.
fn wire_merge_reshape(
    patch: &mut TypedModelPatch,
    boundary_name: &str,
    chunked_form: OutletId,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let shape = patch.outlet_fact(chunked_form)?.shape.clone();
    let chunk_dim = chunk_sym.to_dim();
    let merge_at = shape
        .iter()
        .position(|d| *d == chunk_dim)
        .filter(|&pos| pos + 1 < shape.len() && shape[pos + 1] == k.to_dim());
    let Some(pos) = merge_at else {
        return Ok(chunked_form);
    };
    let reshape =
        AxisOp::Reshape(pos, tvec!(chunk_dim.clone(), k.to_dim()), tvec!(chunk_dim * k));
    let merged =
        patch.wire_node(format!("{}.blockify_merge", boundary_name), reshape, &[chunked_form])?;
    Ok(merged[0])
}
/// Returns `c` when `dim == chunk_sym * k + c` for a constant `c >= 0`.
///
/// Returns `None` when the difference is not a constant or is negative.
fn affine_chunk_offset(dim: &TDim, chunk_sym: &Symbol, k: i64) -> Option<i64> {
    let offset = (dim.clone() - chunk_sym.to_dim() * k).to_i64().ok()?;
    if offset < 0 {
        None
    } else {
        Some(offset)
    }
}
/// Splits `stream_axis` of `input` into a `(chunk_sym, k)` axis pair.
///
/// The axis may be `chunk_sym * k + c` with a positive constant `c`. In that case an
/// `AffineChunkTrim` node (`"{name}.affine_trim"`) is wired first, to drop the `c` extra
/// entries. The split itself is a reshape node named `"{name}.blockify_split"`.
fn wire_chunk_split(
    patch: &mut TypedModelPatch,
    name: &str,
    input: OutletId,
    stream_axis: usize,
    chunk_sym: &Symbol,
    k: i64,
) -> TractResult<OutletId> {
    let dim = patch.outlet_fact(input)?.shape[stream_axis].clone();
    let trim = if dim != chunk_sym.to_dim() * k {
        affine_chunk_offset(&dim, chunk_sym, k).filter(|&c| c > 0)
    } else {
        None
    };
    let wire = match trim {
        Some(c) => {
            let trim_op = crate::ops::array::AffineChunkTrim {
                axis: stream_axis,
                typed_trim: c as usize,
                target_per_pulse: k as usize,
            };
            patch.wire_node(format!("{name}.affine_trim"), trim_op, &[input])?[0]
        }
        None => input,
    };
    let split = AxisOp::Reshape(
        stream_axis,
        tvec!(patch.outlet_fact(wire)?.shape[stream_axis].clone()),
        tvec!(chunk_sym.to_dim(), k.to_dim()),
    );
    Ok(patch.wire_node(format!("{name}.blockify_split"), split, &[wire])?[0])
}
fn first_streaming_symbol(fact: &TypedFact) -> TractResult<Symbol> {
fact.shape
.iter()
.find_map(|d| d.symbols().into_iter().next())
.context("No streaming axis found")
}
/// Maps an axis index of the unchunked tensor to the chunked layout.
///
/// Chunking inserts one extra axis at `chunk_pos`. Axes before it keep their index; the
/// streaming axis itself and every later axis move one slot right.
fn chunked_axis_index(orig_axis: usize, chunk_pos: usize) -> usize {
    orig_axis + usize::from(orig_axis >= chunk_pos)
}
fn chunkify_einsum(
op: &EinSum,
input_streaming_starts: &[Option<usize>],
output_streaming_start: Option<usize>,
) -> TractResult<EinSum> {
let (inputs, outputs) = op.axes.to_strs();
let new_repr = op.axes.available_label();
let insert_at = |s: &String, pos: Option<usize>| -> String {
let Some(p) = pos else {
return s.clone();
};
let mut chars: Vec<char> = s.chars().collect();
chars.insert(p, new_repr);
chars.into_iter().collect()
};
let new_inputs: Vec<String> = inputs
.iter()
.zip(input_streaming_starts.iter())
.map(|(s, &pos)| insert_at(s, pos))
.collect();
let new_outputs: Vec<String> = outputs
.iter()
.enumerate()
.map(|(i, s)| if i == 0 { insert_at(s, output_streaming_start) } else { s.clone() })
.collect();
let new_mapping = AxesMapping::from_strs(&new_inputs, &new_outputs)?;
Ok(EinSum { axes: new_mapping, operating_dt: op.operating_dt, q_params: op.q_params.clone() })
}
#[cfg(test)]
mod tests {
    // Unit tests for mask decoding, einsum chunkification, axis remapping and grouping.
    use super::*;

    /// Coordinate symbol for `axis`, in the form used inside mask expressions.
    fn coord(scope: &SymbolScope, axis: usize) -> TDim {
        TDim::Sym(scope.sym(&format!("🎯{axis}")))
    }

    /// Block-diagonal mask `i / k == j / k`.
    fn make_block_diag(scope: &SymbolScope, i: usize, j: usize, k: u64) -> TDim {
        TDim::Eq(
            Box::new(TDim::Div(Box::new(coord(scope, i)), k)),
            Box::new(TDim::Div(Box::new(coord(scope, j)), k)),
        )
    }

    /// Banded mask `lo <= a / k - b / k <= up`, reduced to canonical form.
    fn make_banded(scope: &SymbolScope, a: usize, b: usize, k: u64, lo: i64, up: i64) -> TDim {
        let div_a = TDim::Div(Box::new(coord(scope, a)), k);
        let div_b = TDim::Div(Box::new(coord(scope, b)), k);
        let diff = (div_a - div_b).reduce();
        let ge_upper = TDim::Ge(Box::new(TDim::Val(up)), Box::new(diff.clone())).reduce();
        let ge_lower = TDim::Ge(Box::new(diff), Box::new(TDim::Val(lo))).reduce();
        TDim::Mul(vec![ge_upper, ge_lower]).reduce()
    }

    #[test]
    fn decode_mask_recognises_block_diag_canonical_form() {
        let scope = SymbolScope::default();
        let expr = make_block_diag(&scope, 0, 1, 2);
        let m = decode_mask(&expr, &[0, 1]).unwrap();
        assert_eq!((m.chunk_size, m.lower, m.upper), (2, 0, 0));
    }

    #[test]
    fn decode_mask_recognises_block_diag_arbitrary_chunk_size() {
        let scope = SymbolScope::default();
        let expr = make_block_diag(&scope, 0, 1, 137);
        let m = decode_mask(&expr, &[0, 1]).unwrap();
        assert_eq!(m.chunk_size, 137);
    }

    #[test]
    fn decode_mask_recognises_block_diag_swapped_axes() {
        let scope = SymbolScope::default();
        let expr = make_block_diag(&scope, 1, 0, 2);
        let m = decode_mask(&expr, &[0, 1]).unwrap();
        assert_eq!(m.chunk_size, 2);
    }

    #[test]
    fn decode_mask_recognises_banded_form() {
        let scope = SymbolScope::default();
        let expr = make_banded(&scope, 0, 1, 2, 0, 1);
        let m = decode_mask(&expr, &[0, 1]).unwrap();
        assert_eq!((m.chunk_size, m.lower, m.upper, m.axis_a, m.axis_b), (2, 0, 1, 0, 1));
    }

    #[test]
    fn decode_mask_recognises_banded_form_negative_lower() {
        let scope = SymbolScope::default();
        let expr = make_banded(&scope, 0, 1, 2, -1, 1);
        let m = decode_mask(&expr, &[0, 1]).unwrap();
        assert_eq!((m.chunk_size, m.lower, m.upper), (2, -1, 1));
    }

    #[test]
    fn decode_mask_rejects_mismatched_chunk_sizes() {
        let scope = SymbolScope::default();
        let expr = TDim::Eq(
            Box::new(TDim::Div(Box::new(coord(&scope, 0)), 2)),
            Box::new(TDim::Div(Box::new(coord(&scope, 1)), 3)),
        );
        assert_eq!(decode_mask(&expr, &[0, 1]), None);
    }

    #[test]
    fn decode_mask_rejects_non_streaming_axis() {
        let scope = SymbolScope::default();
        let expr = make_block_diag(&scope, 0, 2, 2);
        assert_eq!(decode_mask(&expr, &[0, 1]), None);
    }

    #[test]
    fn decode_mask_rejects_bare_ge() {
        let scope = SymbolScope::default();
        let expr = TDim::Ge(
            Box::new(TDim::Div(Box::new(coord(&scope, 0)), 2)),
            Box::new(TDim::Div(Box::new(coord(&scope, 1)), 2)),
        );
        assert_eq!(decode_mask(&expr, &[0, 1]), None);
    }

    #[test]
    fn decode_banded_probe_canonical_form() {
        // Builds the banded mask by hand (lower bound first) rather than via `make_banded`.
        // The point is to pin down that the canonical form produced by `TDim::reduce` is
        // decoded. This test used to only print the expression and assert nothing.
        let scope = SymbolScope::default();
        let div_a = TDim::Div(Box::new(coord(&scope, 0)), 2);
        let div_b = TDim::Div(Box::new(coord(&scope, 1)), 2);
        let diff = (div_a - div_b).reduce();
        let ge_lower = TDim::Ge(Box::new(diff.clone()), Box::new(TDim::Val(0))).reduce();
        let ge_upper = TDim::Ge(Box::new(TDim::Val(1)), Box::new(diff)).reduce();
        let mask = TDim::Mul(vec![ge_upper, ge_lower]).reduce();
        let m = decode_mask(&mask, &[0, 1])
            .unwrap_or_else(|| panic!("canonical banded mask not recognised: {mask}"));
        assert_eq!((m.chunk_size, m.lower, m.upper, m.axis_a, m.axis_b), (2, 0, 1, 0, 1));
    }

    #[test]
    fn decode_mask_rejects_offset_in_numerator() {
        let scope = SymbolScope::default();
        let expr = TDim::Eq(
            Box::new(TDim::Div(Box::new(TDim::Add(vec![coord(&scope, 0), TDim::Val(1)])), 2)),
            Box::new(TDim::Div(Box::new(TDim::Add(vec![coord(&scope, 1), TDim::Val(1)])), 2)),
        );
        assert_eq!(decode_mask(&expr, &[0, 1]), None);
    }

    fn einsum_for(inputs: &[&str], output: &str) -> EinSum {
        EinSum {
            axes: AxesMapping::from_strs(inputs, &[output]).unwrap(),
            operating_dt: f32::datum_type(),
            q_params: None,
        }
    }

    fn axes_to_strings(op: &EinSum) -> (Vec<String>, Vec<String>) {
        let (ins, outs) = op.axes.to_strs();
        (ins.into_iter().collect(), outs.into_iter().collect())
    }

    /// Chunkifies `op` with every input and the output marked as streaming.
    fn ck(op: &EinSum, ins: &[usize], out: usize) -> EinSum {
        let in_starts: Vec<Option<usize>> = ins.iter().map(|&p| Some(p)).collect();
        chunkify_einsum(op, &in_starts, Some(out)).unwrap()
    }

    #[test]
    fn chunkify_einsum_handles_streaming_at_position_zero() {
        let op = einsum_for(&["id", "jd"], "ij");
        let chunked = ck(&op, &[0, 0], 0);
        let (ins, outs) = axes_to_strings(&chunked);
        let chunk_char = op.axes.available_label();
        assert_eq!(ins[0], format!("{chunk_char}id"));
        assert_eq!(ins[1], format!("{chunk_char}jd"));
        assert_eq!(outs[0], format!("{chunk_char}ij"));
    }

    #[test]
    fn chunkify_einsum_handles_streaming_at_inner_position() {
        let op = einsum_for(&["bid", "bjd"], "bij");
        let chunked = ck(&op, &[1, 1], 1);
        let (ins, outs) = axes_to_strings(&chunked);
        let chunk_char = op.axes.available_label();
        assert_eq!(ins[0], format!("b{chunk_char}id"));
        assert_eq!(ins[1], format!("b{chunk_char}jd"));
        assert_eq!(outs[0], format!("b{chunk_char}ij"));
    }

    #[test]
    fn chunkify_einsum_handles_mixed_input_positions() {
        let op = einsum_for(&["id", "bjd"], "bij");
        let chunked = ck(&op, &[0, 1], 1);
        let (ins, outs) = axes_to_strings(&chunked);
        let chunk_char = op.axes.available_label();
        assert_eq!(ins[0], format!("{chunk_char}id"));
        assert_eq!(ins[1], format!("b{chunk_char}jd"));
        assert_eq!(outs[0], format!("b{chunk_char}ij"));
    }

    #[test]
    fn chunkify_einsum_for_terminator_with_two_streaming_input() {
        let op = einsum_for(&["ij", "jd"], "id");
        let chunked = ck(&op, &[0, 0], 0);
        let (ins, outs) = axes_to_strings(&chunked);
        let chunk_char = op.axes.available_label();
        assert_eq!(ins[0], format!("{chunk_char}ij"));
        assert_eq!(ins[1], format!("{chunk_char}jd"));
        assert_eq!(outs[0], format!("{chunk_char}id"));
    }

    #[test]
    fn chunked_axis_index_zero_chunk_position() {
        assert_eq!(chunked_axis_index(0, 0), 1);
        assert_eq!(chunked_axis_index(1, 0), 2);
    }

    #[test]
    fn chunked_axis_index_inner_chunk_position() {
        assert_eq!(chunked_axis_index(0, 1), 0);
        assert_eq!(chunked_axis_index(1, 1), 2);
        assert_eq!(chunked_axis_index(2, 1), 3);
    }

    #[test]
    fn connected_components_splits_independent_subgraphs() {
        let mut model = TypedModel::default();
        let t = model.symbols.sym("T");
        let a1 = model.add_source("a1", f32::fact(dims![t.clone(), 4_usize].as_ref())).unwrap();
        let b1 =
            model.wire_node("b1", tract_core::ops::change_axes::AxisOp::Add(0), &[a1]).unwrap()[0];
        let c1 =
            model.wire_node("c1", tract_core::ops::change_axes::AxisOp::Add(0), &[b1]).unwrap()[0];
        let a2 = model.add_source("a2", f32::fact(dims![t.clone(), 4_usize].as_ref())).unwrap();
        let b2 =
            model.wire_node("b2", tract_core::ops::change_axes::AxisOp::Add(0), &[a2]).unwrap()[0];
        let c2 =
            model.wire_node("c2", tract_core::ops::change_axes::AxisOp::Add(0), &[b2]).unwrap()[0];
        model.select_output_outlets(&[c1, c2]).unwrap();
        let multi: BTreeSet<usize> = [b1.node, c1.node, b2.node, c2.node].into_iter().collect();
        let groups = connected_components(&model, &multi);
        assert_eq!(groups.len(), 2, "expected two independent components: {groups:?}");
        let g0: BTreeSet<usize> = [b1.node, c1.node].into_iter().collect();
        let g1: BTreeSet<usize> = [b2.node, c2.node].into_iter().collect();
        assert!(groups.iter().any(|g| *g == g0), "expected component {g0:?} in {groups:?}");
        assert!(groups.iter().any(|g| *g == g1), "expected component {g1:?} in {groups:?}");
    }
}