use crate::internal::*;
use crate::ops::logic::sym_to_coord_axis;
use crate::optim::OptimizerSession;
#[derive(Clone, Debug, Default)]
pub struct PropagateRoi;
fn roi_union(a: &TDim, b: &TDim) -> TDim {
if a == b {
return a.clone();
}
a.clone() + b.clone() - a.clone() * b.clone()
}
pub fn bubble_roi(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TVec<Option<TDim>>>> {
let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
let input_facts: TVec<&TypedFact> =
node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
let output_facts = tvec![output_fact];
let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
let mapping = node.op.as_typed().unwrap().axes_mapping(&inputs_ref, &outputs_ref)?;
let roi_coord_syms: Vec<(usize, Symbol)> =
roi.symbols().into_iter().filter_map(|s| sym_to_coord_axis(&s).map(|k| (k, s))).collect();
let remap_for_input = |input_ix: usize| -> Option<TDim> {
let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
for (out_pos, sym) in &roi_coord_syms {
let logical = mapping
.iter_all_axes()
.find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
if logical.inputs[input_ix].is_empty() {
return None;
}
let in_pos = logical.inputs[input_ix][0];
if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
return None;
}
if in_pos != *out_pos {
let scope = sym.scope()?;
sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(in_pos)));
}
}
if sub_map.is_empty() { Some(roi.clone()) } else { roi.substitute_all(&sub_map).ok() }
};
let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(|ix| remap_for_input(ix)).collect();
Ok(Some(result))
}
impl super::TypedPass for PropagateRoi {
fn reset(&mut self) -> TractResult<()> {
Ok(())
}
fn next(
&mut self,
_session: &mut OptimizerSession,
_model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
Ok(None)
}
fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
let order = model.eval_order()?;
let mut changed = false;
let mut demands: HashMap<OutletId, Option<TDim>> = HashMap::new();
for &node_id in &order {
let node = &model.nodes()[node_id];
let Some(input_rois) = node.op.as_typed().unwrap().input_roi(model, node)? else {
continue;
};
for (ix, roi) in input_rois.into_iter().enumerate() {
let outlet = node.inputs[ix];
match (demands.get(&outlet), &roi) {
(_, None) => {
demands.insert(outlet, None);
}
(Option::None, Some(roi)) => {
demands.insert(outlet, Some(roi.clone()));
}
(Some(None), Some(_)) => {}
(Some(Some(existing)), Some(new)) => {
demands.insert(outlet, Some(roi_union(existing, new)));
}
}
}
}
for (outlet, demand) in demands {
if let Some(roi) = demand {
let fact = &mut model.nodes_mut()[outlet.node].outputs[outlet.slot].fact;
if fact.region_of_interest.as_ref() != Some(&roi) {
fact.region_of_interest = Some(roi);
changed = true;
}
}
}
Ok(changed)
}
}