use super::prelude::*;
use burn_store::TensorSnapshot;
use onnx_ir::gru::{GruActivationFunction, GruDirection};
/// Extracts the GRU weight initializers from the ONNX node inputs and converts
/// them into `TensorSnapshot`s laid out the way Burn's GRU module expects.
///
/// ONNX packs GRU parameters as (see the ONNX GRU operator spec):
/// - `W` (input 1): `[num_directions, 3 * hidden_size, input_size]`
/// - `R` (input 2): `[num_directions, 3 * hidden_size, hidden_size]`
/// - `B` (input 3, optional): `[num_directions, 6 * hidden_size]`, the input
///   biases `Wb` followed by the recurrent biases `Rb`
///
/// with gates stacked in order `z` (update), `r` (reset), `h` (new/candidate).
/// Burn stores each gate as a separate `Linear`, with the weight transposed
/// relative to ONNX, under paths such as
/// `{field_name}.{dir}update_gate.input_transform.weight` where `{dir}` is
/// empty for unidirectional GRUs and `forward.`/`reverse.` for bidirectional.
///
/// Returns an empty vec when `W` or `R` has no constant data (e.g. the
/// weights are runtime inputs rather than initializers).
#[allow(clippy::single_range_in_vec_init)]
fn collect_gru_snapshots(
    field_name: &str,
    inputs: &[Argument],
    config: &onnx_ir::gru::GruConfig,
) -> Vec<TensorSnapshot> {
    use crate::burn::node_traits::{SerializationBackend, extract_node_data};
    use burn::tensor::Tensor;

    let hidden_size = config.hidden_size;
    let input_size = config.input_size;

    // W and R are required; without constant data there is nothing to snapshot.
    let Some(data_w) = extract_node_data(inputs, 1) else {
        return vec![];
    };
    let Some(data_r) = extract_node_data(inputs, 2) else {
        return vec![];
    };
    // Bias is optional (ONNX input 3).
    let data_b = extract_node_data(inputs, 3);

    let dtype = data_w.dtype;
    let device = Default::default();

    // ONNX gate order is z (update), r (reset), h (new); Burn's module path
    // names follow the same order, so the enumerate index is the gate offset.
    let gate_names = ["update_gate", "reset_gate", "new_gate"];
    // Forward and Reverse both map to a single unprefixed GRU (the sequence
    // flipping for Reverse happens in the generated forward code);
    // Bidirectional maps to BiGru's `forward.` / `reverse.` submodules.
    let direction_prefixes: Vec<&str> = match config.direction {
        GruDirection::Forward | GruDirection::Reverse => vec![""],
        GruDirection::Bidirectional => vec!["forward.", "reverse."],
    };

    let mut snapshots = Vec::new();
    let w_tensor: Tensor<SerializationBackend, 3> = Tensor::from_data(data_w.clone(), &device);
    let r_tensor: Tensor<SerializationBackend, 3> = Tensor::from_data(data_r.clone(), &device);
    let b_tensor: Option<Tensor<SerializationBackend, 2>> =
        data_b.map(|b| Tensor::from_data(b, &device));

    for (dir_idx, dir_prefix) in direction_prefixes.iter().enumerate() {
        // Slice out this direction and drop the leading num_directions axis.
        let w_dir = w_tensor
            .clone()
            .slice([dir_idx..dir_idx + 1, 0..3 * hidden_size, 0..input_size])
            .squeeze::<2>();
        let r_dir = r_tensor
            .clone()
            .slice([dir_idx..dir_idx + 1, 0..3 * hidden_size, 0..hidden_size])
            .squeeze::<2>();
        let b_dir = b_tensor.as_ref().map(|b| {
            b.clone()
                .slice([dir_idx..dir_idx + 1, 0..6 * hidden_size])
                .squeeze::<1>()
        });

        for (gate_idx, gate_name) in gate_names.iter().enumerate() {
            // Row range of this gate inside the stacked [3*hidden, ...] matrices.
            let start = gate_idx * hidden_size;
            let end = start + hidden_size;

            // Input weight: ONNX [hidden, input] -> Burn Linear [input, hidden].
            let w_gate = w_dir.clone().slice([start..end, 0..input_size]).transpose();
            let path = format!(
                "{}.{}{}.input_transform.weight",
                field_name, dir_prefix, gate_name
            );
            snapshots.push(create_snapshot_from_data(
                w_gate.into_data(),
                &path,
                "Linear",
                dtype,
            ));

            if let Some(ref b) = b_dir {
                // Input bias Wb occupies the first 3*hidden entries of B; the
                // per-gate offset is the same `start..end` as the weight rows.
                let wb: Tensor<SerializationBackend, 1> = b.clone().slice([start..end]);
                let path = format!(
                    "{}.{}{}.input_transform.bias",
                    field_name, dir_prefix, gate_name
                );
                snapshots.push(create_snapshot_from_data(wb.into_data(), &path, "Linear", dtype));
            }

            // Recurrent weight: ONNX [hidden, hidden] -> transposed for Linear.
            let r_gate = r_dir
                .clone()
                .slice([start..end, 0..hidden_size])
                .transpose();
            let path = format!(
                "{}.{}{}.hidden_transform.weight",
                field_name, dir_prefix, gate_name
            );
            snapshots.push(create_snapshot_from_data(
                r_gate.into_data(),
                &path,
                "Linear",
                dtype,
            ));

            if let Some(b) = &b_dir {
                // Recurrent bias Rb starts after the 3*hidden Wb block.
                let rb_start = 3 * hidden_size + start;
                let rb_end = rb_start + hidden_size;
                let rb: Tensor<SerializationBackend, 1> = b.clone().slice([rb_start..rb_end]);
                let path = format!(
                    "{}.{}{}.hidden_transform.bias",
                    field_name, dir_prefix, gate_name
                );
                snapshots.push(create_snapshot_from_data(rb.into_data(), &path, "Linear", dtype));
            }
        }
    }
    snapshots
}
/// Wraps `data` (converted to `dtype`) in a `TensorSnapshot` addressed by the
/// dot-separated `path`, tagged with a single `Struct:{container_type}`
/// container entry and a freshly generated `ParamId`.
///
/// The tensor data is captured in a closure so the snapshot materializes it
/// lazily on request.
fn create_snapshot_from_data(
    data: burn::tensor::TensorData,
    path: &str,
    container_type: &str,
    dtype: burn::tensor::DType,
) -> TensorSnapshot {
    use burn::module::ParamId;
    use burn_store::TensorSnapshotError;
    use std::rc::Rc;

    let converted = data.convert_dtype(dtype);
    let shape = converted.shape.clone();
    let segments: Vec<String> = path.split('.').map(str::to_owned).collect();
    let containers = vec![format!("Struct:{}", container_type)];
    let provider: Rc<dyn Fn() -> Result<burn::tensor::TensorData, TensorSnapshotError>> =
        Rc::new(move || Ok(converted.clone()));
    TensorSnapshot::from_closure(provider, dtype, shape, segments, containers, ParamId::new())
}
/// Emits the forward-pass code for a unidirectional (forward or reverse) GRU.
///
/// Burn's `Gru` consumes batch-first input, so non-batch-first input is first
/// swapped into `[batch, seq, feat]`; a reverse GRU additionally flips the
/// sequence axis before the call and un-flips the output afterwards. `Y` is
/// the per-step output (rank 4, with a singleton direction axis inserted) and
/// `Y_h` the hidden state of the last processed step (rank 3). Either output
/// ident may be absent when unused in the graph.
fn forward_unidirectional(
    node: &onnx_ir::gru::GruNode,
    scope: &mut ScopeAtPosition<'_>,
    input: TokenStream,
    field: Ident,
    output_y: Option<Ident>,
    output_y_h: Option<Ident>,
) -> TokenStream {
    let config = &node.config;
    let reverse = matches!(config.direction, GruDirection::Reverse);

    // Optional initial hidden state (ONNX input 5); the leading
    // num_directions axis is squeezed away for the unidirectional module.
    let init_state = if config.has_initial_h {
        let h = scope.arg(&node.inputs[5]);
        quote! { Some(#h.squeeze_dim(0)) }
    } else {
        quote! { None }
    };

    // Normalize the input to the batch-first layout Burn's Gru expects.
    let batch_first_input = if config.batch_first {
        quote! { #input }
    } else {
        quote! { #input.swap_dims(0, 1) }
    };

    // A reverse GRU processes the sequence back-to-front.
    let directed_input = if reverse {
        quote! {
            {
                let batch_first_input = #batch_first_input;
                batch_first_input.flip([1])
            }
        }
    } else {
        quote! { #batch_first_input }
    };

    let call = quote! {
        let gru_output = self.#field.forward(#directed_input, #init_state);
    };

    // Undo the flip so the emitted Y is in the original sequence order.
    let directed_output = if reverse {
        quote! { gru_output.flip([1]) }
    } else {
        quote! { gru_output }
    };

    // Y_h is the last processed step: step 0 of a reversed sequence,
    // the final step otherwise.
    let final_step_range = if reverse {
        quote! { 0..1 }
    } else {
        quote! { (seq_len - 1)..seq_len }
    };

    let y_h_expr = quote! {
        {
            let [_batch, seq_len, _hidden] = batch_first_output.dims();
            let step = batch_first_output.clone().slice([0.._batch, #final_step_range, 0.._hidden]);
            step.squeeze_dim::<2>(1).unsqueeze_dims::<3>(&[0])
        }
    };

    // Y gains a singleton num_directions axis: dim 2 when batch-first,
    // dim 1 (after swapping back to seq-first) otherwise.
    let y_expr = if config.batch_first {
        quote! { batch_first_output.clone().unsqueeze_dims::<4>(&[2]) }
    } else {
        quote! { batch_first_output.clone().swap_dims(0, 1).unsqueeze_dims::<4>(&[1]) }
    };

    match (output_y, output_y_h) {
        (Some(y), Some(y_h)) => quote! {
            let (#y, #y_h) = {
                #call
                let batch_first_output = #directed_output;
                (
                    #y_expr,
                    #y_h_expr
                )
            };
        },
        (Some(y), None) => quote! {
            let #y = {
                #call
                let batch_first_output = #directed_output;
                #y_expr
            };
        },
        (None, Some(y_h)) => quote! {
            let #y_h = {
                #call
                let batch_first_output = #directed_output;
                #y_h_expr
            };
        },
        (None, None) => quote! {
            {
                #call
            }
        },
    }
}
/// Emits the forward-pass code for a bidirectional GRU (Burn's `BiGru`).
///
/// `BiGru::forward` returns `(output_seq, final_state)` directly. `Y` is
/// produced by reshaping the concatenated `2 * hidden` feature axis into an
/// explicit direction axis (its position depends on batch-first layout), and
/// `Y_h` is the final state exactly as returned by the module. Either output
/// ident may be absent when unused in the graph.
fn forward_bidirectional(
    node: &onnx_ir::gru::GruNode,
    scope: &mut ScopeAtPosition<'_>,
    input: TokenStream,
    field: Ident,
    output_y: Option<Ident>,
    output_y_h: Option<Ident>,
) -> TokenStream {
    let hidden_size = node.config.hidden_size;

    // Optional initial hidden state (ONNX input 5), passed through unchanged:
    // BiGru takes the full [num_directions, ...] state.
    let init_state = if node.config.has_initial_h {
        let h = scope.arg(&node.inputs[5]);
        quote! { Some(#h) }
    } else {
        quote! { None }
    };

    // Split the stacked 2*hidden feature axis into [2, hidden]; the
    // non-batch-first form additionally swaps dims to place the direction
    // axis at ONNX position 1.
    let y_expr = if node.config.batch_first {
        quote! {
            {
                let [batch_size, seq_len, _] = output_seq.dims();
                output_seq.reshape([batch_size, seq_len, 2, #hidden_size])
            }
        }
    } else {
        quote! {
            {
                let [seq_len, batch_size, _] = output_seq.dims();
                let reshaped = output_seq.reshape([seq_len, batch_size, 2, #hidden_size]);
                reshaped.swap_dims(1, 2)
            }
        }
    };

    match (output_y, output_y_h) {
        (Some(y), Some(y_h)) => quote! {
            let (#y, #y_h) = {
                let (output_seq, final_state) = self.#field.forward(#input, #init_state);
                (#y_expr, final_state)
            };
        },
        (Some(y), None) => quote! {
            let #y = {
                let (output_seq, _final_state) = self.#field.forward(#input, #init_state);
                #y_expr
            };
        },
        (None, Some(y_h)) => quote! {
            let #y_h = {
                let (_output_seq, final_state) = self.#field.forward(#input, #init_state);
                final_state
            };
        },
        (None, None) => quote! {
            {
                let _ = self.#field.forward(#input, #init_state);
            }
        },
    }
}
/// Code generation for the ONNX `GRU` operator.
impl NodeCodegen for onnx_ir::gru::GruNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    /// Declares the generated module field: `Gru` for unidirectional graphs
    /// (forward or reverse), `BiGru` for bidirectional ones.
    ///
    /// # Panics
    /// Panics when the ONNX node uses features Burn's GRU cannot express:
    /// the `clip` attribute, or activations other than Sigmoid (gates) /
    /// Tanh (hidden).
    fn field(&self) -> Option<Field> {
        if self.config.clip.is_some() {
            panic!(
                "GRU clip attribute is not supported. Burn's GRU module does not support cell state clipping."
            );
        }
        if self.config.gate_activation != GruActivationFunction::Sigmoid
            || self.config.hidden_activation != GruActivationFunction::Tanh
        {
            panic!(
                "Custom GRU activations are not supported. Burn's GRU uses fixed Sigmoid (gates) and Tanh (hidden). Got gate: {:?}, hidden: {:?}",
                self.config.gate_activation, self.config.hidden_activation
            );
        }
        let name = Ident::new(&self.name, Span::call_site());
        let d_input = self.config.input_size.to_tokens();
        let d_hidden = self.config.hidden_size.to_tokens();
        let bias = self.config.has_bias;
        // ONNX `linear_before_reset` maps onto Burn's `reset_after` flag.
        let reset_after = self.config.linear_before_reset;
        match self.config.direction {
            // Reverse uses the same module as Forward; the sequence flipping
            // is emitted in `forward_unidirectional` instead.
            GruDirection::Forward | GruDirection::Reverse => Some(Field::new(
                self.name.clone(),
                quote! { burn::nn::gru::Gru<B> },
                quote! {
                    let #name = burn::nn::gru::GruConfig::new(#d_input, #d_hidden, #bias)
                        .with_reset_after(#reset_after)
                        .init(device);
                },
            )),
            GruDirection::Bidirectional => {
                let batch_first = self.config.batch_first;
                Some(Field::new(
                    self.name.clone(),
                    quote! { burn::nn::gru::BiGru<B> },
                    quote! {
                        let #name = burn::nn::gru::BiGruConfig::new(#d_input, #d_hidden, #bias)
                            .with_reset_after(#reset_after)
                            .with_batch_first(#batch_first)
                            .init(device);
                    },
                ))
            }
        }
    }

    /// Converts the node's constant W/R/B initializers into tensor snapshots
    /// whose paths are rooted at `field_name`.
    fn collect_snapshots(&self, field_name: &str) -> Vec<TensorSnapshot> {
        collect_gru_snapshots(field_name, &self.inputs, &self.config)
    }

    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        let input = scope.arg(self.inputs.first().unwrap());
        let field = Ident::new(&self.name, Span::call_site());
        // Both ONNX outputs are optional; an empty name means the output is
        // unused in the graph and no binding is emitted for it.
        let output_y = self
            .outputs
            .first()
            .filter(|a| !a.name.is_empty())
            .map(arg_to_ident);
        let output_y_h = self
            .outputs
            .get(1)
            .filter(|a| !a.name.is_empty())
            .map(arg_to_ident);
        if matches!(self.config.direction, GruDirection::Bidirectional) {
            forward_bidirectional(self, scope, input, field, output_y, output_y_h)
        } else {
            forward_unidirectional(self, scope, input, field, output_y, output_y_h)
        }
    }

    // Nothing to register: the emitted code uses fully-qualified paths
    // (e.g. `burn::nn::gru::GruConfig`) rather than imported names.
    fn register_imports(&self, _imports: &mut BurnImports) {
    }
}
// Snapshot tests (via `insta`) covering the code generated for every GRU
// direction / layout / output combination, plus the field-initializer code.
// NOTE: the inline snapshot bodies below are compared verbatim by `insta`;
// do not reflow them.
#[cfg(test)]
mod tests {
use super::super::test_helpers::*;
use burn::tensor::DType;
use insta::assert_snapshot;
use onnx_ir::gru::{GruActivationFunction, GruConfig, GruDirection, GruNode};
use onnx_ir::ir::{ArgType, Argument, TensorType};
/// Builds a minimal GRU node (input_size = 4, hidden_size = 8, with bias)
/// for codegen tests.
///
/// `num_outputs` controls which ONNX outputs are present: 1 => `Y` only,
/// 2 => `Y` and `Y_h`.
fn create_gru_node(
name: &str,
direction: GruDirection,
batch_first: bool,
has_initial_h: bool,
num_outputs: usize,
) -> GruNode {
let config = GruConfig::new(
4, 8, direction,
true, has_initial_h,
batch_first,
None, false, GruActivationFunction::Sigmoid, GruActivationFunction::Tanh, None, None, );
let input = Argument::new(
"input",
ArgType::Tensor(TensorType::new(DType::F32, 3, None)),
);
let w = Argument::new("W", ArgType::Tensor(TensorType::new(DType::F32, 3, None)));
let r = Argument::new("R", ArgType::Tensor(TensorType::new(DType::F32, 3, None)));
let b = Argument::new("B", ArgType::Tensor(TensorType::new(DType::F32, 2, None)));
let mut inputs = vec![input, w, r, b];
if has_initial_h {
// `sequence_lens` occupies ONNX input slot 4 so that `initial_h`
// lands at slot 5, where the codegen looks it up.
inputs.push(Argument::new(
"sequence_lens",
ArgType::ScalarNative(DType::I64),
));
inputs.push(Argument::new(
"initial_h",
ArgType::Tensor(TensorType::new(DType::F32, 3, None)),
));
}
let mut outputs = vec![];
if num_outputs > 0 {
outputs.push(Argument::new(
"Y",
ArgType::Tensor(TensorType::new(DType::F32, 4, None)),
));
}
if num_outputs > 1 {
outputs.push(Argument::new(
"Y_h",
ArgType::Tensor(TensorType::new(DType::F32, 3, None)),
));
}
GruNode {
name: name.to_string(),
inputs,
outputs,
config,
}
}
// Forward direction, seq-first layout, both outputs: input is swapped to
// batch-first and Y_h is taken from the last time step.
#[test]
fn test_gru_forward_basic() {
let node = create_gru_node("gru1", GruDirection::Forward, false, false, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r#"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let gru_output = self.gru1.forward(input.swap_dims(0, 1), None);
let batch_first_output = gru_output;
(
batch_first_output.clone().swap_dims(0, 1).unsqueeze_dims::<4>(&[1]),
{
let [_batch, seq_len, _hidden] = batch_first_output.dims();
let step = batch_first_output
.clone()
.slice([0.._batch, (seq_len - 1)..seq_len, 0.._hidden]);
step.squeeze_dim::<2>(1).unsqueeze_dims::<3>(&[0])
},
)
};
(Y, Y_h)
}
"#);
}
// Reverse direction: the sequence is flipped before and after the module
// call, and Y_h comes from step 0 of the (flipped-back) output.
#[test]
fn test_gru_forward_reverse() {
let node = create_gru_node("gru1", GruDirection::Reverse, false, false, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r#"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let gru_output = self
.gru1
.forward(
{
let batch_first_input = input.swap_dims(0, 1);
batch_first_input.flip([1])
},
None,
);
let batch_first_output = gru_output.flip([1]);
(
batch_first_output.clone().swap_dims(0, 1).unsqueeze_dims::<4>(&[1]),
{
let [_batch, seq_len, _hidden] = batch_first_output.dims();
let step = batch_first_output
.clone()
.slice([0.._batch, 0..1, 0.._hidden]);
step.squeeze_dim::<2>(1).unsqueeze_dims::<3>(&[0])
},
)
};
(Y, Y_h)
}
"#);
}
// Only `Y` requested: no Y_h computation is emitted.
#[test]
fn test_gru_forward_y_only() {
let node = create_gru_node("gru1", GruDirection::Forward, false, false, 1);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r#"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> Tensor<B, 4> {
let Y = {
let gru_output = self.gru1.forward(input.swap_dims(0, 1), None);
let batch_first_output = gru_output;
batch_first_output.clone().swap_dims(0, 1).unsqueeze_dims::<4>(&[1])
};
Y
}
"#);
}
// Field initializer for a forward GRU.
#[test]
fn test_gru_field_forward() {
let node = create_gru_node("gru1", GruDirection::Forward, false, false, 2);
let code = codegen_field_init(&node);
assert_snapshot!(code, @r"
let gru1 = burn::nn::gru::GruConfig::new(4, 8, true)
.with_reset_after(false)
.init(device);
");
}
// Reverse uses the same field initializer as forward (flipping is done in
// the generated forward code, not in the module config).
#[test]
fn test_gru_field_reverse() {
let node = create_gru_node("gru1", GruDirection::Reverse, false, false, 2);
let code = codegen_field_init(&node);
assert_snapshot!(code, @r"
let gru1 = burn::nn::gru::GruConfig::new(4, 8, true)
.with_reset_after(false)
.init(device);
");
}
// Batch-first input needs no swap_dims; Y's direction axis goes at dim 2.
#[test]
fn test_gru_forward_batch_first() {
let node = create_gru_node("gru1", GruDirection::Forward, true, false, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r#"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let gru_output = self.gru1.forward(input, None);
let batch_first_output = gru_output;
(
batch_first_output.clone().unsqueeze_dims::<4>(&[2]),
{
let [_batch, seq_len, _hidden] = batch_first_output.dims();
let step = batch_first_output
.clone()
.slice([0.._batch, (seq_len - 1)..seq_len, 0.._hidden]);
step.squeeze_dim::<2>(1).unsqueeze_dims::<3>(&[0])
},
)
};
(Y, Y_h)
}
"#);
}
// initial_h (input slot 5) is squeezed to drop the direction axis before
// being handed to the unidirectional module.
#[test]
fn test_gru_forward_with_initial_h() {
let node = create_gru_node("gru1", GruDirection::Forward, false, true, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r#"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
sequence_lens: i64,
initial_h: Tensor<B, 3>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let gru_output = self
.gru1
.forward(input.swap_dims(0, 1), Some(initial_h.squeeze_dim(0)));
let batch_first_output = gru_output;
(
batch_first_output.clone().swap_dims(0, 1).unsqueeze_dims::<4>(&[1]),
{
let [_batch, seq_len, _hidden] = batch_first_output.dims();
let step = batch_first_output
.clone()
.slice([0.._batch, (seq_len - 1)..seq_len, 0.._hidden]);
step.squeeze_dim::<2>(1).unsqueeze_dims::<3>(&[0])
},
)
};
(Y, Y_h)
}
"#);
}
// Bidirectional, seq-first: output is reshaped to expose the direction
// axis then swapped into ONNX position 1.
#[test]
fn test_gru_forward_bidirectional() {
let node = create_gru_node("gru1", GruDirection::Bidirectional, false, false, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let (output_seq, final_state) = self.gru1.forward(input, None);
(
{
let [seq_len, batch_size, _] = output_seq.dims();
let reshaped = output_seq.reshape([seq_len, batch_size, 2, 8usize]);
reshaped.swap_dims(1, 2)
},
final_state,
)
};
(Y, Y_h)
}
");
}
// Bidirectional, batch-first: plain reshape, no axis swap needed.
#[test]
fn test_gru_forward_bidirectional_batch_first() {
let node = create_gru_node("gru1", GruDirection::Bidirectional, true, false, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let (output_seq, final_state) = self.gru1.forward(input, None);
(
{
let [batch_size, seq_len, _] = output_seq.dims();
output_seq.reshape([batch_size, seq_len, 2, 8usize])
},
final_state,
)
};
(Y, Y_h)
}
");
}
// Bidirectional with only `Y`: final state is discarded.
#[test]
fn test_gru_forward_bidirectional_y_only() {
let node = create_gru_node("gru1", GruDirection::Bidirectional, false, false, 1);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
) -> Tensor<B, 4> {
let Y = {
let (output_seq, _final_state) = self.gru1.forward(input, None);
{
let [seq_len, batch_size, _] = output_seq.dims();
let reshaped = output_seq.reshape([seq_len, batch_size, 2, 8usize]);
reshaped.swap_dims(1, 2)
}
};
Y
}
");
}
// Bidirectional initial_h is passed through unchanged (BiGru takes the
// full [num_directions, ...] state).
#[test]
fn test_gru_forward_bidirectional_with_initial_h() {
let node = create_gru_node("gru1", GruDirection::Bidirectional, false, true, 2);
let code = codegen_forward_default(&node);
assert_snapshot!(code, @r"
pub fn forward(
&self,
input: Tensor<B, 3>,
W: Tensor<B, 3>,
R: Tensor<B, 3>,
B: Tensor<B, 2>,
sequence_lens: i64,
initial_h: Tensor<B, 3>,
) -> (Tensor<B, 4>, Tensor<B, 3>) {
let (Y, Y_h) = {
let (output_seq, final_state) = self.gru1.forward(input, Some(initial_h));
(
{
let [seq_len, batch_size, _] = output_seq.dims();
let reshaped = output_seq.reshape([seq_len, batch_size, 2, 8usize]);
reshaped.swap_dims(1, 2)
},
final_state,
)
};
(Y, Y_h)
}
");
}
// Field initializer for a bidirectional GRU (BiGru, with batch_first flag).
#[test]
fn test_gru_field_bidirectional() {
let node = create_gru_node("gru1", GruDirection::Bidirectional, false, false, 2);
let code = codegen_field_init(&node);
assert_snapshot!(code, @r"
let gru1 = burn::nn::gru::BiGruConfig::new(4, 8, true)
.with_reset_after(false)
.with_batch_first(false)
.init(device);
");
}
}