use rlx_ir::{Graph, NodeId, OpKind};
/// Outcome of [`legalize_for_backend`]: `Ok(())` when every node is
/// supported, otherwise the `(node id, op kind)` of each offending node.
pub type LegalizeResult = Result<(), Vec<(NodeId, OpKind)>>;

/// Checks every node of `graph` against the op kinds a backend claims to
/// support.
///
/// An empty `supported` slice is treated as "no restriction" and always
/// passes, regardless of the graph's contents.
///
/// # Errors
///
/// Returns one `(node id, op kind)` pair per node whose kind is absent from
/// `supported`, in the order `graph.nodes()` yields them. The same kind may
/// therefore appear several times, once per offending node.
pub fn legalize_for_backend(graph: &Graph, supported: &[OpKind]) -> LegalizeResult {
    // A backend that declares no capability set is not restricted.
    if supported.is_empty() {
        return Ok(());
    }

    let mut offenders = Vec::new();
    for node in graph.nodes() {
        let kind = node.op.kind();
        if supported.contains(&kind) {
            continue;
        }
        offenders.push((node.id, kind));
    }

    if offenders.is_empty() {
        Ok(())
    } else {
        Err(offenders)
    }
}
/// Renders the offenders produced by a failed legalization as a
/// human-readable, multi-line diagnostic for `backend_name`.
///
/// The header states how many *distinct* op kinds are unsupported and how
/// many nodes are affected. One line per offending node follows, then a hint
/// about `Backend::supported_ops()`. When any offender is `OpKind::Custom`,
/// an extra paragraph explains where per-backend custom-op kernels are
/// registered.
pub fn format_legalize_error(backend_name: &str, errors: &[(NodeId, OpKind)]) -> String {
    use std::fmt::Write as _;

    // `errors` holds one entry per offending *node*, so a kind can repeat.
    // Count distinct kinds for the header. A linear dedup keeps the bound at
    // `PartialEq + Copy` (no `Hash`/`Ord` needed on `OpKind`).
    let mut distinct_kinds: Vec<OpKind> = Vec::new();
    for (_, kind) in errors {
        if !distinct_kinds.contains(kind) {
            distinct_kinds.push(*kind);
        }
    }

    let mut s = format!(
        "rlx-opt: backend {backend_name:?} doesn't claim support for {} op kind(s) \
         ({} offending node(s)):\n",
        distinct_kinds.len(),
        errors.len(),
    );
    for (id, kind) in errors {
        // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
        let _ = writeln!(s, " - node {id:?}: {kind:?}");
    }
    s.push_str(
        " Backend::supported_ops() must include each kind, or rewrite \
         the graph upstream to remove them.",
    );
    if distinct_kinds.contains(&OpKind::Custom) {
        s.push_str(
            "\n `Op::Custom` is registered by name; the IR-level \
             extension (`rlx_ir::register_op`) routes shape inference \
             and autodiff. Per-backend execution requires registering \
             a kernel in that backend's `op_registry`:\n\
             \x20 - CPU: `rlx_cpu::op_registry::register_cpu_kernel`\n\
             \x20 - Metal: `rlx_metal::op_registry::register_metal_kernel` \
             (trait surface only — execution dispatch not wired yet)\n\
             \x20 - MLX: `rlx_mlx::op_registry::register_mlx_kernel` \
             (trait surface only — execution dispatch not wired yet)\n\
             \x20For now, pin custom-op graphs to `Device::Cpu`.",
        );
    }
    s
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::*;

    /// Fixture computing `relu(a + b)` over two 4-element f32 inputs.
    ///
    /// Per the expectations below, every node kind is `Input`, `Binary` or
    /// `Activation`, with exactly one `Binary` and one `Activation` node.
    fn tiny_graph() -> Graph {
        let dtype = DType::F32;
        let mut graph = Graph::new("legalize");
        let lhs = graph.input("a", Shape::new(&[4], dtype));
        let rhs = graph.input("b", Shape::new(&[4], dtype));
        let sum = graph.add(lhs, rhs);
        let out = graph.relu(sum);
        graph.set_outputs(vec![out]);
        graph
    }

    #[test]
    fn empty_supported_set_accepts_anything() {
        let outcome = legalize_for_backend(&tiny_graph(), &[]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn supported_set_with_all_required_kinds_passes() {
        let graph = tiny_graph();
        let kinds = [OpKind::Input, OpKind::Binary, OpKind::Activation];
        assert!(legalize_for_backend(&graph, &kinds).is_ok());
    }

    #[test]
    fn unsupported_op_kind_is_reported() {
        let graph = tiny_graph();
        let kinds = [OpKind::Input, OpKind::Binary];
        let errors = legalize_for_backend(&graph, &kinds).expect_err("should fail");
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].1, OpKind::Activation);
    }

    #[test]
    fn multiple_offenders_all_reported() {
        let graph = tiny_graph();
        let kinds = [OpKind::Input];
        let errors = legalize_for_backend(&graph, &kinds).expect_err("should fail");
        assert_eq!(errors.len(), 2);
        let reported = |want: OpKind| errors.iter().any(|(_, kind)| *kind == want);
        assert!(reported(OpKind::Binary));
        assert!(reported(OpKind::Activation));
    }

    #[test]
    fn format_error_includes_kind_and_count() {
        let kinds = [OpKind::Input];
        let errors = legalize_for_backend(&tiny_graph(), &kinds).unwrap_err();
        let msg = format_legalize_error("test_backend", &errors);
        for needle in ["test_backend", "2 op kind", "Binary", "Activation"] {
            assert!(msg.contains(needle), "missing {needle:?} in: {msg}");
        }
    }
}