use burn_core as burn;
use crate::{Linear, LinearConfig, LinearLayout};
use burn::module::{Initializer, Module};
use burn::tensor::{Tensor, backend::Backend};
/// Pairs two linear layers whose outputs are summed to form a gate pre-activation.
///
/// One layer is applied to the current input and the other to the hidden
/// state; see [`GateController::gate_product`] for how they are combined.
/// No activation function is applied by this type itself.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
/// Linear layer applied to the `input` tensor in `gate_product`
/// (configured as `d_input -> d_output` by `new`).
pub input_transform: Linear<B>,
/// Linear layer applied to the `hidden` tensor in `gate_product`
/// (configured as `d_output -> d_output` by `new`).
pub hidden_transform: Linear<B>,
}
impl<B: Backend> GateController<B> {
pub fn new(
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
device: &B::Device,
) -> Self {
Self {
input_transform: LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
layout: LinearLayout::Row,
}
.init(device),
hidden_transform: LinearConfig {
d_input: d_output,
d_output,
bias,
initializer,
layout: LinearLayout::Row,
}
.init(device),
}
}
pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
}
#[cfg(test)]
pub fn create_with_weights(
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
input_record: crate::LinearRecord<B>,
hidden_record: crate::LinearRecord<B>,
) -> Self {
let l1 = LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
layout: LinearLayout::Row,
}
.init(&input_record.weight.device())
.load_record(input_record);
let l2 = LinearConfig {
d_input,
d_output,
bias,
initializer,
layout: LinearLayout::Row,
}
.init(&hidden_record.weight.device())
.load_record(hidden_record);
Self {
input_transform: l1,
hidden_transform: l2,
}
}
}