use std::collections::HashMap;
use bb_ir::proto::onnx::GraphProto;
use bb_ir::types::{PortRef, RelationResult, TypeNode, TypeRelation, TYPE_ANY};
/// Errors produced while building a [`TypeSolver`] from a graph or while solving it.
#[derive(Debug)]
pub enum TypeError {
    /// A relation reported a conflict while the solver was running.
    ConstraintFailed {
        /// `domain::op_type` of the node that declared the failing relation.
        op: String,
        /// Reason string returned by the relation.
        detail: String,
    },
    /// A value's type was still abstract after solving (raised by `TypeSolver::solve_strict`).
    UnresolvedType {
        /// Name of the graph value that stayed abstract.
        value: String,
    },
    /// A relation referenced a port the node does not have, or a port whose value name is
    /// empty (an omitted optional input/output).
    PortOutOfRange {
        /// `domain::op_type` of the offending node.
        op: String,
        /// The port reference that could not be resolved.
        port: PortRef,
    },
}
impl std::fmt::Display for TypeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ConstraintFailed { op, detail } => {
write!(f, "type constraint failed at {op}: {detail}")
}
Self::UnresolvedType { value } => {
write!(f, "value `{value}` did not resolve to a concrete type")
}
Self::PortOutOfRange { op, port } => {
write!(f, "op {op} references out-of-range port {port:?}")
}
}
}
}
impl std::error::Error for TypeError {}
/// Result of running a [`TypeSolver`]: the type chosen for every value the solver interned.
#[derive(Debug)]
pub struct TypeSolution {
    // Value name -> solved type. Entries produced by `TypeSolver::solve` may still be
    // abstract; `TypeSolver::solve_strict` rejects solutions containing such entries.
    by_value: HashMap<String, &'static TypeNode>,
}
impl TypeSolution {
    /// Looks up the solved type of `value`.
    ///
    /// Returns `None` when `value` was not one of the names the solver interned.
    pub fn type_of(&self, value: &str) -> Option<&'static TypeNode> {
        let node = self.by_value.get(value)?;
        Some(*node)
    }

    /// Iterates over every `(value name, solved type)` pair.
    ///
    /// The order is that of the backing `HashMap`, i.e. unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &'static TypeNode)> {
        self.by_value
            .iter()
            .map(|(name, &node)| (name.as_str(), node))
    }
}
/// Worklist-based solver that propagates type information between graph values through
/// the type relations declared by each node's op.
pub struct TypeSolver {
    // One slot per interned value; indexed through `value_index`.
    types: Vec<TypeNodeSlot>,
    // Every relation instantiated from the graph, in node/declaration order.
    relations: Vec<RelationNode>,
    // Value name -> index into `types`.
    value_index: HashMap<String, usize>,
}
/// Solver state for a single graph value.
struct TypeNodeSlot {
    // Current best type; starts as `TYPE_ANY` and is narrowed by seeding and by relations.
    resolved: &'static TypeNode,
    // Indices into `TypeSolver::relations` of the relations that mention this value.
    rel_set: Vec<usize>,
}
/// One declared type relation bound to the concrete values of a specific node.
struct RelationNode {
    // The relation as declared by the op.
    decl: &'static TypeRelation,
    // `domain::op_type` of the owning node, used in error messages.
    op_name: String,
    // Slot index for each port the relation mentions, in declaration order.
    slots: Vec<usize>,
    // Set once the relation reports `Satisfied`; satisfied relations are never re-run.
    satisfied: bool,
}
impl TypeSolver {
pub fn from_graph(
graph: &GraphProto,
decl_for_op: impl Fn(&str, &str) -> Option<&'static bb_ir::atomic::AtomicOpDecl>,
) -> Result<Self, TypeError> {
let mut solver = Self {
types: Vec::new(),
relations: Vec::new(),
value_index: HashMap::new(),
};
for input in &graph.input {
solver.intern_value(&input.name);
}
for node in &graph.node {
for out in &node.output {
if !out.is_empty() {
solver.intern_value(out);
}
}
for inp in &node.input {
if !inp.is_empty() {
solver.intern_value(inp);
}
}
}
for node in &graph.node {
let Some(decl) = decl_for_op(&node.domain, &node.op_type) else {
continue;
};
for relation in decl.type_relations {
let slots = solver.resolve_relation_ports(node, relation)?;
let rel_idx = solver.relations.len();
solver.relations.push(RelationNode {
decl: relation,
op_name: format!("{}::{}", node.domain, node.op_type),
slots: slots.clone(),
satisfied: false,
});
for s in slots {
solver.types[s].rel_set.push(rel_idx);
}
}
}
Ok(solver)
}
fn intern_value(&mut self, name: &str) -> usize {
if let Some(&idx) = self.value_index.get(name) {
return idx;
}
let idx = self.types.len();
self.types.push(TypeNodeSlot {
resolved: &TYPE_ANY,
rel_set: Vec::new(),
});
self.value_index.insert(name.to_string(), idx);
idx
}
fn resolve_relation_ports(
&mut self,
node: &bb_ir::proto::onnx::NodeProto,
relation: &TypeRelation,
) -> Result<Vec<usize>, TypeError> {
let ports: Vec<PortRef> = match relation {
TypeRelation::SameType(p) | TypeRelation::SameElementType(p) => p.to_vec(),
TypeRelation::Elementwise { input, output } => vec![*input, *output],
TypeRelation::BroadcastShape { in0, in1, out } => vec![*in0, *in1, *out],
TypeRelation::ReduceOver { input, output } => vec![*input, *output],
TypeRelation::Custom { .. } => Vec::new(),
};
let op_name = format!("{}::{}", node.domain, node.op_type);
let mut slots = Vec::with_capacity(ports.len());
for port in ports {
let value_name = match port {
PortRef::Input(i) => node.input.get(i as usize).cloned(),
PortRef::Output(o) => node.output.get(o as usize).cloned(),
};
let Some(name) = value_name else {
return Err(TypeError::PortOutOfRange { op: op_name, port });
};
if name.is_empty() {
return Err(TypeError::PortOutOfRange { op: op_name, port });
}
slots.push(self.intern_value(&name));
}
Ok(slots)
}
pub fn seed(&mut self, value: &str, node: &'static TypeNode) {
if let Some(&idx) = self.value_index.get(value) {
self.types[idx].resolved = node;
}
}
pub fn seed_from_value_info(&mut self, graph: &GraphProto) {
for vi in graph.input.iter().chain(graph.value_info.iter()) {
let Some(type_proto) = vi.r#type.as_ref() else {
continue;
};
let denotation = type_proto.denotation.as_str();
if denotation.is_empty() {
continue;
}
if let Some(node) = bb_ir::types::builtins::lookup_denotation(denotation) {
self.seed(&vi.name, node);
}
}
}
pub fn solve(mut self) -> Result<TypeSolution, TypeError> {
let mut worklist: std::collections::VecDeque<usize> = (0..self.relations.len()).collect();
while let Some(rel_idx) = worklist.pop_front() {
if self.relations[rel_idx].satisfied {
continue;
}
let outcome = self.run_relation(rel_idx)?;
match outcome {
RelationResult::Refined => {
let slots = self.relations[rel_idx].slots.clone();
for s in slots {
for &dep in &self.types[s].rel_set {
if dep != rel_idx && !self.relations[dep].satisfied {
worklist.push_back(dep);
}
}
}
}
RelationResult::Satisfied => {
self.relations[rel_idx].satisfied = true;
}
RelationResult::Defer => {
}
RelationResult::Failed(detail) => {
return Err(TypeError::ConstraintFailed {
op: self.relations[rel_idx].op_name.clone(),
detail: detail.to_string(),
});
}
}
}
let mut by_value: HashMap<String, &'static TypeNode> = HashMap::new();
for (name, &idx) in &self.value_index {
let node = self.types[idx].resolved;
by_value.insert(name.clone(), node);
}
Ok(TypeSolution { by_value })
}
pub fn apply_solution_to_value_info(graph: &mut GraphProto, solution: &TypeSolution) {
for vi in graph.input.iter_mut().chain(graph.value_info.iter_mut()) {
let Some(node) = solution.type_of(&vi.name) else {
continue;
};
if node.is_abstract() {
continue;
}
let denotation = type_node_to_denotation(node);
if denotation.is_empty() {
continue;
}
if let Some(type_proto) = vi.r#type.as_mut() {
type_proto.denotation = denotation.to_string();
}
}
}
pub fn solve_strict(self) -> Result<TypeSolution, TypeError> {
let solution = self.solve()?;
for (name, node) in &solution.by_value {
if node.is_abstract() {
return Err(TypeError::UnresolvedType {
value: name.clone(),
});
}
}
Ok(solution)
}
fn run_relation(&mut self, idx: usize) -> Result<RelationResult, TypeError> {
let slots = self.relations[idx].slots.clone();
let decl = self.relations[idx].decl;
let outcome = match decl {
TypeRelation::SameType(_) => self.run_same_type(&slots),
TypeRelation::SameElementType(_) => self.run_same_element_type(&slots),
TypeRelation::Elementwise { .. } => self.run_elementwise(&slots),
TypeRelation::BroadcastShape { .. } => self.run_broadcast_shape(&slots),
TypeRelation::ReduceOver { .. } => self.run_reduce_over(&slots),
TypeRelation::Custom { run, .. } => {
let _ = run;
Ok(RelationResult::Defer)
}
}?;
Ok(outcome)
}
fn run_same_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
let pivot: Option<&'static TypeNode> = slots
.iter()
.map(|&s| self.types[s].resolved)
.find(|n| n.is_concrete());
let Some(pivot) = pivot else {
return Ok(RelationResult::Defer);
};
let mut refined = false;
for &s in slots {
let cur = self.types[s].resolved;
if std::ptr::eq(cur, pivot) {
continue;
}
if cur.is_abstract() && pivot.is_subtype_of(cur) {
self.types[s].resolved = pivot;
refined = true;
} else {
return Ok(RelationResult::Failed(
"SameType: incompatible concrete types",
));
}
}
Ok(if refined {
RelationResult::Refined
} else {
RelationResult::Satisfied
})
}
fn run_same_element_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
self.run_same_type(slots)
}
fn run_elementwise(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
let inp = self.types[slots[0]].resolved;
let out = self.types[slots[1]].resolved;
if inp.is_concrete() && std::ptr::eq(inp, out) {
return Ok(RelationResult::Satisfied);
}
if inp.is_concrete() && out.is_abstract() && inp.is_subtype_of(out) {
self.types[slots[1]].resolved = inp;
return Ok(RelationResult::Refined);
}
if out.is_concrete() && inp.is_abstract() && out.is_subtype_of(inp) {
self.types[slots[0]].resolved = out;
return Ok(RelationResult::Refined);
}
if inp.is_concrete() && out.is_concrete() && !std::ptr::eq(inp, out) {
return Ok(RelationResult::Failed("Elementwise: input != output"));
}
Ok(RelationResult::Defer)
}
fn run_broadcast_shape(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
self.run_same_element_type(&[slots[0], slots[1], slots[2]])
}
fn run_reduce_over(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
self.run_elementwise(slots)
}
}
/// Maps a builtin type node to the denotation string stored in ONNX `TypeProto.denotation`.
///
/// Nodes are matched by pointer identity against the `bb_ir::types::builtins` items, so
/// only those exact nodes are recognised. Any other node yields `""`, which
/// `TypeSolver::apply_solution_to_value_info` treats as "no denotation".
fn type_node_to_denotation(node: &'static TypeNode) -> &'static str {
    use bb_ir::types::builtins as B;
    let table: [(&'static TypeNode, &'static str); 14] = [
        (&B::TYPE_TENSOR_F32, "ai.bytesandbrains.tensor.f32"),
        (&B::TYPE_TENSOR_F64, "ai.bytesandbrains.tensor.f64"),
        (&B::TYPE_TENSOR_F16, "ai.bytesandbrains.tensor.f16"),
        (&B::TYPE_TENSOR_U8, "ai.bytesandbrains.tensor.u8"),
        (&B::TYPE_TENSOR_I32, "ai.bytesandbrains.tensor.i32"),
        (&B::TYPE_SCALAR_F32, "bb.f32"),
        (&B::TYPE_SCALAR_F64, "bb.f64"),
        (&B::TYPE_SCALAR_F16, "bb.f16"),
        (&B::TYPE_SCALAR_U8, "bb.u8"),
        (&B::TYPE_SCALAR_I32, "bb.i32"),
        (&B::TYPE_PEER_ID, "bb.peer_id"),
        (&B::TYPE_PEER_ID_VEC, "bb.peer_id_vec"),
        (&B::TYPE_TRIGGER, "bb.trigger"),
        (&B::TYPE_WIRE_REQ_ID, "bb.wire_req_id"),
    ];
    table
        .iter()
        .find(|(candidate, _)| std::ptr::eq(*candidate, node))
        .map_or("", |&(_, denotation)| denotation)
}