Skip to main content

bb_compiler/
insert_dedup_gate_rx.rs

//! Compiler pass - pair every synthesized `wire::Recv` op with a
//! downstream `DedupGateRx`. The gate is the first link in the RX gate
//! chain, so replayed messages are dropped cheaply.
//!
//! The gate consumes the Recv's data output, hashes the wire bytes,
//! and consults [`crate::framework::InboundDedup`]. On a duplicate it returns
//! an `OpError` whose `detail` carries `duplicate`; on first arrival
//! the value is forwarded polymorphically (whatever its type).
//!
//! The pass is idempotent: gated Recvs are stamped with [`GATED_KEY`] and
//! skipped on later runs. It also updates the Recv's `RX_CHAIN_HEAD_KEY`
//! metadata so subsequent RX passes attach to the gate's output rather
//! than the Recv's.
13
14use crate::error::CompileError;
15use crate::rx_chain::{rx_chain_head, set_rx_chain_head};
16use bb_ir::proto::onnx::{GraphProto, NodeProto, StringStringEntryProto};
17use bb_ir::syscall_ids::{OP_DEDUP_GATE_RX as GATE_OP_TYPE, SYSCALL_DOMAIN as GATE_DOMAIN};
18
/// Op domain that marks a node as a synthesized wire op.
const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";
/// Op type (within [`WIRE_DOMAIN`]) of the synthesized receive op.
const WIRE_RECV_OP: &str = "Recv";

/// Metadata key stamped (value `"true"`) on a Recv once it has been gated;
/// its presence makes the pass skip that Recv on subsequent runs.
pub const GATED_KEY: &str = "ai.bytesandbrains.dedup_rx_gated";
24
25/// Insert a `DedupGateRx` after every synthesized `wire::Recv`.
26pub fn insert_dedup_gate_rx(sub_graph: &mut GraphProto) -> Result<(), CompileError> {
27    let recv_indices: Vec<usize> = sub_graph
28        .node
29        .iter()
30        .enumerate()
31        .filter_map(|(i, n)| (n.domain == WIRE_DOMAIN && n.op_type == WIRE_RECV_OP).then_some(i))
32        .collect();
33
34    let mut new_gates: Vec<NodeProto> = Vec::new();
35
36    for recv_idx in recv_indices {
37        if metadata_value(&sub_graph.node[recv_idx], GATED_KEY).is_some() {
38            continue;
39        }
40        let recv_name = sub_graph.node[recv_idx].name.clone();
41        let head = rx_chain_head(&sub_graph.node[recv_idx]);
42        let new_head = format!("{recv_name}#dedup_rx_out");
43
44        new_gates.push(build_gate_node(&recv_name, &head, &new_head));
45
46        // Rewire consumers that read from the prior head.
47        rewire_consumers(sub_graph, recv_idx, &head, &new_head);
48
49        set_metadata(
50            &mut sub_graph.node[recv_idx].metadata_props,
51            GATED_KEY,
52            "true",
53        );
54        set_rx_chain_head(&mut sub_graph.node[recv_idx], &new_head);
55    }
56
57    sub_graph.node.extend(new_gates);
58    Ok(())
59}
60
61fn build_gate_node(source_name: &str, input: &str, output: &str) -> NodeProto {
62    NodeProto {
63        op_type: GATE_OP_TYPE.to_string(),
64        domain: GATE_DOMAIN.to_string(),
65        name: format!("DedupGateRx@{source_name}"),
66        input: vec![input.to_string()],
67        output: vec![output.to_string()],
68        metadata_props: vec![StringStringEntryProto {
69            key: "ai.bytesandbrains.dedup_rx_source".to_string(),
70            value: source_name.to_string(),
71        }],
72        ..Default::default()
73    }
74}
75
76fn rewire_consumers(sub_graph: &mut GraphProto, recv_idx: usize, old_name: &str, new_name: &str) {
77    for (idx, node) in sub_graph.node.iter_mut().enumerate() {
78        if idx == recv_idx {
79            continue;
80        }
81        for inp in node.input.iter_mut() {
82            if inp == old_name {
83                *inp = new_name.to_string();
84            }
85        }
86    }
87}
88
89fn metadata_value(node: &NodeProto, key: &str) -> Option<String> {
90    node.metadata_props
91        .iter()
92        .find(|p| p.key == key)
93        .map(|p| p.value.clone())
94}
95
96fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
97    if let Some(existing) = props.iter_mut().find(|p| p.key == key) {
98        existing.value = value.to_string();
99        return;
100    }
101    props.push(StringStringEntryProto {
102        key: key.to_string(),
103        value: value.to_string(),
104    });
105}
106