use super::prelude::*;
/// Builds the token stream that evaluates one clip bound (min or max) as an
/// `f64` expression, or returns `None` when the bound is absent.
///
/// Static bounds become literal tokens; runtime bounds are resolved from
/// `inputs` by index and widened to `f64`.
///
/// # Panics
/// Panics if a runtime bound's argument is not a scalar
/// (`ScalarNative` or `ScalarTensor`).
fn clip_bound_expr(
    bound: &Option<onnx_ir::node::clip::ClipInput>,
    inputs: &[Argument],
    scope: &mut ScopeAtPosition<'_>,
) -> Option<TokenStream> {
    use onnx_ir::node::clip::ClipInput;

    let expr = match bound.as_ref()? {
        ClipInput::Static(value) => {
            let value = *value;
            quote! { #value }
        }
        ClipInput::Runtime(runtime) => {
            let arg = &inputs[runtime.input_index];
            match &arg.ty {
                // Native scalar: already a host value, just widen to f64.
                ArgType::ScalarNative(_) => {
                    let ident = arg_to_ident(arg);
                    quote! { (#ident as f64) }
                }
                // Scalar stored in a tensor: materialize it on the host
                // before the cast.
                ArgType::ScalarTensor(dtype) => {
                    let tensor = scope.arg(arg);
                    let native = on_device_to_native(quote! { #tensor }, dtype);
                    quote! { (#native as f64) }
                }
                other => panic!(
                    "Clip min/max must be a scalar (ScalarNative or ScalarTensor), got {other:?}"
                ),
            }
        }
    };
    Some(expr)
}
impl NodeCodegen for onnx_ir::clip::ClipNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Emits `clamp`, `clamp_min`, or `clamp_max` depending on which bounds
    /// the Clip node configuration carries.
    ///
    /// # Panics
    /// Panics when neither a min nor a max bound is configured.
    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        let output = arg_to_ident(self.outputs.first().unwrap());
        // Resolve bound expressions before the data input so `scope` observes
        // the arguments in the same order as before (scope may track usage).
        let min_expr = clip_bound_expr(&self.config.min, &self.inputs, scope);
        let max_expr = clip_bound_expr(&self.config.max, &self.inputs, scope);
        let input = scope.arg(self.inputs.first().unwrap());

        // Build only the clamp call body here; the shared
        // `let <output> = { ... };` wrapper is applied once below.
        let body = match (min_expr, max_expr) {
            (Some(min), Some(max)) => quote! {
                let __clip_min = #min;
                let __clip_max = #max;
                #input.clamp(__clip_min, __clip_max)
            },
            (Some(min), None) => quote! {
                let __clip_min = #min;
                #input.clamp_min(__clip_min)
            },
            (None, Some(max)) => quote! {
                let __clip_max = #max;
                #input.clamp_max(__clip_max)
            },
            (None, None) => panic!("Clip node must have at least one min or max value"),
        };

        quote! {
            let #output = { #body };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_helpers::*;
    use burn::tensor::DType;
    use insta::assert_snapshot;
    use onnx_ir::clip::{ClipConfig, ClipNode, ClipNodeBuilder};
    use onnx_ir::node::clip::ClipInput;

    /// Builds a Clip node whose bounds (if any) are static f64 constants.
    fn create_clip_node(name: &str, min: Option<f64>, max: Option<f64>) -> ClipNode {
        ClipNodeBuilder::new(name)
            .input_tensor("input", 2, DType::F32)
            .output_tensor("output", 2, DType::F32)
            .config(ClipConfig {
                min: min.map(ClipInput::Static),
                max: max.map(ClipInput::Static),
            })
            .build()
    }

    /// Wraps a named graph input (at `input_index`) as a runtime clip bound.
    fn runtime_bound(name: &str, input_index: usize) -> Option<ClipInput> {
        Some(ClipInput::Runtime(onnx_ir::ir::RuntimeInputRef::new(
            name.to_string(),
            input_index,
        )))
    }

    #[test]
    fn test_clip_both_bounds() {
        // Both bounds present -> single clamp(min, max) call.
        let node = create_clip_node("clip1", Some(-1.0), Some(1.0));
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
            let output = {
                let __clip_min = -1f64;
                let __clip_max = 1f64;
                input.clamp(__clip_min, __clip_max)
            };
            output
        }
        ");
    }

    #[test]
    fn test_clip_min_only() {
        // Only a lower bound -> clamp_min.
        let node = create_clip_node("clip1", Some(0.0), None);
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
            let output = {
                let __clip_min = 0f64;
                input.clamp_min(__clip_min)
            };
            output
        }
        ");
    }

    #[test]
    fn test_clip_max_only() {
        // Only an upper bound -> clamp_max.
        let node = create_clip_node("clip1", None, Some(10.0));
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
            let output = {
                let __clip_max = 10f64;
                input.clamp_max(__clip_max)
            };
            output
        }
        ");
    }

    #[test]
    fn test_clip_runtime_min_scalar_tensor() {
        // Runtime min carried in a 1-element tensor: codegen must pull it to
        // the host (into_scalar) and cast to f64.
        let node = ClipNodeBuilder::new("clip1")
            .input_tensor("input", 2, DType::F32)
            .input_scalar_tensor("min_val", DType::F32)
            .output_tensor("output", 2, DType::F32)
            .config(ClipConfig {
                min: runtime_bound("min_val", 1),
                max: None,
            })
            .build();
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(&self, input: Tensor<B, 2>, min_val: Tensor<B, 1>) -> Tensor<B, 2> {
            let output = {
                let __clip_min = (min_val.into_scalar().elem::<f32>() as f64);
                input.clamp_min(__clip_min)
            };
            output
        }
        ");
    }

    #[test]
    fn test_clip_runtime_both_scalar_tensors() {
        // Both bounds supplied as runtime scalar tensors.
        let node = ClipNodeBuilder::new("clip1")
            .input_tensor("input", 2, DType::F32)
            .input_scalar_tensor("min_val", DType::F32)
            .input_scalar_tensor("max_val", DType::F32)
            .output_tensor("output", 2, DType::F32)
            .config(ClipConfig {
                min: runtime_bound("min_val", 1),
                max: runtime_bound("max_val", 2),
            })
            .build();
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(
            &self,
            input: Tensor<B, 2>,
            min_val: Tensor<B, 1>,
            max_val: Tensor<B, 1>,
        ) -> Tensor<B, 2> {
            let output = {
                let __clip_min = (min_val.into_scalar().elem::<f32>() as f64);
                let __clip_max = (max_val.into_scalar().elem::<f32>() as f64);
                input.clamp(__clip_min, __clip_max)
            };
            output
        }
        ");
    }

    #[test]
    fn test_clip_runtime_min_scalar_native() {
        // Runtime min as a native scalar argument: a plain `as f64` cast,
        // no tensor round-trip.
        let node = ClipNodeBuilder::new("clip1")
            .input_tensor("input", 2, DType::F32)
            .input_scalar("min_val", DType::F32)
            .output_tensor("output", 2, DType::F32)
            .config(ClipConfig {
                min: runtime_bound("min_val", 1),
                max: None,
            })
            .build();
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @r"
        pub fn forward(&self, input: Tensor<B, 2>, min_val: f32) -> Tensor<B, 2> {
            let output = {
                let __clip_min = (min_val as f64);
                input.clamp_min(__clip_min)
            };
            output
        }
        ");
    }
}