use crate::serial::wire::encode::WireEncodeErr;
use crate::serial::wire::framing::{put_len_u32, put_string, put_u32, put_u8};
use crate::serial::wire::tags::{atomic_op_tag, bin_op_tag, put_data_type, un_op_tag};
use crate::serial::wire::{Expr, MAX_OPAQUE_PAYLOAD_LEN};
/// Appends the wire encoding of `expr` to `out`.
///
/// Every node is written as a one-byte discriminant, then its fixed fields
/// (operator tags, names, literals), then its child expressions in
/// declaration order. The reserved byte `0x80` marks extension (`Opaque`)
/// operators and expressions. The tree is walked with an explicit work stack
/// instead of recursion, so deep nesting does not consume call-stack frames.
///
/// # Errors
///
/// Returns the first error produced while encoding a node. This is either an
/// error propagated from a framing/tag helper (`put_string`, `put_len_u32`,
/// `put_data_type`, `bin_op_tag`, `un_op_tag`, `atomic_op_tag`), or a
/// [`WireEncodeErr`] when an opaque expression's payload is longer than
/// [`MAX_OPAQUE_PAYLOAD_LEN`] bytes.
///
/// On error, `out` is truncated back to the length it had on entry, so a
/// failed call never leaves a partially written expression in the buffer.
#[inline]
pub fn put_expr(out: &mut Vec<u8>, expr: &Expr) -> Result<(), WireEncodeErr> {
    let start = out.len();
    let result = encode_expr_tree(out, expr);
    if result.is_err() {
        // Roll back so callers never observe a half-encoded expression.
        out.truncate(start);
    }
    result
}

