use super::{put_expr, put_nodes};
use crate::serial::wire::encode::WireEncodeErr;
use crate::serial::wire::framing::{put_len_u32, put_string, put_u32, put_u8};
use crate::serial::wire::{Node, MAX_OPAQUE_PAYLOAD_LEN};
/// Appends the wire encoding of a single [`Node`] to `out`.
///
/// Each node is written as a one-byte discriminant followed by its fields in a
/// fixed order. Built-in nodes use tags `0..=18`, and opaque extension nodes use
/// `0x80`. The discriminant table is an ABI contract, so existing tags must not
/// be renumbered.
///
/// On failure, `out` is truncated back to the length it had on entry. The
/// buffer therefore never ends with a half-encoded node from this call.
///
/// # Errors
///
/// Returns [`WireEncodeErr`] in two cases:
/// - a nested field encoder (`put_string`, `put_expr`, `put_nodes`,
///   `put_len_u32`) fails;
/// - an opaque node's payload is longer than [`MAX_OPAQUE_PAYLOAD_LEN`] bytes.
// `#[must_use]` is intentionally absent: `Result` is already `#[must_use]`, and
// repeating it trips `clippy::double_must_use`.
#[inline]
pub fn put_node(out: &mut Vec<u8>, node: &Node) -> Result<(), WireEncodeErr> {
    let start = out.len();
    let result = put_node_fields(out, node);
    if result.is_err() {
        // Drop whatever prefix of this node was already emitted (tag, leading
        // fields) so a failed encode never leaves a dangling partial node.
        out.truncate(start);
    }
    result
}

/// Writes the discriminant and fields of `node` to `out` without any rollback.
/// [`put_node`] wraps this and restores `out` on error.
#[expect(
    clippy::too_many_lines,
    reason = "wire discriminant table is an ABI contract and must remain auditable in one encoder"
)]
fn put_node_fields(out: &mut Vec<u8>, node: &Node) -> Result<(), WireEncodeErr> {
    // Arms are listed in ascending discriminant order so the table can be
    // audited top to bottom. Field order inside each arm is the wire layout.
    match node {
        // 0: name, value expression.
        Node::Let { name, value } => {
            put_u8(out, 0);
            put_string(out, name.as_str())?;
            put_expr(out, value)?;
        }
        // 1: name, value expression.
        Node::Assign { name, value } => {
            put_u8(out, 1);
            put_string(out, name.as_str())?;
            put_expr(out, value)?;
        }
        // 2: buffer name, index expression, value expression.
        Node::Store {
            buffer,
            index,
            value,
        } => {
            put_u8(out, 2);
            put_string(out, buffer.as_str())?;
            put_expr(out, index)?;
            put_expr(out, value)?;
        }
        // 3: condition, then-branch nodes, else-branch nodes.
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            put_u8(out, 3);
            put_expr(out, cond)?;
            put_nodes(out, then)?;
            put_nodes(out, otherwise)?;
        }
        // 4: loop variable name, `from` expression, `to` expression, body nodes.
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            put_u8(out, 4);
            put_string(out, var.as_str())?;
            put_expr(out, from)?;
            put_expr(out, to)?;
            put_nodes(out, body)?;
        }
        // 5: tag only, no fields.
        Node::Return => put_u8(out, 5),
        // 6: nested node list.
        Node::Block(nodes) => {
            put_u8(out, 6);
            put_nodes(out, nodes)?;
        }
        // 7: one byte for the ordering's wire tag.
        Node::Barrier { ordering } => {
            put_u8(out, 7);
            put_u8(out, ordering.wire_tag());
        }
        // 8: count buffer name, then `count_offset` as raw little-endian bytes
        // (width follows the field's integer type).
        Node::IndirectDispatch {
            count_buffer,
            count_offset,
        } => {
            put_u8(out, 8);
            put_string(out, count_buffer.as_str())?;
            out.extend_from_slice(&count_offset.to_le_bytes());
        }
        // 9: source, destination, offset expression, size expression, tag.
        Node::AsyncLoad {
            source,
            destination,
            offset,
            size,
            tag,
        } => {
            put_u8(out, 9);
            put_string(out, source.as_str())?;
            put_string(out, destination.as_str())?;
            put_expr(out, offset)?;
            put_expr(out, size)?;
            put_string(out, tag.as_str())?;
        }
        // 10: tag string.
        Node::AsyncWait { tag } => {
            put_u8(out, 10);
            put_string(out, tag.as_str())?;
        }
        // 11: generator name, presence byte (1 = Some, 0 = None) followed by the
        // region name when present, then body nodes.
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            put_u8(out, 11);
            put_string(out, generator.as_str())?;
            match source_region {
                Some(region) => {
                    put_u8(out, 1);
                    put_string(out, region.name.as_str())?;
                }
                None => put_u8(out, 0),
            }
            put_nodes(out, body)?;
        }
        // 12: same field layout as AsyncLoad (tag 9).
        Node::AsyncStore {
            source,
            destination,
            offset,
            size,
            tag,
        } => {
            put_u8(out, 12);
            put_string(out, source.as_str())?;
            put_string(out, destination.as_str())?;
            put_expr(out, offset)?;
            put_expr(out, size)?;
            put_string(out, tag.as_str())?;
        }
        // 13: address expression, tag string.
        Node::Trap { address, tag } => {
            put_u8(out, 13);
            put_expr(out, address)?;
            put_string(out, tag.as_str())?;
        }
        // 14: tag string.
        Node::Resume { tag } => {
            put_u8(out, 14);
            put_string(out, tag.as_str())?;
        }
        // 15: buffer name, op wire tag byte, group as u32.
        Node::AllReduce { buffer, op, group } => {
            put_u8(out, 15);
            put_string(out, buffer.as_str())?;
            put_u8(out, op.builtin_wire_tag());
            put_u32(out, group.as_u32());
        }
        // 16: input name, output name, group as u32.
        Node::AllGather {
            input,
            output,
            group,
        } => {
            put_u8(out, 16);
            put_string(out, input.as_str())?;
            put_string(out, output.as_str())?;
            put_u32(out, group.as_u32());
        }
        // 17: input name, output name, op wire tag byte, group as u32.
        Node::ReduceScatter {
            input,
            output,
            op,
            group,
        } => {
            put_u8(out, 17);
            put_string(out, input.as_str())?;
            put_string(out, output.as_str())?;
            put_u8(out, op.builtin_wire_tag());
            put_u32(out, group.as_u32());
        }
        // 18: buffer name, root as u32, group as u32.
        Node::Broadcast {
            buffer,
            root,
            group,
        } => {
            put_u8(out, 18);
            put_string(out, buffer.as_str())?;
            put_u32(out, *root);
            put_u32(out, group.as_u32());
        }
        // 0x80: extension kind string, u32 payload length, raw payload bytes.
        // Payloads above MAX_OPAQUE_PAYLOAD_LEN are rejected.
        Node::Opaque(extension) => {
            put_u8(out, 0x80);
            put_string(out, extension.extension_kind())?;
            let payload = extension.wire_payload();
            if payload.len() > MAX_OPAQUE_PAYLOAD_LEN {
                return Err(WireEncodeErr::fmt_usize(
                    "opaque node payload",
                    payload.len(),
                    &format!(" exceeds {MAX_OPAQUE_PAYLOAD_LEN}. Fix: split the payload across multiple opaque nodes or reduce the extension data size."),
                ));
            }
            put_len_u32(out, payload.len(), "opaque node payload length")?;
            out.extend_from_slice(&payload);
        }
    }
    Ok(())
}