use runmat_accelerate_api::{
GpuTensorHandle, ProviderBitModulationRequest, ProviderModulationRequest,
};
use runmat_builtins::{ComplexTensor, LogicalArray, ResolveContext, Tensor, Type, Value};
use runmat_macros::runtime_builtin;
use crate::builtins::common::gpu_helpers;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinGpuSpec, ConstantStrategy, GpuOpKind, ProviderHook, ReductionNaN,
ResidencyPolicy, ScalarType,
};
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
const NAME: &str = "qammod";
const INTEGER_TOL: f64 = 1e-9;
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::comms::qammod")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: NAME,
op_kind: GpuOpKind::Custom("modulation"),
supported_precisions: &[ScalarType::F32, ScalarType::F64, ScalarType::I32, ScalarType::Bool],
broadcast: BroadcastSemantics::None,
provider_hooks: &[ProviderHook::Custom("qammod")],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::NewHandle,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Uses provider-side constellation modulation for integer and bit gpuArray inputs and returns complex-interleaved GPU storage; providers without the hook fall back to host-compatible validation and re-upload.",
};
fn qammod_error(message: impl Into<String>) -> RuntimeError {
build_runtime_error(message).with_builtin(NAME).build()
}
fn qammod_type(args: &[Type], _ctx: &ResolveContext) -> Type {
match args.first() {
Some(Type::Tensor { shape }) | Some(Type::Logical { shape }) => Type::Tensor {
shape: shape.clone(),
},
Some(Type::Num) | Some(Type::Int) | Some(Type::Bool) => Type::Tensor {
shape: Some(vec![Some(1), Some(1)]),
},
_ => Type::tensor(),
}
}
#[runtime_builtin(
name = "qammod",
category = "comms/modulation",
summary = "Map integer or bit symbols onto a QAM complex-baseband constellation.",
keywords = "qammod,qam,modulation,communications,gray,binary,gpu",
type_resolver(qammod_type),
builtin_path = "crate::builtins::comms::qammod"
)]
async fn qammod_builtin(x: Value, m: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
let order = parse_modulation_order(&m)?;
let options = ParsedOptions::parse(&rest, order)?;
match x {
Value::GpuTensor(handle) => qammod_gpu(handle, order, options).await,
other => {
let symbols = SymbolInput::from_value(other, order, options.input_type)?;
modulate_symbols(symbols, order, &options)
}
}
}
async fn qammod_gpu(
handle: GpuTensorHandle,
order: usize,
options: ParsedOptions,
) -> BuiltinResult<Value> {
let provider = runmat_accelerate_api::provider_for_handle(&handle)
.or_else(runmat_accelerate_api::provider)
.ok_or_else(|| qammod_error("qammod: no acceleration provider registered"))?;
let constellation = constellation_table(order, &options)?;
if runmat_accelerate_api::handle_is_logical(&handle)
&& !matches!(options.input_type, InputType::Bit)
{
return Err(qammod_error("qammod: logical X requires InputType='bit'"));
}
match options.input_type {
InputType::Integer => {
let request = ProviderModulationRequest {
input: &handle,
constellation: &constellation,
};
if let Ok(out) = provider.modulate_constellation(request).await {
return Ok(gpu_helpers::complex_gpu_value(out));
}
}
InputType::Bit => {
let input_rows = handle.shape.first().copied().unwrap_or(0);
if input_rows > 0 {
let request = ProviderBitModulationRequest {
input: &handle,
input_rows,
bits_per_symbol: bits_per_symbol(order)?,
constellation: &constellation,
};
if let Ok(out) = provider.modulate_bits_constellation(request).await {
return Ok(gpu_helpers::complex_gpu_value(out));
}
}
}
}
let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
let symbols = SymbolInput::from_tensor(tensor, order, options.input_type)?;
let result = modulate_symbols(symbols, order, &options)?;
let Value::ComplexTensor(tensor) = result else {
return Err(qammod_error("qammod: expected complex modulation output"));
};
let out = gpu_helpers::upload_complex_tensor(provider, &tensor)
.map_err(|err| qammod_error(format!("qammod: {err}")))?;
Ok(gpu_helpers::complex_gpu_value(out))
}
#[derive(Clone, Debug)]
struct ParsedOptions {
mapping: SymbolMapping,
input_type: InputType,
unit_average_power: bool,
output_dtype: OutputDType,
}
#[derive(Clone, Debug)]
enum SymbolMapping {
Binary,
Gray,
Custom(Vec<usize>),
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InputType {
Integer,
Bit,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OutputDType {
Double,
Single,
}
impl ParsedOptions {
fn parse(args: &[Value], order: usize) -> BuiltinResult<Self> {
let mut mapping = SymbolMapping::Gray;
let mut input_type = InputType::Integer;
let mut unit_average_power = false;
let mut output_dtype = OutputDType::Double;
let mut idx = 0;
if let Some(first) = args.first() {
if !is_name_value_key(first) {
mapping = parse_symbol_mapping(first, order)?;
idx = 1;
}
}
while idx < args.len() {
let key = value_as_string(&args[idx])
.ok_or_else(|| qammod_error("qammod: expected name-value option name"))?;
let Some(value) = args.get(idx + 1) else {
return Err(qammod_error(format!(
"qammod: expected value after option '{key}'"
)));
};
match normalize_key(&key).as_str() {
"unitaveragepower" => {
unit_average_power = value_as_bool(value, "UnitAveragePower")?;
}
"plotconstellation" => {
if value_as_bool(value, "PlotConstellation")? {
return Err(qammod_error(
"qammod: PlotConstellation is not implemented in RunMat yet",
));
}
}
"outputdatatype" => {
let dtype = value_as_string(value).ok_or_else(|| {
qammod_error("qammod: OutputDataType must be 'double' or 'single'")
})?;
output_dtype = match dtype.trim().to_ascii_lowercase().as_str() {
"double" => OutputDType::Double,
"single" => OutputDType::Single,
_ => {
return Err(qammod_error(
"qammod: only OutputDataType 'double' and 'single' are accepted",
));
}
};
}
"inputtype" => {
let text = value_as_string(value)
.ok_or_else(|| qammod_error("qammod: InputType must be a string"))?;
input_type = match text.trim().to_ascii_lowercase().as_str() {
"integer" => InputType::Integer,
"bit" => InputType::Bit,
other => {
return Err(qammod_error(format!(
"qammod: unsupported InputType '{other}'"
)));
}
};
}
other => {
return Err(qammod_error(format!(
"qammod: unrecognised option '{other}'"
)));
}
}
idx += 2;
}
Ok(Self {
mapping,
input_type,
unit_average_power,
output_dtype,
})
}
}
#[derive(Debug)]
struct SymbolInput {
data: Vec<usize>,
shape: Vec<usize>,
}
impl SymbolInput {
fn from_value(value: Value, order: usize, input_type: InputType) -> BuiltinResult<Self> {
match value {
Value::Tensor(tensor) => Self::from_tensor(tensor, order, input_type),
Value::LogicalArray(logical) => Self::from_logical(logical, order, input_type),
Value::Num(n) => match input_type {
InputType::Integer => Ok(Self {
data: vec![number_to_symbol(n)?],
shape: vec![1, 1],
}),
InputType::Bit => bits_to_symbols(vec![number_to_bit(n, "X")?], vec![1, 1], order),
},
Value::Int(i) => {
let n = i.to_f64();
match input_type {
InputType::Integer => Ok(Self {
data: vec![number_to_symbol(n)?],
shape: vec![1, 1],
}),
InputType::Bit => {
bits_to_symbols(vec![number_to_bit(n, "X")?], vec![1, 1], order)
}
}
}
Value::Bool(b) => Ok(Self {
data: vec![usize::from(b)],
shape: vec![1, 1],
}),
Value::Complex(_, _) | Value::ComplexTensor(_) => {
Err(qammod_error("qammod: X must contain real integer symbols"))
}
Value::String(_) | Value::StringArray(_) | Value::CharArray(_) => {
Err(qammod_error("qammod: X must be numeric or logical"))
}
other => Err(qammod_error(format!(
"qammod: unsupported input type {other:?}"
))),
}
}
fn from_tensor(tensor: Tensor, order: usize, input_type: InputType) -> BuiltinResult<Self> {
match input_type {
InputType::Integer => {
let data = tensor
.data
.iter()
.map(|&value| number_to_symbol(value))
.collect::<BuiltinResult<Vec<_>>>()?;
Ok(Self {
data,
shape: tensor.shape,
})
}
InputType::Bit => {
let bits = tensor
.data
.iter()
.map(|&value| number_to_bit(value, "X"))
.collect::<BuiltinResult<Vec<_>>>()?;
bits_to_symbols(bits, tensor.shape, order)
}
}
}
fn from_logical(
logical: LogicalArray,
order: usize,
input_type: InputType,
) -> BuiltinResult<Self> {
match input_type {
InputType::Integer => Err(qammod_error("qammod: logical X requires InputType='bit'")),
InputType::Bit => bits_to_symbols(logical.data, logical.shape, order),
}
}
}
fn bits_to_symbols(
bits: Vec<u8>,
mut shape: Vec<usize>,
order: usize,
) -> BuiltinResult<SymbolInput> {
let bits_per_symbol = bits_per_symbol(order)?;
let rows = shape.first().copied().unwrap_or(bits.len());
if rows == 0 {
if shape.is_empty() {
shape.push(0);
shape.push(1);
} else {
shape[0] = 0;
}
return Ok(SymbolInput {
data: Vec::new(),
shape,
});
}
if !bits.len().is_multiple_of(rows) {
return Err(qammod_error("qammod: bit input shape is inconsistent"));
}
if rows % bits_per_symbol != 0 {
return Err(qammod_error(format!(
"qammod: number of bit rows must be a multiple of log2(M) ({bits_per_symbol})"
)));
}
let out_rows = rows / bits_per_symbol;
let channels = bits.len() / rows;
let mut data = Vec::with_capacity(out_rows * channels);
for channel in 0..channels {
let channel_offset = channel * rows;
for group in 0..out_rows {
let mut symbol = 0usize;
for bit_idx in 0..bits_per_symbol {
symbol = (symbol << 1)
| usize::from(bits[channel_offset + group * bits_per_symbol + bit_idx]);
}
data.push(symbol);
}
}
if shape.is_empty() {
shape.push(out_rows);
shape.push(1);
} else {
shape[0] = out_rows;
}
Ok(SymbolInput { data, shape })
}
fn modulate_symbols(
symbols: SymbolInput,
order: usize,
options: &ParsedOptions,
) -> BuiltinResult<Value> {
validate_symbol_range(&symbols.data, order)?;
let constellation = constellation_table(order, options)?;
let mut out = Vec::with_capacity(symbols.data.len());
for symbol in symbols.data {
let point = symbol * 2;
out.push((constellation[point], constellation[point + 1]));
}
let tensor =
ComplexTensor::new(out, symbols.shape).map_err(|e| qammod_error(format!("qammod: {e}")))?;
Ok(Value::ComplexTensor(tensor))
}
fn constellation_table(order: usize, options: &ParsedOptions) -> BuiltinResult<Vec<f64>> {
let layout = ConstellationLayout::for_order(order)?;
let scale = if options.unit_average_power {
1.0 / layout.average_power().sqrt()
} else {
1.0
};
let custom_inverse = match &options.mapping {
SymbolMapping::Custom(mapping) => Some(invert_mapping(mapping, order)?),
_ => None,
};
let mut table = Vec::with_capacity(order * 2);
for symbol in 0..order {
let point_index = match &options.mapping {
SymbolMapping::Binary => symbol,
SymbolMapping::Gray => layout.gray_point_index(symbol),
SymbolMapping::Custom(_) => custom_inverse
.as_ref()
.and_then(|inverse| inverse.get(symbol).copied())
.ok_or_else(|| qammod_error("qammod: invalid custom mapping"))?,
};
let (i_amp, q_amp) = layout.point_for_index(point_index);
let point = match options.output_dtype {
OutputDType::Double => (i_amp * scale, q_amp * scale),
OutputDType::Single => (
((i_amp * scale) as f32) as f64,
((q_amp * scale) as f32) as f64,
),
};
table.push(point.0);
table.push(point.1);
}
Ok(table)
}
#[derive(Clone, Debug)]
struct ConstellationLayout {
rows: usize,
cols: usize,
}
impl ConstellationLayout {
fn for_order(order: usize) -> BuiltinResult<Self> {
let bits = order.trailing_zeros() as usize;
let row_bits = bits / 2;
let col_bits = bits - row_bits;
Ok(Self {
rows: 1usize
.checked_shl(row_bits as u32)
.ok_or_else(|| qammod_error("qammod: modulation order is too large"))?,
cols: 1usize
.checked_shl(col_bits as u32)
.ok_or_else(|| qammod_error("qammod: modulation order is too large"))?,
})
}
fn gray_point_index(&self, symbol: usize) -> usize {
let row = symbol % self.rows;
let col = symbol / self.rows;
gray_encode(col) * self.rows + gray_encode(row)
}
fn point_for_index(&self, point_index: usize) -> (f64, f64) {
let row = point_index % self.rows;
let col = point_index / self.rows;
let i_amp = 2.0 * col as f64 - (self.cols as f64 - 1.0);
let q_amp = (self.rows as f64 - 1.0) - 2.0 * row as f64;
(i_amp, q_amp)
}
fn average_power(&self) -> f64 {
average_axis_power(self.rows) + average_axis_power(self.cols)
}
}
fn average_axis_power(points: usize) -> f64 {
if points <= 1 {
0.0
} else {
((points * points - 1) as f64) / 3.0
}
}
fn gray_encode(value: usize) -> usize {
value ^ (value >> 1)
}
fn parse_modulation_order(value: &Value) -> BuiltinResult<usize> {
let number = scalar_number(value, "M")?;
let order = number_to_symbol_with_name(number, "M")?;
if order < 2 || !order.is_power_of_two() {
return Err(qammod_error(
"qammod: M must be a positive power-of-two integer greater than 1",
));
}
Ok(order)
}
fn parse_symbol_mapping(value: &Value, order: usize) -> BuiltinResult<SymbolMapping> {
if let Some(text) = value_as_string(value) {
return match text.trim().to_ascii_lowercase().as_str() {
"gray" => Ok(SymbolMapping::Gray),
"bin" | "binary" => Ok(SymbolMapping::Binary),
other => Err(qammod_error(format!(
"qammod: unsupported symbol order '{other}'"
))),
};
}
let mapping = vector_to_symbols(value, "symOrder")?;
if mapping.len() != order {
return Err(qammod_error(format!(
"qammod: custom symbol order must contain exactly {order} elements"
)));
}
validate_custom_mapping(&mapping, order)?;
Ok(SymbolMapping::Custom(mapping))
}
fn validate_symbol_range(symbols: &[usize], order: usize) -> BuiltinResult<()> {
for &symbol in symbols {
if symbol >= order {
return Err(qammod_error(format!(
"qammod: input symbols must be in the range [0, {}]",
order - 1
)));
}
}
Ok(())
}
fn validate_custom_mapping(mapping: &[usize], order: usize) -> BuiltinResult<()> {
let mut seen = vec![false; order];
for &symbol in mapping {
if symbol >= order {
return Err(qammod_error(format!(
"qammod: custom symbol order values must be in the range [0, {}]",
order - 1
)));
}
if seen[symbol] {
return Err(qammod_error(
"qammod: custom symbol order values must be unique",
));
}
seen[symbol] = true;
}
Ok(())
}
fn invert_mapping(mapping: &[usize], order: usize) -> BuiltinResult<Vec<usize>> {
let mut inverse = vec![0usize; order];
for (point_index, &symbol) in mapping.iter().enumerate() {
inverse[symbol] = point_index;
}
Ok(inverse)
}
fn vector_to_symbols(value: &Value, name: &str) -> BuiltinResult<Vec<usize>> {
match value {
Value::Tensor(tensor) => tensor
.data
.iter()
.map(|&number| number_to_symbol_with_name(number, name))
.collect(),
Value::LogicalArray(logical) => Ok(logical.data.iter().map(|&v| usize::from(v)).collect()),
Value::Num(n) => Ok(vec![number_to_symbol_with_name(*n, name)?]),
Value::Int(i) => Ok(vec![number_to_symbol_with_name(i.to_f64(), name)?]),
Value::Bool(b) => Ok(vec![usize::from(*b)]),
other => Err(qammod_error(format!(
"qammod: {name} must be a numeric vector, got {other:?}"
))),
}
}
fn scalar_number(value: &Value, name: &str) -> BuiltinResult<f64> {
match value {
Value::Num(n) => Ok(*n),
Value::Int(i) => Ok(i.to_f64()),
Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
Value::Tensor(tensor) if tensor.data.len() == 1 => Ok(tensor.data[0]),
Value::LogicalArray(logical) if logical.data.len() == 1 => {
Ok(if logical.data[0] != 0 { 1.0 } else { 0.0 })
}
_ => Err(qammod_error(format!("qammod: {name} must be a scalar"))),
}
}
fn number_to_symbol(value: f64) -> BuiltinResult<usize> {
number_to_symbol_with_name(value, "X")
}
fn bits_per_symbol(order: usize) -> BuiltinResult<usize> {
if order < 2 || !order.is_power_of_two() {
return Err(qammod_error(
"qammod: M must be a power-of-two integer for bit input",
));
}
Ok(order.trailing_zeros() as usize)
}
fn number_to_symbol_with_name(value: f64, name: &str) -> BuiltinResult<usize> {
if !value.is_finite() {
return Err(qammod_error(format!(
"qammod: {name} values must be finite integers"
)));
}
let rounded = value.round();
if (value - rounded).abs() > INTEGER_TOL || rounded < 0.0 {
return Err(qammod_error(format!(
"qammod: {name} values must be nonnegative integers"
)));
}
if rounded > usize::MAX as f64 {
return Err(qammod_error(format!(
"qammod: {name} value is too large for this platform"
)));
}
Ok(rounded as usize)
}
fn number_to_bit(value: f64, name: &str) -> BuiltinResult<u8> {
let symbol = number_to_symbol_with_name(value, name)?;
match symbol {
0 | 1 => Ok(symbol as u8),
_ => Err(qammod_error(format!(
"qammod: {name} bit values must be 0 or 1"
))),
}
}
fn value_as_bool(value: &Value, name: &str) -> BuiltinResult<bool> {
let n = scalar_number(value, name)?;
if (n - 0.0).abs() <= INTEGER_TOL {
Ok(false)
} else if (n - 1.0).abs() <= INTEGER_TOL {
Ok(true)
} else {
Err(qammod_error(format!(
"qammod: {name} must be logical or numeric 0/1"
)))
}
}
fn value_as_string(value: &Value) -> Option<String> {
match value {
Value::String(s) => Some(s.clone()),
Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
_ => None,
}
}
fn normalize_key(key: &str) -> String {
key.chars()
.filter(|c| *c != '_' && *c != '-' && !c.is_whitespace())
.flat_map(char::to_lowercase)
.collect()
}
fn is_name_value_key(value: &Value) -> bool {
let Some(text) = value_as_string(value) else {
return false;
};
matches!(
normalize_key(&text).as_str(),
"unitaveragepower" | "plotconstellation" | "outputdatatype" | "inputtype"
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::builtins::common::test_support;
use futures::executor::block_on;
fn qammod(x: Value, order: usize, rest: Vec<Value>) -> ComplexTensor {
match block_on(super::qammod_builtin(x, Value::Num(order as f64), rest)).expect("qammod") {
Value::ComplexTensor(tensor) => tensor,
other => panic!("expected ComplexTensor, got {other:?}"),
}
}
fn tensor(data: Vec<f64>, shape: Vec<usize>) -> Value {
Value::Tensor(Tensor::new(data, shape).expect("tensor"))
}
fn assert_complex_close(actual: &[(f64, f64)], expected: &[(f64, f64)]) {
assert_eq!(actual.len(), expected.len());
for (idx, ((ar, ai), (er, ei))) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(ar - er).abs() < 1e-12 && (ai - ei).abs() < 1e-12,
"at {idx}: expected ({er}, {ei}), got ({ar}, {ai})"
);
}
}
#[test]
fn qammod_4_gray_mapping() {
let out = qammod(tensor(vec![0.0, 1.0, 2.0, 3.0], vec![1, 4]), 4, vec![]);
assert_eq!(out.shape, vec![1, 4]);
assert_complex_close(
&out.data,
&[(-1.0, 1.0), (-1.0, -1.0), (1.0, 1.0), (1.0, -1.0)],
);
}
#[test]
fn qammod_16_gray_mapping() {
let input: Vec<f64> = (0..16).map(|v| v as f64).collect();
let out = qammod(tensor(input, vec![1, 16]), 16, vec![]);
assert_complex_close(
&out.data,
&[
(-3.0, 3.0),
(-3.0, 1.0),
(-3.0, -3.0),
(-3.0, -1.0),
(-1.0, 3.0),
(-1.0, 1.0),
(-1.0, -3.0),
(-1.0, -1.0),
(3.0, 3.0),
(3.0, 1.0),
(3.0, -3.0),
(3.0, -1.0),
(1.0, 3.0),
(1.0, 1.0),
(1.0, -3.0),
(1.0, -1.0),
],
);
}
#[test]
fn qammod_64_gray_has_expected_edges() {
let out = qammod(
tensor(vec![0.0, 1.0, 2.0, 3.0, 63.0], vec![1, 5]),
64,
vec![],
);
assert_complex_close(
&out.data,
&[
(-7.0, 7.0),
(-7.0, 5.0),
(-7.0, 1.0),
(-7.0, 3.0),
(1.0, -1.0),
],
);
}
#[test]
fn qammod_binary_mapping_is_sequential_columnwise() {
let out = qammod(
tensor(vec![0.0, 1.0, 2.0, 3.0, 4.0], vec![1, 5]),
16,
vec![Value::from("bin")],
);
assert_complex_close(
&out.data,
&[
(-3.0, 3.0),
(-3.0, 1.0),
(-3.0, -1.0),
(-3.0, -3.0),
(-1.0, 3.0),
],
);
}
#[test]
fn qammod_bit_input_groups_rows_by_channel() {
let out = qammod(
tensor(
vec![
0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, ],
vec![4, 2],
),
16,
vec![Value::from("InputType"), Value::from("bit")],
);
assert_eq!(out.shape, vec![1, 2]);
assert_complex_close(&out.data, &[(-3.0, 3.0), (1.0, -1.0)]);
}
#[test]
fn qammod_bit_input_honors_binary_mapping() {
let out = qammod(
tensor(vec![0.0, 0.0, 1.0, 1.0], vec![4, 1]),
16,
vec![
Value::from("bin"),
Value::from("InputType"),
Value::from("bit"),
],
);
assert_eq!(out.shape, vec![1, 1]);
assert_complex_close(&out.data, &[(-3.0, -3.0)]);
}
#[test]
fn qammod_logical_bit_input_preserves_channels() {
let logical = LogicalArray::new(vec![0, 0, 1, 1], vec![2, 2]).expect("logical");
let out = qammod(
Value::LogicalArray(logical),
4,
vec![Value::from("InputType"), Value::from("bit")],
);
assert_eq!(out.shape, vec![1, 2]);
assert_complex_close(&out.data, &[(-1.0, 1.0), (1.0, -1.0)]);
}
#[test]
fn qammod_inprocess_gpu_bit_input_returns_complex_resident_output() {
use runmat_accelerate_api::{GpuTensorStorage, HostTensorView};
test_support::with_test_provider(|provider| {
let input = Tensor::new(vec![0.0, 0.0, 1.0, 1.0], vec![4, 1]).unwrap();
let expected = qammod(
Value::Tensor(input.clone()),
16,
vec![Value::from("InputType"), Value::from("bit")],
);
let input_handle = provider
.upload(&HostTensorView {
data: &input.data,
shape: &input.shape,
})
.expect("upload");
let result = block_on(super::qammod_builtin(
Value::GpuTensor(input_handle.clone()),
Value::Num(16.0),
vec![Value::from("InputType"), Value::from("bit")],
))
.expect("gpu qammod bit");
let Value::GpuTensor(output_handle) = result else {
panic!("expected resident gpu output");
};
assert_eq!(
runmat_accelerate_api::handle_storage(&output_handle),
GpuTensorStorage::ComplexInterleaved
);
let gathered = block_on(crate::dispatcher::gather_if_needed_async(
&Value::GpuTensor(output_handle.clone()),
))
.expect("gather output");
let Value::ComplexTensor(actual) = gathered else {
panic!("expected gathered complex tensor");
};
assert_eq!(actual.shape, expected.shape);
assert_complex_close(&actual.data, &expected.data);
provider.free(&input_handle).ok();
provider.free(&output_handle).ok();
});
}
#[test]
fn qammod_bit_input_rejects_partial_symbols() {
let err = block_on(super::qammod_builtin(
tensor(vec![0.0, 1.0, 0.0], vec![3, 1]),
Value::Num(16.0),
vec![Value::from("InputType"), Value::from("bit")],
))
.expect_err("expected bit grouping error");
assert!(err.to_string().contains("multiple of log2(M)"));
}
#[test]
fn qammod_bit_input_rejects_non_bits() {
let err = block_on(super::qammod_builtin(
tensor(vec![0.0, 2.0, 0.0, 1.0], vec![4, 1]),
Value::Num(16.0),
vec![Value::from("InputType"), Value::from("bit")],
))
.expect_err("expected bit value error");
assert!(err.to_string().contains("bit values must be 0 or 1"));
}
#[test]
fn qammod_custom_mapping_uses_columnwise_symbol_order() {
let mapping = tensor(vec![0.0, 3.0, 1.0, 2.0], vec![1, 4]);
let out = qammod(
tensor(vec![0.0, 1.0, 2.0, 3.0], vec![1, 4]),
4,
vec![mapping],
);
assert_complex_close(
&out.data,
&[(-1.0, 1.0), (1.0, 1.0), (1.0, -1.0), (-1.0, -1.0)],
);
}
#[test]
fn qammod_unit_average_power_normalizes_constellation() {
let input: Vec<f64> = (0..16).map(|v| v as f64).collect();
let out = qammod(
tensor(input, vec![1, 16]),
16,
vec![Value::from("UnitAveragePower"), Value::Bool(true)],
);
let avg_power = out
.data
.iter()
.map(|(re, im)| re * re + im * im)
.sum::<f64>()
/ out.data.len() as f64;
assert!((avg_power - 1.0).abs() < 1e-12, "{avg_power}");
}
#[test]
fn qammod_output_datatype_single_rounds_samples() {
let out = qammod(
tensor(vec![0.0], vec![1, 1]),
16,
vec![
Value::from("UnitAveragePower"),
Value::Bool(true),
Value::from("OutputDataType"),
Value::from("single"),
],
);
let expected = (-3.0 / 10.0_f64.sqrt()) as f32 as f64;
assert_eq!(out.data[0].0, expected);
assert_eq!(out.data[0].1, -expected);
}
#[test]
fn qammod_rejects_out_of_range_symbols() {
let err = block_on(super::qammod_builtin(
tensor(vec![0.0, 16.0], vec![1, 2]),
Value::Num(16.0),
vec![],
))
.expect_err("expected out-of-range error");
assert!(err.to_string().contains("range [0, 15]"));
}
#[test]
#[cfg(feature = "wgpu")]
fn qammod_wgpu_input_returns_complex_resident_output() {
use runmat_accelerate_api::{AccelProvider, GpuTensorStorage, HostTensorView};
match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
) {
Ok(provider) => {
let input = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![1, 4]).unwrap();
let expected = qammod(Value::Tensor(input.clone()), 4, vec![]);
let view = HostTensorView {
data: &input.data,
shape: &input.shape,
};
let input_handle = provider.upload(&view).expect("upload");
let result = block_on(super::qammod_builtin(
Value::GpuTensor(input_handle.clone()),
Value::Num(4.0),
vec![],
))
.expect("gpu qammod");
let Value::GpuTensor(output_handle) = result else {
panic!("expected resident gpu output");
};
assert_eq!(
runmat_accelerate_api::handle_storage(&output_handle),
GpuTensorStorage::ComplexInterleaved
);
let gathered = block_on(crate::dispatcher::gather_if_needed_async(
&Value::GpuTensor(output_handle.clone()),
))
.expect("gather output");
let Value::ComplexTensor(actual) = gathered else {
panic!("expected gathered complex tensor");
};
assert_eq!(actual.shape, expected.shape);
assert_complex_close(&actual.data, &expected.data);
provider.free(&input_handle).ok();
provider.free(&output_handle).ok();
}
Err(err) => {
tracing::warn!("Skipping qammod_wgpu_input_returns_complex_resident_output: {err}");
}
}
runmat_accelerate::simple_provider::register_inprocess_provider();
}
#[test]
#[cfg(feature = "wgpu")]
fn qammod_wgpu_bit_input_returns_complex_resident_output() {
use runmat_accelerate_api::{AccelProvider, GpuTensorStorage, HostTensorView};
match runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
) {
Ok(provider) => {
let input = Tensor::new(vec![0.0, 0.0, 1.0, 1.0], vec![4, 1]).unwrap();
let expected = qammod(
Value::Tensor(input.clone()),
16,
vec![Value::from("InputType"), Value::from("bit")],
);
let view = HostTensorView {
data: &input.data,
shape: &input.shape,
};
let input_handle = provider.upload(&view).expect("upload");
let result = block_on(super::qammod_builtin(
Value::GpuTensor(input_handle.clone()),
Value::Num(16.0),
vec![Value::from("InputType"), Value::from("bit")],
))
.expect("gpu qammod bit");
let Value::GpuTensor(output_handle) = result else {
panic!("expected resident gpu output");
};
assert_eq!(
runmat_accelerate_api::handle_storage(&output_handle),
GpuTensorStorage::ComplexInterleaved
);
let gathered = block_on(crate::dispatcher::gather_if_needed_async(
&Value::GpuTensor(output_handle.clone()),
))
.expect("gather output");
let Value::ComplexTensor(actual) = gathered else {
panic!("expected gathered complex tensor");
};
assert_eq!(actual.shape, expected.shape);
assert_complex_close(&actual.data, &expected.data);
provider.free(&input_handle).ok();
provider.free(&output_handle).ok();
}
Err(err) => {
tracing::warn!(
"Skipping qammod_wgpu_bit_input_returns_complex_resident_output: {err}"
);
}
}
runmat_accelerate::simple_provider::register_inprocess_provider();
}
}