/// Iteratively encodes `root` into `out`, with no rollback on failure.
///
/// Children are pushed onto the work stack in reverse order so that they
/// pop, and are therefore written, in declaration order. `Step::WriteByte`
/// lets a node emit a literal byte between two of its children's encodings.
/// On error `out` may hold a partial encoding; `put_expr` rolls it back.
#[expect(
    clippy::too_many_lines,
    reason = "wire discriminant table is an ABI contract and must remain auditable in one encoder"
)]
fn encode_expr_tree(out: &mut Vec<u8>, root: &Expr) -> Result<(), WireEncodeErr> {
    enum Step<'e> {
        // Write this node's discriminant and fields, then schedule its children.
        Encode(&'e Expr),
        // Write one raw byte at this point in the stream.
        WriteByte(u8),
    }

    let mut stack: Vec<Step<'_>> = Vec::with_capacity(16);
    stack.push(Step::Encode(root));
    while let Some(step) = stack.pop() {
        let node = match step {
            Step::WriteByte(byte) => {
                put_u8(out, byte);
                continue;
            }
            Step::Encode(node) => node,
        };
        // Discriminant values are a fixed wire ABI and are intentionally not
        // in variant-declaration order (e.g. `LitF32` is 15).
        match node {
            Expr::LitU32(value) => {
                put_u8(out, 0);
                put_u32(out, *value);
            }
            Expr::LitI32(value) => {
                put_u8(out, 1);
                // Reinterpret the two's-complement bits; no numeric conversion.
                put_u32(out, u32::from_le_bytes(value.to_le_bytes()));
            }
            Expr::LitBool(value) => {
                put_u8(out, 2);
                put_u8(out, u8::from(*value));
            }
            Expr::LitF32(value) => {
                put_u8(out, 15);
                put_u32(out, canonical_f32_bits(*value));
            }
            Expr::Var(name) => {
                put_u8(out, 3);
                put_string(out, name)?;
            }
            Expr::Load { buffer, index } => {
                put_u8(out, 4);
                put_string(out, buffer)?;
                stack.push(Step::Encode(index));
            }
            Expr::BufLen { buffer } => {
                put_u8(out, 5);
                put_string(out, buffer)?;
            }
            Expr::InvocationId { axis } => {
                put_u8(out, 6);
                put_u8(out, *axis);
            }
            Expr::WorkgroupId { axis } => {
                put_u8(out, 7);
                put_u8(out, *axis);
            }
            Expr::LocalId { axis } => {
                put_u8(out, 8);
                put_u8(out, *axis);
            }
            Expr::BinOp { op, left, right } => {
                put_u8(out, 9);
                // Extension operators: 0x80 marker followed by a u32 id.
                if let crate::ir::BinOp::Opaque(id) = op {
                    put_u8(out, 0x80);
                    put_u32(out, id.as_u32());
                } else {
                    put_u8(out, bin_op_tag(*op)?);
                }
                stack.push(Step::Encode(right));
                stack.push(Step::Encode(left));
            }
            Expr::UnOp { op, operand } => {
                put_u8(out, 10);
                if let crate::ir::UnOp::Opaque(id) = op {
                    put_u8(out, 0x80);
                    put_u32(out, id.as_u32());
                } else {
                    put_u8(out, un_op_tag(op)?);
                }
                stack.push(Step::Encode(operand));
            }
            Expr::Call { op_id, args } => {
                put_u8(out, 11);
                put_string(out, op_id.as_str())?;
                put_len_u32(out, args.len(), "call argument count")?;
                // Reverse so the first argument is popped (and written) first.
                for arg in args.iter().rev() {
                    stack.push(Step::Encode(arg));
                }
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                put_u8(out, 12);
                stack.push(Step::Encode(false_val));
                stack.push(Step::Encode(true_val));
                stack.push(Step::Encode(cond));
            }
            Expr::Cast { target, value } => {
                put_u8(out, 13);
                put_data_type(out, target)?;
                stack.push(Step::Encode(value));
            }
            Expr::Fma { a, b, c } => {
                put_u8(out, 16);
                stack.push(Step::Encode(c));
                stack.push(Step::Encode(b));
                stack.push(Step::Encode(a));
            }
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
                ordering,
            } => {
                put_u8(out, 14);
                if let crate::ir::AtomicOp::Opaque(id) = op {
                    put_u8(out, 0x80);
                    put_u32(out, id.as_u32());
                } else {
                    put_u8(out, atomic_op_tag(*op)?);
                }
                put_u8(out, ordering.wire_tag());
                put_string(out, buffer)?;
                // Stream order after the header is: index, presence flag
                // (1 = expected follows, 0 = absent), [expected], value.
                // Pushed in reverse so the stack pops them in that order.
                stack.push(Step::Encode(value));
                match expected {
                    Some(expected_expr) => {
                        stack.push(Step::Encode(expected_expr));
                        stack.push(Step::WriteByte(1));
                    }
                    None => {
                        stack.push(Step::WriteByte(0));
                    }
                }
                stack.push(Step::Encode(index));
            }
            Expr::SubgroupAdd { value } => {
                put_u8(out, 17);
                stack.push(Step::Encode(value));
            }
            Expr::SubgroupShuffle { value, lane } => {
                put_u8(out, 18);
                stack.push(Step::Encode(lane));
                stack.push(Step::Encode(value));
            }
            Expr::SubgroupBallot { cond } => {
                put_u8(out, 19);
                stack.push(Step::Encode(cond));
            }
            Expr::SubgroupLocalId => {
                put_u8(out, 20);
            }
            Expr::SubgroupSize => {
                put_u8(out, 21);
            }
            Expr::Opaque(extension) => {
                put_u8(out, 0x80);
                put_string(out, extension.extension_kind())?;
                let payload = extension.wire_payload();
                if payload.len() > MAX_OPAQUE_PAYLOAD_LEN {
                    return Err(WireEncodeErr::fmt_usize(
                        "opaque expression payload",
                        payload.len(),
                        &format!(" exceeds {MAX_OPAQUE_PAYLOAD_LEN}. Fix: split the payload across multiple opaque expressions or reduce the extension data size."),
                    ));
                }
                put_len_u32(out, payload.len(), "opaque expression payload length")?;
                out.extend_from_slice(&payload);
            }
        }
    }
    Ok(())
}
/// Maps an `f32` to the canonical bit pattern written on the wire.
///
/// - Every NaN, whatever its sign or payload, becomes the quiet NaN `0x7FC0_0000`.
/// - `+0.0`, `-0.0` and all subnormals of either sign become the bits of `+0.0`.
/// - Normal values and infinities keep their exact IEEE-754 bits.
#[inline]
fn canonical_f32_bits(value: f32) -> u32 {
    match value.classify() {
        // A single NaN encoding so equivalent programs serialise identically.
        std::num::FpCategory::Nan => 0x7FC0_0000,
        // Signed zeros and subnormals are flushed to positive zero.
        std::num::FpCategory::Zero | std::num::FpCategory::Subnormal => 0.0f32.to_bits(),
        std::num::FpCategory::Normal | std::num::FpCategory::Infinite => value.to_bits(),
    }
}