use super::prelude::*;
use onnx_ir::ir::ArgType;
/// Code generation for `ConstantOfShapeNode`: emits Rust source that produces
/// an output filled with a single constant value.
impl NodeCodegen for onnx_ir::node::constant_of_shape::ConstantOfShapeNode {
    /// Returns the node's input arguments.
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    /// Returns the node's output arguments.
    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Builds the `let <output> = ...;` statement that binds the node's first output.
    ///
    /// The fill value is the first element of `config.value`, rendered as a
    /// literal for `F32`, `F64`, `I32`, `I64` and `Bool` data. A missing value,
    /// or a value of any other dtype, is rendered as `0.0f32`.
    ///
    /// The emitted statement depends on the type of the first output:
    /// - scalar outputs bind the literal directly;
    /// - tensor outputs with a bool fill value use `Int` `ones`/`zeros` of the
    ///   requested shape followed by `.bool()`;
    /// - other tensor outputs build a one-element tensor from the value
    ///   (cast to `f64` for float dtypes, `i64` otherwise), reshape it to
    ///   all-ones dimensions of the output rank and expand it to the shape;
    /// - shape outputs bind an `[i64; N]` array holding `N` copies of the value.
    ///
    /// # Panics
    ///
    /// Panics if the node has no outputs, if the stored value cannot be viewed
    /// as a slice of its dtype or holds no elements, or if a runtime shape
    /// refers to an input index that is out of range.
    fn forward(&self, _scope: &mut super::super::scope::ScopeAtPosition<'_>) -> TokenStream {
        use super::super::codegen::{f32_to_tokens, f64_to_tokens};
        use onnx_ir::ir::DType;
        use onnx_ir::node::constant_of_shape::ConstantOfShapeShape;

        let out_arg = self.outputs.first().unwrap();
        let output = arg_to_ident(out_arg);

        // Literal for the fill value; `0.0f32` covers both "no value given"
        // and dtypes that have no dedicated arm below.
        let value = self.config.value.as_ref().map_or_else(
            || quote! { 0.0f32 },
            |data| match data.dtype {
                DType::F32 => f32_to_tokens(data.as_slice::<f32>().unwrap()[0]),
                DType::F64 => f64_to_tokens(data.as_slice::<f64>().unwrap()[0]),
                DType::I32 => {
                    let scalar = data.as_slice::<i32>().unwrap()[0];
                    quote! { #scalar }
                }
                DType::I64 => {
                    let scalar = data.as_slice::<i64>().unwrap()[0];
                    quote! { #scalar }
                }
                DType::Bool(_) => {
                    let scalar = data.as_slice::<bool>().unwrap()[0];
                    quote! { #scalar }
                }
                _ => quote! { 0.0f32 },
            },
        );

        // Target shape: either an array literal of `usize` dimensions or the
        // identifier of the input that carries the shape.
        let shape_expr = match &self.config.shape {
            ConstantOfShapeShape::Static(static_dims) => {
                let dims = static_dims.iter().map(|&dim| dim as usize);
                quote! { [#(#dims),*] }
            }
            ConstantOfShapeShape::Runtime(runtime_ref) => {
                let shape_input = arg_to_ident(&self.inputs[runtime_ref.input_index]);
                quote! { #shape_input }
            }
        };

        match &out_arg.ty {
            ArgType::ScalarNative(_) | ArgType::ScalarTensor(_) => quote! {
                let #output = #value;
            },
            ArgType::Tensor(tensor) => {
                let output_rank = tensor.rank.to_tokens();

                // `Some(..)` only when a fill value exists and it is a bool.
                let bool_fill = self
                    .config
                    .value
                    .as_ref()
                    .filter(|data| data.dtype.is_bool())
                    .map(|data| data.as_slice::<bool>().unwrap()[0]);

                if let Some(fill) = bool_fill {
                    // Bool tensors are produced from an Int tensor of ones or zeros.
                    let constructor = if fill {
                        quote! { ones }
                    } else {
                        quote! { zeros }
                    };
                    quote! {
                        let #output = Tensor::<B, #output_rank, Int>::#constructor(#shape_expr, &self.device).bool();
                    }
                } else {
                    let dtype_tokens = tensor.dtype.to_tokens();
                    // One `1` per output dimension, so the seed tensor can be expanded.
                    let unit_dims = (0..tensor.rank).map(|_| quote! { 1 });

                    // One-element tensor holding the fill value.
                    let seed = if tensor.dtype.is_float() {
                        quote! {
                            Tensor::<B, 1>::from_data(
                                burn::tensor::TensorData::from([#value as f64]),
                                (&self.device, #dtype_tokens)
                            )
                        }
                    } else {
                        quote! {
                            Tensor::<B, 1, Int>::from_data(
                                burn::tensor::TensorData::from([#value as i64]),
                                (&self.device, #dtype_tokens)
                            )
                        }
                    };

                    quote! {
                        let #output = #seed.reshape([#(#unit_dims),*]).expand(#shape_expr);
                    }
                }
            }
            ArgType::Shape(size) => {
                let len = *size;
                let elements = std::iter::repeat_n(value, len);
                quote! {
                    let #output: [i64; #len] = [#(#elements),*];
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::super::test_helpers::*;
use insta::assert_snapshot;
use onnx_ir::ir::{BoolStore, DType, RuntimeInputRef, TensorData};
use onnx_ir::node::constant_of_shape::{
ConstantOfShapeConfig, ConstantOfShapeNodeBuilder, ConstantOfShapeShape,
};
#[test]
fn test_constant_of_shape_scalar_f32() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![3.14f32], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_scalar("result", DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> f32 {
let result = 3.14f32;
result
}
");
}
#[test]
fn test_constant_of_shape_scalar_f64() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![2.718f64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("shape_in", 1)
.output_scalar("value", DType::F64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, shape_in: [i64; 1]) -> f64 {
let value = 2.718f64;
value
}
");
}
#[test]
fn test_constant_of_shape_scalar_i32() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![42i32], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("s", 1)
.output_scalar("num", DType::I32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, s: [i64; 1]) -> i32 {
let num = 42i32;
num
}
");
}
#[test]
fn test_constant_of_shape_scalar_i64() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![999i64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("shape_data", 1)
.output_scalar("output", DType::I64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, shape_data: [i64; 1]) -> i64 {
let output = 999i64;
output
}
");
}
#[test]
fn test_constant_of_shape_scalar_bool() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![true], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("shape_vec", 1)
.output_scalar("flag", DType::Bool(BoolStore::Native))
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, shape_vec: [i64; 1]) -> bool {
let flag = true;
flag
}
");
}
#[test]
fn test_constant_of_shape_scalar_default() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: None, };
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("s", 1)
.output_scalar("zero", DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, s: [i64; 1]) -> f32 {
let zero = 0.0f32;
zero
}
");
}
#[test]
fn test_constant_of_shape_tensor_f32_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![2, 3, 4]),
value: Some(TensorData::new(vec![1.5f32], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("target_shape", 3)
.output_tensor("filled", 3, DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, target_shape: [i64; 3]) -> Tensor<B, 3> {
let filled = Tensor::<
B,
1,
>::from_data(
burn::tensor::TensorData::from([1.5f32 as f64]),
(&self.device, burn::tensor::DType::F32),
)
.reshape([1, 1, 1])
.expand([2usize, 3usize, 4usize]);
filled
}
");
}
#[test]
fn test_constant_of_shape_tensor_f64_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![10, 20]),
value: Some(TensorData::new(vec![0.5f64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 2)
.output_tensor("matrix", 2, DType::F64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 2]) -> Tensor<B, 2> {
let matrix = Tensor::<
B,
1,
>::from_data(
burn::tensor::TensorData::from([0.5f64 as f64]),
(&self.device, burn::tensor::DType::F64),
)
.reshape([1, 1])
.expand([10usize, 20usize]);
matrix
}
");
}
#[test]
fn test_constant_of_shape_tensor_i32_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![5, 5]),
value: Some(TensorData::new(vec![7i32], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("size", 2)
.output_tensor("grid", 2, DType::I32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, size: [i64; 2]) -> Tensor<B, 2, Int> {
let grid = Tensor::<
B,
1,
Int,
>::from_data(
burn::tensor::TensorData::from([7i32 as i64]),
(&self.device, burn::tensor::DType::I32),
)
.reshape([1, 1])
.expand([5usize, 5usize]);
grid
}
");
}
#[test]
fn test_constant_of_shape_tensor_i64_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![8]),
value: Some(TensorData::new(vec![100i64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("length", 1)
.output_tensor("vector", 1, DType::I64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, length: [i64; 1]) -> Tensor<B, 1, Int> {
let vector = Tensor::<
B,
1,
Int,
>::from_data(
burn::tensor::TensorData::from([100i64 as i64]),
(&self.device, burn::tensor::DType::I64),
)
.reshape([1])
.expand([8usize]);
vector
}
");
}
#[test]
fn test_constant_of_shape_tensor_bool_true_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![3, 4]),
value: Some(TensorData::new(vec![true], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("shape_dims", 2)
.output_tensor("mask", 2, DType::Bool(BoolStore::Native))
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, shape_dims: [i64; 2]) -> Tensor<B, 2, Bool> {
let mask = Tensor::<B, 2, Int>::ones([3usize, 4usize], &self.device).bool();
mask
}
");
}
#[test]
fn test_constant_of_shape_tensor_bool_false_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![6, 7, 8]),
value: Some(TensorData::new(vec![false], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dimensions", 3)
.output_tensor("flags", 3, DType::Bool(BoolStore::Native))
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dimensions: [i64; 3]) -> Tensor<B, 3, Bool> {
let flags = Tensor::<B, 3, Int>::zeros([6usize, 7usize, 8usize], &self.device)
.bool();
flags
}
");
}
#[test]
fn test_constant_of_shape_tensor_default_static() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![2, 2]),
value: None, };
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("size", 2)
.output_tensor("zeros", 2, DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, size: [i64; 2]) -> Tensor<B, 2> {
let zeros = Tensor::<
B,
1,
>::from_data(
burn::tensor::TensorData::from([0.0f32 as f64]),
(&self.device, burn::tensor::DType::F32),
)
.reshape([1, 1])
.expand([2usize, 2usize]);
zeros
}
");
}
#[test]
fn test_constant_of_shape_tensor_runtime_f32() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Runtime(RuntimeInputRef {
name: "dynamic_shape".to_string(),
input_index: 0,
}),
value: Some(TensorData::new(vec![2.5f32], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dynamic_shape", 3)
.output_tensor("tensor", 3, DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dynamic_shape: [i64; 3]) -> Tensor<B, 3> {
let tensor = Tensor::<
B,
1,
>::from_data(
burn::tensor::TensorData::from([2.5f32 as f64]),
(&self.device, burn::tensor::DType::F32),
)
.reshape([1, 1, 1])
.expand(dynamic_shape);
tensor
}
");
}
#[test]
fn test_constant_of_shape_tensor_runtime_i64() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Runtime(RuntimeInputRef {
name: "shape_param".to_string(),
input_index: 0,
}),
value: Some(TensorData::new(vec![255i64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("shape_param", 2)
.output_tensor("data", 2, DType::I64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, shape_param: [i64; 2]) -> Tensor<B, 2, Int> {
let data = Tensor::<
B,
1,
Int,
>::from_data(
burn::tensor::TensorData::from([255i64 as i64]),
(&self.device, burn::tensor::DType::I64),
)
.reshape([1, 1])
.expand(shape_param);
data
}
");
}
#[test]
fn test_constant_of_shape_tensor_runtime_bool_true() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Runtime(RuntimeInputRef {
name: "sz".to_string(),
input_index: 0,
}),
value: Some(TensorData::new(vec![true], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("sz", 4)
.output_tensor("bitmask", 4, DType::Bool(BoolStore::Native))
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, sz: [i64; 4]) -> Tensor<B, 4, Bool> {
let bitmask = Tensor::<B, 4, Int>::ones(sz, &self.device).bool();
bitmask
}
");
}
#[test]
fn test_constant_of_shape_tensor_runtime_bool_false() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Runtime(RuntimeInputRef {
name: "target_dims".to_string(),
input_index: 0,
}),
value: Some(TensorData::new(vec![false], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("target_dims", 2)
.output_tensor("empty_mask", 2, DType::Bool(BoolStore::Native))
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, target_dims: [i64; 2]) -> Tensor<B, 2, Bool> {
let empty_mask = Tensor::<B, 2, Int>::zeros(target_dims, &self.device).bool();
empty_mask
}
");
}
#[test]
fn test_constant_of_shape_tensor_runtime_default() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Runtime(RuntimeInputRef {
name: "runtime_shape".to_string(),
input_index: 0,
}),
value: None, };
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("runtime_shape", 3)
.output_tensor("zeros", 3, DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, runtime_shape: [i64; 3]) -> Tensor<B, 3> {
let zeros = Tensor::<
B,
1,
>::from_data(
burn::tensor::TensorData::from([0.0f32 as f64]),
(&self.device, burn::tensor::DType::F32),
)
.reshape([1, 1, 1])
.expand(runtime_shape);
zeros
}
");
}
#[test]
fn test_constant_of_shape_scalar_f32_infinity() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![f32::INFINITY], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_scalar("result", DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> f32 {
let result = f32::INFINITY;
result
}
");
}
#[test]
fn test_constant_of_shape_scalar_f32_neg_infinity() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![f32::NEG_INFINITY], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_scalar("result", DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> f32 {
let result = f32::NEG_INFINITY;
result
}
");
}
#[test]
fn test_constant_of_shape_scalar_f32_nan() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![f32::NAN], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_scalar("result", DType::F32)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> f32 {
let result = f32::NAN;
result
}
");
}
#[test]
fn test_constant_of_shape_scalar_f64_infinity() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![]),
value: Some(TensorData::new(vec![f64::INFINITY], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_scalar("result", DType::F64)
.config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> f64 {
let result = f64::INFINITY;
result
}
");
}
#[test]
fn test_constant_of_shape_shape_output_i64() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![3]), value: Some(TensorData::new(vec![10i64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("in_shape", 1)
.output_shape("out_shape", 3) .config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, in_shape: [i64; 1]) -> [i64; 3] {
let out_shape: [i64; 3usize] = [10i64, 10i64, 10i64];
out_shape
}
");
}
#[test]
fn test_constant_of_shape_shape_output_single_element() {
let config = ConstantOfShapeConfig {
shape: ConstantOfShapeShape::Static(vec![1]), value: Some(TensorData::new(vec![5i64], [0usize; 0])),
};
let node = ConstantOfShapeNodeBuilder::new("const1")
.input_shape("dims", 1)
.output_shape("result", 1) .config(config)
.build();
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(&self, dims: [i64; 1]) -> [i64; 1] {
let result: [i64; 1usize] = [5i64];
result
}
");
}
}