use std::collections::BTreeMap;
use crate::error::CompileError;
use crate::partition_by_wire_ops::WIRE_DOMAIN;
use bb_ir::peer_class::{home_class_of_node, HOME_CLASS_KEY, SELF_CLASS};
use bb_ir::proto::onnx::{
GraphProto, NodeProto, StringStringEntryProto, TypeProto, ValueInfoProto,
};
/// Op type that identifies a wire `Send` node (matched together with `WIRE_DOMAIN`).
const SEND_OP: &str = "Send";
/// Op type given to the `Recv` nodes synthesized by [`synthesize_wire_recvs`].
const RECV_OP: &str = "Recv";
/// Metadata key stamped on every synthesized `Recv` node; its value is the
/// payload name (first output) of the `Send` the node was derived from.
pub const SYNTHESIZED_FROM_KEY: &str = "ai.bytesandbrains.synthesized_from_send";
/// Inserts one wire `Recv` per consuming peer class after every wire `Send`.
///
/// A node qualifies as a wire `Send` when its `domain` is `WIRE_DOMAIN` and its
/// `op_type` is `Send`. The pass only acts on it when its first output (the
/// *payload*) is a non-empty name that at least one *other* node reads. For
/// each qualifying `Send`, the pass does the following:
///
/// * The readers are grouped by home peer class ([`home_class_of_node`]),
///   falling back to [`SELF_CLASS`] when a node has no class.
/// * The `Send`'s first output is renamed to the sentinel
///   `{payload}__send_sentinel_{send_index}`.
/// * For each class, in ascending class order, a `Recv` node is inserted
///   directly after the `Send`. It has no inputs and two outputs,
///   `{payload}__recv_{seq}` and `{payload}__recv_{seq}__sender`. Its metadata
///   maps [`HOME_CLASS_KEY`] to the class and [`SYNTHESIZED_FROM_KEY`] to the
///   payload name.
/// * Every reader in that class has its references to the payload rewritten
///   to `{payload}__recv_{seq}`.
///
/// `seq` is a single counter shared by the whole graph. It advances in node
/// order and then in class order, so minted names are unique among themselves
/// and deterministic.
///
/// `graph.value_info` is updated as follows:
///
/// * If it gives the payload a non-empty `denotation`, that denotation is
///   stamped onto the sentinel and onto every `…__recv_{seq}` name.
/// * Every `…__sender` output is stamped with `bb.peer_id` unconditionally.
/// * All new entries are appended to `graph.value_info`.
///
/// NOTE(review): `graph.output` and the payload's existing `value_info` entry
/// are left untouched. A graph output named after a renamed payload would
/// therefore lose its producer — confirm upstream that this cannot happen.
///
/// Returns the number of `Recv` nodes inserted.
///
/// # Errors
///
/// The current implementation never returns `Err`.
pub fn synthesize_wire_recvs(graph: &mut GraphProto) -> Result<usize, CompileError> {
    // Move the nodes out instead of cloning them. Every node is emitted
    // exactly once, and nothing below can return early before `graph.node`
    // is restored.
    let nodes = std::mem::take(&mut graph.node);
    let WireRecvPlan {
        mut sentinels,
        mut recvs_by_send,
        rewrites,
        new_value_info,
        synthesized,
    } = plan_wire_recvs(&nodes, &graph.value_info);

    let mut emitted: Vec<NodeProto> = Vec::with_capacity(nodes.len() + synthesized);
    for (idx, mut node) in nodes.into_iter().enumerate() {
        if let Some(sentinel) = sentinels.remove(&idx) {
            if let Some(first) = node.output.first_mut() {
                *first = sentinel;
            }
        }
        if let Some(renames) = rewrites.get(&idx) {
            for input in node.input.iter_mut() {
                // Lookup by `&str`: no per-input allocation just to build a key.
                if let Some(replacement) = renames.get(input.as_str()) {
                    input.clone_from(replacement);
                }
            }
        }
        emitted.push(node);
        // Keep each synthesized Recv adjacent to the Send it was derived from.
        if let Some(recvs) = recvs_by_send.remove(&idx) {
            emitted.extend(recvs);
        }
    }

    graph.node = emitted;
    graph.value_info.extend(new_value_info);
    Ok(synthesized)
}

/// Everything [`synthesize_wire_recvs`] changes, computed before any node is
/// mutated.
#[derive(Default)]
struct WireRecvPlan {
    /// `Send` node index → sentinel name replacing its first output.
    sentinels: BTreeMap<usize, String>,
    /// `Send` node index → synthesized `Recv` nodes to insert right after it.
    recvs_by_send: BTreeMap<usize, Vec<NodeProto>>,
    /// Reader node index → (payload name → minted `Recv` output name).
    rewrites: BTreeMap<usize, BTreeMap<String, String>>,
    /// Value-info entries to append to the graph, in creation order.
    new_value_info: Vec<ValueInfoProto>,
    /// Number of `Recv` nodes synthesized.
    synthesized: usize,
}

