use crate::pass::Pass;
use rlx_ir::*;
use std::collections::HashMap;
/// Graph pass that rewrites `Op::DotGeneral` nodes which match the plain
/// two-operand matrix-multiply pattern (no batch dimensions, lhs contracts
/// dimension 1, rhs contracts dimension 0) into `Op::MatMul`.
///
/// All other nodes, including non-matching `DotGeneral` nodes, are copied
/// through unchanged.
pub struct LowerDotGeneral;
impl Pass for LowerDotGeneral {
    /// Stable identifier of this pass.
    fn name(&self) -> &str {
        "lower_dot_general"
    }

    /// Rebuilds `graph`, replacing every simple 2-operand `DotGeneral` with `MatMul`.
    ///
    /// A `DotGeneral` node is lowered only when it has no batch dimensions on
    /// either side, the lhs contracts dimension `[1]`, and the rhs contracts
    /// dimension `[0]`. In that case only its first two inputs become the
    /// `MatMul` operands. Every other node keeps its op and all of its inputs,
    /// with ids translated into the new graph. Graph outputs are translated the
    /// same way.
    ///
    /// If the graph contains no `DotGeneral` node at all, it is returned as-is
    /// without being rebuilt.
    ///
    /// NOTE(review): the guard checks only contracting/batch dims, not operand
    /// rank. It presumably assumes both operands are rank-2; confirm that
    /// `MatMul` matches `DotGeneral` semantics for higher-rank inputs.
    ///
    /// # Panics
    /// Panics if a node refers to an input id that has not been emitted earlier
    /// in node order, if a lowered node has fewer than two inputs, or if an
    /// output id is unknown.
    fn run(&self, graph: Graph) -> Graph {
        let has_dot_general = graph
            .nodes()
            .iter()
            .any(|candidate| matches!(candidate.op, Op::DotGeneral { .. }));
        if !has_dot_general {
            // Nothing to lower: skip the rebuild entirely.
            return graph;
        }

        let mut lowered = Graph::new(&graph.name);
        // Old node id -> id of its counterpart in `lowered`.
        let mut remap: HashMap<NodeId, NodeId> = HashMap::new();

        for old in graph.nodes() {
            let is_plain_matmul = matches!(
                &old.op,
                Op::DotGeneral {
                    lhs_contracting,
                    rhs_contracting,
                    lhs_batch,
                    rhs_batch,
                } if lhs_batch.is_empty()
                    && rhs_batch.is_empty()
                    && lhs_contracting.as_slice() == [1]
                    && rhs_contracting.as_slice() == [0]
            );

            let fresh = if is_plain_matmul {
                // Only the lhs/rhs operands (inputs 0 and 1) are carried over.
                let operands: Vec<NodeId> = vec![remap[&old.inputs[0]], remap[&old.inputs[1]]];
                lowered.add_node(Op::MatMul, operands, old.shape.clone())
            } else {
                let operands: Vec<NodeId> = old.inputs.iter().map(|id| remap[id]).collect();
                lowered.add_node(old.op.clone(), operands, old.shape.clone())
            };
            remap.insert(old.id, fresh);
        }

        let outputs: Vec<NodeId> = graph.outputs.iter().map(|id| remap[id]).collect();
        lowered.set_outputs(outputs);
        lowered
    }
}