use rlx_ir::op::Activation;
use rlx_ir::{Graph, Op};
use std::collections::HashMap;
/// Disconnects `FakeQuantize` nodes that re-quantize an already-quantized value.
///
/// A `FakeQuantize` node is considered redundant when walking upwards from its
/// first input, through single-input ops accepted by
/// [`is_magnitude_preserving`], reaches another `FakeQuantize` with the same
/// `bits`, `axis` and `ste`. Every consumer of a redundant node (and every
/// graph output naming it) is re-pointed at the node's *own input*, so the ops
/// sitting between the two quantizers stay in the graph.
///
/// The redundant nodes themselves are not deleted here; they are only left
/// without consumers.
///
/// Returns the number of `FakeQuantize` nodes that were bypassed.
///
/// NOTE(review): `scale_mode` is not compared when matching quantizers —
/// confirm that two nodes differing only in `scale_mode` are really
/// interchangeable.
pub fn run(graph: &mut Graph) -> usize {
    // redundant FakeQuantize id -> the value that should be used in its place.
    let mut redirect: HashMap<rlx_ir::NodeId, rlx_ir::NodeId> = HashMap::new();

    for node in graph.nodes() {
        let Op::FakeQuantize {
            bits, axis, ste, ..
        } = &node.op
        else {
            continue;
        };
        // A quantizer without inputs has nothing to walk; skip it rather than panic.
        let Some(&inner_input) = node.inputs.first() else {
            continue;
        };

        // Walk towards the graph inputs looking for an identical quantizer.
        let mut cur = inner_input;
        let redundant = loop {
            let parent = graph.node(cur);
            match &parent.op {
                Op::FakeQuantize {
                    bits: b2,
                    axis: a2,
                    ste: s2,
                    ..
                } if b2 == bits && a2 == axis && s2 == ste => break true,
                op if is_magnitude_preserving(op, *axis) => {
                    // Only a strictly linear chain is followed.
                    if parent.inputs.len() != 1 {
                        break false;
                    }
                    cur = parent.inputs[0];
                }
                _ => break false,
            }
        };

        if redundant {
            // Bypass only this node: its input already carries the quantized
            // value, and the intermediate ops must remain in the data path.
            redirect.insert(node.id, inner_input);
        }
    }

    if redirect.is_empty() {
        return 0;
    }
    let n_eliminated = redirect.len();

    // A redirect target may itself be a redundant quantizer (e.g. three
    // identical quantizers back to back), so follow the chain to its end.
    // Every hop moves to a node's own input, so the walk ends on any acyclic graph.
    let resolve = |mut id: rlx_ir::NodeId| {
        while let Some(&next) = redirect.get(&id) {
            id = next;
        }
        id
    };

    let node_ids: Vec<_> = graph.nodes().iter().map(|n| n.id).collect();
    for id in node_ids {
        let mut new_inputs = graph.node(id).inputs.clone();
        let mut changed = false;
        for slot in new_inputs.iter_mut() {
            let target = resolve(*slot);
            if target != *slot {
                *slot = target;
                changed = true;
            }
        }
        if changed {
            graph.set_inputs(id, new_inputs);
        }
    }

    let outs: Vec<_> = graph.outputs.iter().map(|&o| resolve(o)).collect();
    if outs != graph.outputs {
        graph.set_outputs(outs);
    }
    n_eliminated
}

/// Tells whether `run` may look through `op` while searching for an upstream
/// `FakeQuantize`.
///
/// `axis` is the quantization axis of the node being examined:
/// * reshape / transpose / narrow / concat are accepted only when `axis` is
///   `None` (presumably because they can move or resize the quantized axis —
///   confirm);
/// * max-pooling and ReLU are accepted for any `axis`;
/// * every other op is rejected.
fn is_magnitude_preserving(op: &Op, axis: Option<usize>) -> bool {
    match op {
        Op::Reshape { .. } | Op::Transpose { .. } | Op::Narrow { .. } | Op::Concat { .. } => {
            axis.is_none()
        }
        Op::Pool {
            kind: rlx_ir::op::ReduceOp::Max,
            ..
        } => true,
        Op::Activation(Activation::Relu) => true,
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    #[test]
    fn collapses_redundant_fake_quant_through_relu() {
        let f = DType::F32;
        let mut g = Graph::new("collapse");
        let x = g.input("x", Shape::new(&[4], f));
        let q1 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![x],
            Shape::new(&[4], f),
        );
        let r = g.activation(Activation::Relu, q1, Shape::new(&[4], f));
        let q2 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![r],
            Shape::new(&[4], f),
        );
        g.set_outputs(vec![q2]);
        let n = run(&mut g);
        assert_eq!(n, 1, "should have eliminated the second fake-quant");
        // Only the second quantizer is bypassed; the ReLU stays in the path.
        assert_eq!(g.outputs, vec![r]);
        assert_eq!(g.node(r).inputs[0], q1);
    }

    #[test]
    fn collapses_back_to_back_fake_quants() {
        let f = DType::F32;
        let mut g = Graph::new("chain");
        let x = g.input("x", Shape::new(&[4], f));
        let q1 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![x],
            Shape::new(&[4], f),
        );
        let q2 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![q1],
            Shape::new(&[4], f),
        );
        let q3 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![q2],
            Shape::new(&[4], f),
        );
        g.set_outputs(vec![q3]);
        let n = run(&mut g);
        assert_eq!(n, 2, "both trailing fake-quants are redundant");
        // Redirects are resolved transitively, all the way back to q1.
        assert_eq!(g.outputs, vec![q1]);
    }

    #[test]
    fn keeps_fake_quant_with_different_bits() {
        let f = DType::F32;
        let mut g = Graph::new("keep");
        let x = g.input("x", Shape::new(&[4], f));
        let q1 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![x],
            Shape::new(&[4], f),
        );
        let r = g.activation(Activation::Relu, q1, Shape::new(&[4], f));
        let q2 = g.add_node(
            Op::FakeQuantize {
                bits: 4,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![r],
            Shape::new(&[4], f),
        );
        g.set_outputs(vec![q2]);
        let n = run(&mut g);
        assert_eq!(n, 0, "different bits → don't collapse");
    }

    #[test]
    fn keeps_fake_quant_when_intermediate_isnt_safe() {
        let f = DType::F32;
        let mut g = Graph::new("unsafe_chain");
        let x = g.input("x", Shape::new(&[4], f));
        let q1 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![x],
            Shape::new(&[4], f),
        );
        let e = g.activation(Activation::Exp, q1, Shape::new(&[4], f));
        let q2 = g.add_node(
            Op::FakeQuantize {
                bits: 8,
                axis: None,
                ste: SteKind::default(),
                scale_mode: ScaleMode::default(),
            },
            vec![e],
            Shape::new(&[4], f),
        );
        g.set_outputs(vec![q2]);
        let n = run(&mut g);
        assert_eq!(n, 0, "Exp can grow magnitude; don't collapse");
    }
}