use rlx_fusion::pass::Pass;
use rlx_ir::*;
use std::collections::HashMap;
/// Floating-point precision that the mixed-precision pass can assign to a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Lowered to `DType::F32`.
    F32,
    /// Lowered to `DType::F16`.
    F16,
    /// Lowered to `DType::BF16`.
    BF16,
}
impl Precision {
    /// Returns the IR [`DType`] used to represent values of this precision.
    pub fn dtype(self) -> DType {
        match self {
            Self::BF16 => DType::BF16,
            Self::F16 => DType::F16,
            Self::F32 => DType::F32,
        }
    }
}
/// Description of a cast: the produced dtype plus optional `sf_block` / `round_sf`
/// settings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CastConfig {
    /// Dtype produced by the cast.
    pub out_dtype: DType,
    /// NOTE(review): not interpreted in this file — presumably a block shape for
    /// scale factors; confirm against the consumer of `CastConfig`.
    pub sf_block: Option<(usize, usize)>,
    /// NOTE(review): not interpreted in this file — presumably whether scale
    /// factors are rounded; confirm against the consumer of `CastConfig`.
    pub round_sf: bool,
}
impl CastConfig {
    /// Builds a cast to `out_dtype` with no `sf_block` and `round_sf` disabled.
    pub const fn plain(out_dtype: DType) -> Self {
        CastConfig {
            round_sf: false,
            sf_block: None,
            out_dtype,
        }
    }
    /// Returns `true` when casting a value of `in_dtype` with this config would
    /// change nothing: the dtype already matches and no `sf_block` is requested.
    pub fn is_noop(&self, in_dtype: DType) -> bool {
        self.sf_block.is_none() && in_dtype == self.out_dtype
    }
}
/// Coarse category of an IR op; a [`PrecisionPolicy`] maps each category to a
/// precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Matmul/convolution/attention-style ops, and any op not classified otherwise.
    Compute,
    /// Normalizations, softmax, reductions, scans and related backward ops.
    Reduction,
    /// Activations, binary/compare ops, `Where` and (fake-)quantization ops.
    Elementwise,
    /// Layout/indexing ops, casts, pooling, rope and control flow.
    DataMovement,
    /// Graph inputs, parameters and constants.
    Boundary,
}
/// Classifies `op` into an [`OpKind`].
///
/// Any variant not listed below falls through to [`OpKind::Compute`].
fn op_kind(op: &Op) -> OpKind {
    match op {
        // Values entering the graph.
        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => OpKind::Boundary,

        // Heavy arithmetic, including user-defined and scan/FFT ops.
        Op::MatMul
        | Op::FusedMatMulBiasAct { .. }
        | Op::Conv { .. }
        | Op::DotGeneral { .. }
        | Op::DenseSolve
        | Op::BatchedDenseSolve
        | Op::Attention { .. }
        | Op::FusedTransformerLayer { .. }
        | Op::GroupedMatMul
        | Op::DequantGroupedMatMul { .. }
        | Op::DequantMoEWeights { .. }
        | Op::LoraMatMul { .. }
        | Op::DequantMatMul { .. }
        | Op::QMatMul { .. }
        | Op::QConv2d { .. }
        | Op::Conv2dBackwardInput { .. }
        | Op::Conv2dBackwardWeight { .. }
        | Op::AttentionBackward { .. }
        | Op::Custom { .. }
        | Op::Scan { .. }
        | Op::ScanBackward { .. }
        | Op::ScanBackwardXs { .. }
        | Op::CustomFn { .. }
        | Op::Fft { .. } => OpKind::Compute,

        // Ops that accumulate across elements.
        Op::LayerNorm { .. }
        | Op::RmsNorm { .. }
        | Op::Softmax { .. }
        | Op::FusedResidualLN { .. }
        | Op::FusedResidualRmsNorm { .. }
        | Op::Reduce { .. }
        | Op::Cumsum { .. }
        | Op::Sample { .. }
        | Op::SelectiveScan { .. }
        | Op::GatedDeltaNet { .. }
        | Op::SoftmaxCrossEntropyWithLogits
        | Op::SoftmaxCrossEntropyBackward
        | Op::LayerNormBackwardInput { .. }
        | Op::LayerNormBackwardGamma { .. }
        | Op::GroupNorm { .. } => OpKind::Reduction,

        // Per-element maps.
        Op::Activation(_)
        | Op::Binary(_)
        | Op::FusedSwiGLU { .. }
        | Op::Compare(_)
        | Op::Where
        | Op::ElementwiseRegion { .. }
        | Op::Quantize { .. }
        | Op::Dequantize { .. }
        | Op::FakeQuantize { .. }
        | Op::FakeQuantizeBackward { .. }
        | Op::FakeQuantizeLSQ { .. }
        | Op::FakeQuantizeLSQBackwardX { .. }
        | Op::FakeQuantizeLSQBackwardScale { .. }
        | Op::ReluBackward
        | Op::ActivationBackward { .. }
        | Op::ComplexNormSq
        | Op::ComplexNormSqBackward
        | Op::Conjugate => OpKind::Elementwise,

        // Layout, indexing and control flow.
        Op::Gather { .. }
        | Op::Narrow { .. }
        | Op::Reshape { .. }
        | Op::Transpose { .. }
        | Op::Concat { .. }
        | Op::Expand { .. }
        | Op::Cast { .. }
        | Op::Rope { .. }
        | Op::Pool { .. }
        | Op::FusedAttentionBlock { .. }
        | Op::TopK { .. }
        | Op::ScatterAdd
        | Op::MaxPool2dBackward { .. }
        | Op::ResizeNearest2x
        | Op::AxialRope2d { .. }
        | Op::If { .. }
        | Op::While { .. } => OpKind::DataMovement,

        // Anything unrecognized is treated as compute.
        _ => OpKind::Compute,
    }
}
#[derive(Debug, Clone, Default)]
pub enum PrecisionPolicy {
#[default]
AlwaysF32,
AlwaysF16,
AutoMixedConservative,
AutoMixed,
AutoMixedBf16,
Custom(HashMap<OpKind, Precision>),
}
impl PrecisionPolicy {
pub fn precision_for(&self, kind: OpKind) -> Precision {
match self {
PrecisionPolicy::AlwaysF32 => Precision::F32,
PrecisionPolicy::AlwaysF16 => match kind {
OpKind::Boundary => Precision::F32, _ => Precision::F16,
},
PrecisionPolicy::AutoMixedConservative => match kind {
OpKind::Compute => Precision::F16,
OpKind::Reduction => Precision::F32,
OpKind::Elementwise => Precision::F16,
OpKind::DataMovement => Precision::F16,
OpKind::Boundary => Precision::F32,
},
PrecisionPolicy::AutoMixed => match kind {
OpKind::Compute => Precision::F16,
OpKind::Reduction => Precision::F16,
OpKind::Elementwise => Precision::F16,
OpKind::DataMovement => Precision::F16,
OpKind::Boundary => Precision::F32,
},
PrecisionPolicy::AutoMixedBf16 => match kind {
OpKind::Compute => Precision::BF16,
OpKind::Reduction => Precision::BF16,
OpKind::Elementwise => Precision::BF16,
OpKind::DataMovement => Precision::BF16,
OpKind::Boundary => Precision::F32,
},
PrecisionPolicy::Custom(map) => map.get(&kind).copied().unwrap_or(Precision::F32),
}
}
}
/// Graph pass that retypes nodes according to a [`PrecisionPolicy`] and inserts
/// casts between producers and consumers of differing precision.
pub struct AutoMixedPrecision {
    /// Policy consulted for every node the pass rewrites.
    pub policy: PrecisionPolicy,
}
impl AutoMixedPrecision {
    /// Creates the pass using `policy`.
    pub fn new(policy: PrecisionPolicy) -> Self {
        AutoMixedPrecision { policy }
    }
}
impl Pass for AutoMixedPrecision {
fn name(&self) -> &str {
"auto_mixed_precision"
}
fn run(&self, graph: Graph) -> Graph {
if matches!(self.policy, PrecisionPolicy::AlwaysF32) {
return graph;
}
let mut new_graph = Graph::new(&graph.name);
let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
let mut node_precision: HashMap<NodeId, Precision> = HashMap::new();
let mut cast_cache: HashMap<(NodeId, Precision), NodeId> = HashMap::new();
for node in graph.nodes() {
let kind = op_kind(&node.op);
let target = self.policy.precision_for(kind);
let target = match kind {
OpKind::Boundary => Precision::F32,
_ => target,
};
let mut new_inputs = Vec::with_capacity(node.inputs.len());
for &in_id in &node.inputs {
let src_new_id = id_map[&in_id];
let src_prec = node_precision
.get(&in_id)
.copied()
.unwrap_or(Precision::F32);
if src_prec == target {
new_inputs.push(src_new_id);
} else {
let cast_id = *cast_cache.entry((src_new_id, target)).or_insert_with(|| {
let shape = new_graph
.node(src_new_id)
.shape
.clone()
.with_dtype(target.dtype());
new_graph.add_node(Op::Cast { to: target.dtype() }, vec![src_new_id], shape)
});
new_inputs.push(cast_id);
}
}
let new_shape = node.shape.clone().with_dtype(target.dtype());
let new_id = new_graph.add_node(node.op.clone(), new_inputs, new_shape);
id_map.insert(node.id, new_id);
node_precision.insert(node.id, target);
}
let new_outputs: Vec<NodeId> = graph
.outputs
.iter()
.map(|&out_id| {
let src_new_id = id_map[&out_id];
let src_prec = node_precision
.get(&out_id)
.copied()
.unwrap_or(Precision::F32);
if src_prec == Precision::F32 {
src_new_id
} else {
let shape = new_graph
.node(src_new_id)
.shape
.clone()
.with_dtype(DType::F32);
new_graph.add_node(Op::Cast { to: DType::F32 }, vec![src_new_id], shape)
}
})
.collect();
new_graph.set_outputs(new_outputs);
new_graph
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `mm = x[2,4] @ w[4,3]` (all F32) with `mm` as the only output.
    fn matmul_graph() -> Graph {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);
        g
    }

    #[test]
    fn always_f32_is_noop() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AlwaysF32).run(matmul_graph());
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn auto_mixed_inserts_casts_at_boundary() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AutoMixed).run(matmul_graph());
        // x, w, their F16 casts, the matmul, and the F32 cast on the output.
        assert!(out.len() >= 6);
        let final_node = out.node(out.outputs[0]);
        assert!(matches!(final_node.op, Op::Cast { to: DType::F32 }));
    }
}