use crate::backend::wgpu::types::NumericPrecision;
/// Builds the WGSL compute shader that maps symbol indices to constellation points.
///
/// Each invocation `idx < params.len` reads `Symbols[idx]`, checks that it is a
/// finite, non-negative value within `EPSILON` of an integer below `order`, and copies
/// the two consecutive values `Constellation[2 * symbol]` and
/// `Constellation[2 * symbol + 1]` into `Out[2 * idx]` and `Out[2 * idx + 1]`.
///
/// Invalid inputs zero both output slots and record an error through `set_error`:
/// * code `1` — the symbol is NaN or infinite,
/// * code `2` — the symbol lies above `order - 0.5` (rounds to an index `>= order`),
/// * code `3` — the symbol is negative or not within `EPSILON` of an integer.
///
/// Note that values slightly below zero are rejected with code `3` even when they
/// are within `EPSILON` of `0`.
///
/// The error word packs `code << 30 | min(index, 0x3ffffffe)` and is combined with
/// `atomicMin`, so the smallest packed value (lowest code, then lowest index) is kept.
/// NOTE(review): this presumes the host initialises `Error.state` to `u32::MAX`
/// ("no error") — confirm against the caller.
///
/// # Parameters
/// * `precision` — selects the WGSL scalar type (`f32` or `f64`) used for all tensors.
/// * `order` — number of constellation points; it is emitted as a `u32` literal, so it
///   must fit in `u32`.
/// * `workgroup_size` — x-dimension of the shader's `@workgroup_size`.
pub(crate) fn modulate_constellation_shader(
    precision: NumericPrecision,
    order: usize,
    workgroup_size: u32,
) -> String {
    let ty = match precision {
        NumericPrecision::F64 => "f64",
        NumericPrecision::F32 => "f32",
    };
    // Largest finite value of `ty`; any larger magnitude is infinite.
    let max_val = match precision {
        NumericPrecision::F64 => "1.7976931348623157e308",
        NumericPrecision::F32 => "3.4028234663852886e38",
    };
    // Tolerance for accepting a floating-point symbol as an integer index.
    let epsilon = match precision {
        NumericPrecision::F64 => "1.0e-9",
        NumericPrecision::F32 => "1.0e-5",
    };
    format!(
        r#"
struct Tensor {{
    data: array<{ty}>,
}};
struct ErrorState {{
    state: atomic<u32>,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}};
struct Params {{
    len: u32,
}};
@group(0) @binding(0) var<storage, read> Symbols: Tensor;
@group(0) @binding(1) var<storage, read> Constellation: Tensor;
@group(0) @binding(2) var<storage, read_write> Out: Tensor;
@group(0) @binding(3) var<storage, read_write> Error: ErrorState;
@group(0) @binding(4) var<uniform> params: Params;
const ORDER: u32 = {order}u;
const EPSILON: {ty} = {epsilon};
const MAX_FINITE: {ty} = {ty}({max_val});
// NaN fails the self-comparison and infinities exceed MAX_FINITE; the inclusive
// bound keeps the largest finite value classified as finite.
fn isfinite_scalar(x: {ty}) -> bool {{
    return (x == x) && (abs(x) <= MAX_FINITE);
}}
// Code in the top two bits, clamped index below; atomicMin keeps the smallest word.
fn set_error(code: u32, index: u32) {{
    let packed_index = min(index, 0x3ffffffeu);
    let packed = (code << 30u) | packed_index;
    atomicMin(&Error.state, packed);
}}
// Records an error for element idx and zeroes its output pair.
fn reject(code: u32, idx: u32) {{
    set_error(code, idx);
    Out.data[idx * 2u] = {ty}(0.0);
    Out.data[idx * 2u + 1u] = {ty}(0.0);
}}
@compute @workgroup_size({workgroup_size}, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let idx = gid.x;
    if idx >= params.len {{
        return;
    }}
    let raw = Symbols.data[idx];
    if !isfinite_scalar(raw) {{
        reject(1u, idx);
        return;
    }}
    if raw < {ty}(0.0) {{
        reject(3u, idx);
        return;
    }}
    // ORDER - 0.5 (not (ORDER - 1) + 0.5) so ORDER == 0 cannot underflow in const eval.
    if raw > {ty}(ORDER) - {ty}(0.5) {{
        reject(2u, idx);
        return;
    }}
    let rounded = round(raw);
    if abs(rounded - raw) > EPSILON {{
        reject(3u, idx);
        return;
    }}
    let symbol = u32(rounded);
    if symbol >= ORDER {{
        reject(2u, idx);
        return;
    }}
    let point = symbol * 2u;
    Out.data[idx * 2u] = Constellation.data[point];
    Out.data[idx * 2u + 1u] = Constellation.data[point + 1u];
}}
"#,
        ty = ty,
        order = order,
        epsilon = epsilon,
        max_val = max_val,
        workgroup_size = workgroup_size,
    )
}
/// Builds the WGSL compute shader that packs groups of bits into symbol indices and
/// maps each index to a constellation point.
///
/// Output element `out_idx < params.output_len` belongs to channel
/// `out_idx / output_rows` and row `group = out_idx % output_rows`. Its
/// `bits_per_symbol` input bits are read from
/// `Bits[channel * input_rows + group * bits_per_symbol + k]` for `k` in
/// `0..bits_per_symbol` and combined most-significant-bit first. The two values
/// `Constellation[2 * symbol]` and `Constellation[2 * symbol + 1]` are written to
/// `Out[2 * out_idx]` and `Out[2 * out_idx + 1]`.
///
/// Each input must be finite and within `BIT_TOL` of exactly `0` or `1`. On failure the
/// output pair is zeroed (instead of being left unwritten) and an error is recorded
/// through `set_error`:
/// * code `1` — a bit is NaN or infinite (index: offending input element),
/// * code `2` — a bit is not close to `0` or `1` (index: offending input element),
/// * code `3` — the assembled symbol is `>= order` (index: output element).
///
/// The error word packs `code << 30 | min(index, 0x3ffffffe)` and is combined with
/// `atomicMin`, so the smallest packed value is kept.
///
/// NOTE(review): presumably the host sets `bits_per_symbol = log2(order)`. Values above
/// 32 would shift leading bits out of the `u32` symbol — confirm against the caller.
///
/// # Parameters
/// * `precision` — selects the WGSL scalar type (`f32` or `f64`) used for all tensors.
/// * `order` — number of constellation points; it is emitted as a `u32` literal, so it
///   must fit in `u32`.
/// * `workgroup_size` — x-dimension of the shader's `@workgroup_size`.
pub(crate) fn modulate_bits_constellation_shader(
    precision: NumericPrecision,
    order: usize,
    workgroup_size: u32,
) -> String {
    let ty = match precision {
        NumericPrecision::F64 => "f64",
        NumericPrecision::F32 => "f32",
    };
    // Largest finite value of `ty`; any larger magnitude is infinite.
    let max_val = match precision {
        NumericPrecision::F64 => "1.7976931348623157e308",
        NumericPrecision::F32 => "3.4028234663852886e38",
    };
    // How far an input may deviate from 0.0 / 1.0 and still count as a bit.
    let bit_tol = match precision {
        NumericPrecision::F64 => "1.0e-9",
        NumericPrecision::F32 => "1.0e-6",
    };
    format!(
        r#"
struct Tensor {{
    data: array<{ty}>,
}};
struct ErrorState {{
    state: atomic<u32>,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}};
struct Params {{
    output_len: u32,
    input_rows: u32,
    output_rows: u32,
    bits_per_symbol: u32,
}};
@group(0) @binding(0) var<storage, read> Bits: Tensor;
@group(0) @binding(1) var<storage, read> Constellation: Tensor;
@group(0) @binding(2) var<storage, read_write> Out: Tensor;
@group(0) @binding(3) var<storage, read_write> Error: ErrorState;
@group(0) @binding(4) var<uniform> params: Params;
const ORDER: u32 = {order}u;
const MAX_FINITE: {ty} = {ty}({max_val});
const BIT_TOL: {ty} = {ty}({bit_tol});
// NaN fails the self-comparison and infinities exceed MAX_FINITE; the inclusive
// bound keeps the largest finite value classified as finite.
fn isfinite_scalar(x: {ty}) -> bool {{
    return (x == x) && (abs(x) <= MAX_FINITE);
}}
// Code in the top two bits, clamped index below; atomicMin keeps the smallest word.
fn set_error(code: u32, index: u32) {{
    let packed_index = min(index, 0x3ffffffeu);
    let packed = (code << 30u) | packed_index;
    atomicMin(&Error.state, packed);
}}
// Records an error at error_index and zeroes the output pair of out_idx.
fn reject(code: u32, error_index: u32, out_idx: u32) {{
    set_error(code, error_index);
    Out.data[out_idx * 2u] = {ty}(0.0);
    Out.data[out_idx * 2u + 1u] = {ty}(0.0);
}}
@compute @workgroup_size({workgroup_size}, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let out_idx = gid.x;
    if out_idx >= params.output_len {{
        return;
    }}
    let channel = out_idx / params.output_rows;
    let group = out_idx - channel * params.output_rows;
    // Accumulate bits_per_symbol bits, most significant first.
    var symbol: u32 = 0u;
    for (var bit_idx: u32 = 0u; bit_idx < params.bits_per_symbol; bit_idx = bit_idx + 1u) {{
        let input_row = group * params.bits_per_symbol + bit_idx;
        let input_idx = input_row + channel * params.input_rows;
        let raw = Bits.data[input_idx];
        if !isfinite_scalar(raw) {{
            reject(1u, input_idx, out_idx);
            return;
        }}
        let rounded = round(raw);
        if abs(raw - rounded) > BIT_TOL || (rounded != {ty}(0.0) && rounded != {ty}(1.0)) {{
            reject(2u, input_idx, out_idx);
            return;
        }}
        symbol = (symbol << 1u) | u32(rounded);
    }}
    if symbol >= ORDER {{
        reject(3u, out_idx, out_idx);
        return;
    }}
    let point = symbol * 2u;
    Out.data[out_idx * 2u] = Constellation.data[point];
    Out.data[out_idx * 2u + 1u] = Constellation.data[point + 1u];
}}
"#,
        ty = ty,
        order = order,
        max_val = max_val,
        bit_tol = bit_tol,
        workgroup_size = workgroup_size,
    )
}