use crate::internal::*;
use crate::optim::OptimizerSession;
/// Optimizer pass that recomputes the `uniform_tdim` field of every typed node's output facts
/// from the node's current input facts, repeating until no `uniform_tdim` changes any more.
///
/// Only `uniform_tdim` is rewritten; the rest of each stored output fact is left untouched.
/// The pass works in place on the model (see `run_direct`) and never produces patches.
#[derive(Clone, Debug, Default)]
pub struct PropagateUniformTdim;

impl super::TypedPass for PropagateUniformTdim {
    /// The pass keeps no state between runs, so there is nothing to reset.
    fn reset(&mut self) -> TractResult<()> {
        Ok(())
    }

    /// Never yields a patch: all of this pass's work happens in place in `run_direct`.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        _model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }

    /// Propagates `uniform_tdim` through the model, in evaluation order, until a fixed point.
    ///
    /// For each node of `model.eval_order()` whose op is typed, the op's `output_facts` is
    /// re-evaluated on the node's current input facts. Any output whose stored `uniform_tdim`
    /// differs from the recomputed one is overwritten. Full passes repeat until one of them
    /// changes nothing.
    ///
    /// Nodes whose op is not typed, or whose `output_facts` call fails, are skipped (best
    /// effort). If `output_facts` returns a different number of facts than the node has
    /// outputs, only the outputs that have a matching recomputed fact are compared.
    ///
    /// Returns `Ok(true)` if at least one `uniform_tdim` was modified, `Ok(false)` otherwise.
    ///
    /// # Errors
    /// Fails if the evaluation order cannot be computed or an input outlet fact cannot be read.
    fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
        // Only facts are modified below, never the graph structure, so one order suffices.
        let order = model.eval_order()?;
        let mut any_changed = false;
        loop {
            let mut changed = false;
            for &node_id in &order {
                // Scoped so every shared borrow of `model` (the op and the borrowed input facts)
                // ends before the node's outputs are borrowed mutably below. This lets the input
                // facts be borrowed instead of cloned.
                let new_facts = {
                    let node = &model.nodes()[node_id];
                    let typed_op = match node.op.as_typed() {
                        Some(op) => op,
                        None => continue,
                    };
                    let input_facts: TVec<&TypedFact> = node
                        .inputs
                        .iter()
                        .map(|i| model.outlet_fact(*i))
                        .collect::<TractResult<_>>()?;
                    match typed_op.output_facts(&input_facts) {
                        Ok(facts) => facts,
                        // Best effort: leave this node's facts as they are.
                        Err(_) => continue,
                    }
                };
                // `zip` pairs each output with its recomputed fact without a panicking index.
                for (outlet, new_fact) in
                    model.nodes_mut()[node_id].outputs.iter_mut().zip(new_facts.iter())
                {
                    // Compare in place; clone the new value only when it actually differs.
                    if outlet.fact.uniform_tdim != new_fact.uniform_tdim {
                        outlet.fact.uniform_tdim = new_fact.uniform_tdim.clone();
                        changed = true;
                    }
                }
            }
            if !changed {
                break;
            }
            any_changed = true;
        }
        Ok(any_changed)
    }
}