/// Computes the rewrite plan for every wire `Send` in `nodes`, reading payload
/// denotations from `value_info`.
fn plan_wire_recvs(nodes: &[NodeProto], value_info: &[ValueInfoProto]) -> WireRecvPlan {
    let readers_by_name = index_readers(nodes);
    let mut plan = WireRecvPlan::default();
    let mut next_seqno: usize = 0;

    for (send_idx, send) in nodes.iter().enumerate() {
        if send.domain != WIRE_DOMAIN || send.op_type != SEND_OP {
            continue;
        }
        let data_name: &str = match send.output.first() {
            Some(name) if !name.is_empty() => name,
            _ => continue,
        };

        // Group the payload's readers by home class. The Send itself is
        // excluded in case it also lists its own payload as an input.
        let mut by_class: BTreeMap<String, Vec<usize>> = BTreeMap::new();
        for &reader_idx in readers_by_name.get(data_name).into_iter().flatten() {
            if reader_idx == send_idx {
                continue;
            }
            let class = home_class_of_node(&nodes[reader_idx])
                .map(str::to_string)
                .unwrap_or_else(|| SELF_CLASS.to_string());
            by_class.entry(class).or_default().push(reader_idx);
        }
        if by_class.is_empty() {
            continue;
        }

        let sentinel_name = format!("{data_name}__send_sentinel_{send_idx}");
        let payload_denotation = value_info
            .iter()
            .find(|v| v.name == data_name)
            .and_then(|v| v.r#type.as_ref())
            .map(|t| t.denotation.clone())
            .filter(|d| !d.is_empty());
        if let Some(denot) = payload_denotation.as_deref() {
            plan.new_value_info
                .push(stamped_value_info(&sentinel_name, denot));
        }
        plan.sentinels.insert(send_idx, sentinel_name);

        for (class, reader_indices) in by_class {
            let minted = format!("{data_name}__recv_{next_seqno}");
            let minted_sender = format!("{data_name}__recv_{next_seqno}__sender");
            next_seqno += 1;
            plan.synthesized += 1;

            if let Some(denot) = payload_denotation.as_deref() {
                plan.new_value_info.push(stamped_value_info(&minted, denot));
            }
            plan.new_value_info
                .push(stamped_value_info(&minted_sender, "bb.peer_id"));

            for reader_idx in reader_indices {
                plan.rewrites
                    .entry(reader_idx)
                    .or_default()
                    .entry(data_name.to_string())
                    .or_insert_with(|| minted.clone());
            }

            plan.recvs_by_send
                .entry(send_idx)
                .or_default()
                .push(NodeProto {
                    op_type: RECV_OP.into(),
                    domain: WIRE_DOMAIN.into(),
                    input: vec![],
                    output: vec![minted, minted_sender],
                    metadata_props: vec![
                        StringStringEntryProto {
                            key: HOME_CLASS_KEY.into(),
                            value: class,
                        },
                        StringStringEntryProto {
                            key: SYNTHESIZED_FROM_KEY.into(),
                            value: data_name.to_string(),
                        },
                    ],
                    ..Default::default()
                });
        }
    }
    plan
}

/// Maps each non-empty input name to the indices of the nodes that read it.
/// Indices are ascending and contain no duplicates.
///
/// The map is built once so that each `Send` can look up its readers directly.
/// Rescanning the whole node list per `Send` made the pass O(sends × nodes).
fn index_readers(nodes: &[NodeProto]) -> BTreeMap<&str, Vec<usize>> {
    let mut readers: BTreeMap<&str, Vec<usize>> = BTreeMap::new();
    for (idx, node) in nodes.iter().enumerate() {
        for name in node.input.iter().filter(|n| !n.is_empty()) {
            let entry = readers.entry(name.as_str()).or_default();
            // A node may list the same input several times; record it once.
            // Nodes are visited in order, so a repeat is always the last entry.
            if entry.last() != Some(&idx) {
                entry.push(idx);
            }
        }
    }
    readers
}
/// Builds a `ValueInfoProto` called `name` whose type carries only the given
/// `denotation`; every other field keeps its default value.
fn stamped_value_info(name: &str, denotation: &str) -> ValueInfoProto {
    let ty = TypeProto {
        denotation: denotation.to_owned(),
        ..Default::default()
    };
    ValueInfoProto {
        name: name.to_owned(),
        r#type: Some(ty),
        ..Default::default()
    }
}