use rlx_ir::{DType, Graph, NodeId, Op, Shape};
use std::collections::HashMap;
/// Snapshot of the dtypes seen at a graph's I/O boundary.
///
/// Populated by `IoDtypeManifest::from_graph`, which reads the dtype off each
/// `Op::Input` / `Op::Param` node and off each node listed in the graph's
/// outputs.
#[derive(Debug, Clone, Default)]
pub struct IoDtypeManifest {
/// Dtype of every `Op::Input` node, keyed by the input's name.
// NOTE(review): the reason for `allow(dead_code)` is not visible here;
// presumably the field is not read in every build — confirm it is still needed.
#[allow(dead_code)]
pub inputs: HashMap<String, DType>,
/// Dtype of every `Op::Param` node, keyed by the parameter's name.
// NOTE(review): same `allow(dead_code)` question as for `inputs`.
#[allow(dead_code)]
pub params: HashMap<String, DType>,
/// Dtype of each graph output, in the same order as the graph's output list.
pub outputs: Vec<DType>,
}
impl IoDtypeManifest {
pub fn from_graph(g: &Graph) -> Self {
let mut inputs = HashMap::new();
let mut params = HashMap::new();
for node in g.nodes() {
match &node.op {
Op::Input { name } => {
inputs.insert(name.clone(), node.shape.dtype());
}
Op::Param { name } => {
params.insert(name.clone(), node.shape.dtype());
}
_ => {}
}
}
let outputs = g
.outputs
.iter()
.map(|&id| g.node(id).shape.dtype())
.collect();
Self {
inputs,
params,
outputs,
}
}
pub fn output_dtype(&self, idx: usize, fallback: DType) -> DType {
self.outputs.get(idx).copied().unwrap_or(fallback)
}
}
/// Returns the graph to execute together with a manifest of its original
/// I/O dtypes.
///
/// The manifest is captured from `graph` *before* any promotion, so it
/// reflects the dtypes the graph was built with. The returned graph is the
/// f32-promoted copy when [`needs_f32_exec`] reports true, and `graph` itself
/// otherwise.
#[allow(dead_code)]
pub fn prepare_f32_exec_graph(graph: Graph) -> (Graph, IoDtypeManifest) {
    let manifest = IoDtypeManifest::from_graph(&graph);
    if !needs_f32_exec(&graph) {
        // Nothing to widen: hand the graph back untouched.
        return (graph, manifest);
    }
    (promote_to_f32(graph), manifest)
}
/// Reports whether `g` contains a half-precision node that requires
/// promotion.
///
/// Returns `true` as soon as it finds a node whose dtype is `F16` or `BF16`
/// and whose op is anything other than `Custom`, `Constant`, `Input` or
/// `Param`. Half-precision nodes of those four kinds alone do not trigger
/// promotion.
pub fn needs_f32_exec(g: &Graph) -> bool {
    for node in g.nodes().iter() {
        let is_half = matches!(node.shape.dtype(), DType::F16 | DType::BF16);
        if !is_half {
            continue;
        }
        match &node.op {
            Op::Custom { .. } | Op::Constant { .. } | Op::Input { .. } | Op::Param { .. } => {}
            _ => return true,
        }
    }
    false
}
/// Maps `F16` and `BF16` to `F32`; every other dtype is returned unchanged.
fn promote_dtype(dt: DType) -> DType {
    if matches!(dt, DType::F16 | DType::BF16) {
        DType::F32
    } else {
        dt
    }
}
/// Returns a copy of `shape` whose dtype has been passed through
/// `promote_dtype`.
fn promote_shape(shape: &Shape) -> Shape {
    let widened = promote_dtype(shape.dtype());
    shape.clone().with_dtype(widened)
}
/// Re-encodes a constant's raw bytes from a 16-bit float format to f32.
///
/// For `F16` and `BF16`, `data` is read as consecutive little-endian 2-byte
/// elements and each one is emitted as the 4 little-endian bytes of its f32
/// value. A trailing odd byte, if any, is dropped (`chunks_exact` skips the
/// remainder). For any other dtype the bytes are copied through unchanged.
fn widen_constant_bytes(data: &[u8], from: DType) -> Vec<u8> {
    /// Widens every 2-byte element of `data` using `decode` to obtain its
    /// f32 value.
    fn widen_pairs(data: &[u8], decode: impl Fn([u8; 2]) -> f32) -> Vec<u8> {
        // Each 2-byte element becomes 4 bytes, so the output is twice as long.
        let mut widened = Vec::with_capacity(data.len() * 2);
        for pair in data.chunks_exact(2) {
            let value = decode([pair[0], pair[1]]);
            widened.extend_from_slice(&value.to_le_bytes());
        }
        widened
    }

    match from {
        DType::F16 => widen_pairs(data, |raw| half::f16::from_le_bytes(raw).to_f32()),
        DType::BF16 => widen_pairs(data, |raw| half::bf16::from_le_bytes(raw).to_f32()),
        _ => data.to_vec(),
    }
}
/// Rebuilds `graph` with every `F16`/`BF16` dtype widened to `F32`.
///
/// If [`needs_f32_exec`] is false the input graph is returned as-is.
/// Otherwise a new graph named `"<name>_f32_exec"` is built by visiting the
/// nodes in `graph.nodes()` order:
///
/// * every node's shape is passed through `promote_shape`;
/// * `Op::Constant` payloads are re-encoded with `widen_constant_bytes`
///   according to the node's original dtype;
/// * `Op::Cast` targets are passed through `promote_dtype`;
/// * all other ops are cloned unchanged.
///
/// # Panics
///
/// Panics if a node's input, or a graph output, refers to a node id that has
/// not been visited yet (the id lookup indexes the remap table directly).
pub fn promote_to_f32(graph: Graph) -> Graph {
    if !needs_f32_exec(&graph) {
        return graph;
    }

    let mut promoted = Graph::new(format!("{}_f32_exec", graph.name));
    // Old node id -> id of its counterpart in `promoted`.
    let mut remap: HashMap<NodeId, NodeId> = HashMap::new();

    for node in graph.nodes() {
        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|old| remap[old]).collect();
        let new_shape = promote_shape(&node.shape);
        let new_op = if let Op::Constant { data } = &node.op {
            // The payload is decoded using the dtype the constant had
            // *before* promotion.
            Op::Constant {
                data: widen_constant_bytes(data, node.shape.dtype()),
            }
        } else if let Op::Cast { to } = &node.op {
            Op::Cast {
                to: promote_dtype(*to),
            }
        } else {
            node.op.clone()
        };
        let new_id = promoted.add_node(new_op, new_inputs, new_shape);
        remap.insert(node.id, new_id);
    }

    promoted.set_outputs(graph.outputs.iter().map(|old| remap[old]).collect());
    promoted
}