use crate::error::Result;
use crate::tensor::Tensor;
/// Wraps `sublayer` in a residual (skip) connection.
///
/// The closure is invoked once with `input`; the tensor it returns is then
/// added to `input` via `Tensor::add`, and that sum is the result.
///
/// # Errors
///
/// Returns the error produced by `sublayer`, in which case no addition is
/// attempted, or the error produced by `Tensor::add`.
pub fn residual_connection<F>(input: &Tensor, sublayer: F) -> Result<Tensor>
where
    F: FnOnce(&Tensor) -> Result<Tensor>,
{
    // Only reach the addition when the sublayer succeeded.
    sublayer(input).and_then(|branch| input.add(&branch))
}
/// Residual connection with normalization applied before the sublayer.
///
/// `input` is first passed through `norm.forward`; the normalized tensor is
/// what `sublayer` receives. The sublayer's output is then added to the
/// original, un-normalized `input` via `Tensor::add`.
///
/// # Errors
///
/// Propagates the first failure among `norm.forward`, `sublayer`, and
/// `Tensor::add`; later steps are skipped once one of them fails.
pub fn pre_norm_residual<F>(
    input: &Tensor,
    norm: &crate::normalize::LayerNorm,
    sublayer: F,
) -> Result<Tensor>
where
    F: FnOnce(&Tensor) -> Result<Tensor>,
{
    let normalized = norm.forward(input)?;
    // The skip path uses `input` itself, not the normalized tensor.
    sublayer(&normalized).and_then(|branch| input.add(&branch))
}
#[cfg(test)]
mod tests {
    use super::*;

    // A sublayer that yields zeros must leave the input data unchanged.
    #[test]
    fn test_residual_identity() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();

        let output = residual_connection(&input, |_| Ok(Tensor::zeros(&[3]))).unwrap();

        assert_eq!(output.data(), input.data());
    }

    // A sublayer that returns its input scaled by 1.0 doubles every element
    // once the skip path is added.
    #[test]
    fn test_residual_adds() {
        let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]).unwrap();

        let output = residual_connection(&input, |inp| Ok(inp.scale(1.0))).unwrap();

        assert_eq!(output.data(), &[2.0, 4.0, 6.0]);
    }
}